问题描述
我试图在PyTorch中将两个复杂的矩阵相乘,看来the torch.matmul functions is not added yet to PyTorch library for complex numbers.
您有任何建议吗?还是有另一种方法可以在PyTorch中乘以复杂矩阵?
解决方法
torch.matmul
等复杂张量目前不支持ComplexFloatTensor
,但是您可以执行以下代码一样紧凑的操作:
def matmul_complex(t1,t2):
return torch.view_as_complex(torch.stack((t1.real @ t2.real - t1.imag @ t2.imag,t1.real @ t2.imag + t1.imag @ t2.real),dim=2))
在可能的情况下,避免使用for循环,因为这会导致实现速度大大降低。 通过使用我随附的代码中演示的内置方法来实现矢量化。 例如,对于2个尺寸为1000 X 1000的随机复杂矩阵,您的代码在CPU上花费大约6.1s,而矢量化版本仅花费101ms(快60倍)。
,我使用torch.mv为pytorch.matmul实现了此函数,以处理复数,并且在时间上运行良好:
def matmul_complex(t1,t2):
m = list(t1.size())[0]
n = list(t2.size())[1]
t = torch.empty((1,n),dtype=torch.cfloat)
t_total = torch.empty((m,dtype=torch.cfloat)
for i in range(0,n):
if i == 0:
t_total = torch.mv(t1,t2[:,i])
else:
t_total = torch.cat((t_total,torch.mv(t1,i])),0)
t_final = torch.reshape(t_total,(m,n))
return t_final
我是PyTorch的新手,所以如果我错了,请纠正我。