问题描述
假设我创建了以下 numpy 数组:
a = np.array([[3,3],[1,3,1,3]])
b = np.array([0,2,3])
如果我做 a == 0
我得到:
array([[False,True,False],[False,False,False]])
如果我做 a == 1
我得到:
array([[False,[ True,False]])
等等。但是,如果我想获得一个包含与所有条件 a == n
相关的所有掩码的数组,其中 n
属于 b
,我应该如何进行?
np.array([a == n for n in b])
做了我想要的但看起来不太numpythonic。我也试过 a == b
,它只返回 False
。
解决方法
只需a == b[:,None,None]
,其余的由广播处理:
>>> a == b[:,None]
array([[[False,True,False],[False,False,False]],[[False,[ True,[[ True,True],True]]])