计算一批数组的数组矩阵乘法的有效方法

问题描述

我想并行化以下问题。给定一个形状为 w 的数组 (dim1,)一个形状为 A 的矩阵 (dim1,dim2),我希望 A 的每一行都与 { 的相应元素相乘{1}}。 这很微不足道。

但是,我想对一堆数组 w 执行此操作,最后对结果求和。因此,为了避免 for 循环,我创建了形状为 w 的矩阵 W,并按以下方式使用了 (n_samples,dim1) 函数

np.einsum

其中 x = np.einsum('ji,ik -> jik',W,A)) r = x.sum(axis=0) 的形状为 x,最终和的形状为 (n_samples,dim1,dim2)

我注意到 (dim1,dim2) 对于大矩阵 np.einsum 来说非常慢。有没有更有效的方法解决这个问题?我也想尝试使用 A,但可能不是这样。

谢谢:-)

解决方法

In [455]: W = np.arange(1,7).reshape(2,3); A = np.arange(1,13).reshape(3,4)

你的计算:

In [463]: x = np.einsum('ji,ik -> jik',W,A)
     ...: r = x.sum(axis=0)
In [464]: r
Out[464]: 
array([[  5,10,15,20],[ 35,42,49,56],[ 81,90,99,108]])

如评论中所述,einsum 可以对 j 进行求和:

In [465]: np.einsum('ji,ik -> ik',A)
Out[465]: 
array([[  5,108]])

由于 j 只出现在 A 中,我们可以先对 A 求和:

In [466]: np.sum(W,axis=0)[:,None]*A
Out[466]: 
array([[  5,108]])

这不涉及乘积和,所以不涉及矩阵乘法。

或者在乘法后求和:

In [475]: (W[:,:,None]*A).sum(axis=0)
Out[475]: 
array([[  5,108]])

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...