问题描述
在训练时,我将 epochs 设置为迭代数据的次数。我想知道当我已经可以用 tf.data.Datasets.repeat(EPOCHS)
做同样的事情时 model.fit(train_dataset,epochs=EPOCHS)
有什么用?
解决方法
它的工作方式略有不同。
让我们选择 2 个不同的例子。
- dataset.repeat(20) 和 model.fit(epochs=10)
- dataset.repeat(10) 和 model.fit(epochs=20)
我们还假设您有一个包含 100 条记录的数据集。
如果您选择选项 1,则每个 epoch 将有 2,000 条记录。在通过模型传递 2,000 条记录后,您将“检查”模型的改进情况,并且将执行 10 次。
如果选择选项 2,每个 epoch 将有 1,000 条记录。您将评估您的模型在推送 1000 条记录后的改进情况,并且您将执行 20 次。
在这两个选项中,您将用于训练的记录总数相同,但您评估、记录等模型行为的“时间”不同。
,tf.data.Datasets.repeat()
在图像数据的情况下可用于 tf.data.Datasets
上的数据增强。
假设您想增加训练数据集中的图像数量,使用随机变换然后重复训练数据集 count
次并应用随机变换,如下所示
train_dataset = (
train_dataset
.map(resize,num_parallel_calls=AUTOTUNE)
.map(rescale,num_parallel_calls=AUTOTUNE)
.map(onehot,num_parallel_calls=AUTOTUNE)
.shuffle(BUFFER_SIZE,reshuffle_each_iteration=True)
.batch(BATCH_SIZE)
.repeat(count=5)
.map(random_flip,num_parallel_calls=AUTOTUNE)
.map(random_rotate,num_parallel_calls=AUTOTUNE)
.prefetch(buffer_size=AUTOTUNE)
)
如果没有 repeat() 方法,您必须创建数据集的副本,单独应用转换,然后连接数据集。但是使用repeat() 可以简化这一点,还利用了方法链的优势,并且代码看起来很整洁。
关于数据增强的更多信息:https://www.tensorflow.org/tutorials/images/data_augmentation#apply_augmentation_to_a_dataset