keras数据扩充

问题描述

我知道ImageDataGenerator为每个输入图像生成一个随机扩充的图像。现在,我想为每个输入图像生成两个增强图像:

datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        rotation_range=40,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,fill_mode='nearest')
train_ds = datagen.flow_from_directory('/home/train/')

为进一步说明,我想在同一张图像上应用2个不同的增强函数,即,如果我们对5张图像进行采样,最终将在批处理中得到2×5 = 10个增强观察

那么我该如何进行?

解决方法

我建议创建一个从tf.keras.utils.Sequence继承的自定义数据生成器。有很多方法可以解决此问题,但这应该符合您的寻找思路:

class double_aug_generator(tf.keras.utils.Sequence):
    def __init__(self,x,y,batch_size,aug_params1,aug_params2):
        self.x,self.y = x,y
        self.batch_size = batch_size
        self.datagen = tf.keras.preprocessing.image.ImageDataGenerator(**aug_params1)
        
        // dictionary of parameters for the second augmentation
        self.aug_params2 = aug_params2

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)
    
    def load(self,file_names):
        // load and return raw images however you like

    def __getitem__(self,idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) *
        self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) *
        self.batch_size]
        
        // load images
        batch_x = self.load(batch_x)
        
        // apply first augmentation
        batch_x = self.datagen.flow(batch_x)
        
        // apply second
        batch_x = self.datagen.apply_transform(batch_x,self.aug_params2)
        
        return batch_x,np.array(batch_y)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...