如何允许复杂的输入和复杂的权重到 Pytorch 模型?

问题描述

假设即使是最简单的模型(取自 here

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(1,32,3,1)
        self.conv2 = nn.Conv2d(32,64,1)
        self.fc1 = nn.Linear(9216,128)
        self.fc2 = nn.Linear(128,10)

    def forward(self,x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x,2)
        x = torch.flatten(x,1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)

        output = F.log_softmax(x,dim=1)
        return output

向模型提供复杂数据时,

output = model(data.complex())

它给了

ret = torch.addmm(bias,input,weight.t())

RuntimeError: expected scalar type Float but found ComplexDouble

(为了简单起见,我没有复制整个堆栈跟踪,也没有复制整个训练代码


在模型的 self.complex() 之后做 __init__,就像我通常会做的那样 self.double(),不起作用,

torch.nn.modules.module.ModuleAttributeError: 'Net' object has no attribute 'complex'

  1. 如何让模型的权重变得复杂?
  2. 如何允许对模型进行复杂输入?
  3. 哪些内置激活函数支持这一点?
  4. 是否还支持一维操作?

编辑:

同时,我发现 this paper。还在读。

解决方法

正如您通常所做的self.double(),我从https://pytorch.org/docs/stable/generated/torch.nn.Module.html

找到了self.type(dst_type)

就我而言,self.type(torch.complex64) 对我有用。