过滤Numpy的数组数组

问题描述

使用numpy的ndarray将数据预处理到神经网络。它基本上包含用于传感器数据的几个固定长度的数组。例如:

>>> type(arr)
<class 'numpy.ndarray'>

>>> arr.shape
(400,1,5,4)

>>> arr
 [
  [[ 9.4 -3.7 -5.2  3.8]
   [ 2.8  1.4 -1.7  3.4]
   [ 0.0  0.0  0.0  0.0]
   [ 0.0  0.0  0.0  0.0]
   [ 0.0  0.0  0.0  0.0]]
  ..
  [[ 0.0 -1.0  2.1  0.0]
   [ 3.0  2.8 -3.0  8.2]
   [ 7.5  1.7 -3.8  2.6]
   [ 0.0  0.0  0.0  0.0]
   [ 0.0  0.0  0.0  0.0]]
 ]

每个嵌套数组的形状为(1,4)。目标是遍历此arr并仅将至少具有前三行的那些数组选择为非零(尽管单个条目可以为零,但不能整行)。

因此,在上面给出的示例中,应该删除第一个嵌套数组,因为只有2个第一行非零,而我们需要3个及以上。

解决方法

这是您可以使用的技巧:

mask = arr[:,:,:3].any(axis=3).all(axis=2)
arr_filtered = arr[mask]

快速说明:要保留一个嵌套数组,它应至少有3个第一行(因此我们只需要查看arr[:,:3]),以使所有它们(因此.all(axis=2)结尾)都具有至少一个非零条目(因此.any(axis=3))。

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...