使用 ImageDataGenerator 对 Keras 中的视频4D 张量进行数据增强

问题描述

我在 Keras 中有一个 ImageDataGenerator,我想在训练期间将其应用于短视频剪辑中的每一帧,这些视频剪辑表示为具有形状 (num_frames,width,height,3) 的 4D numpy 数组。

对于由图像组成的标准数据集,每个图像都具有形状(宽度、高度、3),我通常会执行以下操作:

aug = tf.keras.preprocessing.image.ImageDataGenerator(
        rotation_range=15,zoom_range=0.15)

model.fit_generator(
        aug.flow(X_train,y_train),epochs=100)

如何将这些相同的数据增强应用于表示图像序列的 4D numpy 数组数据集?

解决方法

我想通了。我创建了一个继承自 tensorflow.keras.utils.Sequence 的自定义类,该类使用 scipy 为每个图像执行增强。

       class CustomDataset(tf.keras.utils.Sequence):
            def __init__(self,batch_size,*args,**kwargs):
                self.batch_size = batch_size
                self.X_train = args[0]
                self.Y_train = args[1]

            def __len__(self):
                # returns the number of batches
                return int(self.X_train.shape[0] / self.batch_size)

            def __getitem__(self,index):
                # returns one batch
                X = []
                y = []
                for i in range(self.batch_size):
                    r = random.randint(0,self.X_train.shape[0] - 1)
                    next_x = self.X_train[r]
                    next_y = self.Y_train[r]
                    
                    augmented_next_x = []
                    
                    ###
                    ### Augmentation parameters for this clip.
                    ###
                    rotation_amt = random.randint(-45,45)
                    
                    for j in range(self.X_train.shape[1]):
                        transformed_img = ndimage.rotate(next_x[j],rotation_amt,reshape=False)
                        transformed_img[transformed_img == 0] = 255
                        augmented_next_x.append(transformed_img)
                
                    X.append(augmented_next_x)
                    y.append(next_y)
                    
                X = np.array(X).astype('uint8')
                y = np.array(y)

                encoder = LabelBinarizer()
                y = encoder.fit_transform(y)
                
                return X,y

            def on_epoch_end(self):
                # option method to run some logic at the end of each epoch: e.g. reshuffling
                pass

然后我将其传递给 fit_generator 方法:

training_data_augmentation = CustomDataset(BS,X_train_L,y_train_L)
model.fit_generator(
    training_data_augmentation,epochs=300)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...