问题描述
我有一个返回预测 (n,n)
的 GAN。为了指导这个网络,我有一个损失函数,它是二元交叉熵损失 (torch.tensor
) 和 Wasserstein 距离的总和。但是,为了计算 Wasserstein 距离,我使用了 bceloss
库中的 scipy.stats.wasserstein_distance
函数。您可能知道,此函数需要两个 SciPy
数组作为输入。所以,为了使用这个函数,我将我的预测张量和地面实况张量转换为 NumPy
数组,如下所示
NumPy
然后,将pred_np = pred_tensor.detach().cpu().clone().numpy().ravel()
target_np = target_tensor.detach().cpu().clone().numpy().ravel()
W_loss = wasserstein_distance(pred_np,target_np)
与W_loss
相加得到总损失。我现在展示这部分是因为它有点不必要并且与我的问题无关。
我担心的是我正在分离梯度,所以我想在优化和更新模型参数时它不会考虑 bceloss
。我是个新手,所以我希望我的问题很清楚,并感谢您提前回答。
解决方法
将一个不是张量的对象添加到您的损失中本质上是添加一个常量。常数的导数为零,所以这个增加的项对您的网络的权重没有任何影响。
tl;博士: 您需要在 pytorch 中重写损失计算(或者只是找到一个现有的实现,互联网上有很多)。