Tensorflow SavedModel 在加载时忽略资产文件

问题描述

我对来自 Tensorflow hub 的 BERT 模型进行了微调,以构建一个简单的情感分析器。该模型训练和运行良好。在导出时,我只是使用了:

tf.saved_model.save(model,export_dir='models')

这工作得很好..直到我重新启动。

重新启动后,模型不再加载。我曾尝试使用 Keras 加载器和 Tensorflow 服务器,但遇到了同样的错误

我收到以下错误消息:

未找到:/tmp/tfhub_modules/09bd4e665682e6f03bc72fbcff7a68bf879910e/assets/vocab.txt;没有那个文件或目录

该模型正在尝试从 tfhub 模块缓存加载资产,该缓存因重启而被擦除。我知道我可以保留缓存,但我不想这样做,因为我希望能够生成模型,然后将它们复制到单独的应用程序中,而不必担心缓存。

关键是我认为根本没有必要在缓存中查找资产。该模型与生成 vocab.txt 的资产文件夹一起保存,因此为了找到资产,它只需要查看自己的资产文件夹(我认为)。然而,它似乎并没有这样做。

有什么办法可以改变这种行为吗?


添加了用于构建和导出模型的代码(这不是一个聪明的模型,只是对我的工作流程进行原型设计):

bert_model_name = "bert_en_uncased_L-12_H-768_A-12"

BATCH_SIZE = 64
EPOCHS = 1 # Initial

def build_bert_model(bert_model_name):
    input_layer = tf.keras.layers.Input(shape=(),dtype=tf.string,name="inputs")
    preprocessing_layer = hub.KerasLayer(
        map_model_to_preprocess[bert_model_name],name="preprocessing"
    )

    encoder_inputs = preprocessing_layer(input_layer)
    bert_model = hub.KerasLayer(
        map_name_to_handle[bert_model_name],name="BERT_encoder"
    )
    outputs = bert_model(encoder_inputs)

    net = outputs["pooled_output"]
    net = tf.keras.layers.Dropout(0.1)(net)
    net = tf.keras.layers.Dense(1,activation=None,name="classifier")(net)
    return tf.keras.Model(input_layer,net)

def main():
    train_ds,val_ds = load_sentiment140(batch_size=BATCH_SIZE,epochs=EPOCHS)
    steps_per_epoch = tf.data.experimental.cardinality(train_ds).numpy()
    init_lr = 3e-5

    optimizer = tf.keras.optimizers.Adam(learning_rate=init_lr)
    model = build_bert_model(bert_model_name)

    model.compile(optimizer=optimizer,loss='mse',metrics='mse')
    model.fit(train_ds,validation_data=val_ds,steps_per_epoch=steps_per_epoch)

    tf.saved_model.save(model,export_dir='models')

解决方法

此问题来自由 bug 的版本 /1 和 /2 触发的 TensorFlow https://tfhub.dev/tensorflow/bert_en_uncased_preprocess。更新的模型 tensorflow/bert_*_preprocess/3(上周五发布)避免了这个错误。请更新到最新版本。

Classify Text with BERT 教程已相应更新。

感谢您提出这个问题!

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...