优雅的 Numpy Tensor 产品

问题描述

我需要在 numpy(或 pytorch)中将乘积超过两个张量:

我有

A = np.arange(1024).reshape(8,1,128)
B = np.arange(9216).reshape(8,128,9)

并且想要获得 C,点积在 A 的最后一个点 (axis=2) 和 B 的中间点 (axis=1 )。这应该有尺寸 8x9。目前,我正在做:

C = np.zeros([8,9])
for i in range(8):
    C[i,:] = np.matmul(A[i,:,:],B[i,:])

如何优雅地做到这一点?

我试过了:

np.tensordot(weights,features,axes=(2,1)).

但它返回 8x1x8x9

解决方法

一种方法是使用 numpy.einsum

C = np.einsum('ijk,ikl->il',A,B)

或者你可以使用 broadcasted 矩阵乘法。

C = (A @ B).squeeze(axis=1)
# equivalent: C = np.matmul(A,B).squeeze(axis=1)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...