使用 np.where() 选择数组的有效方法?

问题描述

我有三个数组:mgrad1grad2m 的形状为 (x,),而 grad1grad2 的形状为 (x,y,z)。我试图找出基于 grad1 的值创建具有 grad2m 值条目的新数组的最有效方法。我尝试使用以下代码执行此操作:

param0_grad = np.where(m[:] > 0,grad1,grad2)

根据我对 np.where() 的理解,我认为这应该根据 param0_grad 中的每个值使用 grad1grad2 填充 m。但是,我收到以下广播错误(当 x=3、y=4、z=2 时):

ValueError: operands Could not be broadcast together with shapes (3,) (3,4,2) (3,2)

代码适用于 x=2,但没有 x>2 的值。

解决方法

试试这个:

param0_grad = np.where(m[:,None,None] > 0,grad1,grad2)

基本上,您需要添加一个空维度来广播尾随轴。