问题描述
假设我有以下数组数组:
Input = np.array([[[[17.63,0.,-0.71,29.03],[17.63,-0.09,0.71,56.12],[ 0.17,1.24,-2.04,18.49],[ 1.41,-0.8,0.51,11.85],[ 0.61,-0.29,0.15,36.75]]],[[[ 0.32,-0.14,0.39,24.52],[ 0.18,0.25,-0.38,18.08],[ 0.,0. ],[ 0.43,0.3,0. ]]],[[[ 0.75,0.65,19.51],[ 0.37,0.27,0.52,24.27],0. ]]]])
Input.shape
(3,1,5,4)
与此Input
数组一起的是所有输入的相应Label
数组,因此:
Label = np.array([0,2])
Label.shape
(3,)
我需要某种方法来检查Input
的所有嵌套数组,以仅选择具有足够数据点的数组。
通过这个,我的意思是我想要一种消除(或者应该说删除)所有后三行条目均为零的数组的方法。同时还要消除该数组的相应Label
。
预期输出:
Input_filtered
array([[[[17.63,0. ]]]])
Label_filtered
array([0,1])
我需要什么技巧?
解决方法
您应该只能使用矢量化numpy命令执行此操作。
filter = np.any(Input[:,:,-3:],axis=(1,2,3))
Label_filtered = Label[filter]
Input_filtered = Input[[filter]]
对于示例集,您提供的每个循环(每个100000个循环)产生4.95 µs±9.69 ns,而anon01的解决方案是每个循环17.1 µs±111 ns(每个100000个循环)。在更大的阵列上,改进应该更加明显。
如果数据的维数不同,则可以更改axis参数。 对于任意数量的轴,它可能如下所示:
filter = np.any(Input[:,axis=tuple(range(1,Input.ndim)))
,
执行此操作的最佳方法取决于数据规模。如果子数组很少(数千个或更少),则可以生成一个应用于Label和Input数组的过滤器列表:
filter = []
for j in range(len(Input)):
arr = Input[j,-3:]
filter.append(np.any(arr))
Label_filtered = Label[filter]
Input_filtered = Input[[filter]]
需要注意的几件事:向量化/ numpy位(Input[j,-3]
,np.any(arr)
)非常快速,而本机python迭代和列表用法({{1 }},for j in range
)非常慢。