处理自定义数据增强层中的批量大小 - tensorflow

问题描述

我已经实现了这个简单的数据增强层,基本上它将图像旋转特定的角度(我知道它可以通过 ImageDataGenerator 完成,但这只是为了解释问题)。

    class Randomrotation(tf.keras.layers.Layer):

        def __init__(self,rotation_range=None,**kwargs):
            super(Randomrotation,self).__init__(**kwargs)    
 

        def call(self,images,training=None,**kwargs):

            batch_size = tf.shape(images)[0]

            if training is None:
               training = K.learning_phase()

            if not training:
               return images

            angles = np.random.uniform(-0.5,0.5,batch_size)
            images = tfa.image.rotate(images,angles)

我的模型,包括这一层,然后通过 fit 方法使用 ImageDataGenerator 进行训练,以自动获取生成器。我收到此错误是因为批大小的值为无。

TypeError: 预期序列对象 len >= 0 或单个整数

解决方法

tf.config.run_functions_eagerly(True)

因为脚本的第一行解决了这个问题。