keras.models.load_model() gives error: ValueError: Got 0 inputs for equation "baik,baij->bakj", expecting 2

Problem description

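My code includes the batch matrix multiplication tf.einsum('baik,baij->bakj', q, k) / np.sqrt(dv) as one of its parts. After training the model, I saved it with model.save('./model'), and now I want to load the saved model. I tried model = keras.models.load_model('./model', compile=False, custom_objects={'f1': f1}), but it gives the error below. Why does this happen?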
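For reference, the equation 'baik,baij->bakj' sums over the shared axis i for every (b, a) pair, so it is a batched matrix product of q (with its last two axes swapped) and k. The sketch below uses NumPy as a stand-in for TensorFlow, with arbitrary small sizes and random data, to check that equivalence; in TensorFlow the same product can be written as tf.linalg.matmul(q, k, transpose_a=True).

```python
import numpy as np

# Arbitrary small sizes: batch b, positions a, shared axis i, output axes k and j.
rng = np.random.default_rng(0)
b, a, i, k_dim, j_dim = 2, 3, 4, 5, 6
q = rng.standard_normal((b, a, i, k_dim))
k = rng.standard_normal((b, a, i, j_dim))

# The einsum from the question: contract over the shared axis i.
att_einsum = np.einsum('baik,baij->bakj', q, k)

# Same result as a batched matmul: swap q's last two axes, then multiply.
att_matmul = np.matmul(np.swapaxes(q, -1, -2), k)

print(att_einsum.shape)                      # (2, 3, 5, 6)
print(np.allclose(att_einsum, att_matmul))   # True
```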

Traceback (most recent call last):
  File "test_pro.py", line 37, in <module>
    model = keras.models.load_model('./model', custom_objects={'f1': f1})
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py", line 212, in load_model
    return saved_model_load.load(filepath, compile, options)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 147, in load
    keras_loader.finalize_objects()
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 612, in finalize_objects
    self._reconstruct_all_models()
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 631, in _reconstruct_all_models
    self._reconstruct_model(model_id, model, layers)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/load.py", line 677, in _reconstruct_model
    created_layers) = functional_lib.reconstruct_from_config(
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 1285, in reconstruct_from_config
    process_node(layer, node_data)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py", line 1233, in process_node
    output_tensors = layer(input_tensors, **kwargs)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py", line 1012, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/keras/layers/core.py", line 1327, in _call_wrapper
    return self._call_wrapper(*args, line 1359, in _call_wrapper
    result = self.function(*args, **kwargs)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py", line 201, in wrapper
    return target(*args, **kwargs)
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/ops/special_math_ops.py", line 751, in einsum
    return _einsum_v2(equation, *inputs, line 1174, in _einsum_v2
    _einsum_v2_parse_and_resolve_equation(equation, input_shapes))
  File "/home/dcs2016csc007/.local/lib/python3.8/site-packages/tensorflow/python/ops/special_math_ops.py", line 1254, in _einsum_v2_parse_and_resolve_equation
    raise ValueError('Got {} inputs for equation "{}", expecting {}'.format(
ValueError: Got 0 inputs for equation "baik,baij->bakj", expecting 2
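The message is an operand-count check: an einsum equation names one input per comma-separated term before "->", so "baik,baij->bakj" expects two tensors. According to the traceback, the einsum call that Keras rebuilt during model reconstruction (process_node in functional.py) received no tensor operands. The helpers below are hypothetical and only mimic that count using the message format shown in the traceback; they are not TensorFlow's implementation.

```python
def expected_operand_count(equation: str) -> int:
    """Count the input terms of an einsum equation (hypothetical helper).

    The inputs are the comma-separated terms to the left of '->'.
    """
    inputs_part = equation.split('->')[0]
    return len(inputs_part.split(','))


def check_operands(equation: str, *operands) -> None:
    """Raise a ValueError worded like the traceback's when the counts differ."""
    expected = expected_operand_count(equation)
    if len(operands) != expected:
        raise ValueError('Got {} inputs for equation "{}", expecting {}'.format(
            len(operands), equation, expected))


print(expected_operand_count('baik,baij->bakj'))  # 2
try:
    check_operands('baik,baij->bakj')  # no tensors passed, as in the failing reload
except ValueError as err:
    print(err)  # Got 0 inputs for equation "baik,baij->bakj", expecting 2
```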

This is how I created the model:

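import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Add, BatchNormalization, Dense, Input, Lambda, Reshape
from tensorflow.keras.models import Model


def MultiHeadsAttModel(l=7*7, d=1024, dv=64, dout=1024, nv=16):

    # Three inputs of shape (l, d): values, queries and keys
    v1 = Input(shape=(l, d))
    q1 = Input(shape=(l, d))
    k1 = Input(shape=(l, d))

    # Project each input to nv heads of size dv (dv * nv must equal d for the final Reshape)
    v2 = Dense(dv*nv, activation="relu")(v1)
    q2 = Dense(dv*nv, activation="relu")(q1)
    k2 = Dense(dv*nv, activation="relu")(k1)

    # Split the projections into heads: (batch, l, nv, dv)
    v = Reshape([l, nv, dv])(v2)
    q = Reshape([l, nv, dv])(q2)
    k = Reshape([l, nv, dv])(k2)

    att = tf.einsum('baik,baij->bakj', q, k) / np.sqrt(dv)  # batch matrix multiplication
    att = Lambda(lambda x: K.softmax(x), output_shape=(l, nv))(att)
    out = tf.einsum('bajk,baik->baji', att, v)
    out = Reshape([l, d])(out)
    out = Add()([out, q1])  # residual connection with the query input

    out = Dense(dout, activation="relu")(out)

    return Model(inputs=[q1, k1, v1], outputs=out)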
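To make the shapes in this block concrete, the sketch below replays only the tensor operations between the head-splitting Reshape layers and the final Reshape([l, d]), using NumPy in place of TensorFlow, small stand-in sizes (l=4, nv=2, dv=3, so d = nv * dv = 6) and random data. The Dense layers, ReLU and the residual Add are omitted, and the softmax function is a hand-written stand-in for K.softmax over the last axis. It also checks that the second einsum, 'bajk,baik->baji', equals att multiplied by v with v's last two axes swapped.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax over one axis (stand-in for K.softmax).
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
batch, l, nv, dv = 2, 4, 2, 3
d = nv * dv  # the final Reshape([l, d]) needs d == nv * dv

# Stand-ins for the outputs of Reshape([l, nv, dv]) on the projected inputs.
q = rng.standard_normal((batch, l, nv, dv))
k = rng.standard_normal((batch, l, nv, dv))
v = rng.standard_normal((batch, l, nv, dv))

att = np.einsum('baik,baij->bakj', q, k) / np.sqrt(dv)  # (batch, l, dv, dv)
att = softmax(att)                                       # same shape, rows sum to 1
out_heads = np.einsum('bajk,baik->baji', att, v)         # (batch, l, dv, nv)
out_matmul = np.matmul(att, np.swapaxes(v, -1, -2))      # same product via matmul
out = out_heads.reshape(batch, l, d)                     # (batch, l, d)

print(att.shape, out_heads.shape, out.shape)  # (2, 4, 3, 3) (2, 4, 3, 2) (2, 4, 6)
```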

def create_model(input_shape, output_classes):
    mobile = tf.keras.applications.mobilenet.MobileNet(weights='imagenet')
    x = mobile.layers[-6].input

    if True:
        x = Reshape([7*7, 1024])(x)
        att = MultiHeadsAttModel(l=7*7, nv=16)
        x = att([x, x, x])
        x = Reshape([7, 7, 1024])(x)
        x = BatchNormalization()(x)
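In create_model, Reshape([7*7, 1024]) turns the 7x7 grid of 1024-channel features into a sequence of 49 positions for the attention block, and Reshape([7, 7, 1024]) restores the grid afterwards. The NumPy sketch below (batch size 1 and random features are arbitrary; Keras Reshape leaves out the batch axis, so the NumPy calls include it explicitly) illustrates that the row-major round trip is lossless and that grid cell (r, c) becomes position 7*r + c.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 7, 7, 1024))  # (batch, rows, cols, channels)

tokens = x.reshape(1, 7 * 7, 1024)        # like Reshape([7*7, 1024])
restored = tokens.reshape(1, 7, 7, 1024)  # like Reshape([7, 7, 1024])

r, c = 2, 5
print(np.array_equal(tokens[0, 7 * r + c], x[0, r, c]))  # True
print(np.array_equal(restored, x))                       # True
```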
