问题描述
我定义了一个三层卷积层(self.convs),输入张量的形状为([100,10,24])
x_convs = self.convs(Variable(torch.from_numpy(X).type(torch.FloatTensor)))
>>Variable(torch.from_numpy(X).type(torch.FloatTensor)).shape
torch.Size([100,24])
>>self.convs
ModuleList(
(0): ConvBlock(
(conv): Conv1d(24,8,kernel_size=(5,),stride=(1,padding=(2,))
(relu): ReLU()
(maxpool): AdaptiveMaxPool1d(output_size=10)
(zp): ConstantPad1d(padding=(1,0),value=0)
)
(1): ConvBlock(
(conv): Conv1d(8,value=0)
)
(2): ConvBlock(
(conv): Conv1d(8,value=0)
)
)
当我开除x_convs = self.convs(Variable(torch.from_numpy(X).type(torch.FloatTensor)))
时,它给了我错误
`94 registered hooks while the latter silently ignores them.
95 """
---> 96 raise NotImplementedError`
ConvBlock定义如下
class ConvBlock(nn.Module):
def __init__(self,T,in_channels,out_channels,filter_size):
super(ConvBlock,self).__init__()
padding = self._calc_padding(T,filter_size)
self.conv=nn.Conv1d(in_channels,filter_size,padding=padding)
self.relu=nn.ReLU()
self.maxpool=nn.AdaptiveMaxPool1d(T)
self.zp=nn.ConstantPad1d((1,0)
def _calc_padding(self,Lin,kernel_size,stride=1,dilation=1):
p = int(((Lin-1)*stride + 1 + dilation*(kernel_size - 1) - Lin)/2)
return p
def forward(self,x):
x = x.permute(0,2,1)
x = self.conv(x)
x = self.relu(x)
x = self.maxpool(x)
x = x.permute(0,1)
return x
“前进”功能具有正确的缩进,因此我无法弄清楚发生了什么。
解决方法
如果要按顺序执行这3层,则应使用nn.Sequential
而不是nn.ModuleList
。 nn.ModuleList
没有实现forward()
方法,但是nn.Sequential
实现了。
如果您希望在forward()
方法中有某些特殊行为,则可以继承nn.ModuleList
的子类并覆盖其forward()
。
您正在尝试调用ModuleList
(即list
(即Python中的列表对象),对其进行了稍微的修改以用于PyTorch。
一种快速解决方案是将self.convs
称为:
x_convs = self.convs[0](Variable(torch.from_numpy(X).type(torch.FloatTensor)))
if len(self.convs) > 1:
for conv in self.convs[1:]:
x_convs = self.convs[0](x_convs)
也就是说,尽管self.convs
是list
,但每个成员都是Module
。您可以使用其索引直接调用self.convs
的每个成员,例如``self.convsan_index`。
或者,您可以借助functools
模块来做到这一点:
from functools import reduce
def apply_layer(layer_input,layer):
return layer(layer_input)
output_of_self_convs = reduce(apply_layer,self.convs,Variable(torch.from_numpy(X).type(torch.FloatTensor)))
P.S。不过,Variable
关键字已不再使用。