如何在 tensorflow 中实现动态增强?

问题描述

我想在具有 tensorflow 模型的 3D 数据集中实现增强。

增强函数是这样的:

    def augmentation(img,label):
            
            p = .5
            print('augmentation')
            
            if random.random() > p:
                img = tf.numpy_function(augment_noise,[img],tf.double)
                
            if random.random() > p:
                img = tf.numpy_function(flip_x,tf.double)
    
            if random.random() > p:
                img = tf.numpy_function(augment_scale,tf.double)
    
            if random.random() > p:
                img = tf.numpy_function(distort_elastic_cv2,tf.double)    
            
           
            img = tf.image.convert_image_dtype(img,tf.float32)
            
            return img,label

tensorflow 中没有实现增强函数

使用该函数的 Tensorflow 代码如下:

ds_train = tf.data.Dataset.from_tensor_slices((image_train,label_train))
ds_valid = tf.data.Dataset.from_tensor_slices((image_val,label_val))


batch_size = 16
repeat_count = int((1000 * batch_size)/len(image_train))
# AUTOTUNE =  tf.data.experimental.AUTOTUNE # tf.data.AUTOTUNE
AUTOTUNE = 16

# Augment the on the fly during training.
ds_train = (
    ds_train.shuffle(len(ds_train)).repeat(repeat_count)
    .map(augmentation,num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)


ds_valid = (
    ds_valid.batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)

initial_epoch = 0
epochs = 1000
H = model.fit(ds_train,validation_data=ds_valid,initial_epoch=initial_epoch,epochs = epochs,callbacks = chkpts,use_multiprocessing=False,workers=1,verbose=2)

我想在每个 epoch 中从数据集中随机选择大约 1000 个批次,然后在它们上进行扩充。我计算 repeat_count 以创建大小为 batch_size 的 1000 个批次。

问题是我不知道每个时期的模型调用增强函数并将其隐含到批次的每个图像中(我的意思是每个时期中有 161000 张图像),所以我添加了 {{1}在 print 函数中,它只打印一次,而不是在每个时期或每个图像中。 增强函数是否在每个 epoch 中调用 161000 次?

此外,每次运行代码时,cpu 和 gpu 的利用率也不同。有时cpu的利用率约为25%,gpu为30,但几乎在运行中为100%和5。

如何解决这两个问题?

解决方法

你的字符串被打印一次,因为它调用一次来制作一个 Tensorflow 图。如果您使用 tf.print 打印,它将成为图表的一部分,因此每次都会打印。

复制/粘贴:

import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.datasets import load_sample_image
import numpy as np
import random

imgs = np.stack([load_sample_image('flower.jpg') for i in range(4*4)],axis=0)

def augmentation(img):      
        p = .5
        tf.print('augmentation successful!')  
        img = tf.image.convert_image_dtype(img,tf.float32)
        return img 
    

ds_train = tf.data.Dataset.from_tensor_slices(imgs)


batch_size = 16
repeat_count = 10
AUTOTUNE = 16

ds_train = (
    ds_train.shuffle(len(ds_train)).repeat(repeat_count)
    .map(augmentation,num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)

for i in ds_train:
    pass
augmentation successful!
augmentation successful!
augmentation successful!
augmentation successful!
augmentation successful!
augmentation successful!
augmentation successful!
augmentation successful!
augmentation successful!

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...