Problem description
Basically, I have the arrays:
ids = [134543,...,234]
a = [123,3546]
b = [[435,549],[245,4986]]
all of the same length (i.e. ids.shape = (600000,), a.shape = (600000,), b.shape = (600000,2)),
and a much smaller array of important ids (length about 100):
ids_important = [345,549]
I want to find the indices of the important ids in my ids array and then output the corresponding elements of a and b.
My current algorithm is:
for i in range(ids_important.shape[0]):
    for j in range(ids.shape[0]):
        if ids[j] == ids_important[i]:
            print(a[j])
            print(b[j, 0])
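Given the size of the arrays, this algorithm is very slow. I was told I could speed it up with masked arrays, but I haven't been able to figure out how to do that. Any help is much appreciated.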
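The "mask" suggestion most likely refers to a boolean index array rather than the numpy.ma masked-array module. A minimal sketch of that idea, assuming NumPy 1.13 or later (for np.isin) and using small made-up values in place of the real 600000-element arrays:

```python
import numpy as np

# Small illustrative arrays (made-up values, same layout as in the question)
ids = np.array([134543, 345, 77, 549, 345, 234])
a = np.array([10, 11, 12, 13, 14, 15])
b = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]])
ids_important = np.array([345, 549])

# Boolean mask: True wherever ids[j] is one of the important ids
mask = np.isin(ids, ids_important)

print(np.nonzero(mask)[0])  # [1 3 4]  (indices of the matches)
print(a[mask])              # [11 13 14]
print(b[mask, 0])           # [2 6 8]
```

Unlike the nested loop, the matches come out in the order they appear in ids, not grouped by important id.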
Solution
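This may not be the most space-efficient approach, since it allocates a len(important_ids) x len(ids) boolean array (for 100 important ids and 600000 ids that is 60 million booleans, about 60 MB). It should still be much faster than the other answers and your original approach, because it relies on NumPy vectorization and avoids slow Python loops: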
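import numpy as np
ids = np.random.randint(0, 15, size=20)
important_ids = np.random.randint(5, 9, size=10)
# a and b must have the same length as ids
a = np.random.randint(10, 30, size=20)
b = np.random.randint(10, size=(20, 2))
# Broadcast (len(important_ids), 1) against (1, len(ids)); each row of equals
# is a (row in important_ids, column in ids) pair where the values match
equals = np.argwhere(ids[None, :] == important_ids[:, None])
inds = equals[:, 1]
print(a[inds])
print(b[inds, 0])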
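To see what the broadcasting step produces, here is a deterministic version of the same code on tiny hand-picked arrays (the values are illustrative only). Comparing a (2, 1) column with a (1, 5) row yields a (2, 5) boolean matrix, and np.argwhere returns its True positions as (row, column) pairs:

```python
import numpy as np

ids = np.array([7, 3, 9, 3, 5])
important_ids = np.array([3, 5])
a = np.array([70, 30, 90, 31, 50])
b = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])

# Row r, column c is True when important_ids[r] == ids[c]
matches = ids[None, :] == important_ids[:, None]
print(matches.shape)  # (2, 5)

# argwhere lists the True positions in row-major order
equals = np.argwhere(matches)
inds = equals[:, 1]
print(inds)        # [1 3 4]
print(a[inds])     # [30 31 50]
print(b[inds, 0])  # [2 6 8]
```

Because np.argwhere walks the matrix row by row, the indices come out grouped by important id, which is the same order the original nested loop prints them in.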
If you sort the important ids and the arrays, you don't have to scan all of the ids for every search.
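import numpy as np
ids_important.sort()
# Sort ids, a and b together by the value of ids.
# np.argsort gives the permutation that sorts ids; applying the same
# permutation to a and b keeps every row aligned with its id.
# For plain Python lists the same idea looks like this:
#   X = ["a","b","c","d"]
#   Y = [0,1,2,1]
#   Z = [x for _,x in sorted(zip(Y,X))]
#   print(Z)  # ['a', 'b', 'd', 'c']
order = np.argsort(ids, kind="stable")
ids_sorted = ids[order]
a_sorted = a[order]
b_sorted = b[order]
Then you can walk through the sorted ids and the sorted important ids together in a single pass, which takes linear time once everything is sorted.
i, j = 0, 0
n_important = len(ids_important)
n_ids = len(ids_sorted)
while i < n_important and j < n_ids:
    if ids_sorted[j] < ids_important[i]:
        # not an important id: move on to the next id
        j += 1
    elif ids_sorted[j] == ids_important[i]:
        # match: print the corresponding entries, keep scanning for duplicates
        print(a_sorted[j])
        print(b_sorted[j, 0])
        j += 1
    else:
        # ids_sorted[j] > ids_important[i]: no (more) matches for this important id
        i += 1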
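The two-pointer loop can also be replaced by np.searchsorted, which finds, for every important id at once, the range of positions where it occurs in the sorted ids. This is a swapped-in variant, not part of the answer above; a sketch with small illustrative arrays:

```python
import numpy as np

ids = np.array([7, 3, 9, 3, 5])
a = np.array([70, 30, 90, 31, 50])
b = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
ids_important = np.array([5, 3, 8])  # 8 does not occur in ids

# Sort once; a stable sort keeps duplicate ids in their original order
order = np.argsort(ids, kind="stable")
ids_sorted = ids[order]

# For each important id, the half-open range [lo, hi) of matching positions
lo = np.searchsorted(ids_sorted, ids_important, side="left")
hi = np.searchsorted(ids_sorted, ids_important, side="right")

# Expand each range into positions, then map back to original indices
pos = np.concatenate([np.arange(l, h) for l, h in zip(lo, hi)])
inds = order[pos]

print(inds)        # [4 1 3]
print(a[inds])     # [50 30 31]
print(b[inds, 0])  # [8 2 6]
```

Sorting costs O(N log N) and each lookup O(log N), so with about 100 important ids the lookups are negligible. An important id that does not occur (8 here) simply yields an empty range, and duplicate ids are all reported.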