关于 Pytorch 中特定层的参数的梯度

问题描述

我正在 pytorch 中构建一个具有多个网络的模型。例如,让我们考虑 netAnetB。在损失函数中,我需要使用组合 netA(netB)。在优化的不同部分,我需要仅针对 loss_func(netA(netB)) 的参数计算 netA 的梯度,而在另一种情况下,我需要计算 netB 参数的梯度。应该如何解决这个问题?

我的方法:在计算梯度的情况下,netA 的参数我使用 loss_func(netA(netB.detach()))

如果我写 loss_func(netA(netB).detach()),似乎 netAnetB 的两个参数都是分离的。

我尝试使用 loss_func(netA.detach(netB)) 来仅分离 netA 的参数,但它不起作用。 (我收到 netA 没有属性分离的错误。)

解决方法

梯度是张量的属性,而不是网络
因此,您只能.detach 一个张量。

您可以为每个网络使用不同的优化器。通过这种方式,您可以一直计算所有网络的梯度,但只更新相关网络的权重(调用相关优化器的 step)。