用于图像分割数据集的 Keras 数据增强管道具有相同操作的图像和掩码

问题描述

我正在为我的图像分割数据集构建预处理和数据增强管道 keras 有一个强大的 API 可以做到这一点,但我遇到了在图像上重现相同增强以及分割掩码(第二张图像)的问题。两个图像必须经过完全相同的操作。还不支持吗?

https://www.tensorflow.org/tutorials/images/data_augmentation

示例/伪代码

data_augmentation = tf.keras.Sequential([
layers.experimental.preprocessing.RandomFlip(mode="horizontal_and_vertical",seed=SEED_VAL),layers.experimental.preprocessing.Randomrotation(factor=0.4,fill_mode="constant",fill_value=0,layers.experimental.preprocessing.RandomZoom(height_factor=(-0.0,-0.2),fill_mode='constant',seed=SEED_VAL)])

(train_ds,test_ds),info = tfds.load('somedataset',split=['train[:80%]','train[80%:]'],with_info=True)

这段代码不起作用,但说明了我梦想中的 api 是如何工作的:

train_ds = train_ds.map(lambda datapoint: data_augmentation((datapoint['image'],datapoint['segmentation_mask']),training=True))

替代方案

另一种方法是编写图像分割教程 (https://www.tensorflow.org/tutorials/images/segmentation) 中提出的自定义加载和操作/随机方法

非常感谢有关此类数据集最先进数据增强的任何提示:)

解决方法

您可以尝试使用外部库进行额外的图像增强。这些链接可能有助于图像增强和分割掩码,

相册

https://github.com/albumentations-team/albumentations

https://albumentations.ai/docs/getting_started/mask_augmentation/

enter image description here

图片

https://github.com/aleju/imgaug

https://nbviewer.jupyter.org/github/aleju/imgaug-doc/blob/master/notebooks/B05%20-%20Augment%20Segmentation%20Maps.ipynb

enter image description here

,

这是我自己的实现,以防其他人想在 2020 年 12 月使用 tf 内置函数 (tf.image api) :)

@tf.function
def load_image(datapoint,augment=True):
    
    # resize image and mask
    img_orig = input_image = tf.image.resize(datapoint['image'],(IMG_SIZE,IMG_SIZE))
    mask_orig = input_mask = tf.image.resize(datapoint['segmentation_mask'],IMG_SIZE))
    
    # rescale the image
    if IMAGE_CHANNELS == 1:
        input_image = tf.image.rgb_to_grayscale(input_image)
    input_image = tf.cast(input_image,tf.float32) / 255.0
    
    # augmentation
    if augment:
        # zoom in a bit
        if tf.random.uniform(()) > 0.5:
            # use original image to preserve high resolution
            input_image = tf.image.central_crop(img_orig,0.75)
            input_mask = tf.image.central_crop(mask_orig,0.75)
            # resize
            input_image = tf.image.resize(input_image,IMG_SIZE))
            input_mask = tf.image.resize(input_mask,IMG_SIZE))
        
        # random brightness adjustment illumination
        input_image = tf.image.random_brightness(input_image,0.3)
        # random contrast adjustment
        input_image = tf.image.random_contrast(input_image,0.2,0.5)
        
        # flipping random horizontal or vertical
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_left_right(input_image)
            input_mask = tf.image.flip_left_right(input_mask)
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_up_down(input_image)
            input_mask = tf.image.flip_up_down(input_mask)

        # rotation in 30° steps
        rot_factor = tf.cast(tf.random.uniform(shape=[],maxval=12,dtype=tf.int32),tf.float32)
        angle = np.pi/12*rot_factor
        input_image = tfa.image.rotate(input_image,angle)
        input_mask = tfa.image.rotate(input_mask,angle)

    return input_image,input_mask

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...