ndarray的条件过滤

问题描述

假设我有以下数组数组:

Input = np.array([[[[17.63,0.,-0.71,29.03],[17.63,-0.09,0.71,56.12],[ 0.17,1.24,-2.04,18.49],[ 1.41,-0.8,0.51,11.85],[ 0.61,-0.29,0.15,36.75]]],[[[ 0.32,-0.14,0.39,24.52],[ 0.18,0.25,-0.38,18.08],[ 0.,0.  ],[ 0.43,0.3,0.  ]]],[[[ 0.75,0.65,19.51],[ 0.37,0.27,0.52,24.27],0.  ]]]])

Input.shape
(3,1,5,4)

与此Input数组一起的是所有输入的相应Label数组,因此:

Label = np.array([0,2])

Label.shape
(3,)

我需要某种方法来检查Input的所有嵌套数组,以仅选择具有足够数据点的数组。

通过这个,我的意思是我想要一种消除(或者应该说删除)所有后三行条目均为零的数组的方法。同时还要消除该数组的相应Label

预期输出:

Input_filtered
array([[[[17.63,0.  ]]]])

Label_filtered
array([0,1])

我需要什么技巧?

解决方法

您应该只能使用矢量化numpy命令执行此操作。

filter = np.any(Input[:,:,-3:],axis=(1,2,3))
Label_filtered = Label[filter]
Input_filtered = Input[[filter]]

对于示例集,您提供的每个循环(每个100000个循环)产生4.95 µs±9.69 ns,而anon01的解决方案是每个循环17.1 µs±111 ns(每个100000个循环)。在更大的阵列上,改进应该更加明显。

如果数据的维数不同,则可以更改axis参数。 对于任意数量的轴,它可能如下所示:

filter = np.any(Input[:,axis=tuple(range(1,Input.ndim)))
,

执行此操作的最佳方法取决于数据规模。如果子数组很少(数千个或更少),则可以生成一个应用于Label和Input数组的过滤器列表:

filter = []
for j in range(len(Input)):
    arr = Input[j,-3:]
    filter.append(np.any(arr))
Label_filtered = Label[filter]
Input_filtered = Input[[filter]]

需要注意的几件事:向量化/ numpy位(Input[j,-3]np.any(arr)非常快速,而本机python迭代和列表用法({{1 }},for j in range非常慢。

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...