How to create a custom Keras generator for multiple outputs that works with workers

Problem description

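I have one input and several outputs, similar to multi-label classification, but I chose to try a different approach to see whether it gives any improvement.

I use these generators, built with flow_from_dataframe because my dataset is huge (200k):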

self.train_generator = datagen.flow_from_dataframe(
    dataframe=train, directory='dataset', x_col='Filename', y_col=columns,
    batch_size=BATCH_SIZE, color_mode='rgb', class_mode='raw',
    shuffle=True, target_size=(HEIGHT, WIDTH))

self.test_generator = datatest.flow_from_dataframe(
    dataframe=test, ...)  # remaining arguments are missing from the original post

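I use this function to reshape the batches for fitting:

def generator(self, generator):
    while True:
        X, y = next(generator)
        # One target array per model output: column x of the (batch, len(columns)) label array.
        y = [y[:, x] for x in range(len(columns))]
        yield X, y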
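
A multi-output Keras model expects its targets as a list with one entry per output layer, each entry holding that output's labels for the whole batch. The sketch below mimics the column split done by y[:, x] in pure Python; plain lists stand in for the NumPy arrays Keras would return (an assumption to keep it dependency-free), and split_targets is a hypothetical helper name.

```python
def split_targets(y_batch):
    """Split a 'raw' label batch into one label list per model output.

    y_batch holds one row per sample and one column per label
    (plain lists stand in for the NumPy array Keras would return).
    """
    n_outputs = len(y_batch[0])
    # Column k collects the k-th label of every sample: the equivalent of y[:, k].
    return [[row[k] for row in y_batch] for k in range(n_outputs)]


# Batch of 2 samples with 3 labels each -> 3 targets of length 2.
targets = split_targets([[1, 0, 1],
                         [0, 1, 1]])
print(targets)  # [[1, 0], [0, 1], [1, 1]]
```

The number of entries in the returned list has to match the number of model outputs, which is the point of the split.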

If I fit like this:

self.h = self.model.fit_generator(
    self.generator(self.train_generator), steps_per_epoch=self.STEP_SIZE_TRAIN,
    validation_data=self.generator(self.test_generator), validation_steps=self.STEP_SIZE_TEST,
    epochs=50, verbose=1, workers=2)

I get:

RuntimeError: Your generator is NOT thread-safe. Keras requires a thread-safe generator when `use_multiprocessing=False,workers > 1`. 
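
Keras reports this error when two worker threads call next() on the same plain Python generator at once (Python itself refuses with "generator already executing"). A common workaround is to guard next() with a lock. A minimal sketch using only the standard library follows; ThreadSafeIterator, threadsafe_generator and count_up are hypothetical names, not Keras APIs.

```python
import functools
import threading


class ThreadSafeIterator:
    """Wrap an iterator so that concurrent next() calls are serialized by a lock."""

    def __init__(self, iterator):
        self._iterator = iterator
        self._lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        # Only one thread at a time may advance the wrapped generator.
        with self._lock:
            return next(self._iterator)


def threadsafe_generator(func):
    """Decorator: calling func(...) returns a thread-safe iterator over its generator."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return ThreadSafeIterator(func(*args, **kwargs))

    return wrapper


@threadsafe_generator
def count_up(n):
    # Toy generator standing in for the batch generator.
    for i in range(n):
        yield i
```

In the question's class, one could try decorating the generator method with @threadsafe_generator. Because the lock lets only one thread run the generator body at a time, image loading is serialized as well, so extra workers may bring little speed-up.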

With use_multiprocessing=True:

self.h = self.model.fit_generator(self.generator(self.train_generator), use_multiprocessing=True)

The result is:

  File "C:\ProgramData\Anaconda3\lib\threading.py", line 932, in _bootstrap_inner
    self.run()
  File "C:\ProgramData\Anaconda3\lib\threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\keras\utils\data_utils.py", line 877, in _run
    with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\keras\utils\data_utils.py", line 867, in pool_fn
    pool = get_pool_class(True)(
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\context.py", line 119, in Pool
    return Pool(processes, initializer, initargs, maxtasksperchild,
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\pool.py", line 212, in __init__
    self._repopulate_pool()
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\pool.py", line 303, in _repopulate_pool
    return self._repopulate_pool_static(self._ctx, self.Process,
  File "...", line 326, in _repopulate_pool_static
    w.start()
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\process.py", line 121, in start
    self._popen = self._Popen(self)
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\context.py", line 327, in _Popen
    return Popen(process_obj)
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\popen_spawn_win32.py", line 93, in __init__
    reduction.dump(process_obj, to_child)
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
TypeError: cannot pickle 'generator' object

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "C:\ProgramData\Anaconda3\lib\multiprocessing\spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
EOFError: Ran out of input
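
The traceback points at the cause: on Windows, multiprocessing starts workers with the spawn method, which pickles what it hands to each child process, and a Python generator object cannot be pickled. An object that produces batches by index (the __len__/__getitem__ interface of keras.utils.Sequence, which the Keras documentation recommends for multiprocessing) holds only ordinary attributes and pickles normally. The standard-library sketch below only demonstrates that pickling difference; batch_stream, IndexedBatches and is_picklable are hypothetical names, and a real loader would subclass tf.keras.utils.Sequence instead.

```python
import pickle


def batch_stream(n_batches):
    # A plain generator: its suspended execution state cannot be pickled.
    for i in range(n_batches):
        yield i


class IndexedBatches:
    """Hypothetical stand-in for a keras.utils.Sequence subclass.

    Batches are produced on demand from an index, so the object holds
    only ordinary attributes and can be pickled for a worker process.
    """

    def __init__(self, n_batches):
        self.n_batches = n_batches

    def __len__(self):
        return self.n_batches

    def __getitem__(self, index):
        if not 0 <= index < self.n_batches:
            raise IndexError(index)
        return index  # a real Sequence would load and return (X, y) here


def is_picklable(obj):
    try:
        pickle.dumps(obj)
    except TypeError:
        return False
    return True


print(is_picklable(batch_stream(3)))    # False
print(is_picklable(IndexedBatches(3)))  # True
```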

Now I'm stuck. How can I fix this?

Solution

According to the documentation (https://keras.io/api/preprocessing/image/), the class_mode parameter can be set to "multi_output", so no custom generator is needed:

 class_mode: one of "binary", "categorical", "input", "multi_output", "raw", "sparse" or None. Default: "categorical". Mode for yielding the targets:

- "binary": 1D numpy array of binary labels,
- "categorical": 2D numpy array of one-hot encoded labels. Supports multi-label output.
- "input": images identical to input images (mainly used to work with autoencoders),
- "multi_output": list with the values of the different columns,
- "raw": numpy array of values in y_col column(s),
- "sparse": 1D numpy array of integer labels,
- None, no targets are returned (the generator will only yield batches of image data, which is useful to use in model.predict()).
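
Based on the documentation text above, "raw" and "multi_output" differ only in how a batch of y_col values is arranged. The pure-Python model below illustrates that difference; the rows, the column names label_a/label_b and the two function names are hypothetical, and Keras itself returns NumPy arrays rather than lists.

```python
def raw_targets(rows, y_col, index_array):
    # class_mode="raw": one row of len(y_col) label values per sample.
    return [[rows[i][col] for col in y_col] for i in index_array]


def multi_output_targets(rows, y_col, index_array):
    # class_mode="multi_output": one list per column in y_col, each of batch length.
    return [[rows[i][col] for i in index_array] for col in y_col]


# Hypothetical dataframe rows with two label columns.
rows = [
    {"Filename": "a.jpg", "label_a": 1, "label_b": 0},
    {"Filename": "b.jpg", "label_a": 0, "label_b": 1},
    {"Filename": "c.jpg", "label_a": 1, "label_b": 1},
]
batch = [2, 0, 1]  # indices of the samples drawn into this (shuffled) batch
print(raw_targets(rows, ["label_a", "label_b"], batch))           # [[1, 1], [1, 0], [0, 1]]
print(multi_output_targets(rows, ["label_a", "label_b"], batch))  # [[1, 1, 0], [1, 0, 1]]
```

With "multi_output", each list entry lines up with one output layer, which is the rearrangement the custom wrapper in the question was doing by hand.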

I can now use workers > 1, but I don't get any performance improvement.
