3D输入的Pytorch交叉熵损失

问题描述

我有一个网络,该网络输出大小为(batch_size,max_len,num_classes)的3D张量。我的真实真理是(batch_size,max_len)的形式。如果我确实对标签执行一次热编码,则其形状将为(batch_size,num_classes),即max_len中的值是[0,num_classes]范围内的整数。由于原始代码太长,我编写了一个更简单的版本来再现原始错误

criterion = nn.CrossEntropyLoss()
batch_size = 32
max_len = 350
num_classes = 1000
pred = torch.randn([batch_size,num_classes])
label = torch.randint(0,num_classes,[batch_size,max_len])
pred = nn.softmax(dim = 2)(pred)
criterion(pred,label)

pred和label的形状分别为torch.Size([32,350,1000])torch.Size([32,350])

遇到的错误

ValueError:预期的目标大小(32,1000),得到了torch.Size([32,350,1000])

如果我对标签进行一次热编码以计算损失

x = nn.functional.one_hot(label)
criterion(pred,x)

它将引发以下错误

ValueError:预期的目标大小(32,1000),得到了torch.Size([32,350,1000])

解决方法

Pytorch documentation中,CrossEntropyLoss期望其输入的形状为(N,C,...),因此第二维始终是类的数量。如果将preds的大小调整为(batch_size,num_classes,max_len),则代码应该可以正常工作。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...