用pytorch训练输入梯度的神经网络

问题描述

我目前正在尝试用pytorch训练神经网络,在这里我尝试匹配输入导数上的输入。我要这样做是因为这确保了保守的矢量场。 (完成神经网络训练以进行分子动力学力匹配) 这意味着:

input = torch.rand((n,3),requires_grad=True)
output = torch.rand((n,requires_grad=True)
prediction = model(input) # size of the prediction[1]
input_grad = torch.autograd.grad(outputs=prediction,inputs=input,retain_graph=True,create_graph=True)
loss = loss_fn(output,input_grad)
...

问题是,如果我尝试更新神经网络的参数,则所有参数的梯度均为0。我不知道如何建立模型正在正确训练的图形。 在Jaxmd中,可以像[Jax Glass Training] [1]所示训练这样的模型。 我也尝试过

input_grad = torch.autograd.grad(outputs=prediction,create_graph=True)

但是这会产生相似的结果并且没有任何意义。 [1]:https://colab.research.google.com/github/google/jax-md/blob/master/notebooks/neural_networks.ipynb#scrollTo=WNs8v2745Mc3

编辑:

再现pytorch版本1.6.0的更新代码示例

import torch 

class Feedforward(torch.nn.Module):
        def __init__(self,input_size,hidden_size):
            super(Feedforward,self).__init__()
            self.input_size = input_size
            self.hidden_size  = hidden_size
            self.fc1 = torch.nn.Linear(self.input_size,self.hidden_size)
            self.relu = torch.nn.ReLU()
            self.fc2 = torch.nn.Linear(self.hidden_size,1)
     
        
        def forward(self,x):
            hidden = self.fc1(x)
            relu = self.relu(hidden)
            output = self.fc2(relu)
            output = output.sum()
            output =torch.autograd.grad(outputs=output,inputs=x,create_graph=True)
            return output[0]

test_input = torch.rand((10,requires_grad=True)
test_output = torch.rand((10,3))


model = Feedforward(3,10)
optim = torch.optim.Adam(model.parameters())
optim.zero_grad()
loss_fn = torch.nn.L1Loss()
model.train()
out = model(test_input)
loss = loss_fn(out,test_output)
loss.backward()
optim.step() # if you break here and investigate the gradients
             # of the FFNN,the gradients will be 0 

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...