在Tensorflow中为输入和输出保留相同的数据集扩充

问题描述

我有一个批处理数据集,其中包含图像作为输入和输出代码是这样的:

os.chdir(r'E:/trainTest')

def process_img(file_path):
    img = tf.io.read_file(file_path)
    img = tf.image.decode_png(img,channels=3)
    img = tf.image.convert_image_dtype(img,tf.float32)
    img = tf.image.resize(img,size=(img_height,img_width))
    return img

x_files = glob('input/*.png')
y_files = glob('output/*.png')

files_ds = tf.data.Dataset.from_tensor_slices((x_files,y_files))


#Dataset which gives me input-output 
files_ds = files_ds.map(lambda x,y: (process_img(x),process_img(y))).batch(batch_size)

#model init etc
#----

model.fit(files_ds,epochs=25)

问题是我的模型没有足够的图像。所以我的问题是,如何从files_ds创建增强的图像(如翻转,旋转,缩放等)?因为必须以相同的方式增强输出图像,所以必须增强输入图像。

这个问题实际上来自以下问题,我想在自己的部分中提问:
Tensorflow image_dataset_from_directory for input dataset and output dataset

解决方法

tf.image有很多可以使用的随机转换。例如:

这是一个关于猫的完全随机图像的例子。

from skimage import data
import matplotlib.pyplot as plt
import tensorflow as tf

cat = data.chelsea()

plt.imshow(cat)
plt.show()

enter image description here

具有变换的图像:

from skimage import data
import matplotlib.pyplot as plt
import tensorflow as tf

cat = data.chelsea()

plt.imshow(tf.image.random_hue(cat,.2,.5))
plt.show()

enter image description here

您可以像这样在您的tf.data.Dataset中实现它:

def process_img(file_path):
    img = tf.io.read_file(file_path)
    img = tf.image.decode_png(img,channels=3)
    img = tf.image.convert_image_dtype(img,tf.float32)
    img = tf.image.resize(img,size=(img_height,img_width))
    img = tf.image.random_hue(img,0.,.5) # Here
    return img

我找到了一种在图形模式下保持相同变换的方法。基本上是在同一调用中将两个图像传递给转换器。

import os
import tensorflow as tf
os.chdir(r'c:/users/user/Pictures')
from glob2 import glob
import matplotlib.pyplot as plt

x_files = glob('inputs/*.jpg')
y_files = glob('targets/*.jpg')

files_ds = tf.data.Dataset.from_tensor_slices((x_files,y_files))

def load(file_path):
    img = tf.io.read_file(file_path)
    img = tf.image.decode_jpeg(img,size=(28,28))
    return img

def process_img(file_path1,file_path2):
    img = tf.stack([load(file_path1),load(file_path2)])
    img = tf.image.random_hue(img,max_delta=.5)
    return img[0],img[1]

files_ds = files_ds.map(lambda x,y: process_img(x,y)).batch(1)

a,b = next(iter(files_ds))

plt.imshow(a[0,...])
plt.imshow(b[0,...])
,

您可以使用 tf.keras.preprocessing.image.ImageDataGenerator 进行预处理。 This是完整选项的文档页面。我分享了一个小示例,说明如何通过 flow_from_directory 使用它。您没有事先阅读图像并用完RAM。图像从目录中加载,经过预处理并根据需要输入到模型中。

# we create two instances with the same arguments
data_gen_args = dict(rescale=1./255,shear_range=0.2,horizontal_flip=True,rotation_range=90,width_shift_range=0.1,height_shift_range=0.1,zoom_range=0.2)

image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)

# Provide the same seed and keyword arguments to the fit and flow methods
seed = 1

image_generator = image_datagen.flow_from_directory(
    'data/images',class_mode=None,seed=seed)

mask_generator = mask_datagen.flow_from_directory(
    'data/masks',seed=seed)

# combine generators into one which yields image and masks
train_generator = zip(image_generator,mask_generator)

model.fit(
    train_generator,steps_per_epoch=2000,epochs=50)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...