以行方式使用 numpy matmul 和广播

问题描述

我有一个 3D 点 (n,3) 数组,它们将使用以 nx3x3 数组形式存储的 3x3rotation 矩阵绕原点旋转。

目前我只是在使用 matmul 的 for 循环中执行此操作,但我认为这是毫无意义的,因为必须有更快的广播方式来执行此操作。

当前代码

n = 10
points = np.random.random([10,3])
rotation_matrices = np.tile(np.random.random([3,3]),(n,1,1))

result = []

for point in range(len(points)):
    rotated_point = np.matmul(rotation_matrices[point],points[point])

    result.append(rotated_point)

result = np.asarray(result)

注意:在这个例子中,我只是平铺了相同的旋转矩阵,但在我的真实情况下,每个 3x3 旋转矩阵都是不同的。

我想做什么

我猜一定有某种广播方式,因为当点云变得非常大时,for 循环会变得非常缓慢。我想这样做:

np.matmul(rotation_matrices,points)

其中 row 中的每个 points 乘以其相应的旋转矩阵。可能有一种方法可以使用 np.einsum 执行此操作,但我无法弄清楚签名。

解决方法

如果您看到 https://ghostbin.co/paste/s2qdw,则 np.einsum('ij,jk',a,b)matmul 的签名。

所以你可以试试带有签名的np.einsum

np.einsum('kij,kj->ki',rotation_matrices,points)

测试

einsum = np.einsum('kij,points)
manual = np.array([np.matmul(x,y) for x,y in zip (rotation_matrices,points)])
np.allclose(einsum,manual)
# True

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...