Problem description
I am working on a seq2seq NMT model for French-to-English translation. When I run the inference model, I get a data cardinality error:
ValueError: Data cardinality is ambiguous:
x sizes: 1, 5, 5
Please provide data which shares the same first dimension.
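Keras raises this error when the arrays passed to `predict` do not all have the same size along the first axis, which is the number of samples. Here the three sizes belong to the three inputs of `decoder_model`: the single token and the two LSTM state arrays. The NumPy sketch below mimics that check using shapes read off the error message; `first_dim_sizes` and `check_cardinality` are hypothetical helpers, not Keras functions.

```python
import numpy as np

def first_dim_sizes(arrays):
    """Return the size of axis 0 (the sample/batch axis) of each input array."""
    return [np.asarray(a).shape[0] for a in arrays]

def check_cardinality(arrays):
    """Hypothetical helper mimicking the spirit of Keras's input check:
    every input must contain the same number of samples."""
    sizes = first_dim_sizes(arrays)
    if len(set(sizes)) > 1:
        raise ValueError("Data cardinality is ambiguous: x sizes: "
                         + ", ".join(str(s) for s in sizes))
    return sizes[0]

# Assumed shapes: one decoder token, but encoder states for 5 "samples".
input_single = np.zeros((1, 1))
h = np.zeros((5, 256))
c = np.zeros((5, 256))

try:
    check_cardinality([input_single, h, c])
except ValueError as err:
    print(err)  # Data cardinality is ambiguous: x sizes: 1, 5, 5
```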
import numpy as np
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint

# Encoder: source token ids -> embeddings -> LSTM; keep only the final states
encoder_inputs = Input(shape=(None,))
embedding_e = Embedding(num_source_vocab, 256, mask_zero=True)
encoder_embedding = embedding_e(encoder_inputs)
encoder = LSTM(256, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_embedding)
encoder_states = [state_h, state_c]

# Decoder: starts from the encoder states and predicts the next target token
decoder_inputs = Input(shape=(None,))
embedding_f = Embedding(num_target_vocab, 256, mask_zero=True)  # output_dim is required (256 assumed, matching the encoder)
decoder_embedding = embedding_f(decoder_inputs)
decoder = LSTM(256, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder(decoder_embedding, initial_state=encoder_states)
decoder_dense = Dense(num_target_vocab, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

model = Model([encoder_inputs, decoder_inputs], [decoder_outputs])
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

filepath = 'eng2fre.h5'
checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
history = model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
                    epochs=20, batch_size=64, validation_split=0.2, callbacks=[checkpoint])
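With `loss='categorical_crossentropy'`, `decoder_target_data` has to be one-hot encoded, with shape `(num_samples, max_target_length, num_target_vocab)`. The question does not show how it was built, so the sketch below assumes the usual teacher-forcing layout, where the target at step t is the decoder input at step t + 1, and that id 0 is padding (as `mask_zero=True` suggests). `make_decoder_target` and the toy ids are hypothetical.

```python
import numpy as np

def make_decoder_target(decoder_input_data, num_target_vocab):
    """Hypothetical helper: build one-hot teacher-forcing targets.

    Assumes decoder_input_data holds integer ids of shape
    (num_samples, max_target_length) and that id 0 is padding.
    """
    n, t = decoder_input_data.shape
    # Target at step t is the decoder input at step t + 1; the last step gets padding (0).
    shifted = np.zeros_like(decoder_input_data)
    shifted[:, :-1] = decoder_input_data[:, 1:]
    # One-hot encode: target[i, j, shifted[i, j]] = 1
    target = np.zeros((n, t, num_target_vocab), dtype="float32")
    rows, cols = np.indices((n, t))
    target[rows, cols, shifted] = 1.0
    return target

# Toy example: sos=1, eos=2, padding=0, vocabulary of 8 ids.
decoder_input_data = np.array([[1, 5, 6, 2, 0]])
decoder_target_data = make_decoder_target(decoder_input_data, num_target_vocab=8)
print(decoder_target_data.shape)       # (1, 5, 8)
print(decoder_target_data.argmax(-1))  # [[5 6 2 0 0]]
```

If the targets are kept as integer ids of shape `(num_samples, max_target_length)` instead, `loss='sparse_categorical_crossentropy'` accepts them directly and avoids materialising the large one-hot array.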
# Inference encoder: source sentence -> [state_h, state_c]
encoder_model = Model(encoder_inputs, encoder_states)

# Inference decoder: one token plus the previous states -> next-token probabilities and new states
decoder_state_input_h = Input(shape=(256,))
decoder_state_input_c = Input(shape=(256,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_inputs_single = Input(shape=(1,))
decoder_inputs_single_x = embedding_f(decoder_inputs_single)
decoder_outputs2, state_h2, state_c2 = decoder(
    decoder_inputs_single_x, initial_state=decoder_states_inputs)
decoder_states2 = [state_h2, state_c2]
decoder_outputs2 = decoder_dense(decoder_outputs2)
decoder_model = Model(
    [decoder_inputs_single] + decoder_states_inputs,
    [decoder_outputs2] + decoder_states2)
x = encoder_input_data[100]
states = encoder_model.predict(x)
input_single = np.zeros((1, 1))
input_single[0, 0] = target_vocab['sos']
eos_id = target_vocab['eos']
# getting error after the following chunk of code
for i in range(max_target_length):
    dec_op, h, c = decoder_model.predict([input_single] + states)
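The sizes `1, 5, 5` most likely trace back to the encoder call rather than the loop itself. Indexing `encoder_input_data[100]` with an integer drops the batch axis, so `x` has shape `(max_source_length,)` instead of `(1, max_source_length)`. Keras then appears to treat each padded token id as a separate sample, so `encoder_model.predict(x)` returns `h` and `c` with first dimension 5 (assuming `max_source_length` is 5), while `input_single` has first dimension 1. The NumPy sketch below shows the shape difference between integer indexing and slicing, using a hypothetical random stand-in for `encoder_input_data`.

```python
import numpy as np

# Hypothetical stand-in for encoder_input_data: 200 sentences padded to length 5.
rng = np.random.default_rng(0)
encoder_input_data = rng.integers(1, 50, size=(200, 5))

row = encoder_input_data[100]                        # integer index drops the batch axis
batch = encoder_input_data[100:101]                  # a slice keeps a batch of one
same_batch = encoder_input_data[100][np.newaxis, :]  # equivalent: add the axis back

print(row.shape)    # (5,)
print(batch.shape)  # (1, 5)
```

Under that reading, `states = encoder_model.predict(encoder_input_data[100:101])` should return two state arrays of shape `(1, 256)`, which matches the `(1, 1)` token array passed to `decoder_model.predict`.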
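The loop in the question stops right after the `predict` call. A typical greedy decoding loop also picks the most probable token, stops at `eos`, and feeds both that token and the new LSTM states back in. The sketch below is one way to write it, not the author's code: `greedy_decode` is a hypothetical name, `predict` is assumed to behave like `decoder_model.predict` (so it can be passed in directly), and `fake_predict` is a stub that lets the example run without a trained model.

```python
import numpy as np

def greedy_decode(predict, states, sos_id, eos_id, max_target_length):
    """Greedy decoding sketch.

    `predict` is assumed to take [token (1, 1), h (1, units), c (1, units)]
    and return [probs (1, 1, vocab), h, c], like decoder_model.predict.
    """
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = sos_id
    decoded = []
    for _ in range(max_target_length):
        dec_op, h, c = predict([target_seq] + states)
        token_id = int(np.argmax(dec_op[0, -1, :]))  # most probable next token
        if token_id == eos_id:
            break
        decoded.append(token_id)
        target_seq[0, 0] = token_id  # feed the prediction back in
        states = [h, c]              # carry the LSTM state forward
    return decoded

def fake_predict(inputs):
    """Stub standing in for decoder_model.predict: always predicts token + 1."""
    token, h, c = inputs
    vocab = 6
    probs = np.zeros((1, 1, vocab))
    probs[0, 0, (int(token[0, 0]) + 1) % vocab] = 1.0
    return [probs, h, c]

states = [np.zeros((1, 256)), np.zeros((1, 256))]
print(greedy_decode(fake_predict, states, sos_id=1, eos_id=5, max_target_length=10))  # [2, 3, 4]
```

With the trained models, the call would look like `greedy_decode(decoder_model.predict, encoder_model.predict(encoder_input_data[100:101]), target_vocab['sos'], eos_id, max_target_length)`. Mapping the ids back to words needs an inverse of `target_vocab`, which the question does not show.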