Derivative of BatchNorm2d in PyTorch

Problem description

In my network, I want to compute both the forward pass and the backward pass of my network during the forward pass. To do that, I have to manually define the backward-pass method for every layer of the forward pass.
For activation functions this is easy, and it also works well for linear and convolutional layers, but I am really struggling with BatchNorm, since the BatchNorm paper only discusses the 1D case. My implementation so far looks like this:

def backward_batchnorm2d(input, output, grad_output, layer):
    gamma = layer.weight
    beta = layer.bias
    avg = layer.running_mean
    var = layer.running_var
    eps = layer.eps
    B = input.shape[0]

    # avg, var, gamma and beta are of shape [channel_size]
    # while input, grad_output are of shape [batch_size, channel_size, w, h]
    # for my calculations I have to reshape avg, gamma and beta to [batch_size, channel_size, w, h] by repeating the channel values over the whole image and batches

    dL_dxi_hat = grad_output * gamma
    dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
    dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
    dL_dxi = dL_dxi_hat / torch.sqrt(var + eps) + 2.0 * dL_dvar * (input - avg) / B + dL_davg / B # dL_dxi_hat / sqrt()
    dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True)
    dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta

When I check my gradients with torch.autograd.grad(), I notice that dL_dgamma and dL_dbeta are correct, but dL_dxi is incorrect (by a lot). I just can't find my mistake, though. Where is my error?
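
The check itself looks roughly like this (a minimal sketch rather than my exact test code; it assumes backward_batchnorm2d broadcasts gamma, beta, avg and var to the input shape as described in the comment above):

import torch
import torch.nn as nn

layer = nn.BatchNorm2d(8)
layer.train()
x = torch.randn(4, 8, 6, 6, requires_grad=True)
out = layer(x)
grad_output = torch.randn_like(out)  # stand-in for the gradient arriving from the rest of the network

# reference gradients from autograd
dx_ref, dgamma_ref, dbeta_ref = torch.autograd.grad(
    out, (x, layer.weight, layer.bias), grad_outputs=grad_output)

# gradients from the manual backward
dx, dgamma, dbeta = backward_batchnorm2d(x, out, grad_output, layer)

print(torch.allclose(dgamma.flatten(), dgamma_ref, atol=1e-5))
print(torch.allclose(dbeta.flatten(), dbeta_ref, atol=1e-5))
print(torch.allclose(dx, dx_ref, atol=1e-5))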

For reference, here is the definition of BatchNorm:

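(Restating the standard formulation from the paper, since the original image is not reproduced here; m is the mini-batch size:)

\mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i
\sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
y_i = \gamma \hat{x}_i + \beta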

And here are the derivative formulas for the 1D case:

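(Again restated from the paper, since the original image is not reproduced here:)

\frac{\partial \ell}{\partial \hat{x}_i} = \frac{\partial \ell}{\partial y_i} \cdot \gamma
\frac{\partial \ell}{\partial \sigma_B^2} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot (x_i - \mu_B) \cdot \left(-\frac{1}{2}\right) (\sigma_B^2 + \epsilon)^{-3/2}
\frac{\partial \ell}{\partial \mu_B} = \left( \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma_B^2 + \epsilon}} \right) + \frac{\partial \ell}{\partial \sigma_B^2} \cdot \frac{\sum_{i=1}^{m} -2 (x_i - \mu_B)}{m}
\frac{\partial \ell}{\partial x_i} = \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{1}{\sqrt{\sigma_B^2 + \epsilon}} + \frac{\partial \ell}{\partial \sigma_B^2} \cdot \frac{2 (x_i - \mu_B)}{m} + \frac{\partial \ell}{\partial \mu_B} \cdot \frac{1}{m}
\frac{\partial \ell}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i} \cdot \hat{x}_i
\frac{\partial \ell}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i}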

Solution

import torch

def backward_batchnorm2d(input, output, grad_output, layer):
    gamma = layer.weight
    gamma = gamma.view(1, -1, 1, 1) # edit
    # beta = layer.bias
    # avg = layer.running_mean
    # var = layer.running_var
    eps = layer.eps
    B = input.shape[0] * input.shape[2] * input.shape[3] # edit

    # add new: batch statistics computed from the current input
    mean = input.mean(dim=(0, 2, 3), keepdim=True)
    variance = input.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    x_hat = (input - mean) / torch.sqrt(variance + eps)

    dL_dxi_hat = grad_output * gamma
    # dL_dvar = (-0.5 * dL_dxi_hat * (input - avg) / ((var + eps) ** 1.5)).sum((0, 2, 3), keepdim=True)
    # dL_davg = (-1.0 / torch.sqrt(var + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + dL_dvar * (-2.0 * (input - avg)).sum((0, 2, 3), keepdim=True) / B
    dL_dvar = (-0.5 * dL_dxi_hat * (input - mean)).sum((0, 2, 3), keepdim=True) * ((variance + eps) ** -1.5) # edit
    dL_davg = (-1.0 / torch.sqrt(variance + eps) * dL_dxi_hat).sum((0, 2, 3), keepdim=True) + (dL_dvar * (-2.0 * (input - mean)).sum((0, 2, 3), keepdim=True) / B) # edit

    dL_dxi = (dL_dxi_hat / torch.sqrt(variance + eps)) + (2.0 * dL_dvar * (input - mean) / B) + (dL_davg / B) # dL_dxi_hat / sqrt()
    # dL_dgamma = (grad_output * output).sum((0, 2, 3), keepdim=True)
    dL_dgamma = (grad_output * x_hat).sum((0, 2, 3), keepdim=True) # edit
    dL_dbeta = (grad_output).sum((0, 2, 3), keepdim=True)
    return dL_dxi, dL_dgamma, dL_dbeta
  1. Since you did not post your forward-pass code: if your gamma has a 1-dimensional shape (i.e. [channel_size]), you need to reshape it to [1, gamma.shape[0], 1, 1] so it broadcasts over the input.
  2. The formulas follow the 1D case, so they sum only over the batch dimension. In 2D, however, we sum over three dimensions, so B = input.shape[0] * input.shape[2] * input.shape[3].
  3. running_mean and running_var are only used in test/inference mode; we do not use them during training (you can find this in the paper). The mean and variance you need are computed from the current input. You can store the mean, variance, and x_hat = (x - mean)/sqrt(variance + eps) in your layer object during the forward pass (a small sketch of this caching variant follows this list), or recompute them as I did in the code above under # add new. Then substitute them into the formulas for dL_dvar, dL_davg, and dL_dxi.
  4. Your dL_dgamma should actually be incorrect as well, because you multiply the gradient by output itself; it should be changed to grad_output * x_hat.
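
A minimal sketch of the caching idea from point 3, assuming layer is an nn.BatchNorm2d; the function name forward_batchnorm2d and the cached_* attribute names are only illustrative:

import torch

def forward_batchnorm2d(input, layer):
    # Training-mode forward: normalize with the batch statistics and cache them
    # on the layer so the manual backward can reuse them directly.
    # (Updating running_mean/running_var for inference is omitted here.)
    mean = input.mean(dim=(0, 2, 3), keepdim=True)
    variance = input.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    x_hat = (input - mean) / torch.sqrt(variance + layer.eps)
    layer.cached_mean, layer.cached_var, layer.cached_xhat = mean, variance, x_hat
    return layer.weight.view(1, -1, 1, 1) * x_hat + layer.bias.view(1, -1, 1, 1)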