变压器中目标的Pytorch NLP序列长度

问题描述

我正在尝试理解Transformer(https://github.com/SamLynnEvans/Transformer)的代码。

如果在“火车”脚本中看到train_model函数,我想知道为什么需要使用与trg不同的trg_input序列长度:

trg_input = trg[:,:-1]

在这种情况下,trg_input的序列长度为“ seq_len(trg)-1”。 这意味着trg就像:

<sos> tok1 tok2 tokn <eos>

和trg_input就像:

<sos> tok1 tok2 tokn    (no eos token)

请让我知道原因。

谢谢。

相关代码如下:

    for i,batch in enumerate(opt.train):
        src = batch.src.transpose(0,1).to('cuda')
        trg = batch.trg.transpose(0,1).to('cuda')

        trg_input = trg[:,:-1]
        src_mask,trg_mask = create_masks(src,trg_input,opt)
        preds = model(src,src_mask,trg_mask)
        ys = trg[:,1:].contiguous().view(-1)
        opt.optimizer.zero_grad()
        loss = F.cross_entropy(preds.view(-1,preds.size(-1)),ys,ignore_index=opt.trg_pad)
        loss.backward()
        opt.optimizer.step()


def create_masks(src,trg,opt):
    
    src_mask = (src != opt.src_pad).unsqueeze(-2)

    if trg is not None:
        trg_mask = (trg != opt.trg_pad).unsqueeze(-2)
        size = trg.size(1) # get seq_len for matrix
        np_mask = nopeak_mask(size,opt)
        if trg.is_cuda:
            np_mask.cuda()
        trg_mask = trg_mask & np_mask
        
    else:
        trg_mask = None
    return src_mask,trg_mask

解决方法

这是因为整个目标是根据到目前为止所看到的标记来生成下一个标记。当我们得到预测时,看看模型的输入。直到我们当前的步骤,我们不仅要馈送源序列,而且还要馈送目标序列。 Models.py中的模型如下:

class Transformer(nn.Module):
    def __init__(self,src_vocab,trg_vocab,d_model,N,heads,dropout):
        super().__init__()
        self.encoder = Encoder(src_vocab,dropout)
        self.decoder = Decoder(trg_vocab,dropout)
        self.out = nn.Linear(d_model,trg_vocab)
    def forward(self,src,trg,src_mask,trg_mask):
        e_outputs = self.encoder(src,src_mask)
        #print("DECODER")
        d_output = self.decoder(trg,e_outputs,trg_mask)
        output = self.out(d_output)
        return output

因此,您可以看到forward方法接收到srctrg,它们分别被馈入编码器和解码器。如果您查看the original paper中的模型架构,这会更容易理解:

enter image description here

“输出(右移)”对应于代码中的trg[:,:-1]

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...