在数组中搜索一个值,然后从其他相同长度的数组/ndarrays 打印相应的值

问题描述

我有一个需要改进的低效算法。

基本上我有数组:

ids = [134543,...,234]
a = [123,3546]
b = [[435,549][245,4986]]

长度相同(即ids.shape = (600000,) a.shape = (600000,) b.shape = (600000,2)

以及较小的注释 id 数组(长度约为 100)

ids_important = [345,549]

我想在我的 ids 数组中找到重要 id 的索引,然后输出 a 和 b 中对应的元素。

我目前的算法是:

for i in range(ids_important.shape[0]):
    for j in range(ids.shape[0]):
        if(ids[j] == ids_important[i]):
           print(a[j])
           print(b[j,0])

考虑到数组的大小,这个算法非常慢。有人告诉我我可以使用掩码数组来改进它,但一直无法弄清楚如何实现它。非常感谢您的帮助。

解决方法

这可能不是最节省空间的方式(分配一个 len(important_ids) x len(ids) 数组),但它应该比其他答案和您的原始方法快得多,因为它利用了numpy 向量化(并避免慢循环):

import numpy as np
ids = np.random.randint(0,15,size=20)
important_ids = np.random.randint(5,9,size=10)
a = np.random.randint(10,30,size=50)
b = np.random.randint(10,size=(50,2))

equals = np.argwhere(ids[None,:] == important_ids[:,None])
inds = equals[:,1]
print(a[inds])
print(b[tuple(inds),0])
,

如果您对重要的 id 和数组进行排序,则每次搜索时都不必遍历所有 id。

    ids_important.sort()
    # Sorting an array based on ids of another
    # X = ["a","b","c","d","e","f","g","h","i"]
    # Y = [ 0,1,2,1]

    # Z = [x for _,x in sorted(zip(Y,X))]
    # print(Z)  # ["a","i","g"]
    a_sorted = [elmt for _,elmt in sorted(zip(ids,a))]
    b_sorted = [[elmt for _,b[0]))],[elmt for _,b[1]))]]

然后,您可以简单地迭代 ids 和 important_ones,这将是线性运行时。

i,j = 0,0
l = len(important_ids)
while i < l and j < 600000:
    if ids[j] < important_ids[i]:
        j += 1
        continue
    if ids[j] == important_ids[i]:
        print(a[j])
        print(b[j,0])
    i += 1