keras.models.load_model() 给出 ValueError

问题描述

我已将训练好的模型和权重保存如下。

model,history,score = fit_model(model,train_batches,val_batches,callbacks=[callback])
model.save('./model')
model.save_weights('./weights')

然后我尝试通过以下方式获取保存的模型

if __name__ == '__main__':
  model = keras.models.load_model('./model',compile= False,custom_objects={"F1score": tfa.metrics.F1score})
  test_batches,nb_samples = test_gen(dataset_test_path,32,img_width,img_height)
  predict,loss,acc = predict_model(model,test_batches,nb_samples)
  print(predict)
  print(acc)
  print(loss)

但它给了我一个错误。我应该怎么做才能克服这个问题?

Traceback (most recent call last):
  File "test_pro.py",line 34,in <module>
    model = keras.models.load_model('./model',custom_objects={"F1score": tfa.metrics.F1score})
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py",line 212,in load_model
    return saved_model_load.load(filepath,compile,options)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py",line 138,in load
    keras_loader.load_layers()
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py",line 379,in load_layers
    self.loaded_nodes[node_Metadata.node_id] = self._load_layer(
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py",line 407,in _load_layer
    obj,setter = revive_custom_object(identifier,Metadata)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py",line 921,in revive_custom_object
    raise ValueError('Unable to restore custom object of type {} currently. '
ValueError: Unable to restore custom object of type _tf_keras_metric currently. Please make sure that the layer implements `get_config`and `from_config` when saving. In addition,please use the `custom_objects` arg when calling `load_model()`.

解决方法

查看Keras的源码,报错when trying to load a model with a custom object

def revive_custom_object(identifier,metadata):
  """Revives object from SavedModel."""
  if ops.executing_eagerly_outside_functions():
    model_class = training_lib.Model
  else:
    model_class = training_lib_v1.Model

  revived_classes = {
      constants.INPUT_LAYER_IDENTIFIER: (
          RevivedInputLayer,input_layer.InputLayer),constants.LAYER_IDENTIFIER: (RevivedLayer,base_layer.Layer),constants.MODEL_IDENTIFIER: (RevivedNetwork,model_class),constants.NETWORK_IDENTIFIER: (RevivedNetwork,functional_lib.Functional),constants.SEQUENTIAL_IDENTIFIER: (RevivedNetwork,models_lib.Sequential),}
  parent_classes = revived_classes.get(identifier,None)

  if parent_classes is not None:
    parent_classes = revived_classes[identifier]
    revived_cls = type(
        compat.as_str(metadata['class_name']),parent_classes,{})
    return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
  else:
    raise ValueError('Unable to restore custom object of type {} currently. '
                     'Please make sure that the layer implements `get_config`'
                     'and `from_config` when saving. In addition,please use '
                     'the `custom_objects` arg when calling `load_model()`.'
                     .format(identifier))

该方法仅适用于 revived_classes 中定义的类型的自定义对象。如您所见,它目前仅适用于输入层、层、模型、网络和顺序自定义对象。

在您的代码中,您在 tfa.metrics.F1Score 参数中传递了一个 custom_objects 类,该类的类型为 METRIC_IDENTIFIER,因此不受支持(可能是因为它没有实现 {{ 1}} 和 get_config 函数如错误输出所述):

from_config

我上次使用 Keras 已经有一段时间了,但也许您可以尝试遵循 this other related answer 中的建议,并将对 keras.models.load_model('./model',compile=False,custom_objects={"F1Score": tfa.metrics.F1Score}) 的调用包装在一个方法中。像这样(根据您的需要进行调整):

tfa.metrics.F1Score

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...