Fairseq变压器解码器接收编码器多路输出问题

问题描述

Fairseq使用解码器从编码器接收多个输出,这些输出保持编码器的多个中间状态 两个全局数组 encoder_inner_states = [] final_encoder_inner_states = [] 编码器:

def forward(self,src_tokens,src_lengths):
    """
    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch,src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`

    Returns:
        dict:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len,batch,embed_dim)`
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch,src_len)`
    """
    encoder_inner_states.clear()
    if self.history is not None:
        self.history.clean()
    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(src_tokens)
    if self.embed_positions is not None:
        x += self.embed_positions(src_tokens)
    x = F.dropout(x,p=self.dropout,training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0,1)

    # add emb into history
    if self.history is not None:
        self.history.add(x)

    # compute padding mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)
    if not encoder_padding_mask.any():
        encoder_padding_mask = None
    #print("emb:{}".format(x.size()))
    
    #intra_sim
    attn_weight_list = []
    inner_states=[]
    # encoder layers
    for layer_id,layer in enumerate(self.layers):
        #if layer_id == 2 :
            #continue
        x,attn_weight = layer(x,encoder_padding_mask)

        if self.history is not None:
            self.history.add(x)
        if self.history is not None:
            y = self.history.pop()

        encoder_inner_states.append(x)
        attn_weight_list.append(attn_weight)



    if self.history is not None:
        x = self.history.pop()

    final_encoder_inner_states.clear()
    for fc_layer_id,fc_layer in enumerate(self.fc_layers):
        y = fc_layer(encoder_inner_states[fc_layer_id])
        y = F.dropout(y,training=self.training)
        if self.normalize:
            y = self.layer_norm(y)
        final_encoder_inner_states.append(y)

    #self.print_attn_weight(attn_weight_list)
    #print("encoder-out:{}".format(x))
    return {
        'encoder_out': x,# T x B x C
        'encoder_padding_mask': encoder_padding_mask,# B x T
    }

解码器:使用全局数组中的数据,而不是使用常规的encoder_out

def forward(self,prev_output_tokens,encoder_out=None,incremental_state=None):
    """
    Args:
        prev_output_tokens (LongTensor): prevIoUs decoder outputs of shape
            `(batch,tgt_len)`,for input Feeding/teacher forcing
        encoder_out (Tensor,optional): output from the encoder,used for
            encoder-side attention
        incremental_state (dict): dictionary used for storing state during
            :param inner_state:
            :ref:`Incremental decoding`

    Returns:
        tuple:
            - the last decoder layer's output of shape `(batch,tgt_len,vocab)`
            - the last decoder layer's attention weights of shape `(batch,src_len)`
    """
    # embed positions
    positions = self.embed_positions(
        prev_output_tokens,incremental_state=incremental_state,) if self.embed_positions is not None else None

    if incremental_state is not None:
        prev_output_tokens = prev_output_tokens[:,-1:]
        if positions is not None:
            positions = positions[:,-1:]

    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(prev_output_tokens)

    if self.project_in_dim is not None:
        x = self.project_in_dim(x)

    if positions is not None:
        x += positions
    x = F.dropout(x,1)
    attn = None

    inner_states = [x]
    enc_dec_attn_weight_list = []



    # decoder layers
    for layer_id,layer in enumerate(self.layers):

        x,attn = layer(
            x,#encoder_out['inner_states'][layer_id],#encoder_out['encoder_out'] if encoder_out is not None else None,final_encoder_inner_states[layer_id],encoder_out['encoder_padding_mask'] if encoder_out is not None else None,incremental_state,self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,)
        inner_states.append(x)
        enc_dec_attn_weight_list.append(attn)
    #for i in range(len(enc_dec_attn_weight_list)):
        #print('layer{}'.format(i))
        #print(enc_dec_attn_weight_list[i].cpu().numpy())

    if self.normalize:
        x = self.layer_norm(x)
    # print("decoder-out:{}".format(x))
    # T x B x C -> B x T x C
    x = x.transpose(0,1)

    if self.project_out_dim is not None:
        x = self.project_out_dim(x)

    if self.adaptive_softmax is None:
        # project back to size of vocabulary
        if self.share_input_output_embed:
            x = F.linear(x,self.embed_tokens.weight)
        else:
            x = F.linear(x,self.embed_out)

    return x,{'attn': attn,'inner_states': inner_states}

但是我有以下问题: enter image description here

enter image description here

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)