How to load video data with tf.data in TensorFlow for end-to-end model training with distributed input?

Problem description

I am trying to train a model on video data in TensorFlow. The current input pipeline uses a data generator that loads videos with the moviepy module and then runs some video preprocessing with dlib and a few other packages. I am training end to end (rather than preprocessing the videos first and feeding the cleaned data to the model) because I want the model to learn stochastically; in other words, in every epoch I want the frames loaded from each video to be different. However, training is extremely slow, about 17 hours per epoch, because the main bottleneck is the data-generator stage (video loading and preprocessing).
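To make the "stochastic learning" idea concrete, here is a minimal sketch of per-epoch random frame sampling; `sample_frames` is a hypothetical helper, not part of the original pipeline:

```python
import random

def sample_frames(num_frames, segments, rng=random):
    """Split a clip into `segments` equal windows and pick one random
    frame index inside each window, so every epoch can draw a different
    set of frames from the same video."""
    window = num_frames // segments
    return [rng.randrange(i * window, (i + 1) * window) for i in range(segments)]

# Two epochs over a 300-frame clip will generally see different frames:
print(sample_frames(300, 5))
print(sample_frames(300, 5))
```

Because the sampling happens inside the generator, the decoded frames change from epoch to epoch, which is exactly why the video decode cost is paid on every epoch.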

Searching the TensorFlow site, I found the distributed-input and data-prefetching features. Here is the part where I wrap the generator in a tf.data dataset; the generator takes a list of file paths (valid_list) via args and yields the processed data:

valid_gen = tf.data.Dataset.from_generator(
    generate_data,
    output_types=((tf.float32, tf.float32, tf.float32), tf.int64),
    output_shapes=((tf.TensorShape([None, segments, imsize, 3]),
                    tf.TensorShape([None, 68]),
                    tf.TensorShape([None, 68, 2])),
                   tf.TensorShape([None, n_class])),
    args=(valid_list, batch_size))
# prefetch() returns a new dataset; the result must be kept, or the call is a no-op
valid_gen = valid_gen.prefetch(buffer_size=batch_size)
dist_valid_data = mirrored_strategy.experimental_distribute_dataset(valid_gen)

However, I get the following error:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-48-3af068b6698e> in <module>
      2                       validation_data=dist_valid_data,callbacks=callbacks,
      3                       validation_steps=len(valid_dict)//batch_size,
----> 4                       batch_size=batch_size,epochs=1000)

~/.conda/envs/video_env_1/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self,*args,**kwargs)
    106   def _method_wrapper(self,**kwargs):
    107     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
--> 108       return method(self,**kwargs)
    109 
    110     # Running inside `run_distribute_coordinator` already.

~/.conda/envs/video_env_1/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self,x,y,batch_size,epochs,verbose,callbacks,validation_split,validation_data,shuffle,class_weight,sample_weight,initial_epoch,steps_per_epoch,validation_steps,validation_batch_size,validation_freq,max_queue_size,workers,use_multiprocessing)
   1096                 batch_size=batch_size):
   1097               callbacks.on_train_batch_begin(step)
-> 1098               tmp_logs = train_function(iterator)
   1099               if data_handler.should_sync:
   1100                 context.async_wait()

~/.conda/envs/video_env_1/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self,**kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args,**kwds)
    781 
    782       new_tracing_count = self._get_tracing_count()

~/.conda/envs/video_env_1/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self,**kwds)
    838         # Lifting succeeded,so variables are initialized and we can run the
    839         # stateless function.
--> 840         return self._stateless_fn(*args,**kwds)
    841     else:
    842       canon_args,canon_kwds = \

~/.conda/envs/video_env_1/lib/python3.7/site-packages/tensorflow/python/eager/function.py in __call__(self,**kwargs)
   2827     with self._lock:
   2828       graph_function,args,kwargs = self._maybe_define_function(args,kwargs)
-> 2829     return graph_function._filtered_call(args,kwargs)  # pylint: disable=protected-access
   2830 
   2831   @property

~/.conda/envs/video_env_1/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _filtered_call(self,kwargs,cancellation_manager)
   1846                            resource_variable_ops.BaseResourceVariable))],
   1847         captured_inputs=self.captured_inputs,
-> 1848         cancellation_manager=cancellation_manager)
   1849 
   1850   def _call_flat(self,captured_inputs,cancellation_manager=None):

~/.conda/envs/video_env_1/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_flat(self,cancellation_manager)
   1922       # No tape is watching; skip to running the function.
   1923       return self._build_call_outputs(self._inference_function.call(
-> 1924           ctx,cancellation_manager=cancellation_manager))
   1925     forward_backward = self._select_forward_and_backward_functions(
   1926         args,

~/.conda/envs/video_env_1/lib/python3.7/site-packages/tensorflow/python/eager/function.py in call(self,ctx,cancellation_manager)
    548               inputs=args,
    549               attrs=attrs,
--> 550               ctx=ctx)
    551         else:
    552           outputs = execute.execute_with_cancellation(

~/.conda/envs/video_env_1/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name,num_outputs,inputs,attrs,name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle,device_name,op_name,
---> 60                                         inputs,num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError:  TypeError: endswith first arg must be bytes or a tuple of bytes,not str
Traceback (most recent call last):

  File "/home/dwa382/.conda/envs/video_env_1/lib/python3.7/site-packages/tensorflow/python/ops/script_ops.py",line 244,in __call__
    ret = func(*args)

  File "/home/dwa382/.conda/envs/video_env_1/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py",line 302,in wrapper
    return func(*args,**kwargs)

  File "/home/dwa382/.conda/envs/video_env_1/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py",line 827,in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "<ipython-input-44-34d69f4e6d29>",line 306,in generate_data
    imsize=imsize,segments=segments)

  File "<ipython-input-44-34d69f4e6d29>",line 221,in pre_processing
    v_array=np.array(list(VideoFileClip(fp).iter_frames(fps=fps)))

  File "/home/dwa382/.conda/envs/video_env_1/lib/python3.7/site-packages/moviepy/video/io/VideoFileClip.py",line 91,in __init__
    fps_source=fps_source)

  File "/home/dwa382/.conda/envs/video_env_1/lib/python3.7/site-packages/moviepy/video/io/ffmpeg_reader.py",line 36,in __init__
    fps_source)

  File "/home/dwa382/.conda/envs/video_env_1/lib/python3.7/site-packages/moviepy/video/io/ffmpeg_reader.py",in ffmpeg_parse_infos
    is_GIF = filename.endswith('.gif')

TypeError: endswith first arg must be bytes or a tuple of bytes,not str


     [[{{node PyFunc}}]]
     [[MultideviceIteratorGetNextFromShard]]
     [[RemoteCall]]
     [[IteratorGetNextAsOptional]] [Op:__inference_train_function_19168]

Function call stack:
train_function
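For what it is worth, the traceback points at the video file path: tf.data.Dataset.from_generator converts string args into bytes tensors, so generate_data receives each path as bytes, and moviepy's str-typed endswith('.gif') check then fails on it. A minimal sketch of the mismatch and the usual decode fix (the fp value here is illustrative, not from the original code):

```python
# tf.data.Dataset.from_generator passes string args to the generator as
# bytes, so inside generate_data the file path looks like this:
fp = b"/data/videos/clip_0001.mp4"

# bytes.endswith() requires a bytes suffix; passing a str raises the
# TypeError seen in the traceback above.
try:
    fp.endswith(".gif")
except TypeError as e:
    print(e)  # endswith first arg must be bytes or a tuple of bytes, not str

# Decoding the path back to str before handing it to moviepy avoids this:
path = fp.decode("utf-8") if isinstance(fp, bytes) else fp
print(path.endswith(".gif"))  # False
```

Under that assumption, decoding fp at the top of generate_data (before it reaches VideoFileClip) should remove this particular error.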

Thanks.
