自编码器 BiLSTM 模型中的 Pytorch 大小不匹配错误

问题描述

我正在尝试开发基于 BiLSTM 的自动编码器。当我试图在解码器的第二个 LSTM 层中重塑时,它给出的问题是大小。但是会报错:

运行时错误:预期隐藏 [0] 大小 (1,64,600),得到 (2,600)。

我也想让第二个 LSTM 层成为双向的。

这里是参数

seed= 0,epochs= 4000,batch_size= 64,lr= 5e-04,dropout= 0.1,embedding_dims=40,e_hidden_dims= 100,seq_length= 131,bottleneck_dims=100,interval=10,d_hidden_dims=600
class Autoencoder(nn.Module):
  def __init__(self,nuc_pair_size,embedding_dims,e_hidden_dims,bottleneck_dims,d_hidden_dims,seq_length,dropout_size = 0.1):
    super().__init__()
    nuc_pair_size+=1
    self.seq_length= seq_length
    # define the vars over here (layers,objects)
    self.embedding= nn.Embedding( nuc_pair_size,embedding_dims)
    self.rnn1= nn.LSTM(input_size= embedding_dims,hidden_size= e_hidden_dims,bidirectional=True)
    self.fc0= nn.Linear(in_features = e_hidden_dims,out_features= bottleneck_dims*2)
    self.fc1= nn.Linear(in_features = bottleneck_dims*2,out_features= bottleneck_dims)
    self.a1= nn.ReLU(True)
    #self.a1= nn.Sigmoid(True)
    self.dropout= nn.Dropout(dropout_size)
        
    self.fc02= nn.Linear(in_features = bottleneck_dims,out_features= d_hidden_dims*2)
    self.fc2= nn.Linear(in_features = d_hidden_dims*2,out_features= d_hidden_dims)
    self.rnn2= nn.LSTM(input_size= d_hidden_dims,hidden_size= d_hidden_dims)#,bidirectional=True)
    self.fc3= nn.Linear(in_features= d_hidden_dims,out_features= nuc_pair_size)
 
  def encoder(self,x):
    x= self.embedding(x).permute(1,2)
    _,(hidden_states,_)= self.rnn1(x)
    lv= self.fc0(hidden_states)
    lv= self.fc1(lv) # latent vector
    lv= self.dropout(lv)
    return lv

  def decoder(self,lv):
    lv=self.fc02(lv)
    lv= self.fc2(lv)
    output,_= self.rnn2(lv.repeat(self.seq_length,1,1),(lv,lv))
    output= output.permute(1,2)
    logits= self.fc3(output)
    return logits.transpose(1,2)
      
  def forward(self,x):
    lv= self.encoder(x)
    logits= self.decoder(lv)
    return (lv.squeeze(),logits)

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)