问题描述
我正在尝试使用批量标准化实现 to_date
模型,但出现错误:
select trunc(time,'HH24'),count(column) as xyz
from table
where time >= date '2021-08-04'
and time <= date '2021-09-04'
and column='xyz'
group by trunc(time,'HH24')
order by trunc(time,'HH24');
我使用的批量大小为 32,数据的特征数为 40。我一直在尝试计算 Conv1d
的来源,但我无法做到。这是我尝试使用的 CNN 架构:
RuntimeError Traceback (most recent call last)
<ipython-input-117-ef6e122ea50c> in <module>()
----> 1 test()
2 for epoch in range(1,n_epochs + 1):
3 train(epoch)
4 test()
7 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in linear(input,weight,bias)
1751 if has_torch_function_variadic(input,weight):
1752 return handle_torch_function(linear,(input,weight),input,bias=bias)
-> 1753 return torch._C._nn.linear(input,bias)
1754
1755
RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x140 and 100x10)
解决方法
这个全连接应该从:
nn.Linear(20*5,10)
到:
nn.Linear(20*7,10)
为什么?
如果您的输入数据长度为 40,则(B
是批量大小):
- 第一次转换后的输出(K=25):
B x 25 x 18
- 第二次转换后的输出 (K=20):
B x 20 x 7
-
nn.Flatten()
后的输出:B x 140
,即如果B=32,则32x140