问题描述
我需要在 numpy(或 pytorch)中将乘积超过两个张量:
A = np.arange(1024).reshape(8,1,128)
B = np.arange(9216).reshape(8,128,9)
并且想要获得 C
,点积在 A
的最后一个点 (axis=2
) 和 B
的中间点 (axis=1
)。这应该有尺寸 8x9
。目前,我正在做:
C = np.zeros([8,9])
for i in range(8):
C[i,:] = np.matmul(A[i,:,:],B[i,:])
如何优雅地做到这一点?
我试过了:
np.tensordot(weights,features,axes=(2,1)).
但它返回 8x1x8x9
。
解决方法
一种方法是使用 numpy.einsum
。
C = np.einsum('ijk,ikl->il',A,B)
或者你可以使用 broadcasted 矩阵乘法。
C = (A @ B).squeeze(axis=1)
# equivalent: C = np.matmul(A,B).squeeze(axis=1)