问题描述
在具有Tensorflow-2.0.0的jupyter笔记本上,以这种方式执行了80-10-10的火车验证测试拆分:
import tensorflow_datasets as tfds
from os import getcwd
splits = tfds.Split.ALL.subsplit(weighted=(80,10,10))
filePath = f"{getcwd()}/../tmp2/"
splits,info = tfds.load('fashion_mnist',with_info=True,as_supervised=True,split=splits,data_dir=filePath)
AttributeError: type object 'Split' has no attribute 'ALL'
我已经看到我可以用这种方式创建两个集合:
splits,split=['train[:80]','test[80:90]'],data_dir=filePath)
但是我不知道如何添加第三组。
解决方法
tfds.Split.ALL.subsplit
或tfds.Split.TRAIN.subsplit
显然已被弃用,不再受支持。
一些数据集已经在训练和测试之间分配。在这种情况下,我找到了以下解决方案(例如使用时尚的MNIST数据集):
splits,info = tfds.load('fashion_mnist',with_info=True,as_supervised=True,split=['train+test[:80]','train+test[80:90]','train+test[90:]'],data_dir=filePath)
(train_examples,validation_examples,test_examples) = splits
,
在 tfds 上的 Rock_paper_scissor 数据集的情况下,它对我有用:
splits = ['train+test[:80]','train+test[90:]']
splits,info = tfds.load( 'rock_paper_scissors',split=splits,with_info=True)
(train_examples,test_examples) = splits
num_examples = info.splits['train'].num_examples
num_classes = info.features['label'].num_classes