问题描述
我正在使用tensorflow 2。
将$LogFolder
方法与Model.fit()
一起使用时,参数'tf.data.Dataset
'被忽略。因此,要按批次训练我的模型,我必须先通过调用batch_size
将样本数据集更改为样本批次数据集。
然后,在阅读了文档之后,我不明白tf.data.Dataset.batch(batch_size)
方法将如何在每个时期重新整理我的数据集。
由于我的数据集是批次的数据集,它会在各个批次之间进行随机排序吗(批次保持不变)?还是会随机整理所有样本,然后将它们重新分组为新批次(这是所需的行为)?
非常感谢您的帮助。
解决方法
使用shuffle
API时,fit
参数对tf.data.Dataset
函数无效。
如果我们读了documentation(强调是我的):
shuffle:布尔值(是否在每个时期之前对训练数据进行随机化)或str(对于“批处理”)。 当x是生成器时,将忽略此参数。“批处理”是处理HDF5数据限制的特殊选项;它以批量大小的块洗牌。当steps_per_epoch不为None时无效。
这不是很清楚,但是我们可以暗示使用tf.data.Dataset
时,shuffle参数将被忽略,因为它的行为类似于生成器。
可以肯定的是,让我们深入研究代码。如果我们查看fit
方法的代码,您将看到数据是由特殊的类DataHandler
处理的。查看此类的代码,我们看到这是一个Adapter类,用于处理不同类型的数据。我们陷入了处理tf.data.Dataset DatasetAdapter
的类中,并且可以看到该类没有考虑shuffle
参数:
def __init__(self,x,y=None,sample_weights=None,steps=None,**kwargs):
super(DatasetAdapter,self).__init__(x,y,**kwargs)
# Note that the dataset instance is immutable,its fine to reuse the user
# provided dataset.
self._dataset = x
# The user-provided steps.
self._user_steps = steps
self._validate_args(y,sample_weights,steps)
如果要改组数据集,请使用tf.data.Dataset
API中的shuffle函数。