numpy 索引 ndarray[(4, 2), (5, 3)] 的说明

问题描述

问题

请帮助理解 Numpy 将元组 (i,j) 索引到 ndarray 中的设计决策或合理性。

背景

当索引为单元组 (4,2) 时,则 (i=row,j=column)。

shape = (6,7)
X = np.zeros(shape,dtype=int)
X[(4,2)] = 1
X[(5,3)] = 1
print("X is :\n{}\n".format(X))
---
X is :
[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 1 0 0 0 0]    <--- (4,2)
 [0 0 0 1 0 0 0]]   <--- (5,3)

然而,当索引是多个元组 (4,2),(5,3) 时,则 (i=row,j=row) for (4,2) 和 (i=column,j=column) for (5,3).

shape = (6,7)
Y = np.zeros(shape,dtype=int)
Y[(4,3)] = 1
print("Y is :\n{}\n".format(Y))
---
Y is :
[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 1 0 0 0]    <--- (2,3)
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 1 0]    <--- (4,5)
 [0 0 0 0 0 0 0]]

这意味着您正在构建一个二维数组 R,例如 R=A[B,C]。 这意味着对于 rij=abijcij.

所以这意味着位于 R[0,0] 的项目是 A 中的项目 作为行索引 B[0,0] 和列索引 C[0,0]。项目 R[0,1]A 中具有行索引 B[0,1] 并作为列索引的项目 C[0,1]

multi_index:整数数组元组每个维度一个数组

为什么不总是 (i=row,j=column)?如果总是 (i=row,j=column) 会怎样?


更新

根据 Akshay 和 @DaniMesejo 的回答,理解:

X[
  (4),# dimension 2 indices with only 1 element
  (2)     # dimension 1 indices with only 1 element
] = 1

Y[
  (4,2,...),# dimension 2 indices 
  (5,3,...)  # dimension 1 indices (dimension 0 is e.g. np.array(3) whose shape is (),in my understanding)
] = 1

解决方法

很容易理解它是如何工作的(以及这个设计决策背后的动机)。

Numpy 将其 ndarray 存储为连续的内存块。每个元素在前一个之后每 n 个字节以顺序方式存储。

(从此 excellent SO post 引用的图像)

所以如果你的 3D 数组看起来像这样 -

enter image description here

然后在内存中存储为 -

enter image description here

当检索一个元素(或一个元素块)时,NumPy 会计算它需要遍历多少个 strides(字节)才能得到下一个元素 in that direction/axis。因此,对于上面的示例,对于 axis=2 它必须遍历 8 个字节(取决于 datatype)但对于 axis=1 它必须遍历 8*4 个字节,并且 {{ 1}} 它需要 axis=0 个字节。

考虑到这一点,让我们看看您正在尝试做什么。

8*8
print(X)
print(X.strides)

对于您的数组,要获取 [[0 0 0 0 0 0 0] [0 0 0 0 0 0 0] [0 0 0 0 0 0 0] [0 0 0 0 0 0 0] [0 0 1 0 0 0 0] [0 0 0 1 0 0 0]] #Strides (bytes) required to traverse in each axis. (56,8) 中的下一个元素,我们需要遍历 axis=0,而对于 56 bytes 中的下一个元素,我们需要 axis=1

当您索引 8 bytes 时,NumPy 将 (4,2) 中的 56*4 字节和 axis=0 中的 8*2 字节用于访问。同样,如果您想访问 axis=1(4,2),则必须访问 (5,3) 中的 56*(4,5)axis=0 中的 8*(2,3)。>

这就是设计如此的原因,因为它与 NumPy 使用 axis=1 实际索引元素的方式一致。

strides
X[(axis0_indices),(axis1_indices),..]

X[(4,5),(2,3)] #(row indices),(column indices)

通过这种设计,也可以轻松扩展到更高维度的张量(例如 8 维数组)! 如果你分别提到每个索引元组,它将需要元素 * 计算的维度数来获取这些。虽然使用这种设计,它可以将步幅值广播到每个轴的元组,并更快地获取这些值!