根据索引的多维数组中的Numpy sum元素

问题描述

我正在处理一个非常大的多维数据,但让我以一个二维数组为例。给定一个每次迭代都在改变的值数组,

arr = np.array([[ 1,2,3,4,5],[5,6,7,8,9]]) # a*b

一个一直固定的索引数组。

idx = np.array([[[0,1,1],[-1,-1,-1]],[[5,3],[1,-1]]]) # n*h*w,where n = a*b,

这里 -1 表示不会应用任何索引。我希望得到一个结果

res = np.array([[1+2+2,0],[5+2+4,2]]) # h*w

在实际实践中,我正在使用一个非常大的 3D 张量(n ~ 万亿),具有非常稀疏的 idx(即很多 -1)。由于 idx 是固定的,我目前的解决方案是通过填充 0 和 1 来预先计算 n*(h*w) 数组 index_tensor,然后执行

tmp = arr.reshape(1,n)
res = (tmp @ index_tensor).reshape([h,w])

它工作正常,但需要大量内存来存储 index_tensor。有什么方法可以利用 idx 的稀疏性和不变性来降低内存成本并在 python 中保持公平的运行速度(使用 numpy 或 pytorch 将是最好的)?提前致谢!

解决方法

暂时忽略 -1 的复杂性,直接索引和求和是:

In [58]: arr = np.array([[ 1,2,3,4,5],[5,6,7,8,9]])
In [59]: idx = np.array([[[0,1,1],[2,6]],...:                 [[5,3],[1,-1,-1]]])
In [60]: arr.flat[idx]
Out[60]: 
array([[[1,2],[3,5,[[5,4],9,9]]])
In [61]: _.sum(axis=-1)
Out[61]: 
array([[ 5,14],[11,20]])

处理 -1 的一种方法(不一定快速或内存高效)是使用掩码数组:

In [62]: mask = idx<0
In [63]: mask
Out[63]: 
array([[[False,False,False],[False,False]],[[False,True,True]]])

In [65]: ma = np.ma.masked_array(Out[60],mask)
In [67]: ma
Out[67]: 
masked_array(
  data=[[[1,--,--]]],mask=[[[False,True]]],fill_value=999999)
In [68]: ma.sum(axis=-1)
Out[68]: 
masked_array(
  data=[[5,2]],mask=[[False,fill_value=999999)

掩码数组通过用中性值替换掩码值来处理求和之类的操作,例如求和的情况下为 0。

(我可能会在早上重温这个)。

求和矩阵乘积

In [72]: np.einsum('ijk,ijk->ij',Out[60],~mask)
Out[72]: 
array([[ 5,2]])

这比掩码数组方法更直接、更快捷。

您尚未详细说明如何构建 index_tensor,因此我不会尝试对其进行比较。

另一种可能是用 0 填充数组,并调整索引:

In [83]: arr1 = np.hstack((0,arr.ravel()))
In [84]: arr1
Out[84]: array([0,9])
In [85]: arr1[idx+1]
Out[85]: 
array([[[1,0]]])
In [86]: arr1[idx+1].sum(axis=-1)
Out[86]: 
array([[ 5,2]])

稀疏

首先尝试使用稀疏矩阵:

idx 重塑为二维:

In [141]: idx1 = np.reshape(idx,(4,3))

从中制作一个稀疏张量。首先,我将采用迭代 lil 方法,尽管通常直接构建 coo(甚至 csr)输入更快:

In [142]: M = sparse.lil_matrix((4,10),dtype=int)
     ...: for i in range(4):
     ...:     for j in range(3):
     ...:         v = idx1[i,j]
     ...:         if v>=0:
     ...:            M[i,v] = 1
     ...: 
In [143]: M
Out[143]: 
<4x10 sparse matrix of type '<class 'numpy.int64'>'
    with 9 stored elements in List of Lists format>
In [144]: M.A
Out[144]: 
array([[1,0],[0,0]])

这可以用于产品总和:

In [145]: M@arr.ravel()
Out[145]: array([ 3,14,11,2])

使用 M.A@arr.ravel() 本质上就是您所做的。虽然 M 是稀疏的,但 arr 不是。对于这种情况,M.A@M@ 快。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...