问题描述
我正在 pytorch 中构建一个具有多个网络的模型。例如,让我们考虑 netA
和 netB
。在损失函数中,我需要使用组合 netA(netB)
。在优化的不同部分,我需要仅针对 loss_func(netA(netB))
的参数计算 netA
的梯度,而在另一种情况下,我需要计算 netB
参数的梯度。应该如何解决这个问题?
我的方法:在计算梯度的情况下,netA
的参数我使用 loss_func(netA(netB.detach()))
。
如果我写 loss_func(netA(netB).detach())
,似乎 netA
和 netB
的两个参数都是分离的。
我尝试使用 loss_func(netA.detach(netB))
来仅分离 netA
的参数,但它不起作用。 (我收到 netA
没有属性分离的错误。)
解决方法
梯度是张量的属性,而不是网络。
因此,您只能.detach
一个张量。
您可以为每个网络使用不同的优化器。通过这种方式,您可以一直计算所有网络的梯度,但只更新相关网络的权重(调用相关优化器的 step
)。