pytorch中的交叉熵损失如何工作?

问题描述

我正在尝试一些pytorch代码。通过交叉熵损失,我发现了一些有趣的结果,同时使用了二进制的交叉熵损失和pytorch的交叉熵损失

import torch
import torch.nn as nn

X = torch.tensor([[1,0],[1,[0,1],1]],dtype=torch.float)
softmax = nn.softmax(dim=1)


bce_loss = nn.bceloss()
ce_loss= nn.CrossEntropyLoss()

pred = softmax(X)

bce_loss(X,X) # tensor(0.)
bce_loss(pred,X) # tensor(0.3133)
bce_loss(pred,pred) # tensor(0.5822)

ce_loss(X,torch.argmax(X,dim=1)) # tensor(0.3133)

我期望相同输入和输出交叉熵损失为零。这里X,pred和torch.argmax(X,dim = 1)与某些转换相同/相似。这种推理仅适用于bce_loss(X,X) # tensor(0.),否则所有其他结果都导致损失大于零。我推测bce_loss(pred,X)bce_loss(pred,pred)ce_loss(X,dim=1))输出应为零。

这是什么错误

解决方法

看到这个的原因是因为nn.CrossEntropyLoss接受logits和目标,aka X应该是logits,但是已经在0和1之间。X应该更大,因为在softmax之后将介于0和1之间。

ce_loss(X * 1000,torch.argmax(X,dim=1)) # tensor(0.)

nn.CrossEntropyLoss与logits一起使用,以利用log sum技巧。

激活后您目前正在尝试的方式,您的预测大约为[0.73,0.26]

二进制交叉熵示例有效,因为它接受了已激活的logit。顺便说一句,您可能想使用nn.Sigmoid来激活二进制交叉熵logit。对于2类示例,softmax也可以。

,

我已经发布了交叉熵和 NLLLoss here 的手动实现,作为对相关 pytorch CrossEntropyLoss 问题的回答。它可能并不完美,但请务必检查一下。

编辑:我在之前的帖子中没有包含代码,所以帖子被删除了。按照给定的建议,计算 CrossEntropyLoss 的部分代码(直接从上面的链接复制)如下:

def compute_crossentropyloss_manual(x,y0):
    """
    x is the vector of probabilities with shape (batch_size,C)
    y0 shape is the same (batch_size),whose entries are integers from 0 to C-1
    """
    loss = 0.
    n_batch,n_class = x.shape
    # print(n_class)
    for x1,y1 in zip(x,y0):
        class_index = int(y1.item())
        loss = loss + torch.log(torch.exp(x1[class_index])/(torch.exp(x1).sum()))
    loss = - loss/n_batch
    return loss

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...