如何用现有的分片tfrecords替换tfds数据集

问题描述

我正在使用使用tfds数据集的克隆代码,并希望将其适应于已存在的分片tfrecrod集合,并进行尽可能少的修改

具体来说,克隆的代码将执行以下操作:

builder = tfds.builder(dataset,data_dir)
builder.download_and_prepare()
...
estimator.train(
        data_lib.build_input_fn(builder,True),max_steps=train_steps
)

在此代码中,“数据集”是tfds数据集的名称(例如cifar10或others)。代替, 我想训练一个已经以分片tfrecords形式存在的外部数据集,即:

'train_ - .tfrecords'

'val_ - .tfrecords'

并驻留在存储桶中(如果该信息有帮助,则存储在Google云上)。

我一直在研究Adding new datasets in TFDS format,但这似乎需要一个完整的管道来从头开始生成tfrecords,这是不可能的,并且鉴于tfrecords已经存在,这似乎是多余的。我确定我会缺少对现有tfrecords的一些简单修改。.

任何建议将不胜感激。

解决方法

阿罗娜,

您的期望是正确的:有一个特殊功能tf.data.TFRecordDataset用于处理tfrecords中的数据。像这样在您的input_fn中使用它:

def input_fn(features,labels,training=True,batch_size=256):
    
    file_paths = [file0,file1]  # pass tfrecords filenames here
    dataset = tf.data.TFRecordDataset(file_paths)

    # Shuffle and repeat if you are in training mode.
    if training:
        dataset = dataset.shuffle(1000).repeat()
    
    return dataset.batch(batch_size)

在TF网站上了解更多信息:1 2