Keras load_model raises "TypeError: Keyword argument not understood" when the model uses a custom layer

Problem description

I am building a model with a custom attention layer as implemented in TensorFlow's nmt tutorial. I used the same layer code with a few changes, as suggested, to fit my problem.

The problem is that once the model contains this custom layer, I can no longer load the model from file after saving it. This is the layer class:

class BahdanauAttention(layers.Layer):
    def __init__(self, output_dim=30, **kwargs):
        super(BahdanauAttention, self).__init__(**kwargs)
        self.W1 = tf.keras.layers.Dense(output_dim)
        self.W2 = tf.keras.layers.Dense(output_dim)
        self.V = tf.keras.layers.Dense(1)

    def call(self, inputs, **kwargs):
        query = inputs[0]
        values = inputs[1]
        query_with_time_axis = tf.expand_dims(query, 1)

        score = self.V(tf.nn.tanh(
            self.W1(query_with_time_axis) + self.W2(values)))

        attention_weights = tf.nn.softmax(score, axis=1)

        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights

    def get_config(self):
        config = super(BahdanauAttention, self).get_config()
        config.update({
            'W1': self.W1, 'W2': self.W2, 'V': self.V})
        return config

I am saving the model with Keras' ModelCheckpoint callback:

path = os.path.join(self.dir, 'model_{}'.format(self.timestamp))
callbacks.append(ModelCheckpoint(path, save_best_only=True, monitor='val_loss', mode='min'))
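For context, a minimal sketch of how such a checkpoint callback is typically wired into training; the data variables and epoch count are placeholders, not taken from the original code:

# Sketch only: x_train/y_train, x_val/y_val and epochs are assumed placeholders.
# The ModelCheckpoint built above is passed to fit() through `callbacks`,
# so the best model (lowest val_loss) is written to `path` during training.
history = model.fit(
    x_train, y_train,
    validation_data=(x_val, y_val),
    epochs=50,
    callbacks=callbacks)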

Later, I load the model like this:

self.model = load_model(path, custom_objects={'BahdanauAttention': BahdanauAttention, 'custom_loss': self.custom_loss})

This is the error message I get:

raise TypeError(error_message, kwarg)
    TypeError: ('Keyword argument not understood:', 'W1')

And the full traceback:

Traceback (most recent call last):
  File "models/lstm.py", line 49, in load_model
    'dollar_mape_loss': self.dollar_mape_loss})
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py", line 187, in load_model
    return saved_model_load.load(filepath, compile, options)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 121, in load
    path, options=options, loader_cls=KerasObjectLoader)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py", line 633, in load_internal
    ckpt_options)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 194, in __init__
    super(KerasObjectLoader, self).__init__(*args, **kwargs)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/saved_model/load.py", line 130, in __init__
    self._load_all()
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 215, in _load_all
    self._layer_nodes = self._load_layers()
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 315, in _load_layers
    layers[node_id] = self._load_layer(proto.user_object, node_id)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 341, in _load_layer
    obj, setter = self._revive_from_config(proto.identifier, metadata, node_id)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 359, in _revive_from_config
    self._revive_layer_from_config(metadata, node_id))
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 417, in _revive_layer_from_config
    generic_utils.serialize_keras_class_and_config(class_name, config))
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/layers/serialization.py", line 175, in deserialize
    printable_module_name='layer')
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 360, in deserialize_keras_object
    return cls.from_config(cls_config)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 697, in from_config
    return cls(**config)
  File "models/lstm.py", line 310, in __init__
    super(BahdanauAttention, self).__init__(**kwargs)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py", line 457, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py", line 318, in __init__
    generic_utils.validate_kwargs(kwargs, allowed_kwargs)
  File "venv/m/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py", line 778, in validate_kwargs
    raise TypeError(error_message, kwarg)
TypeError: ('Keyword argument not understood:', 'W1')

Similar questions suggest this happens when the code mixes different versions of Keras and TensorFlow. I am only using TensorFlow's Keras. These are the imports:

from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import EarlyStopping, CSVLogger, ModelCheckpoint
from tensorflow.keras import layers

Solution

Following keras' documentation on custom layers, the recommendation is not to create any weights in __init__(), but in build() instead. That way the weights do not need to be added to the config, which resolves the error: from_config() passes every entry of the config back to __init__() as keyword arguments, and the base Layer rejects the W1/W2/V entries, which is exactly the TypeError shown above.

This is the updated custom layer class:

class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units=30, **kwargs):
        super(BahdanauAttention, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.W1 = tf.keras.layers.Dense(self.units)
        self.W2 = tf.keras.layers.Dense(self.units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, inputs, **kwargs):
        query = inputs[0]
        values = inputs[1]
        query_with_time_axis = tf.expand_dims(query, 1)

        score = self.V(tf.nn.tanh(
            self.W1(query_with_time_axis) + self.W2(values)))

        attention_weights = tf.nn.softmax(score, axis=1)

        context_vector = attention_weights * values
        context_vector = tf.reduce_sum(context_vector, axis=1)

        return context_vector, attention_weights

    def get_config(self):
        config = super(BahdanauAttention, self).get_config()
        config.update({
            'units': self.units})
        return config
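With the weights created in build() and only the units hyperparameter stored in get_config(), the save/load round trip works through custom_objects. A minimal sketch, where the input shapes and file name are assumptions and not part of the original code:

# Minimal round-trip sketch (shapes and file name are assumptions).
# The fixed layer serializes cleanly because from_config() only has to
# pass `units` back to __init__().
import tensorflow as tf

query_in = tf.keras.Input(shape=(64,))        # assumed decoder state size
values_in = tf.keras.Input(shape=(10, 64))    # assumed encoder outputs
context_vector, attention_weights = BahdanauAttention(units=30)([query_in, values_in])
model = tf.keras.Model([query_in, values_in], context_vector)

model.save('attention_model')                 # SavedModel directory
reloaded = tf.keras.models.load_model(
    'attention_model',
    custom_objects={'BahdanauAttention': BahdanauAttention})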

I also had this problem. I tried many approaches and found that this one works. First, build the model:

model = TextAttBiRNN(maxlen, max_features, embedding_dims).get_model()
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

Second, load the weights. This is how I solved the problem:

model_file = "/content/drive/My Drive/dga/output_data/model_lstm_att_test_v6.h5"
model.load_weights(model_file)

After that, the model can be used.

This way I worked around the problem described above.
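A related variant of this workaround (my own addition, not something this answer shows) is to checkpoint only the weights during training, which avoids the layer-config serialization path entirely:

# Sketch of a weights-only variant; the file name is a placeholder.
# save_weights_only=True makes ModelCheckpoint write just the weights,
# so no get_config()/from_config() round trip happens at load time.
checkpoint = ModelCheckpoint('weights_best.h5',
                             save_weights_only=True,
                             save_best_only=True,
                             monitor='val_loss',
                             mode='min')

# Later: rebuild the same architecture in code, then restore the weights.
model = TextAttBiRNN(maxlen, max_features, embedding_dims).get_model()
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.load_weights('weights_best.h5')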


@user7331538 try replacing path = os.path.join(self.dir, 'model_{}'.format(self.timestamp)) with path = 'anymodel_name.h5'
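If I read that suggestion correctly, pointing ModelCheckpoint at a filename ending in .h5 makes Keras write the checkpoint in the single-file HDF5 format instead of the SavedModel directory format. A minimal sketch of the suggestion, where the file name is just an example and loading still needs the custom objects:

# Saving to a .h5 path switches the checkpoint to HDF5 format.
path = 'anymodel_name.h5'
callbacks.append(ModelCheckpoint(path, save_best_only=True,
                                 monitor='val_loss', mode='min'))

# Loading the HDF5 file still goes through the layer config, so the
# custom layer (and custom loss, if any) must be passed in custom_objects.
model = load_model(path, custom_objects={'BahdanauAttention': BahdanauAttention})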