问题描述
我正在尝试使用六月发布的新TensorFlow对象检测API。但是我在使用他们提供的数据扩充工具时遇到了一些困难。这是因为它们从TensorFlow导入contrib.image
,而TensorFlow仅存在于TF 1.x中。因此,我的问题是:“任何人都知道如何在TF 2.x中使用此数据增强工具吗?”。
最诚挚的问候。
解决方法
您可以在Tensorflow网站https://www.tensorflow.org/tutorials/images/data_augmentation上找到使用TF 2.x的数据增强教程。
您还可以使用ImageDataGenerator库在Tf 2.x中执行数据增强。
tf.keras.preprocessing.image.ImageDataGenerator
imagedatagenerator的示例代码段
import tensorflow as tf
image = tf.keras.preprocessing.image.load_img('flower.jpeg')
image_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rotation_range=40,width_shift_range=0.2,height_shift_range=0.2,rescale=1./255,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,fill_mode='nearest')
#convert image to array
im_array = tf.keras.preprocessing.image.img_to_array(image)
img = im_array.reshape((1,) + im_array.shape)
#Generate the images
count = 0
for batch in image_datagen.flow(img,batch_size=1,save_to_dir ='image_gen',save_prefix='flower',save_format='jpeg'):
count +=1
if count==5:
break
#Input image
import matplotlib.pylab as plt
image = plt.imread('flower.jpeg')
plt.imshow(image)
plt.show()
#After augmentation
import matplotlib.pylab as plt
image = plt.imread('image_gen/flower_0_1167.jpeg')
plt.imshow(image)
plt.show()