Output of a custom stacked LSTM differs from nn.LSTM

Problem description

I implemented a multi-layer LSTM, but when init_state is not None the results differ from nn.LSTM. I loaded the weights of an existing LSTM model into both my custom model and PyTorch's nn.LSTM. I suspect I am doing something wrong in the forward function. Any help would be greatly appreciated. Thank you very much!
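For reference, nn.LSTM with batch_first=False expects time-major input of shape (seq_len, batch, input_size) and an initial state (h_0, c_0) whose tensors each have shape (num_layers, batch, hidden_size), so the tensors stored in hidden.pt presumably follow that layout. A minimal shape check with illustrative sizes (not the asker's data):

# illustrative shapes only: a 2-layer, unidirectional LSTM with hidden size 320,
# a length-5 sequence and batch size 1
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=320, hidden_size=320, num_layers=2)
x = torch.randn(5, 1, 320)      # (seq_len, batch, input_size)
h0 = torch.zeros(2, 1, 320)     # (num_layers, batch, hidden_size)
c0 = torch.zeros(2, 1, 320)

out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape)                # torch.Size([5, 1, 320]) - last layer, every time step
print(hn.shape, cn.shape)       # torch.Size([2, 1, 320]) each - every layer, last time step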


import torch
import torch.nn as nn
from typing import Optional, Tuple


class StackedLSTMs(nn.Module):
    def __init__(self, input_sz: int, hidden_sz: int, num_layers: int):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_sz = hidden_sz
        self.LSTMs = nn.ModuleList()
        for layer in range(num_layers):
            if layer == 0:
                self.LSTMs.append(nn.LSTMCell(input_sz, hidden_sz))
                #self.LSTMs.append(NaiveCustomLSTMCell(input_sz, hidden_sz))
            else:
                self.LSTMs.append(nn.LSTMCell(hidden_sz, hidden_sz))
                #self.LSTMs.append(NaiveCustomLSTMCell(hidden_sz, hidden_sz))


    def forward(self, x, h: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        print('hidden', h)
        seq_size, bs, _ = x.size()
        outputs = []
        if h is None:
            hn = torch.zeros(self.num_layers, bs, self.hidden_sz)
            cn = torch.zeros(self.num_layers, bs, self.hidden_sz)
        else:
            (hn, cn) = h

        for t in range(seq_size):
            for layer, lstm in enumerate(self.LSTMs):
                if layer == 0:
                    # first layer consumes the input at time step t
                    hn[layer], cn[layer] = lstm(x[t], (hn[layer], cn[layer]))
                else:
                    # deeper layers consume the hidden state of the layer below
                    hn[layer], cn[layer] = lstm(hn[layer - 1], (hn[layer], cn[layer]))
            temp = hn[self.num_layers - 1].detach().clone()
            outputs.append(temp)
        outputs = torch.stack(outputs, dim=0)
        h = (hn, cn)
        #outputs = outputs.transpose(0,1).contiguous()
        return outputs, h


torch.manual_seed(999)
lstms = nn.LSTM(320,320,2)
stackedlstms = StackedLSTMs(320, 320, 2)

stackedlstms.LSTMs[0].weight_ih = oldmodel.prediction.dec_rnn.lstm.weight_ih_l0
stackedlstms.LSTMs[0].weight_hh = oldmodel.prediction.dec_rnn.lstm.weight_hh_l0
stackedlstms.LSTMs[0].bias_ih = oldmodel.prediction.dec_rnn.lstm.bias_ih_l0
stackedlstms.LSTMs[0].bias_hh = oldmodel.prediction.dec_rnn.lstm.bias_hh_l0

stackedlstms.LSTMs[1].weight_ih = oldmodel.prediction.dec_rnn.lstm.weight_ih_l1
stackedlstms.LSTMs[1].weight_hh = oldmodel.prediction.dec_rnn.lstm.weight_hh_l1
stackedlstms.LSTMs[1].bias_ih = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1
stackedlstms.LSTMs[1].bias_hh = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1

lstms.weight_ih_l0 = oldmodel.prediction.dec_rnn.lstm.weight_ih_l0
lstms.weight_hh_l0 = oldmodel.prediction.dec_rnn.lstm.weight_hh_l0
lstms.bias_ih_l0 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l0
lstms.bias_hh_l0 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l0

lstms.weight_ih_l1 = oldmodel.prediction.dec_rnn.lstm.weight_ih_l1
lstms.weight_hh_l1 = oldmodel.prediction.dec_rnn.lstm.weight_hh_l1
lstms.bias_ih_l1 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1
lstms.bias_hh_l1 = oldmodel.prediction.dec_rnn.lstm.bias_ih_l1

hidden = torch.load('hidden.pt')
newembedt = torch.load('newembed_t.pt')

lstms_res = lstms(newembedt,hidden)
stackedlstms_res = stackedlstms(newembedt,hidden)

print(torch.sum(abs(lstms_res[0]-stackedlstms_res[0])))
print(torch.sum(abs(lstms_res[1][0]-stackedlstms_res[1][0])))
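Not part of the original post, but one way to narrow the search space: copy the nn.LSTM weights into the LSTMCells programmatically via the state_dict naming convention, then assert that every tensor matches. That rules out slips in the hand-written attribute assignments (for example a bias_ih accidentally assigned to a bias_hh). A minimal sketch, assuming the 2-layer setup above; copy_lstm_weights and check_copy are hypothetical helper names:

# sketch (not the asker's code): copy nn.LSTM parameters layer by layer into the
# LSTMCells using the state_dict keys weight_ih_l{k}, weight_hh_l{k}, bias_ih_l{k}, bias_hh_l{k}
import torch
import torch.nn as nn

def copy_lstm_weights(lstm: nn.LSTM, cells: nn.ModuleList) -> None:
    sd = lstm.state_dict()
    with torch.no_grad():
        for layer, cell in enumerate(cells):
            cell.weight_ih.copy_(sd[f'weight_ih_l{layer}'])
            cell.weight_hh.copy_(sd[f'weight_hh_l{layer}'])
            cell.bias_ih.copy_(sd[f'bias_ih_l{layer}'])
            cell.bias_hh.copy_(sd[f'bias_hh_l{layer}'])

def check_copy(lstm: nn.LSTM, cells: nn.ModuleList) -> None:
    sd = lstm.state_dict()
    for layer, cell in enumerate(cells):
        for name in ('weight_ih', 'weight_hh', 'bias_ih', 'bias_hh'):
            assert torch.equal(getattr(cell, name), sd[f'{name}_l{layer}']), \
                f'mismatch in {name}_l{layer}'

# usage with the objects defined above (hypothetical):
#   copy_lstm_weights(lstms, stackedlstms.LSTMs)
#   check_copy(lstms, stackedlstms.LSTMs)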

Solution

No effective solution to this problem has been found yet.
