如何解决以下 CNN 架构的错误“mat1 和 mat2 形状无法相乘”?

问题描述

我正在尝试使用批量标准化实现 to_date 模型,但出现错误

select trunc(time,'HH24'),count(column) as xyz 
  from table
 where time >= date '2021-08-04'
   and time <= date '2021-09-04'
   and column='xyz'
 group by trunc(time,'HH24')
 order by trunc(time,'HH24');

我使用的批量大小为 32,数据的特征数为 40。我一直在尝试计算 Conv1d 的来源,但我无法做到。这是我尝试使用的 CNN 架构:

RuntimeError                              Traceback (most recent call last)
<ipython-input-117-ef6e122ea50c> in <module>()
----> 1 test()
      2 for epoch in range(1,n_epochs + 1):
      3   train(epoch)
      4   test()

7 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in linear(input,weight,bias)
   1751     if has_torch_function_variadic(input,weight):
   1752         return handle_torch_function(linear,(input,weight),input,bias=bias)
-> 1753     return torch._C._nn.linear(input,bias)
   1754 
   1755 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x140 and 100x10)

解决方法

这个全连接应该从:

nn.Linear(20*5,10)

到:

nn.Linear(20*7,10)

为什么?

如果您的输入数据长度为 40,则(B 是批量大小):

  • 第一次转换后的输出(K=25):B x 25 x 18
  • 第二次转换后的输出 (K=20):B x 20 x 7
  • nn.Flatten()后的输出:B x 140,即如果B=32,则32x140