带有 where 参数的意外 numpy sum 行为

问题描述

举个例子,看看这些 numpy 数组:

>>> a
array([[1,2,3],[4,5,6]])
>>> b
array([[ True,False,True],[False,[ True,True,False]])

假设我想要 a 的每一行的总和,包括 b 的每一行中指定的元素。这里有两条指令可以做到这一点:

>>> np.sum(a[:,None] * b[None],2)
array([[ 4,3,[10,6,9]])
>>> np.sum(np.where(b[None],a[:,None],0),9]])

我通常使用第一个选项,但最近发现 np.sum一个 where 参数,并且希望它可以工作:

>>> np.sum(a[:,where=b[None])
array([[10],[25]])

但结果不同。我可以看到每一行实际上对应于正确结果中的行的总和。

我还发现当维度已经匹配而没有广播时,使用两种方法的结果是一样的:

>>> a
array([[1,True]])
>>> np.sum(a * b,1)
array([4,6])
>>> np.sum(a,1,where=b)
array([4,6])

这种行为的解释是什么?有什么办法可以防止,还是应该坚持我以前的方法

解决方法

所以你一直在做的是创建一个 (2,3,3) 数组,并在最后一个轴上求和:

In [216]: np.where(b,a[:,None],0)
Out[216]: 
array([[[1,3],[0,[1,2,0]],[[4,6],[4,5,0]]])
In [217]: np.sum(_,axis=2)
Out[217]: 
array([[ 4,[10,6,9]])

如果我们复制 a 来创建一个 (2,3) 数组:

In [218]: A=a[:,None,:].repeat(3,1)
In [219]: A
Out[219]: 
array([[[1,3]],6]]])

我们可以总结b为真(b广播到(1,3)到(2,3):

In [221]: np.sum(A,where=b,axis=2)
Out[221]: 
array([[ 4,9]])

where 的这种用法是相对较新的,而且要弄清楚如何去做需要大量的反复试验。不知道有没有速度优势。

where 最简单的方法是使用带有 1 或 2 个参数的 ufunc,例如除法或求逆,我们不希望它在 0 处进行计算。然后我们指定一个具有默认值的 out 数组。 np.sumnp.add.reduce。它具有默认的 0 开始,因此不需要(或允许)out

where : array_like of bool,optional
    A boolean array which is broadcasted to match the dimensions
    of `array`,and selects elements to include in the reduction. Note
    that for ufuncs like ``minimum`` that do not have an identity
    defined,one has to pass in also ``initial``.

虽然 b 广播匹配 A,但 A 不可广播,因此必须复制。

In [231]: np.sum(a[:,axis=2)
Traceback (most recent call last):
  File "<ipython-input-231-b6e9e9179fac>",line 1,in <module>
    np.sum(a[:,axis=2)
  File "<__array_function__ internals>",line 5,in sum
  File "/usr/local/lib/python3.8/dist-packages/numpy/core/fromnumeric.py",line 2247,in sum
    return _wrapreduction(a,np.add,'sum',axis,dtype,out,keepdims=keepdims,File "/usr/local/lib/python3.8/dist-packages/numpy/core/fromnumeric.py",line 87,in _wrapreduction
    return ufunc.reduce(obj,**passkwargs)
ValueError: non-broadcastable operand with shape (2,1,3) doesn't match the broadcast shape (2,3)