问题描述
我正在尝试将图像单应性代码从 TF1 版本转换为 TF2,只是 TF 脚本转换在这里不起作用。我坚持对数据集进行批处理,因为图像、image_patch 和 image_Indices 具有不同的形状。虽然 TF1 在摄取和批处理数据集包方面没有问题,但 TF2 有问题。
imgs= np.random.rand(11,240,320,3)
pts = np.random.randint(100,size =(11,8))
patch = np.random.rand(11,128,1)
imgs = tf.convert_to_tensor(imgs)
pts = tf.convert_to_tensor(pts)
patch = tf.convert_to_tensor(patch)
pts= tf.cast(pts,dtype=tf.float64)
张量流2:
img_batch,pts_batch,patch_batch = tf.data.Dataset.from_tensor_slices([imgs,pts,patch]).shuffle(buffer_size=batch_size*4)
这里 11 是图像数量,240 和 320 是图像尺寸,3 是通道数。
错误 -
tensorflow.python.framework.errors_impl.InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [11,3] != values[2].shape = [11,1] [Op:Pack] name: component_0
张量流1:
tf.compat.v1.train.batch([imgs,patch],batch_size=5)
输出 -
[<tf.Tensor 'batch_2:0' shape=(5,11,3) dtype=float64>,<tf.Tensor 'batch_2:1' shape=(5,8) dtype=float64>,<tf.Tensor 'batch_2:2' shape=(5,1) dtype=float64>]
如何在 tensorflow2 中批量处理不同维度的数据集? 同样运行,“tf.compat.v1.train.batch()”在 TF2(tensoflow 2.3 版)中不起作用,因为它会产生急切执行错误。
在 TF2 中批处理此类数据集的正确方法是什么?
解决方法
这里的问题不是批处理,而是 tf.data.Dataset
本身的生成。错误是由 img_batch,pts_batch,patch_batch = tf.data.Dataset.from_tensor_slices([imgs,pts,patch])
引起的,不是由 .shuffle(batch_size=...)
引起的。
我认为这里的 .from_tensor_slices
级别太高,请查看 tf.data.Dataset.from_generator
。