PyTorch在TransformerEncoder中添加src_key_padding_mask会导致inf丢失

问题描述

我使用this代码作为基础,使用此处提供的不同输入来构建自己的变压器模型。 class TransformerModel(nn.Module):或以下我自己的实现中的部分显示了一些问题:

def make_len_mask(self,inp):
    return (inp == 0).transpose(0,1)


class TransformerModel(nn.Module):
    def __init__(self):
        encoder_layer = TransformerEncoderLayer(ninp,nhead,nhid,dropout)
        encoder_norm = LayerNorm(ninp)
        self.encoder = TransformerEncoder(encoder_layer,nlayers,encoder_norm)

    def forward(self,src,trg):
        src.shape # (x,y,z)
        trg.shape # (x,y)
        # eliminate last dimension of source tensor,which is (batch_size,samples,features) to compute mask
        # resulting in a [true,false]-Vector indicating which elements are padding elements
        padding_tensor = src.mean(2) # padding_tensor.shape: (x,y)
        src_pad_mask = self.make_len_mask(padding_tensor)
        # self.src.mask = None
        output = self.encoder(src,mask=self.src_mask,src_key_padding_mask=src_pad_mask)

使用src_pad_mask会产生ValueError: The loss returned in training_step is nan or inf.。如果未在EncoderLayer中使用该蒙版,则会有结果。

我的输入是

model(source,target)

源是连续的,目标是如下结构的单词:

target = [1] + [4,2,3,8] + [99]  # 0 and 99 are start and end of sentence tokens

我试图用target_tensor[target_tensor == 1] = 0target_tensor[target_tensor == 99] = 0删除它们,很不幸地导致了RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: ...,所以我为序列提供了sos和eos令牌,感觉不对,也许问题出现在那里?但是该序列已在此处填充,因此无法按照示例中的建议删除最后一个或第一个索引。

如果使用nn.Transformer()而不是单个EncoderLayer (),则结果会严重过度拟合而没有遮罩,或者使用该遮罩会出现相同的错误。仅使用target_mask不会产生正确的输入。

是否有可能找出此错误的出处或我的面罩计算错误? github上的更多discussion。是否有必要在训练或推论期间提供口罩?如果是这样,我不这样做,也许有人可以提供帮助或指向消息来源吗?

# Values != 0 => False
# Values == 0 => True
src_pad_mask:
tensor([[False],[False],...
        [ True],[ True]],device='cuda:0')

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...