如何解析Google Quick,Draw中的.npy文件!卷积网络具有deeplearning4j的数据集?

问题描述

我刚开始使用dl4j。看完一些示例后,我认为我可以实现一个简单的卷积网络,对Google Quick,Draw中的图像子集进行分类!数据集。我从示例中使用的DataIterator是针对特定数据集(MNIST,CIFAR等)定制的。

我将如何创建一个自定义DataIterator,将一堆.npy文件转换为可以加载到网络中的格式。更重要的是,这是正确的方法吗?

解决方法

Deeplearning4j可以解析numpy输出。 您可以在此处使用各种numpy方法: https://github.com/eclipse/deeplearning4j/blob/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java#L5654

示例: INDArray arr = Nd4j.createFromNpyFile(yourFile);

数据集只是2个ndarray,因此您可以通过以下方式创建数据集: DataSet d =新的DataSet(功能,标签)

datasetiterator只是返回数据集。您需要做的就是使用上面的npy文件创建在数据集中读取数据集迭代器接口。