使用tf.data.Dataset时,Model.fit方法的改组如何处理批次?

问题描述

我正在使用tensorflow 2。

$LogFolder方法与Model.fit()一起使用时,参数'tf.data.Dataset'被忽略。因此,要按批次训练我的模型,我必须先通过调用batch_size将样本数据集更改为样本批次数据集。

然后,在阅读了文档之后,我不明白tf.data.Dataset.batch(batch_size)方法将如何在每个时期重新整理我的数据集。

由于我的数据集是批次的数据集,它会在各个批次之间进行随机排序吗(批次保持不变)?还是会随机整理所有样本,然后将它们重新分组为新批次(这是所需的行为)

非常感谢您的帮助。

解决方法

使用shuffle API时,fit参数对tf.data.Dataset函数无效。

如果我们读了documentation(强调是我的):

shuffle:布尔值(是否在每个时期之前对训练数据进行随机化)或str(对于“批处理”)。 当x是生成器时,将忽略此参数。“批处理”是处理HDF5数据限制的特殊选项;它以批量大小的块洗牌。当steps_per_epoch不为None时无效。

这不是很清楚,但是我们可以暗示使用tf.data.Dataset时,shuffle参数将被忽略,因为它的行为类似于生成器。

可以肯定的是,让我们深入研究代码。如果我们查看fit方法的代码,您将看到数据是由特殊的类DataHandler处理的。查看此类的代码,我们看到这是一个Adapter类,用于处理不同类型的数据。我们陷入了处理tf.data.Dataset DatasetAdapter的类中,并且可以看到该类没有考虑shuffle参数:

  def __init__(self,x,y=None,sample_weights=None,steps=None,**kwargs):
    super(DatasetAdapter,self).__init__(x,y,**kwargs)
    # Note that the dataset instance is immutable,its fine to reuse the user
    # provided dataset.
    self._dataset = x

    # The user-provided steps.
    self._user_steps = steps

    self._validate_args(y,sample_weights,steps)

如果要改组数据集,请使用tf.data.Dataset API中的shuffle函数。

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...