PyTorch 相当于 tf.nn.softmax_cross_entropy_with_logits 和 tf.nn.sigmoid_cross_entropy_with_logits

问题描述

我找到了帖子 here在这里,我们尝试在 PyTorch 中找到 tf.nn.softmax_cross_entropy_with_logits 的等效项。答案仍然让我感到困惑。

这是Tensorflow 2代码

import tensorflow as tf
import numpy as np

# here we assume 2 batch size with 5 classes

preds = np.array([[.4,0.6,0],[.8,0.2,0]])
labels = np.array([[0,1.0,[1.0,0]])


tf_preds = tf.convert_to_tensor(preds,dtype=tf.float32)
tf_labels = tf.convert_to_tensor(labels,dtype=tf.float32)

loss = tf.nn.softmax_cross_entropy_with_logits(logits=tf_preds,labels=tf_labels)

它给了我 loss 作为

<tf.Tensor: shape=(2,),dtype=float32,numpy=array([1.2427604,1.0636061],dtype=float32)>

这是PyTorch代码

import torch
import numpy as np

preds = np.array([[.4,0]])


torch_preds = torch.tensor(preds).float()
torch_labels = torch.tensor(labels).float()

loss = torch.nn.functional.cross_entropy(torch_preds,torch_labels)

然而,它提出了:

运行时错误:需要一维目标张量,不支持多目标

看来问题还是没有解决。如何在 PyTorch 中实现 tf.nn.softmax_cross_entropy_with_logits

tf.nn.sigmoid_cross_entropy_with_logits 怎么样?

解决方法

tf.nn.softmax_cross_entropy_with_logits 在 PyTorch 中

这相当于 torch.nn.CrossEntropyLoss (F.cross_entropy)。不过有两件事需要解决:

  1. 您需要的是将目标类的索引而不是整个目标向量作为 One-Hot-Encoding 传递。为此,您可以使用应用在 dim=1 上的 torch.argmax

    该标准期望在 [0,C-1][0,C−1] 范围内的类索引作为每个值的目标 - torch.nn.CrossEntropyLoss PyTorch documentation

  2. 默认情况下,torch.nn.functional.cross_entropy() 将取所有批次元素损失的平均值,您可以通过传递 reduction='none' 参数来防止这种情况发生。

如果你这样打电话:

loss = F.cross_entropy(torch_preds,torch.argmax(torch_labels,dim=1),reduction='none')

你会得到想要的结果:

tensor([1.2428,1.0636])

tf.nn.sigmoid_cross_entropy_with_logits 在 PyTorch 中

这有点诡异,因为在 PyTorch 中没有直接的等价物。但是,您可以自己实现表达式。给定目标 logits 和标签 p,定义为:

logits = torch.tensor(preds)
p = torch.tensor(labels)

带有 logits 的 sigmoid 交叉熵是:

loss = p*-torch.log(torch.sigmoid(logits)) + (1-p)*-torch.log(1-torch.sigmoid(logits))

给出:

tensor([[0.9130,0.6931,0.4375,0.6931],[0.3711,0.7981,0.6931]])

您可以检查结果是否匹配:

tf.nn.sotfmax_cross_entropy_with_logits(logits=tf_preds,labels=tf_labels)

还有:

tf.nn.sigmoid_cross_entropy_with_logits(logits=tf_preds,labels=tf_labels)

已将 torch.nn.functional 导入为 F