如何使用每行的索引矩阵索引矩阵的行元素?

问题描述

我有一个索引矩阵,例如

I = np.array([[1,2],[2,1,0]])

第 i 行的索引从第 i 行的另一个矩阵 M 中选择一个元素。

所以有我,例如

M = np.array([[6,7,8],[9,10,11])

M[I] 应该选择:

[[7,6,[11,9]]

我本来可以:

I1 = np.repeat(np.arange(0,I.shape[0]),I.shape[1])
I2 = np.ravel(I)
Result = M[I1,I2].reshape(I.shape)

但这看起来很复杂,我正在寻找更优雅的解决方案。最好不要压平和整形。

在示例中我使用了 numpy,但实际上我使用的是 jax。所以如果jax有更高效的解决方案,欢迎分享

解决方法

这一行代码怎么样?这个想法是枚举矩阵的行和行索引,以便您可以访问索引矩阵中的相应行。

import numpy as np

I = np.array([[1,2],[2,1,0]])
M = np.array([[6,7,8],[9,10,11]])

Result = np.array([row[I[i]] for i,row in enumerate(M)])
print(Result)

输出:

[[ 7  6  8]
 [11 10  9]]
,
In [108]: I = np.array([[1,0]])
     ...: M = np.array([[6,11]])
     ...: 
     ...: I,M

我必须给 M 添加一个“]”。

Out[108]: 
(array([[1,0]]),array([[ 6,[ 9,11]]))

使用 broadcasting 进行高级索引:

In [110]: M[np.arange(2)[:,None],I]
Out[110]: 
array([[ 7,6,[11,9]])

第一个索引具有形状 (2,1),它与 ​​I 的 (2,3) 形状配对以选择 (2,3) 值块。

,

np.take_along_axis 也可以在此处使用 M 上的索引 I 来获取 axis=1 的值:

>>> np.take_along_axis(M,I,axis=1)

array([[ 7,9]])