一种遍历多维数组的有效方法?

问题描述

我试图找到一种方法来对多个2D数组中的每个元素执行操作,而不必遍历它们。或者至少不需要两个for循环。我的代码计算一系列图像(阵列)上每个像素的标准偏差。现在,图像的数量已不成问题,这是数组的大小,这使得代码花费的速度非常慢。以下是我所拥有的工作示例。

import numpy as np

# reshape(# of image (arrays),# of rows,# of cols) 
a = np.arange(32).reshape(2,4,4)

stddev_arr = np.array([])
for i in range(4):
    for j in range(4): 
        pixel = a[0:,i,j]
        stddev = np.std(pixel) 
        stddev_arr = np.append(stddev_arr,stddev)

我的实际数据是2000x2000,使此代码循环4000000次。有一个更好的方法吗? 任何建议都非常感谢。

解决方法

您已经在使用numpy。 numpy的std()函数接受一个axis参数,该参数告诉它您希望对其进行操作的轴(在本例中为第零轴)。因为这会将计算工作转移到numpy的C后端(并且您的处理器可能使用SIMD optimizationsvectorize a lot of operations来处理),所以它的速度比迭代快 。代码中的另一个耗时的操作是附加到stddev_arr时。追加到numpy数组是 slow ,因为 entire数组在添加新元素之前已复制到新的内存中。现在您已经知道该数组需要多大,因此您不妨对其进行预分配。

a = np.arange(32).reshape(2,4,4)
stdev = np.std(a,axis=0)

这给出了一个4x4数组

array([[8.,8.,8.],[8.,8.]])

要将其展平为一维数组,请执行flat_stdev = stdev.flatten()

比较执行时间:

# Using only numpy
def fun1(arr):
    return np.std(arr,axis=0).flatten()

# Your function
def fun2(arr):
    stddev_arr = np.array([])
    for i in range(arr.shape[1]):
        for j in range(arr.shape[2]): 
            pixel = arr[0:,i,j]
            stddev = np.std(pixel) 
            stddev_arr = np.append(stddev_arr,stddev)
    return stddev_arr


# Your function,but pre-allocating stddev_arr
def fun3(arr):
    stddev_arr = np.zeros((arr.shape[1] * arr.shape[2],))
    x = 0
    for i in range(arr.shape[1]):
        for j in range(arr.shape[2]): 
            pixel = arr[0:,j]
            stddev = np.std(pixel) 
            stddev_arr[x] = stddev
            x += 1
    return stddev_arr

首先,让我们确保所有这些功能都等效:

a = np.random.random((3,10,10))
assert np.all(fun1(a) == fun2(a))
assert np.all(fun1(a) == fun3(a))

是的,所有结果都相同。现在,让我们尝试使用更大的数组。

a = np.random.random((3,100,100))

x = timeit.timeit('fun1(a)',setup='from __main__ import fun1,a',number=10)
# x: 0.003302899989648722

y = timeit.timeit('fun2(a)',setup='from __main__ import fun2,number=10)
# y: 5.495519500007504

z = timeit.timeit('fun3(a)',setup='from __main__ import fun3,number=10)
# z: 3.6250679999939166

哇!仅通过预分配,我们就能获得约1.5倍的加速。 更令人惊叹:将numpy的std()axis参数一起使用可以使速度提高1000倍以上,而这仅适用于100x100数组!使用更大的数组,您可以期望看到更大的加速。

,

因此,根据您提供的内容,您可以用另一种方式来重塑数组,以向量化它来替换两个循环。然后,您只需在所需的轴上使用一次np.std

a = np.arange(32).reshape(2,4)

a = a.reshape(2,-1).transpose()

stddev_arr = np.std(a,axis=1)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...