保存经过shuffle(打乱)的数据集的状态

问题描述

我正在寻找一种保存tf.data.Dataset.shuffle所使用的随机状态的机制。背景是:我希望在程序重新启动后能够复现训练结果。

我有一个解决方案(如下所示),但它不够优雅,而且我非常确信batch / unbatch会导致性能问题。有没有直接使用Dataset.shuffle的等效方法?

import tensorflow as tf
import numpy as np


class Shuffler(tf.Module):
    """Block-wise dataset shuffle driven by a ``tf.random.Generator``.

    The dataset is cut into batches of ``buffer_size`` elements, each batch
    is permuted with random values drawn from the generator held on this
    module, and the result is unbatched again. The generator is stored as an
    attribute of the module so that it can be handed to
    ``tf.train.Checkpoint`` together with the module.
    """

    def __init__(self, buffer_size: int, seed: int = 0):
        """Create the shuffler.

        Args:
            buffer_size: Number of consecutive elements permuted together.
            seed: Seed passed to ``tf.random.Generator.from_seed``.
        """
        # tf.Module subclasses are expected to run the base initializer
        # (it sets up the module name / name scope).
        super().__init__()
        self._buffer_size = buffer_size
        self._seed = seed
        self._rng = tf.random.Generator.from_seed(seed)

    def __call__(self, dataset: tf.data.Dataset):
        """Return ``dataset`` with every ``buffer_size`` block permuted.

        Intended to be used as ``dataset.apply(shuffler)``.
        """

        def map_fn(*args):
            # A single-component dataset arrives as a 1-tuple; unwrap it so
            # the returned structure matches the input element structure.
            if len(args) == 1:
                (args,) = args
            # Size the permutation from the batch actually received: the
            # final batch is shorter than buffer_size whenever the dataset
            # length is not a multiple of it, and gathering with
            # buffer_size indices would then go out of range.
            batch_len = tf.shape(tf.nest.flatten(args)[0])[:1]
            vals = self._rng.uniform(batch_len)
            # argsort of i.i.d. uniform values yields a random permutation.
            i = tf.argsort(vals)
            return tf.nest.map_structure(lambda x: tf.gather(x, i), args)

        return dataset.batch(self._buffer_size).map(map_fn).unbatch()


def as_list(ds: tf.data.Dataset):
    """Iterate over ``ds`` once and collect ``.numpy()`` of every element.

    Returns:
        A Python list with one entry per dataset element.
    """
    values = []
    for element in ds:
        values.append(element.numpy())
    return values


# Demo: save a checkpoint of the shuffler before each pass over the dataset,
# then check that restoring a checkpoint reproduces the pass that followed it.
shuffler = Shuffler(5)
chkpt = tf.train.Checkpoint(shuffler=shuffler)
# p0 is saved before the first pass over the dataset.
p0 = chkpt.save("/tmp/chkpt-0")
ds = tf.data.Dataset.range(5).apply(shuffler)
expected0 = as_list(ds)
# p1 is saved after the first pass and before the second.
p1 = chkpt.save("/tmp/chkpt-1")
expected1 = as_list(ds)
# ensure they're actually shuffled
# NOTE: as_list returns Python lists, so `==` below yields a single bool
# (whole-list equality), not an element-wise array.
assert not np.all(expected0 == expected1)
# Both passes must contain the same set of values, only in a different order.
assert set(expected0) == set(expected1)

# Restoring p0 should make the next two passes replay expected0 and then
# expected1, in that order.
chkpt.restore(p0)
np.testing.assert_equal(as_list(ds),expected0)

np.testing.assert_equal(as_list(ds),expected1)
# mangle state by iterating over it again
as_list(ds)

# restore p1
# After restoring p1 the next pass should replay expected1 again, despite the
# extra pass above.
chkpt.restore(p1)
np.testing.assert_equal(as_list(ds),expected1)
print("Passed!")

解决方法

事实证明,该状态已经由Iterator管理。

import tensorflow as tf
import numpy as np


def as_list(it: tf.data.Iterator, length: int = 5):
    """Draw ``length`` elements from ``it`` and return their NumPy values.

    Args:
        it: Iterator to advance; it is consumed in place.
        length: Number of elements to pull from the iterator.

    Returns:
        A Python list holding ``.numpy()`` of each element drawn.
    """
    # Use the builtin next() (the iterator protocol) rather than the
    # Python 2 style ``it.next()`` method.
    return [next(it).numpy() for _ in range(length)]


# Demo: put the dataset iterator itself into the checkpoint and check that
# restoring it replays the same shuffled elements.
# repeat() lets as_list keep drawing from the same iterator call after call.
ds = tf.data.Dataset.range(5).shuffle(5,seed=0).repeat()
it = iter(ds)
chkpt = tf.train.Checkpoint(it=it)
# p0 is saved before any element has been drawn.
p0 = chkpt.save("/tmp/chkpt-0")
expected0 = as_list(it)
# p1 is saved after the first 5 elements (as_list's default length).
p1 = chkpt.save("/tmp/chkpt-1")
expected1 = as_list(it)
# ensure they're actually shuffled
# NOTE: as_list returns Python lists, so `==` below yields a single bool
# (whole-list equality), not an element-wise array.
assert not np.all(expected0 == expected1)
# Both draws must contain the same set of values.
assert set(expected0) == set(expected1)

# Restoring p0 should make the next two draws replay expected0 and then
# expected1, in that order.
chkpt.restore(p0)
np.testing.assert_equal(as_list(it),expected0)

np.testing.assert_equal(as_list(it),expected1)
# mangle state by iterating over it again
as_list(it)

# restore p1
# After restoring p1 the next draw should replay expected1 again, despite the
# extra draw above.
chkpt.restore(p1)
np.testing.assert_equal(as_list(it),expected1)
print("Passed!")