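Problem description
I have two 3-dimensional arrays, A and B, where
- A has shape (500 x 500 x 80), and
- B has shape (500 x 80 x 2000).
In both arrays, the dimension of size 80 can be called "time" (e.g. 80 time points i). The dimension of size 2000 can be called "scenario" (we have 2000 scenarios).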
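What I need to do is take the 500 x 500 matrix A[:,:,i] and multiply it by each vector B[:,i,scenario], for every scenario and every time i.
I eventually ended up with the code below:
from scipy.stats import norm
import numpy as np
A = norm.rvs(size = (500,500,80),random_state = 0)
B = norm.rvs(size = (500,80,2000),random_state = 0)
result = np.einsum('ijk,jkl->ikl',A,B,optimize=True)
A naive approach to the same problem would be to use nested for loops:
out = np.empty((500,80,2000))
for scenario in range(2000):
    for i in range(80):
        out[:,i,scenario] = A[:,:,i] @ B[:,i,scenario]
I expected einsum to be fast, because the problem "only" involves simple operations on large arrays, but it actually runs for several minutes.
I compared the speed of the einsum above with the case where every matrix in A is assumed to be the same. Then A can be kept as a (500 x 500) matrix (instead of a 3d array), and the whole problem can be written as
# A is now a single (500 x 500) matrix shared by all time points
A = norm.rvs(size = (500,500),random_state = 0)
B = norm.rvs(size = (500,80,2000),random_state = 0)
result = np.einsum('ij,jkl->ikl',A,B,optimize=True)
This is fast and runs for only a few seconds, much faster than the "slightly" more general case above.
My question is: have I written the slow, general einsum case in a computationally efficient form?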
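As a sanity check on the subscripts, the einsum expression can be compared with the nested loops on much smaller arrays. This is only a minimal sketch: the reduced sizes (6, 4, 5) and the use of np.random.default_rng in place of norm.rvs are assumptions made to keep it quick to run.

```python
import numpy as np

# Small stand-in sizes (assumed for illustration; the question uses 500, 80, 2000).
m, r, n = 6, 4, 5
rng = np.random.default_rng(0)
A = rng.standard_normal((m, m, r))   # (row, column, time)
B = rng.standard_normal((m, r, n))   # (column, time, scenario)

# einsum: sum over j, while the time index k is shared by A and B.
result = np.einsum('ijk,jkl->ikl', A, B, optimize=True)

# Naive reference: one matrix-vector product per (time, scenario) pair.
out = np.empty((m, r, n))
for scenario in range(n):
    for i in range(r):
        out[:, i, scenario] = A[:, :, i] @ B[:, i, scenario]

same = np.allclose(result, out)
print(result.shape, same)
```

If the two results agree, the subscripts 'ijk,jkl->ikl' describe the intended operation, and the remaining question is purely one of speed.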
Solution
You can do better than the existing version with two nested loops:
m = A.shape[0]
n = B.shape[2]
r = A.shape[2]
out1 = np.empty((m,r,n),dtype=np.result_type(A.dtype,B.dtype))
for i in range(r):
    # one (m x m) @ (m x n) product per time point i
    out1[:,i,:] = A[:,:,i] @ B[:,i,:]
Alternatively, with np.matmul / the @ operator:
out = (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
Both of these seem to scale much better than the einsum version.
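The transposes work because np.matmul treats the leading axis of a 3-d array as a batch axis and multiplies the trailing two axes as matrices. The sketch below traces the shapes on small arrays (the sizes are assumed for illustration) and checks both the batched and the one-loop version against the einsum result.

```python
import numpy as np

m, r, n = 6, 4, 5  # assumed small sizes, not the ones from the question
rng = np.random.default_rng(1)
A = rng.standard_normal((m, m, r))
B = rng.standard_normal((m, r, n))

At = A.transpose(2, 0, 1)      # (r, m, m): one matrix per time point
Bt = B.transpose(1, 0, 2)      # (r, m, n): one block of scenarios per time point
batched = At @ Bt              # (r, m, n): r independent matrix products
out = batched.swapaxes(0, 1)   # (m, r, n): same layout as the einsum output

# One-loop version: fill one time slice per iteration.
out1 = np.empty((m, r, n), dtype=np.result_type(A.dtype, B.dtype))
for i in range(r):
    out1[:, i, :] = A[:, :, i] @ B[:, i, :]

ref = np.einsum('ijk,jkl->ikl', A, B)
print(At.shape, Bt.shape, batched.shape, out.shape)
print(np.allclose(out, ref), np.allclose(out1, ref))
```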
Timings
Case #1: scaled to 1/4 of the sizes
In [44]: m = 500
...: n = 2000
...: r = 80
...: m,n,r = m//4,n//4,r//4
...:
...: A = norm.rvs(size = (m,m,r),random_state = 0)
...: B = norm.rvs(size = (m,r,n),random_state = 0)
In [45]: %%timeit
...: out1 = np.empty((m,r,n),dtype=np.result_type(A.dtype,B.dtype))
...: for i in range(r):
...:     out1[:,i,:] = A[:,:,i] @ B[:,i,:]
175 ms ± 6.54 ms per loop (mean ± std. dev. of 7 runs,10 loops each)
In [46]: %timeit (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
165 ms ± 1.11 ms per loop (mean ± std. dev. of 7 runs,10 loops each)
In [47]: %timeit np.einsum('ijk,jkl->ikl',A,B,optimize=True)
483 ms ± 13.5 ms per loop (mean ± std. dev. of 7 runs,1 loop each)
As we scale up, memory congestion starts to favour the one-loop version.
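To put rough numbers on that memory pressure, the footprint of each array at full size can be estimated from its shape. This back-of-the-envelope sketch assumes float64 data (8 bytes per element) and ignores any temporaries NumPy may allocate along the way.

```python
# Full-size shapes from the question; float64 assumed (8 bytes per element).
m, n, r = 500, 2000, 80
itemsize = 8

bytes_A = m * m * r * itemsize     # A:      (m, m, r)
bytes_B = m * r * n * itemsize     # B:      (m, r, n)
bytes_out = m * r * n * itemsize   # output: (m, r, n)

mb = 1e6
print(f"A:   {bytes_A / mb:.0f} MB")
print(f"B:   {bytes_B / mb:.0f} MB")
print(f"out: {bytes_out / mb:.0f} MB")

# Each output element needs m multiplications and about m additions.
flops = 2 * m * m * r * n
print(f"about {flops:.1e} floating-point operations")
```

Inputs and output together come to well over a gigabyte, which is consistent with memory traffic starting to matter as the sizes grow.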
Case #2: scaled to 1/2 of the sizes
In [48]: m = 500
...: n = 2000
...: r = 80
...: m,n,r = m//2,n//2,r//2
...:
...: A = norm.rvs(size = (m,m,r),random_state = 0)
...: B = norm.rvs(size = (m,r,n),random_state = 0)
In [49]: %%timeit
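...: out1 = np.empty((m,r,n),dtype=np.result_type(A.dtype,B.dtype))
...: for i in range(r):
...:     out1[:,i,:] = A[:,:,i] @ B[:,i,:]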
2.9 s ± 58.3 ms per loop (mean ± std. dev. of 7 runs,1 loop each)
In [50]: %timeit (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
3.02 s ± 94.8 ms per loop (mean ± std. dev. of 7 runs,1 loop each)
Case #3: scaled to 67% of the sizes
In [59]: m = 500
...: n = 2000
...: r = 80
...: m,n,r = int(m/1.5),int(n/1.5),int(r/1.5)
In [60]: A = norm.rvs(size = (m,m,r),random_state = 0)
...: B = norm.rvs(size = (m,r,n),random_state = 0)
In [61]: %%timeit
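...: out1 = np.empty((m,r,n),dtype=np.result_type(A.dtype,B.dtype))
...: for i in range(r):
...:     out1[:,i,:] = A[:,:,i] @ B[:,i,:]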
25.8 s ± 4.9 s per loop (mean ± std. dev. of 7 runs,1 loop each)
In [62]: %timeit (A.transpose(2,0,1) @ B.transpose(1,0,2)).swapaxes(0,1)
29.2 s ± 2.41 s per loop (mean ± std. dev. of 7 runs,1 loop each)
Numba spin-off
from numba import njit,prange

@njit(parallel=True)
def func1(A,B):
    m = A.shape[0]
    n = B.shape[2]
    r = A.shape[2]
    out = np.empty((m,r,n))
    # prange spreads the time points across threads
    for i in prange(r):
        out[:,i,:] = A[:,:,i] @ B[:,i,:]
    return out
Timings for case #3:
In [80]: m = 500
...: n = 2000
...: r = 80
...: m,n,r = int(m/1.5),int(n/1.5),int(r/1.5)
In [81]: A = norm.rvs(size = (m,m,r),random_state = 0)
...: B = norm.rvs(size = (m,r,n),random_state = 0)
In [82]: %timeit func1(A,B)
653 ms ± 10.4 ms per loop (mean ± std. dev. of 7 runs,1 loop each)