问题描述
我正在遵循本教程https://www.tensorflow.org/tutorials/generative/pix2pix,但我正在尝试建立自己的输入管道。我有一个4d numpy数组(数字样本,高度,宽度,通道),并且使用ds = tf.data.Dataset.from_tensor_slices()
创建我的数据集。但是,当我调用ds.take(1)
时,它没有批量大小的尺寸。我可以通过在必要的地方插入tf.expand_dims()
来解决此问题,但我认为应该有一种方法可以在数据集中执行此操作。
解决方法
您可以尝试:
for image in ds.batch(1).take(1):
assert image.shape[0] == 1
# do something with the image