问题描述
我正在使用使用tfds数据集的克隆代码,并希望将其适应于已存在的分片tfrecrod集合,并进行尽可能少的修改。
具体来说,克隆的代码将执行以下操作:
builder = tfds.builder(dataset,data_dir)
builder.download_and_prepare()
...
estimator.train(
data_lib.build_input_fn(builder,True),max_steps=train_steps
)
在此代码中,“数据集”是tfds数据集的名称(例如cifar10或others)。代替, 我想训练一个已经以分片tfrecords形式存在的外部数据集,即:
'train_
'val_
并驻留在存储桶中(如果该信息有帮助,则存储在Google云上)。
我一直在研究Adding new datasets in TFDS format,但这似乎需要一个完整的管道来从头开始生成tfrecords,这是不可能的,并且鉴于tfrecords已经存在,这似乎是多余的。我确定我会缺少对现有tfrecords的一些简单修改。.
任何建议将不胜感激。
解决方法
阿罗娜,
您的期望是正确的:有一个特殊功能tf.data.TFRecordDataset
用于处理tfrecords中的数据。像这样在您的input_fn中使用它:
def input_fn(features,labels,training=True,batch_size=256):
file_paths = [file0,file1] # pass tfrecords filenames here
dataset = tf.data.TFRecordDataset(file_paths)
# Shuffle and repeat if you are in training mode.
if training:
dataset = dataset.shuffle(1000).repeat()
return dataset.batch(batch_size)