问题描述
我修改了 Fairseq,使解码器可以从编码器接收多个输出,这些输出保存了编码器各层的中间状态。为此定义了两个模块级的全局数组:
encoder_inner_states = []
final_encoder_inner_states = []
编码器:
def forward(self, src_tokens, src_lengths):
    """
    Encode a batch and publish every layer's output through module globals.

    Besides returning the usual encoder output, this method refills two
    module-level lists on every call:

    - ``encoder_inner_states``: the output ``x`` of each encoder layer, in
      layer order.
    - ``final_encoder_inner_states``: for each entry of ``self.fc_layers``,
      the same-index element of ``encoder_inner_states`` passed through that
      layer, dropout and (when ``self.normalize`` is set)
      ``self.layer_norm``.

    NOTE(review): these lists are shared module state, not part of the
    returned dict, so anything that post-processes the returned dict (for
    example reordering it during search) will not touch them -- confirm
    this is acceptable for every caller.

    Args:
        src_tokens (LongTensor): tokens in the source language of shape
            `(batch, src_len)`
        src_lengths (torch.LongTensor): lengths of each source sentence of
            shape `(batch)`; not used by this method

    Returns:
        dict:
            - **encoder_out** (Tensor): the last encoder layer's output of
              shape `(src_len, batch, embed_dim)`; when ``self.history`` is
              set it is the value returned by ``self.history.pop()`` after
              the last layer
            - **encoder_padding_mask** (ByteTensor): the positions of
              padding elements of shape `(batch, src_len)`, or ``None``
              when the batch contains no padding

    Raises:
        IndexError: if ``self.fc_layers`` has more entries than
            ``self.layers``.
    """
    use_history = self.history is not None

    encoder_inner_states.clear()
    if use_history:
        self.history.clean()

    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(src_tokens)
    if self.embed_positions is not None:
        x += self.embed_positions(src_tokens)
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    # add emb into history
    if use_history:
        self.history.add(x)

    # compute padding mask; None tells the layers there is nothing to mask
    encoder_padding_mask = src_tokens.eq(self.padding_idx)
    if not encoder_padding_mask.any():
        encoder_padding_mask = None

    # Collected per layer; not consumed here (kept for debugging hooks).
    attn_weight_list = []

    # encoder layers
    for layer in self.layers:
        x, attn_weight = layer(x, encoder_padding_mask)
        if use_history:
            self.history.add(x)
            # NOTE(review): the popped value is never used inside the loop;
            # the call is kept in case pop() has side effects -- confirm.
            self.history.pop()
        encoder_inner_states.append(x)
        attn_weight_list.append(attn_weight)

    if use_history:
        x = self.history.pop()

    # Project each layer's raw output for consumption by the decoder.
    final_encoder_inner_states.clear()
    for fc_layer_id, fc_layer in enumerate(self.fc_layers):
        y = fc_layer(encoder_inner_states[fc_layer_id])
        # NOTE(review): no `p` is passed, so F.dropout's default rate is
        # used here rather than self.dropout -- confirm this is intended.
        y = F.dropout(y, training=self.training)
        if self.normalize:
            y = self.layer_norm(y)
        final_encoder_inner_states.append(y)

    return {
        'encoder_out': x,  # T x B x C
        'encoder_padding_mask': encoder_padding_mask,  # B x T
    }
解码器:使用全局数组中的数据,而不是使用常规的encoder_out
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
    """
    Decode a batch, attending to per-layer encoder states from a module global.

    Decoder layer ``i`` receives ``final_encoder_inner_states[i]`` (a
    module-level list) as its encoder-side input; ``encoder_out`` is only
    consulted for the padding mask.

    Args:
        prev_output_tokens (LongTensor): previous decoder outputs of shape
            `(batch, tgt_len)`, for input feeding/teacher forcing
        encoder_out (dict, optional): output from the encoder; only its
            ``'encoder_padding_mask'`` entry is read
        incremental_state (dict): dictionary used for storing state during
            :ref:`Incremental decoding`

    Returns:
        tuple:
            - the last decoder layer's output; of shape
              `(batch, tgt_len, vocab)` when ``self.adaptive_softmax`` is
              ``None``, otherwise the un-projected features
            - a dict with ``'attn'`` (the attention value returned by the
              last decoder layer, ``None`` if there are no layers) and
              ``'inner_states'`` (the embedded input followed by every
              layer's output)
    """
    # embed positions
    positions = self.embed_positions(
        prev_output_tokens,
        incremental_state=incremental_state,
    ) if self.embed_positions is not None else None

    if incremental_state is not None:
        # Incremental decoding only feeds the newest target position.
        prev_output_tokens = prev_output_tokens[:, -1:]
        if positions is not None:
            positions = positions[:, -1:]

    # embed tokens and positions
    x = self.embed_scale * self.embed_tokens(prev_output_tokens)

    if self.project_in_dim is not None:
        x = self.project_in_dim(x)

    if positions is not None:
        x += positions

    # The previous `F.dropout(x, 1)` meant p=1 with training=True, which
    # zeroes every element on every call. Use the configured rate and
    # respect eval mode instead.
    # NOTE(review): assumes the decoder defines `self.dropout` like the
    # encoder does -- confirm.
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C (undone after the layers, see below)
    x = x.transpose(0, 1)

    attn = None
    inner_states = [x]
    padding_mask = (
        encoder_out['encoder_padding_mask'] if encoder_out is not None else None
    )

    # decoder layers
    for layer_id, layer in enumerate(self.layers):
        x, attn = layer(
            x,
            # Per-layer encoder state from the module-level list, used in
            # place of encoder_out['encoder_out'].
            final_encoder_inner_states[layer_id],
            padding_mask,
            incremental_state,
            self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
        )
        inner_states.append(x)

    if self.normalize:
        x = self.layer_norm(x)

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    if self.project_out_dim is not None:
        x = self.project_out_dim(x)

    if self.adaptive_softmax is None:
        # project back to size of vocabulary
        if self.share_input_output_embed:
            x = F.linear(x, self.embed_tokens.weight)
        else:
            x = F.linear(x, self.embed_out)

    return x, {'attn': attn, 'inner_states': inner_states}
但是我遇到了以下问题:(此处原为一张错误信息截图,图片未能保留)
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)