我如何向量化此操作?

问题描述

我必须执行以下操作数千次,这严重降低了我的代码的速度:

T = 50
D = 10
K = 20

x = np.random.randn(T,D)
y = np.random.randn(T,K)

result = np.zeros((K,D))

for k in range(K):
    for t in range(T):
        result[k] += y[t,k] * x[t]  # Multiply scalar element in y with row in x

基本上,我试图将矩阵k的列y的列x中的每个元素与np.einsum()中的对应行相加并求和。我尝试使用result = np.einsum("ij,ik->jk",y,x) 解决此问题:

result.shape == (K,D)

至少给了我np.einsum(),但结果不匹配!如何有效执行此操作? UNIQUE NOT NULL甚至有可能吗?

解决方法

这些操作是相同的。您已经找到了(可能是最快的)矢量化操作。

T = 50
D = 10
K = 20

x = np.random.randn(T,D)
y = np.random.randn(T,K)

result = np.zeros((K,D))

for k in range(K):
    for t in range(T):
        result[k] += y[t,k] * x[t]
           
result2 = np.einsum("ij,ik->jk",y,x)

np.allclose(result,result2)
Out[]: True

问题很可能是浮点错误,无论您使用哪种方法来确定它们是否“相同”。 np.allclose()是解决方案。它消除了使用float s在不同计算方法之间发生的很小的误差。

尽管@QuangHoang在评论中指出,y.T @ x更具可读性

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...