问题描述
假设我有一个数组 a
和一个布尔数组 b
,我想从 a
的每一行中的有效元素中提取固定数量的元素。有效元素是由 b
指示的元素。
这是一个例子:
a = np.arange(24).reshape(4,6)
b = np.array([[0,1,0],[0,1],1]]).astype(bool)
x = []
for i in range(a.shape[0]):
c = a[i,b[i]]
d = np.random.choice(c,2)
x.append(d)
这里我使用了一个 for 循环,如果这些数组很大并且是高维的,它会很慢。有没有更有效的方法来做到这一点?谢谢。
解决方法
- 生成形状为
a
的随机均匀 [0,1] 矩阵。 - 将此矩阵乘以掩码
b
以将无效元素设置为零。 - 从每行中选择
k
个最大索引(仅从该行中的有效元素模拟无偏随机k
样本)。 - (可选)使用这些索引来获取元素。
a = np.arange(24).reshape(4,6)
b = np.array([[0,1,0],[0,1],1]])
k = 2
r = np.random.uniform(size=a.shape)
indices = np.argpartition(-r * b,k)[:,:k]
从索引中获取元素:
>>> indices
array([[3,2],[5,[3,[4,5]])
>>> a[np.arange(a.shape[0])[:,None],indices]
array([[ 3,[11,7],[15,14],[22,23]])