问题描述
假设即使是最简单的模型(取自 here)
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(1,32,3,1)
self.conv2 = nn.Conv2d(32,64,1)
self.fc1 = nn.Linear(9216,128)
self.fc2 = nn.Linear(128,10)
def forward(self,x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x,2)
x = torch.flatten(x,1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
output = F.log_softmax(x,dim=1)
return output
向模型提供复杂数据时,
output = model(data.complex())
它给了
ret = torch.addmm(bias,input,weight.t())
RuntimeError: expected scalar type Float but found ComplexDouble
(为了简单起见,我没有复制整个堆栈跟踪,也没有复制整个训练代码)
在模型的 self.complex()
之后做 __init__
,就像我通常会做的那样 self.double()
,不起作用,
torch.nn.modules.module.ModuleAttributeError: 'Net' object has no attribute 'complex'
编辑:
同时,我发现 this paper。还在读。
解决方法
正如您通常所做的self.double()
,我从https://pytorch.org/docs/stable/generated/torch.nn.Module.html
self.type(dst_type)
就我而言,self.type(torch.complex64)
对我有用。