如何获得torch.nn.Transformer的稳定输出

问题描述

像pytorch的Transformer层那样的输出无法再现。对于cpu和gpu都是这样。我知道有时是由于在gpu上进行并行计算而发生的。

emb = nn.Embedding(10,12).to(device)
inp1 = torch.LongTensor([1,2,3,4]).to(device)
inp1 = emb(inp1).reshape(inp1.shape[0],1,12) #S N E

encoder_layer = nn.TransformerEncoderLayer(d_model=12,nhead=4)
transformer_encoder = nn.TransformerEncoder(encoder_layer,num_layers=4)

out1 = transformer_encoder(inp1)
out2 = transformer_encoder(inp1)

out1和out2不同。它可以在cpu上进行多处理,但是结果看起来太不稳定了。该如何解决?

解决方法

nn.TransformerEncoderLayer的默认辍学率为0.1。当模型处于训练模式时,将在每次迭代中随机删除要删除的索引。

如果要使用辍学训练模型,只需在训练中忽略此行为,然后在测试中致电model.eval()

如果您想在训练中禁用这种随机行为,请像这样设置dropout=0

nn.TransformerEncoderLayer(d_model=12,nhead=4,dropout=0)

完整的测试脚本:

import torch
import torch.nn as nn

device = 'cpu'

emb = nn.Embedding(10,12).to(device)
inp1 = torch.LongTensor([1,2,3,4]).to(device)
inp1 = emb(inp1).reshape(inp1.shape[0],1,12) #S N E

encoder_layer = nn.TransformerEncoderLayer(d_model=12,dropout=0).to(device)
transformer_encoder = nn.TransformerEncoder(encoder_layer,num_layers=4).to(device)

out1 = transformer_encoder(inp1)
out2 = transformer_encoder(inp1)

print((out1-out2).norm())

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...