问题描述
我正在使用 Deeplearning4j 和 datavec,我有一个 DataSetIterator 对象代表我的所有数据,这是一个时间序列。如何将其拆分为训练和测试迭代器?我检查了一下,不推荐使用 DataSetIterator 类的方法。谢谢。
解决方法
遍历您的 DataSetIterator
,并为每个 DataSet
条目创建两个新的 DataSets
,分别用于训练和测试。
关键是使用 splitTestAndTrain 方法,它接受一个 double fractionTrain
来指定要训练的数据量(其余的要测试)。该方法有不同的重载,因此您可以选择最适合您需要的方法。如果您希望将所有训练和测试数据集添加到一个公共迭代器中,您可以将它们存储在两个不同的列表中,并在以后获取它们对应的迭代器。类似的东西:
List<DataSet> trainList = new ArrayList<>();
List<DataSet> testList= new ArrayList<>();
while (yourDataSetIterator.hasNext())
{
DataSet ds = yourDataSetIterator.next();
SplitTestAndTrain splData = ds.splitTestAndTrain(0.5); //half for each
DataSet trainDs = splData.getTrain();
trainList.add(trainDs);
DataSet testDs = splData.getTest();
testList.add(testDs);
(...)
}
Iterator<DataSet> trainIterator = trainList.iterator();
Iterator<DataSet> testIterator = testList.iterator();
由于我真的不知道这个库的具体细节,所以这个例子只是创建了“基本的”iterators
。这可能是自定义的,因此您可以创建 DataSetIterators
。
请注意,您可能还需要在拆分 DataSet 之前对其进行混洗 (ds.shuffle()
)。你可以找到一些例子here
如果你想以特定的方式分割它,你可以标记不同的条目并找到测试数据集的最大索引;然后,调用 splitTestAndTrain(int max)
方法,该方法专门针对 max 参数拆分数据集。 sortByLabel
方法在这里也很有用。
Adam Gibson
对关于其他机制的评论提出了很好的建议,以拆分 DataSetIterator
,这似乎也是一种“更自然”的方式来做到这一点,DataSetIteratorSplitter
.
它提供了 getTrainIterator()
和 getTestIterator()
方法,它们返回库的特定迭代器 DataSetIterator
。