具有多处理功能的Tensorflow2.x自定义数据生成器

问题描述

我刚刚升级到tensorflow 2.3。 我想制作自己的数据生成器进行培训。 使用tensorflow 1.x,我做到了:

def get_data_generator(test_flag):
  item_list = load_item_list(test_flag)
  print('data loaded')
  while True:
    X = []
    Y = []
    for _ in range(BATCH_SIZE):
      x,y = get_random_augmented_sample(item_list)
      X.append(x)
      Y.append(y)
    yield np.asarray(X),np.asarray(Y)

data_generator_train = get_data_generator(False)
data_generator_test = get_data_generator(True)
model.fit_generator(data_generator_train,validation_data=data_generator_test,epochs=10000,verbose=2,use_multiprocessing=True,workers=8,validation_steps=100,steps_per_epoch=500,)

代码在tensorflow 1.x上运行良好。系统中创建了8个流程。处理器和视频卡已完美加载。 “已加载数据”已打印8次。

使用tensorflow 2.3我得到警告:

警告:张量流:多处理会与TensorFlow严重交互,从而导致不确定的死锁。对于高性能数据管道,建议使用tf.data。

“已加载数据”打印了一次(应该打印8次)。 GPU没有得到充分利用。每个时期也都有内存泄漏,因此转换将在几个时期后停止。 use_multiprocessing标志没有帮助。

如何在tensorflow(keras)2.x中制作一个可以在多个cpu进程之间轻松并行化的生成器/迭代器?死锁和数据顺序并不重要。

解决方法

通过tf.data管道,有几个地方可以并行化。根据数据的存储和读取方式,可以并行读取。您还可以并行化扩充,并且可以在训练时预取数据,因此GPU(或其他硬件)从不渴望数据。

在下面的代码中,我演示了如何并行化扩充并添加预取。

import numpy as np
import tensorflow as tf

x_shape = (32,32,3)
y_shape = ()  # A single item (not array).
classes = 10

def generator_fn(n_samples):
    """Return a function that takes no arguments and returns a generator."""
    def generator():
        for i in range(n_samples):
            # Synthesize an image and a class label.
            x = np.random.random_sample(x_shape).astype(np.float32)
            y = np.random.randint(0,classes,size=y_shape,dtype=np.int32)
            yield x,y
    return generator

def augment(x,y):
    return x * tf.random.normal(shape=x_shape),y

samples = 10
batch_size = 5
epochs = 2

# Create dataset.
gen = generator_fn(n_samples=samples)
dataset = tf.data.Dataset.from_generator(
    generator=gen,output_types=(np.float32,np.int32),output_shapes=(x_shape,y_shape)
)
# Parallelize the augmentation.
dataset = dataset.map(
    augment,num_parallel_calls=tf.data.experimental.AUTOTUNE,# Order does not matter.
    deterministic=False
)
dataset = dataset.batch(batch_size,drop_remainder=True)
# Prefetch some batches.
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

# Prepare model.
model = tf.keras.applications.VGG16(weights=None,input_shape=x_shape,classes=classes)
model.compile(optimizer="adam",loss="sparse_categorical_crossentropy")

# Train. Do not specify batch size because the dataset takes care of that.
model.fit(dataset,epochs=epochs)