How to implement Bahdanau attention in Keras?

Problem description

How do I implement a Bahdanau attention layer in Keras? Does Keras ship a Bahdanau attention layer out of the box (the way it ships Dense, LSTM)? If not, could you please explain how to implement it in Keras?

Solution

There is an implementation of Bahdanau attention for TensorFlow-Keras, more precisely in tensorflow-addons.

You can take a look at it here: https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/BahdanauAttention

To use tensorflow-addons, make sure you first run pip install tensorflow-addons.
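As a quick check that the mechanism is available, here is a minimal sketch of creating it on its own. The encoder outputs and shapes below are dummy placeholders, not part of the tutorial code:

import tensorflow as tf
import tensorflow_addons as tfa

# Dummy stand-ins for real encoder outputs: batch of 64, 20 time steps, 256 units
encoder_outputs = tf.random.normal([64, 20, 256])
source_lengths = tf.fill([64], 20)   # true (unpadded) length of each source sequence

# Bahdanau (additive) attention over the encoder outputs
attention_mechanism = tfa.seq2seq.BahdanauAttention(
    units=256,                                 # size of the attention scoring layer
    memory=encoder_outputs,                    # what the decoder attends over
    memory_sequence_length=source_lengths)     # masks out padded positions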

In this tutorial, which you can find here (https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt), Luong attention is used as the example; you can easily modify it and choose Bahdanau attention instead (see the HERE comment in the code below).

import tensorflow as tf
import tensorflow_addons as tfa

# input_vocab_size, output_vocab_size, embedding_dims, rnn_units, dense_units,
# BATCH_SIZE and Tx (max source length) are hyperparameters defined earlier
# in the tutorial notebook.

#ENCODER
class EncoderNetwork(tf.keras.Model):
    def __init__(self, input_vocab_size, embedding_dims, rnn_units):
        super().__init__()
        self.encoder_embedding = tf.keras.layers.Embedding(input_dim=input_vocab_size,
                                                           output_dim=embedding_dims)
        self.encoder_rnnlayer = tf.keras.layers.LSTM(rnn_units,
                                                     return_sequences=True,
                                                     return_state=True)

#DECODER
class DecoderNetwork(tf.keras.Model):
    def __init__(self, output_vocab_size, rnn_units):
        super().__init__()
        self.decoder_embedding = tf.keras.layers.Embedding(input_dim=output_vocab_size,
                                                           output_dim=embedding_dims)
        self.dense_layer = tf.keras.layers.Dense(output_vocab_size)
        self.decoder_rnncell = tf.keras.layers.LSTMCell(rnn_units)
        # Sampler used for teacher forcing during training
        self.sampler = tfa.seq2seq.sampler.TrainingSampler()
        # Create attention mechanism with memory = None
        # (the encoder outputs are attached later with setup_memory)
        self.attention_mechanism = self.build_attention_mechanism(dense_units, None, BATCH_SIZE*[Tx])
        self.rnn_cell = self.build_rnn_cell(BATCH_SIZE)
        self.decoder = tfa.seq2seq.BasicDecoder(self.rnn_cell, sampler=self.sampler,
                                                output_layer=self.dense_layer)

    def build_attention_mechanism(self, units, memory, memory_sequence_length):
        # HERE: swap Luong attention for Bahdanau attention
        return tfa.seq2seq.LuongAttention(units, memory=memory,
                                          memory_sequence_length=memory_sequence_length)
        #return tfa.seq2seq.BahdanauAttention(units, memory=memory,
        #                                     memory_sequence_length=memory_sequence_length)

    # wrap the decoder RNN cell with the attention mechanism
    def build_rnn_cell(self, batch_size):
        rnn_cell = tfa.seq2seq.AttentionWrapper(self.decoder_rnncell,
                                                self.attention_mechanism,
                                                attention_layer_size=dense_units)
        return rnn_cell

    def build_decoder_initial_state(self, batch_size, encoder_state, Dtype):
        decoder_initial_state = self.rnn_cell.get_initial_state(batch_size=batch_size,
                                                                dtype=Dtype)
        decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
        return decoder_initial_state

encoderNetwork = EncoderNetwork(input_vocab_size, embedding_dims, rnn_units)
decoderNetwork = DecoderNetwork(output_vocab_size, rnn_units)
optimizer = tf.keras.optimizers.Adam()
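The snippet above does not show how the encoder outputs actually reach the attention mechanism; in the tutorial that happens inside the training step. Roughly, it looks like the sketch below (input_batch, output_batch, Ty and BATCH_SIZE are assumed to come from the tutorial's data pipeline, so treat this as an outline rather than a drop-in training loop):

# Run the encoder over an embedded source batch
a, a_tx, c_tx = encoderNetwork.encoder_rnnlayer(
    encoderNetwork.encoder_embedding(input_batch))

# Hand the encoder outputs to the (Luong or Bahdanau) attention mechanism
decoderNetwork.attention_mechanism.setup_memory(a)

# Initial decoder state: AttentionWrapper state cloned with the encoder's final LSTM state
decoder_initial_state = decoderNetwork.build_decoder_initial_state(
    BATCH_SIZE, encoder_state=[a_tx, c_tx], Dtype=tf.float32)

# Run the BasicDecoder over the embedded target sequence (teacher forcing)
decoder_emb = decoderNetwork.decoder_embedding(output_batch)
outputs, _, _ = decoderNetwork.decoder(
    decoder_emb, initial_state=decoder_initial_state,
    sequence_length=BATCH_SIZE * [Ty - 1])
logits = outputs.rnn_output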
