问题描述
我有三个数组:m
、grad1
和 grad2
。 m
的形状为 (x,)
,而 grad1
和 grad2
的形状为 (x,y,z)
。我试图找出基于 grad1
的值创建具有 grad2
或 m
值条目的新数组的最有效方法。我尝试使用以下代码执行此操作:
param0_grad = np.where(m[:] > 0,grad1,grad2)
根据我对 np.where()
的理解,我认为这应该根据 param0_grad
中的每个值使用 grad1
或 grad2
填充 m
。但是,我收到以下广播错误(当 x=3、y=4、z=2 时):
ValueError: operands Could not be broadcast together with shapes (3,) (3,4,2) (3,2)
代码适用于 x=2,但没有 x>2 的值。
解决方法
试试这个:
param0_grad = np.where(m[:,None,None] > 0,grad1,grad2)
基本上,您需要添加一个空维度来广播尾随轴。