numpy 4d数组到tf.data.dataset

问题描述

我正在遵循本教程https://www.tensorflow.org/tutorials/generative/pix2pix,但我正在尝试建立自己的输入管道。我有一个4d numpy数组(数字样本,高度,宽度,通道),并且使用ds = tf.data.Dataset.from_tensor_slices()创建我的数据集。但是,当我调用ds.take(1)时,它没有批量大小的尺寸。我可以通过在必要的地方插入tf.expand_dims()解决此问题,但我认为应该有一种方法可以在数据集中执行此操作。

解决方法

您可以尝试:

for image in ds.batch(1).take(1):
    assert image.shape[0] == 1
    # do something with the image