How do I add an attention mechanism to a seq2seq model?

Problem description

I am working through the seq2seq model example from the Keras blog.

I have a fairly good understanding of how it works so far. I am now stuck on adding an attention mechanism to this model. I found another post that gave me some clues, but at the moment my attempt only raises an index error. I am also not sure what to do at the inference stage.

Here is my code:

from keras.models import Model
from keras.layers import Input, LSTM, Dense, Activation, dot, concatenate
import numpy as np

# Define an input sequence and process it.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = LSTM(latent_dim,return_state=True)
encoder_outputs,state_h,state_c = encoder(encoder_inputs)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h,state_c]

# Set up the decoder,using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None,num_decoder_tokens))
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim,return_sequences=True,return_state=True)
decoder_outputs,_,_ = decoder_lstm(decoder_inputs,initial_state=encoder_states)


attention = dot([decoder_outputs,encoder_outputs],axes=[2,2])
attention = Activation('softmax')(attention)
context = dot([attention, encoder_outputs], axes=[2, 1])
decoder_outputs = concatenate([context,decoder_outputs])
decoder_dense = Dense(num_decoder_tokens,activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)


# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs,decoder_inputs],decoder_outputs)
print(model.summary())
#%%
# Run training
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          batch_size=batch_size,
          epochs=epochs,
          validation_split=0.2)
# Save model
model.save('s2s.h5')

# Next: inference mode (sampling).
# Here's the drill:
# 1) encode input and retrieve initial decoder state
# 2) run one step of decoder with this initial state
# and a "start of sequence" token as target.
# Output will be the next target token
# 3) Repeat with the current target token and current states

# Define sampling models
encoder_model = Model(encoder_inputs,encoder_states)

decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h,decoder_state_input_c]
decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h,state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,[decoder_outputs] + decoder_states)

# Reverse-lookup token index to decode sequences back to
# something readable.
reverse_input_char_index = dict(
    (i,char) for char,i in input_token_index.items())
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())


def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1,1,num_decoder_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index['\t']] = 1.

    # Sampling loop for a batch of sequences
    # (to simplify,here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens,h,c = decoder_model.predict(
            [target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0,-1,:])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or
           len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        # Update states
        states_value = [h,c]

    return decoded_sentence

Solution

No verified fix has been posted for this question yet.
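
A likely cause of the index error, judging from the code above: the encoder LSTM is created with return_state=True but without return_sequences=True, so encoder_outputs is a single 2-D state tensor rather than the full 3-D sequence of hidden states, and dot([decoder_outputs, encoder_outputs], axes=[2, 2]) then refers to an axis the encoder tensor does not have. Below is a minimal, untested sketch of how the training graph could be wired with dot-product (Luong-style) attention; it reuses the names from the question (num_encoder_tokens, num_decoder_tokens, latent_dim) and is an illustration, not a verified fix.

from keras.models import Model
from keras.layers import Input, LSTM, Dense, Activation, dot, concatenate

# Encoder: return the full sequence of hidden states so the decoder can attend
# over it, plus the final states used to initialise the decoder.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = LSTM(latent_dim, return_sequences=True, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
encoder_states = [state_h, state_c]

# Decoder LSTM, initialised with the encoder's final states.
decoder_inputs = Input(shape=(None, num_decoder_tokens))
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)

# Dot-product attention: scores of shape (batch, dec_steps, enc_steps),
# softmax over the encoder time axis, then a weighted sum of encoder states.
attention = dot([decoder_outputs, encoder_outputs], axes=[2, 2])
attention = Activation('softmax')(attention)
context = dot([attention, encoder_outputs], axes=[2, 1])

# Concatenate the context with the decoder output before the final projection.
decoder_combined = concatenate([context, decoder_outputs])
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_combined)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.compile(optimizer='adam', loss='categorical_crossentropy')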

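For the inference stage, the decoder model would also need the encoder's output sequence as an extra input, because the context vector has to be recomputed at every decoding step. Again a rough, untested sketch; encoder_outputs_input is a placeholder name introduced here, and the other names come from the sketch above.

# The encoder model returns the full output sequence as well as the final states.
encoder_model = Model(encoder_inputs, [encoder_outputs, state_h, state_c])

# The decoder model takes the previous states plus the encoder outputs to attend over.
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
encoder_outputs_input = Input(shape=(None, latent_dim))

dec_outputs, dec_state_h, dec_state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
att = dot([dec_outputs, encoder_outputs_input], axes=[2, 2])
att = Activation('softmax')(att)
ctx = dot([att, encoder_outputs_input], axes=[2, 1])
dec_outputs = decoder_dense(concatenate([ctx, dec_outputs]))

decoder_model = Model(
    [decoder_inputs, encoder_outputs_input] + decoder_states_inputs,
    [dec_outputs, dec_state_h, dec_state_c])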
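
The sampling loop then changes only in where the encoder outputs come from and in what is passed to decoder_model.predict at each step. A sketch of the adapted decode_sequence, assuming the two models above:

import numpy as np

def decode_sequence(input_seq):
    # The encoder now also returns the sequence that the decoder attends over.
    enc_outs, h, c = encoder_model.predict(input_seq)

    # Start with the start-of-sequence character.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    target_seq[0, 0, target_token_index['\t']] = 1.

    decoded_sentence = ''
    while True:
        # Feed the encoder outputs to every step so attention can be recomputed.
        output_tokens, h, c = decoder_model.predict([target_seq, enc_outs, h, c])

        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Stop on the end-of-sequence character or at the maximum length.
        if sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length:
            break

        # The next decoder input is the character just sampled.
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

    return decoded_sentence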