哪些 PyTorch 模块受 model.eval() 和 model.train() 影响?

问题描述

model.eval() 方法修改某些模块(层),这些模块(层)在训练和推理期间需要以不同的方式表现。 the docs 中列出了一些示例:

这仅对某些模块有 [an] 影响。如果它们受到影响,请参阅特定模块的文档以了解其在培训/评估模式下的行为的详细信息,例如DropoutBatchnorm

是否有受影响模块的详尽列表?

解决方法

在 google 上搜索 site:https://pytorch.org/docs/stable/generated/torch.nn. "during evaluation",会发现以下模块受到影响:

基类 模块 标准
_InstanceNorm InstanceNorm1d
InstanceNorm2d
InstanceNorm3d
track_running_stats=True
_BatchNorm BatchNorm1d
BatchNorm2d
BatchNorm3d
SyncBatchNorm
_DropoutNd Dropout
Dropout2d
Dropout3d
AlphaDropout
FeatureAlphaDropout
,

除了 @iacob 提供的信息:

基类 模块 标准
RNNBase RNN
LSTM
GRU
dropout > 0(默认:0
变压器层 变压器
变压器编码器
变压器解码器
dropout > 0Transformer 默认值:0.1
懒惰变体 LazyBatchNorm
目前每晚
merged PR
track_running_stats=True