PyTorch中复数的矩阵乘法

问题描述

我试图在PyTorch中将两个复杂的矩阵相乘,看来the torch.matmul functions is not added yet to PyTorch library for complex numbers.

您有任何建议吗?还是有另一种方法可以在PyTorch中乘以复杂矩阵?

解决方法

torch.matmul等复杂张量目前不支持ComplexFloatTensor,但是您可以执行以下代码一样紧凑的操作:

def matmul_complex(t1,t2):
    return torch.view_as_complex(torch.stack((t1.real @ t2.real - t1.imag @ t2.imag,t1.real @ t2.imag + t1.imag @ t2.real),dim=2))

在可能的情况下,避免使用for循环,因为这会导致实现速度大大降低。 通过使用我随附的代码中演示的内置方法来实现矢量化。 例如,对于2个尺寸为1000 X 1000的随机复杂矩阵,您的代码在CPU上花费大约6.1s,而矢量化版本仅花费101ms(快60倍)。

,

我使用torch.mv为pytorch.matmul实现了此函数,以处理复数,并且在时间上运行良好:

def matmul_complex(t1,t2):
  m = list(t1.size())[0]
  n = list(t2.size())[1]
  t = torch.empty((1,n),dtype=torch.cfloat)
  t_total = torch.empty((m,dtype=torch.cfloat)
  for i in range(0,n):
    if i == 0:
      t_total = torch.mv(t1,t2[:,i])
    else:
      t_total = torch.cat((t_total,torch.mv(t1,i])),0)
  t_final = torch.reshape(t_total,(m,n))
  return t_final

我是PyTorch的新手,所以如果我错了,请纠正我。