问题描述
我想在具有 tensorflow 模型的 3D 数据集中实现增强。
增强函数是这样的:
def augmentation(img,label):
p = .5
print('augmentation')
if random.random() > p:
img = tf.numpy_function(augment_noise,[img],tf.double)
if random.random() > p:
img = tf.numpy_function(flip_x,tf.double)
if random.random() > p:
img = tf.numpy_function(augment_scale,tf.double)
if random.random() > p:
img = tf.numpy_function(distort_elastic_cv2,tf.double)
img = tf.image.convert_image_dtype(img,tf.float32)
return img,label
tensorflow 中没有实现增强函数。
ds_train = tf.data.Dataset.from_tensor_slices((image_train,label_train))
ds_valid = tf.data.Dataset.from_tensor_slices((image_val,label_val))
batch_size = 16
repeat_count = int((1000 * batch_size)/len(image_train))
# AUTOTUNE = tf.data.experimental.AUTOTUNE # tf.data.AUTOTUNE
AUTOTUNE = 16
# Augment the on the fly during training.
ds_train = (
ds_train.shuffle(len(ds_train)).repeat(repeat_count)
.map(augmentation,num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(buffer_size=AUTOTUNE)
)
ds_valid = (
ds_valid.batch(batch_size)
.prefetch(buffer_size=AUTOTUNE)
)
initial_epoch = 0
epochs = 1000
H = model.fit(ds_train,validation_data=ds_valid,initial_epoch=initial_epoch,epochs = epochs,callbacks = chkpts,use_multiprocessing=False,workers=1,verbose=2)
我想在每个 epoch 中从数据集中随机选择大约 1000 个批次,然后在它们上进行扩充。我计算 repeat_count
以创建大小为 batch_size
的 1000 个批次。
问题是我不知道每个时期的模型调用增强函数并将其隐含到批次的每个图像中(我的意思是每个时期中有 161000 张图像),所以我添加了 {{1}在 print
函数中,它只打印一次,而不是在每个时期或每个图像中。
增强函数是否在每个 epoch 中调用 161000 次?
此外,每次运行代码时,cpu 和 gpu 的利用率也不同。有时cpu的利用率约为25%,gpu为30,但几乎在运行中为100%和5。
如何解决这两个问题?
解决方法
你的字符串被打印一次,因为它调用一次来制作一个 Tensorflow 图。如果您使用 tf.print
打印,它将成为图表的一部分,因此每次都会打印。
复制/粘贴:
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.datasets import load_sample_image
import numpy as np
import random
imgs = np.stack([load_sample_image('flower.jpg') for i in range(4*4)],axis=0)
def augmentation(img):
p = .5
tf.print('augmentation successful!')
img = tf.image.convert_image_dtype(img,tf.float32)
return img
ds_train = tf.data.Dataset.from_tensor_slices(imgs)
batch_size = 16
repeat_count = 10
AUTOTUNE = 16
ds_train = (
ds_train.shuffle(len(ds_train)).repeat(repeat_count)
.map(augmentation,num_parallel_calls=AUTOTUNE)
.batch(batch_size)
.prefetch(buffer_size=AUTOTUNE)
)
for i in ds_train:
pass
augmentation successful!
augmentation successful!
augmentation successful!
augmentation successful!
augmentation successful!
augmentation successful!
augmentation successful!
augmentation successful!
augmentation successful!