How to handle the batch size in a custom loss function in TensorFlow 2.3

Problem description

I am trying to implement a custom loss function related to triplet loss. The triplet loss takes a custom distance metric that returns the pairwise distances between embeddings. I defined a custom function that works correctly in the forward pass, but it raises an error during backpropagation. Here is the error:

InvalidArgumentError:  slice index 16 of dimension 1 out of bounds.
 [[{{node TripletSemiHardLoss/PartitionedCall/while_1/body/_226/while_1/strided_slice}}]] [Op:__inference_train_function_31232]

Function call stack: train_function

16 is my input batch size. I am not using any while loop in my custom code, though there is a for loop.

Here is what I have tried:

  1. I retrieve the batch size with tf.size(input). This works fine in the forward pass.
  2. I tried both a while loop and a for loop. Both work fine in the forward pass and produce identical results, but both throw the same error during backpropagation.
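For context, here is a minimal hypothetical reconstruction of the failing pattern; the function name, the matmul-based "distance" matrix, and all shapes are my assumptions, not the question's actual code. It loops over the batch and reads one column per step with NumPy-style slicing, which matches the forward pass working eagerly:

```python
import tensorflow as tf

def naive_loss(y_true, y_pred):
    """Hypothetical sketch, not the original code: loop over the batch
    and read one column per step with NumPy-style slicing."""
    batch = tf.size(y_true)  # batch size via tf.size, as in attempt 1
    # stand-in for a pairwise distance matrix of shape [batch, batch]
    dist = tf.matmul(y_pred, y_pred, transpose_b=True)
    total = tf.zeros([], dtype=y_pred.dtype)
    for i in tf.range(batch):  # the for loop mentioned above
        # NumPy-style column slice; under tf.function this is lowered to a
        # strided_slice op like the one named in the error message
        total += tf.reduce_sum(dist[:, i])
    return total / tf.cast(batch, y_pred.dtype)
```

Run eagerly, a function like this evaluates without complaint; the InvalidArgumentError only shows up once the loop is traced into the compiled train_function and differentiated.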

Here is the full error stack:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-22-70c4ddc79f73> in <module>
     11                            epochs=25,
     12                            callbacks=[checkpoint],
---> 13                            verbose=1)

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1827         use_multiprocessing=use_multiprocessing,
   1828         shuffle=shuffle,
-> 1829         initial_epoch=initial_epoch)
   1830 
   1831   @deprecation.deprecated(

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in _method_wrapper(self, *args, **kwargs)
    106   def _method_wrapper(self, *args, **kwargs):
    107     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
--> 108       return method(self, *args, **kwargs)
    109 
    110     # Running inside `run_distribute_coordinator` already.

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, validation_split, sample_weight, initial_epoch, validation_batch_size, use_multiprocessing)
   1096                 batch_size=batch_size):
   1097               callbacks.on_train_batch_begin(step)
-> 1098               tmp_logs = train_function(iterator)
   1099               if data_handler.should_sync:
   1100                 context.async_wait()

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    778       else:
    779         compiler = "nonXla"
--> 780         result = self._call(*args, **kwds)
    781 
    782       new_tracing_count = self._get_tracing_count()

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    838         # Lifting succeeded, so variables are initialized and we can run the
    839         # stateless function.
--> 840         return self._stateless_fn(*args, **kwds)
    841     else:
    842       canon_args, canon_kwds = \

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   2827     with self._lock:
   2828       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2829     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2830 
   2831   @property

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _filtered_call(self, args, kwargs, cancellation_manager)
   1846                            resource_variable_ops.BaseResourceVariable))],
   1847         captured_inputs=self.captured_inputs,
-> 1848         cancellation_manager=cancellation_manager)
   1849 
   1850   def _call_flat(self, args, captured_inputs, cancellation_manager=None):

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1922       # No tape is watching; skip to running the function.
   1923       return self._build_call_outputs(self._inference_function.call(
-> 1924           ctx, args, cancellation_manager=cancellation_manager))
   1925     forward_backward = self._select_forward_and_backward_functions(
   1926         args,

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    548               inputs=args,
    549               attrs=attrs,
--> 550               ctx=ctx)
    551         else:
    552           outputs = execute.execute_with_cancellation(

~/anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError:  slice index 16 of dimension 1 out of bounds.
     [[{{node TripletSemiHardLoss/PartitionedCall/while_1/body/_226/while_1/strided_slice}}]] [Op:__inference_train_function_31232]

Function call stack:
train_function

Solution

It was in fact caused by NumPy-style array slicing. Switching to tf.slice solved the problem.
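To make the fix concrete, here is a minimal sketch (the function name and shapes are assumed, not taken from the question) of the same per-column loop with the NumPy-style slice replaced by tf.slice, where begin and size are tensors computed from tf.shape so the slice bounds stay dynamic instead of being frozen at trace time:

```python
import tensorflow as tf

@tf.function
def column_sum(pairwise_dist):
    """Hypothetical sketch: per-column loop using tf.slice."""
    batch = tf.shape(pairwise_dist)[0]  # dynamic batch size, not a Python int
    total = tf.zeros([], dtype=pairwise_dist.dtype)
    for i in tf.range(tf.shape(pairwise_dist)[1]):
        # dynamic equivalent of the NumPy-style pairwise_dist[:, i]
        col = tf.slice(pairwise_dist, begin=[0, i], size=[batch, 1])
        total += tf.reduce_sum(col)
    return total
```

Because the bounds come from tf.shape at run time, both the traced forward graph and its gradient keep working for any batch, which avoids the out-of-bounds strided slice seen in the error above.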