问题描述
我实现了多层 LSTM,但是如果 init_state 不是 None,结果与 nn.LSTM 不同。我将 LSTM 模型中的权重加载到我的自定义模型和 pytorch 的 nn.LSTM 模型中。我怀疑我可能在前向功能上做错了什么。任何帮助将不胜感激。非常感谢!
class StackedLSTMs(nn.Module):
def __init__(self,input_sz:int,hidden_sz: int,num_layers: int):
super().__init__()
self.num_layers = num_layers
self.hidden_sz = hidden_sz
self.LSTMs = nn.ModuleList()
for layer in range(num_layers):
if layer == 0:
self.LSTMs.append(nn.LSTMCell(input_sz,hidden_sz))
#self.LSTMs.append(NaiveCustomLSTMCell(input_sz,hidden_sz))
else:
self.LSTMs.append(nn.LSTMCell(hidden_sz,hidden_sz))
#self.LSTMs.append(NaiveCustomLSTMCell(hidden_sz,hidden_sz))
def forward(self,x,h: Optional[Tuple[torch.Tensor,torch.Tensor]] = None):
print('hidden',h)
seq_size,bs,_ = x.size()
outputs = []
if h is None:
hn = torch.zeros(self.num_layers,self.hidden_sz)
cn = torch.zeros(self.num_layers,self.hidden_sz)
else:
(hn,cn) = h
for t in range(seq_size):
for layer,lstm in enumerate(self.LSTMs):
if layer == 0:
hn[layer,:,:],cn[layer,:] = lstm(x[t,(hn[layer,:]))
else:
hn[layer,:] = lstm(hn[layer-1,:]))
temp = hn[self.num_layers - 1,:].detach().clone()
outputs.append(temp)
outputs = torch.stack(outputs,dim=0)
h = (hn,cn)
#outputs = outputs.transpose(0,1).contiguous()
return outputs,h
torch.manual_seed(999)
lstms = nn.LSTM(320,320,2)
stackedlstms = StackedLSTMs(320,2)
stackedlstms.LSTMs[0].weight_ih = oldmodel.prediction.dec_rnn.lstm.weight_ih_l0
stackedlstms.LSTMs[0].weight_hh = oldmodel.prediction.dec_rnn.lstm.weight_hh_l0
stackedlstms.LSTMs[0].bias_ih = oldmodel.prediction.dec_rnn.lstm.bias_ih_l0
stackedlstms.LSTMs[0].bias_hh = oldmodel.prediction.dec_rnn.lstm.bias_hh_l0
stackedlstms.LSTMs[1].weight_ih = oldmodel.prediction.dec_rnn.lstm.weight_ih_l1
stackedlstms.LSTMs[1].weight_hh = oldmodel.prediction.dec_rnn.lstm.weight_hh_l1
stackedlstms.LSTMs[1].bias_ih = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1
stackedlstms.LSTMs[1].bias_hh = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1
lstms.weight_ih_l0 = oldmodel.prediction.dec_rnn.lstm.weight_ih_l0
lstms.weight_hh_l0 = oldmodel.prediction.dec_rnn.lstm.weight_hh_l0
lstms.bias_ih_l0 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l0
lstms.bias_hh_l0 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l0
lstms.weight_ih_l1 = oldmodel.prediction.dec_rnn.lstm.weight_ih_l1
lstms.weight_hh_l1 = oldmodel.prediction.dec_rnn.lstm.weight_hh_l1
lstms.bias_ih_l1 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1
lstms.bias_hh_l1 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1
hidden = torch.load('hidden.pt')
newembedt = torch.load('newembed_t.pt')
lstms_res = lstms(newembedt,hidden)
stackedlstms_res = stackedlstms(newembedt,hidden)
print(torch.sum(abs(lstms_res[0]-stackedlstms_res[0])))
print(torch.sum(abs(lstms_res[1][0]-stackedlstms_res[1][0])))
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)