np.isin - 考虑顺序测试 Numpy 数组是否包含给定的行

问题描述

我正在使用以下行来查找 b 的行是否在 a

 a[np.all(np.isin(a[:,0:3],b[:,0:3]),axis=1),3]

数组沿 axis=1 有更多条目,我只比较前 3 个条目并返回 a 的第四个条目 (idx=3)。

我意识到的可能错误是,没有考虑条目的顺序。因此,ab 的以下示例:

a = np.array([[...],[1,2,3,1000],[2,1,2000],[...]])

b = np.array([[1,3]])

将返回 [1000,2000] 而不是只有 [1000]

如何同时考虑行的顺序?

解决方法

对于小 b(少于 100 行),试试这个:

a[(a[:,:3] == b[:,None]).all(axis=-1).any(axis=0)]

示例:

a = np.array([[1,5,0],[1,2,3,1000],[2,1,2000],[0,1]])

b = np.array([[1,3],1]])

>>> a[(a[:,None]).all(axis=-1).any(axis=0),3]
array([1000,1])

说明:

关键是将 a 的所有行(前 3 列)的相等性测试“分发”到 b 的所有行:

# on the example above

>>> a[:,None]
array([[[ True,False,False],[ True,True,True],# <-- a[1,:3] matches b[0]
        [False,[False,False]],[[False,True]]])  # <-- a[3,:3] matches b[1]

请注意,这可能很大:形状为 (len(b),len(a),3)

然后第一个 .all(axis=-1) 表示我们希望所有整行都匹配:

>>> (a[:,None]).all(axis=-1)
array([[False,True]])

最后一位 .any(axis=0) 表示:“匹配 b 中的任何行”:

>>> (a[:,None]).all(axis=-1).any(axis=0)
array([False,True])

即:“a[2,:3] 匹配 someb 以及 a[3,:3]”。

最后,将其用作 a 中的掩码并取第 3 列。

性能注意事项

上述技术将 a 行与 b 行的乘积相等。如果 ab 都有很多行,这会很慢并且会使用大量内存。

作为替代方案,您可以在纯 Python 中使用 set 成员资格(无需对列进行子集设置 -- 这可以由调用者完成):

def py_rows_in(a,b):
    z = set(map(tuple,b))
    return [row in z for row in map(tuple,a)]

b 有超过 50~100 行时,这可能比上面的 np 版本更快,这里写成一个函数:

def np_rows_in(a,b):
    return (a == b[:,None]).all(axis=-1).any(axis=0)
import perfplot

fig,axes = plt.subplots(ncols=2,figsize=(16,5))
plt.subplots_adjust(wspace=.5)
for ax,alen in zip(axes,[100,10_000]):
    a = np.random.randint(0,20,(alen,4))
    plt.sca(ax)
    ax.set_title(f'a: {a.shape[0]:_} rows')
    perfplot.show(
        setup=lambda n: np.random.randint(0,(n,3)),kernels=[
            lambda b: np_rows_in(a[:,:3],b),lambda b: py_rows_in(a[:,],labels=['np_rows_in','py_rows_in'],n_range=[2 ** k for k in range(10)],xlabel='len(b)',)
plt.show()

comparative performance