输入必须有 3 个维度,在创建 LSTM 分类器时出现 2 个错误 使用重塑功能添加重复矢量层

问题描述

网络的结构必须如下:

(lstm): LSTM(1,64,batch_first=True)

(fc1):线性(in_features=64,out_features=32,bias=True)

(relu): ReLU()

(fc2):线性(in_features=32,out_features=5,bias=True)

我写了这段代码:

class LSTMClassifier(nn.Module):

    def __init__(self):
        super(LSTMClassifier,self).__init__() 
        self.lstm = nn.LSTM(1,batch_first=True)
        self.fc1 = nn.Linear(in_features=64,out_features=32,bias=True)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=32,out_features=5,bias=True)
         

    def forward(self,x):
       x = torch.tanh(self.lstm(x)[0])
       x = self.fc1(x)
       x = F.relu(x)
       x = self.fc2(x)

这是为了测试:

    (batch_data,batch_label) = next (iter (train_loader))
    model = LSTMClassifier().to(device)
    output = model (batch_data.to(device)).cpu()
    assert output.shape == (batch_size,5)
    print ("passed")

错误是:

----> 3 输出 = 模型 (batch_data.to(device)).cpu()

5 帧 /usr/local/lib/python3.7/dist-packages/torch/nn/modules/rnn.py in check_input(self,input,batch_sizes) 201 引发运行时错误( 202'输入必须有{}维,得到{}'.format( --> 203 expected_input_dim,input.dim())) 204 如果 self.input_size != input.size(-1): 205 引发运行时错误(

运行时错误:输入必须有 3 个维度,得到 2 个

我的问题是什么?

解决方法

LSTM 支持 3 维输入(样本、时间步长、特征)。您需要将输入从 2D 转换为 3D。为此,您可以:

使用重塑功能

首先,您需要使用 batch_data.shape 获得 2D 输入的形状。让我们假设您的 2D 输入的形状是 (15,4)。 现在要将输入从 2D 重塑为 3D,您可以使用重塑函数 np.reshape(data,new_shape)

    (batch_data,batch_label) = next (iter (train_loader))
    batch_data = np.reshape(batch_data,(15,4,1)) # line to add
    model = LSTMClassifier().to(device)
    output = model (batch_data.to(device)).cpu()
    assert output.shape == (batch_size,5)
    print ("passed")

稍后,您还需要将测试数据从 2D 重塑为 3D。

添加重复矢量层

该层是在 Keras 中实现的,我不确定它是否在 PyTorch 中可用,这是您的情况。 该层为您的数据添加了一个额外的维度(重复输入 n 次)。例如,您可以将 2D 输入 (batch size,input size) 转换为 3D 输入 (batch_size,sequence_length,input size)

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...