如何正确使用交叉熵损失 vs Softmax 进行分类? 1.如何以最佳方式训练“标准”分类网络?2.如果网络有最后一个线性层,如何推断每个类的概率?3.如果网络有最后的 softmax 层,如何训练网络哪个损失,以及如何训练?

问题描述

我想使用 Pytorch 训练一个多类分类器。

以下 the official Pytorch doc 显示了如何在类型为 nn.CrossEntropyLoss() 的最后一层之后使用 nn.Linear(84,10)

不过,我记得这是 softmax 所做的。

这让我很困惑。


  1. 如何以最佳方式训练“标准”分类网络?
  2. 如果网络有最后一个线性层,如何推断每个类的概率?
  3. 如果网络有最后的 softmax 层,如何训练网络(哪个损失,以及如何训练)?

我在 Pytorch 论坛上找到了 this thread,它可能回答了所有这些问题,但我无法将其编译为可用且可读的 Pytorch 代码


我假设的答案:

  1. 喜欢the doc says
  2. 线性层输出的取幂,实际上是对数(对数概率)。
  3. 我不明白。

解决方法

我认为理解 softmax 和交叉熵很重要,至少从实践的角度来看是这样。一旦您掌握了这两个概念,就应该清楚如何在 ML 的上下文中“正确”使用它们。

交叉熵 H(p,q)

交叉熵是一个比较两个概率分布的函数。从实践的角度来看,可能不值得深入探讨交叉熵的正式动机,但如果您有兴趣,我会推荐 Cover 和 Thomas 的信息理论要素作为介绍性文本。这个概念很早就被引入了(我相信是第 2 章)。这是我在研究生院使用的介绍文字,我认为它做得很好(当然我也有一位很棒的导师)。

需要注意的关键是交叉熵是一个函数,它接受两个概率分布:q 和 p,并返回一个当 q 和 p 相等时最小的值。 q 表示估计分布,p 表示真实分布。

在 ML 分类的上下文中,我们知道训练数据的实际标签,因此真实/目标分布 p 对于真实标签的概率为 1,其他地方为 0,即 p 是单热向量。

另一方面,估计分布(模型的输出)q 通常包含一些不确定性,因此 q 中任何类别的概率将在 0 和 1 之间。通过训练系统以最小化交叉熵,我们是告诉系统我们希望它尝试使估计分布尽可能接近真实分布。因此,你的模型认为最有可能的类就是 q 的最大值对应的类。

Softmax

同样,有一些复杂的统计方法来解释 softmax,我们不会在这里讨论。从实用的角度来看,关键是 softmax 是一个函数,它以无界值列表作为输入,并输出一个有效的概率质量函数保持相对顺序。重要的是要强调关于相对顺序的第二点。这意味着 softmax 输入中的最大元素对应于 softmax 输出中的最大元素。

考虑一个经过训练以最小化交叉熵的 softmax 激活模型。在这种情况下,在 softmax 之前,模型的目标是为正确的标签产生尽可能高的值,为不正确的标签产生尽可能低的值。

PyTorch 中的交叉熵损失

PyTorch 中 CrossEntropyLoss 的定义是 softmax 和交叉熵的结合。具体

CrossEntropyLoss(x,y) := H(one_hot(y),softmax(x))

请注意,one_hot 是一个函数,它采用索引 y,并将其扩展为 one-hot 向量。

等效地,您可以将 CrossEntropyLoss 表示为 LogSoftmax 和负对数似然损失(即 PyTorch 中的 NLLLoss)的组合

LogSoftmax(x) := ln(softmax(x))

CrossEntropyLoss(x,y) := NLLLoss(LogSoftmax(x),y)

由于 softmax 中的求幂,有一些计算“技巧”可以使直接使用 CrossEntropyLoss 比分阶段计算更稳定(更准确,不太可能得到 NaN)。

结论

基于以上讨论,您的问题的答案是

1.如何以最佳方式训练“标准”分类网络?

就像医生说的那样。

2.如果网络有最后一个线性层,如何推断每个类的概率?

将 softmax 应用于网络的输出以推断每个类别的概率。如果目标只是找到相对排序或最高概率类,那么只需将 argsort 或 argmax 直接应用于输出(因为 softmax 保持相对排序)。

3.如果网络有最后的 softmax 层,如何训练网络(哪个损失,以及如何训练)?

通常,出于上述稳定性原因,您不希望训练输出 softmaxed 输出的网络。

也就是说,如果您出于某种原因绝对需要,您可以获取输出日志并将它们提供给 NLLLoss

criterion = nn.NLLLoss()
...
x = model(data)    # assuming the output of the model is softmax activated
loss = criterion(torch.log(x),y)

这在数学上等同于将 CrossEntropyLoss 与使用 softmax 激活的模型一起使用。

criterion = nn.CrossEntropyLoss()
...
x = model(data)    # assuming the output of the model is NOT softmax activated
loss = criterion(x,y)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...