如何对两个不等长的 tf.datasets 进行成对迭代?

问题描述

我处理两个长度不等的数据集。

我的目标是为 datasetA 中的每个元素取一个来自另一个 datasetB 的元素。我尝试了 .take(1)(如图 here 所示)从 datasetB 中获取单个元素,但重复调用 .take(1) 不会提高数据集的内部计数,即它始终返回相同的元素;但我想每次都得到一个新元素。

我可以使用 for element in datasetA: 迭代一个数据集,然后使用其中的第二个数据集作为 elementB = iterB.get_next()。这会在消耗 iterB 时引发错误

这是我正在使用的完整玩具代码

datasetA = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6])
datasetB = tf.data.Dataset.from_tensor_slices([11,22,33,44])

iterB = iter(datasetB)
epochs = 5

for epoch in range(epochs):
  print(f"Epoch {epoch}")
  for element in datasetA:
    print(element)
    elementB = iterB.get_next()
    print(elementB)

然后我继续:

for epoch in range(epochs):
  print(f"Epoch {epoch}")
  for element in datasetA:
    print(element)
    elementB = iterB.get_next_as_optional()
    if not elementB.has_value():
      iterB = iter(datasetB)
      elementB = iterB.get_next_as_optional()

    print(elementB.get_value())

这行得通,但重新初始化 datasetB 的迭代器很麻烦。

我进一步发现的是这个 for old TensorFlow,它使用 TF 操作重新初始化迭代器,该迭代器不再可用。 this question 中也提到了这一点,这很有帮助,但没有引导我找到 TF2.+ 解决方案。

我正在寻找的是一种从 datasetAdatasetB获取成对元素的优雅方式,其中 datasetB 在使用时(自动)重复。

我不需要迭代组合数据集,除非较短的数据集通过重复被“填充”到较长的数据集,然后我可以从数据集A和B中对(A,B)与A进行采样数据集B。

TL;博士: 想要在两个长度不等的数据集上进行成对迭代,在消耗时重新启动较短的数据集。

解决方法

我不懂这种编码语言,但这是您应该做的。

datasetA = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6]);
datasetB = tf.data.Dataset.from_tensor_slices([11,22,33,44]);
set i=0,j=0;
get lengths of both alength and blength;
for(i=0;i<alength;i++){
 print(datasetA[i]);
 print(datasetB[j]);
 if(j<blength-1)
   j++;
 else
  j=0;  
 }
,

要从两个数据集中获取所有可能的样本对,可以使用以下 generator

    # assuming that dataset_A and dataset_B are defined globally
    def generator():
        for sample_A in dataset_A:
            for sample_B in dataset_B:
                yield (sample_A,sample_B)

为了只获取数据集中相同位置的样本对(相同索引),有一个标准的 zip 方法:

    dataset = tf.data.Dataset.zip(dataset_A,dataset_B)

当其中一个数据集耗尽时,这种生成器就会停止。

如果目标是从 dataset_A 中获取所有样本的成对样本,但 dataset_B 较小,则可以无限重复第二个数据集,

     dataset_B = dataset_B.repeat()

然后 zip 两个数据集。