"Shapes of all inputs must match" error in the loss function when trying to do custom training with tf.GradientTape

Problem description

I'm using Python 3.7.7 and TensorFlow 2.1.0 with the functional API and eager execution.

I'm trying to do custom training with an encoder extracted from a pretrained U-Net.

  1. I get the U-Net model without compiling it.
  2. I load the weights into the model.
  3. I extract the encoder and the decoder from that model.
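The three steps above can be sketched with the functional API roughly as follows. This is a minimal sketch, not the asker's code: `build_toy_unet`, its layer sizes and the 32x32 input are hypothetical stand-ins for the real pretrained U-Net, and weight loading is skipped. What it shows is that an encoder built with a list of `outputs` (for example, to expose the skip connections) returns a Python list of tensors when called, not a single tensor.

```python
import tensorflow as tf

def build_toy_unet(size=32):
    """Hypothetical, much smaller stand-in for the pretrained U-Net."""
    inp = tf.keras.Input(shape=(size, size, 1), name="input_1")
    c1 = tf.keras.layers.Conv2D(8, 3, padding="same", activation="relu", name="conv1_1")(inp)
    p1 = tf.keras.layers.MaxPooling2D(name="pool1")(c1)
    c2 = tf.keras.layers.Conv2D(16, 3, padding="same", activation="relu", name="conv2_1")(p1)
    # Decoder: upsample and merge with the skip connection from conv1_1.
    u1 = tf.keras.layers.UpSampling2D(name="up1")(c2)
    m1 = tf.keras.layers.Concatenate(name="merge1")([u1, c1])
    out = tf.keras.layers.Conv2D(1, 1, name="output")(m1)
    return tf.keras.Model(inp, out, name="unet")

unet = build_toy_unet()
# Loading pretrained weights (unet.load_weights) is skipped in this sketch.

# Encoder sub-model with one output per level, deepest level first.
encoder = tf.keras.Model(
    inputs=unet.input,
    outputs=[unet.get_layer(name).output for name in ("conv2_1", "conv1_1")],
    name="encoder",
)

features = encoder(tf.zeros((5, 32, 32, 1)))
print(type(features), [f.shape for f in features])
```

Because `outputs` is a list, every call to `encoder` yields a list; any loss that receives it directly has to deal with several tensors of different shapes.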

Then I want to use the encoder, which has this summary:

Model: "encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None,200,1)]     0         
_________________________________________________________________
conv1_1 (Conv2D)             (None,64)      1664      
_________________________________________________________________
conv1_2 (Conv2D)             (None,64)      102464    
_________________________________________________________________
pool1 (MaxPooling2D)         (None,100,64)      0         
_________________________________________________________________
conv2_1 (Conv2D)             (None,96)      55392     
_________________________________________________________________
conv2_2 (Conv2D)             (None,96)      83040     
_________________________________________________________________
pool2 (MaxPooling2D)         (None,50,96)        0         
_________________________________________________________________
conv3_1 (Conv2D)             (None,128)       110720    
_________________________________________________________________
conv3_2 (Conv2D)             (None,128)       147584    
_________________________________________________________________
pool3 (MaxPooling2D)         (None,25,128)       0         
_________________________________________________________________
conv4_1 (Conv2D)             (None,256)       295168    
_________________________________________________________________
conv4_2 (Conv2D)             (None,256)       1048832   
_________________________________________________________________
pool4 (MaxPooling2D)         (None,12,256)       0         
_________________________________________________________________
conv5_1 (Conv2D)             (None,512)       1180160   
_________________________________________________________________
conv5_2 (Conv2D)             (None,512)       2359808   
=================================================================
Total params: 5,384,832
Trainable params: 5,384,832
Non-trainable params: 0
_________________________________________________________________
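Each Conv2D row in the summary follows params = k * k * c_in * c_out + c_out (kernel weights plus one bias per filter), and MaxPooling2D layers have no parameters. The check below reproduces every Param # value and the 5,384,832 total, which is also the trainable count when Non-trainable params is 0. The kernel sizes (5x5 for conv1, 4x4 for conv4_2, 3x3 elsewhere) are inferred from the parameter counts, not stated in the question, and the sketch assumes every layer uses a bias.

```python
def conv2d_params(kernel, c_in, c_out, use_bias=True):
    """Weights (kernel * kernel * c_in * c_out) plus one bias per output filter."""
    return kernel * kernel * c_in * c_out + (c_out if use_bias else 0)

# (layer, kernel size, input channels, output channels); kernel sizes are
# inferred from the Param # column, not read from the model itself.
layers = [
    ("conv1_1", 5, 1, 64),
    ("conv1_2", 5, 64, 64),
    ("conv2_1", 3, 64, 96),
    ("conv2_2", 3, 96, 96),
    ("conv3_1", 3, 96, 128),
    ("conv3_2", 3, 128, 128),
    ("conv4_1", 3, 128, 256),
    ("conv4_2", 4, 256, 256),
    ("conv5_1", 3, 256, 512),
    ("conv5_2", 3, 512, 512),
]

counts = {name: conv2d_params(k, cin, cout) for name, k, cin, cout in layers}
total = sum(counts.values())  # pooling layers contribute 0 parameters
print(counts)
print(f"Total params: {total:,}")  # prints "Total params: 5,384,832"
```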

I use this function for custom training:

import numpy as np
import tensorflow as tf

# num_episodes, no_of_samples, num_shot and num_query are globals defined elsewhere (not shown).
def train_encoder_unet_custom(model, dataset):

  optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

  for episode in range(num_episodes):
    selected = np.random.permutation(no_of_samples)[:num_shot + num_query]
    # Create our Support Set.
    support_set = np.array(dataset[selected[:num_shot]])

    X_train = support_set[:, :]
    y_train = support_set[:, 1, :]

    loss_value, grads = grad(model, X_train, y_train)

    optimizer.apply_gradients(zip(grads, model.trainable_variables))
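The episode sampling in this loop can be checked in isolation with NumPy. In this sketch the dataset layout (each sample stacking two 8x8 planes along axis 1) and the values of `num_episodes`, `no_of_samples`, `num_shot` and `num_query` are all hypothetical; `num_shot = 5` is chosen only to match the batch dimension of 5 in the error message. It prints the shapes that `X_train`, `y_train` and the query set end up with.

```python
import numpy as np

# Hypothetical values: the original code uses these as globals without showing them.
num_episodes = 3
no_of_samples = 100
num_shot, num_query = 5, 5

# Hypothetical toy dataset: each sample stacks two 8x8 planes along axis 1.
dataset = np.random.rand(no_of_samples, 2, 8, 8).astype("float32")

for episode in range(num_episodes):
    # Draw num_shot + num_query distinct sample indices (no replacement).
    selected = np.random.permutation(no_of_samples)[:num_shot + num_query]
    support_set = dataset[selected[:num_shot]]  # fancy indexing copies rows
    query_set = dataset[selected[num_shot:]]    # the remaining num_query samples

    X_train = support_set[:, :]     # identical to support_set: (5, 2, 8, 8)
    y_train = support_set[:, 1, :]  # only the second plane: (5, 8, 8)
    print(episode, X_train.shape, y_train.shape, query_set.shape)
```

Note that `support_set[:, :]` selects the whole array, so in this layout `X_train` still contains the plane that `y_train` is taken from.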

The grad function is:

import tensorflow as tf

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

def loss(model, x, y, training):
  # training=training is needed only if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  y_ = model(x, training=training)

  return loss_object(y_true=y, y_pred=y_)

def grad(model, inputs, targets):
  with tf.GradientTape() as tape:
    loss_value = loss(model, inputs, targets, training=False)
  return loss_value, tape.gradient(loss_value, model.trainable_variables)
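For comparison, here is a minimal sketch of the same loss/grad pattern on a hypothetical single-output model, with both `inputs` and `targets` passed through to `loss()`. `SparseCategoricalCrossentropy(from_logits=True)` expects `y_pred` to be one tensor of logits with a trailing class axis and `y_true` to hold integer class ids; the model, the 3 classes and the data are made up for illustration. It uses `training=True`, which makes no difference here because the model has no Dropout or BatchNormalization.

```python
import numpy as np
import tensorflow as tf

# Hypothetical single-output classifier: 4 features in, logits for 3 classes out.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3),
])
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

def loss(model, x, y, training):
    y_ = model(x, training=training)  # a single tensor of logits, shape (batch, 3)
    return loss_object(y_true=y, y_pred=y_)

def grad(model, inputs, targets):
    # Record the forward pass so the tape can differentiate the loss.
    with tf.GradientTape() as tape:
        loss_value = loss(model, inputs, targets, training=True)
    return loss_value, tape.gradient(loss_value, model.trainable_variables)

x = np.random.rand(5, 4).astype("float32")
y = np.array([0, 2, 1, 0, 2])  # integer class ids, shape (batch,)

loss_value, grads = grad(model, x, y)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
print(float(loss_value))
```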

But when I try to run it, I get this error:

InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [5,512] != values[1].shape = [5,256] [Op:Pack] name: packed
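The error can be reproduced outside the training loop. Converting a Python list of tensors into one tensor packs (stacks) them, as the `_autopacking_helper` and `pack` frames in the call stack below show, and packing requires every element to have the same shape. The two tensors here are hypothetical stand-ins with the batch size and widths from the message; the sketch catches both `InvalidArgumentError` and `ValueError` since the exact exception type can depend on the TensorFlow version.

```python
import tensorflow as tf

# Same batch size (5), different widths, like values[0] and values[1] in the error.
outputs = [tf.zeros((5, 512)), tf.zeros((5, 256))]

try:
    # A list of tensors is stacked into one tensor, which needs equal shapes.
    tf.convert_to_tensor(outputs)
    packed_ok = True
except (tf.errors.InvalidArgumentError, ValueError) as err:
    packed_ok = False
    print(type(err).__name__, err)

# Tensors of identical shape pack fine along a new leading axis.
stacked = tf.convert_to_tensor([tf.zeros((5, 512)), tf.ones((5, 512))])
print(stacked.shape)
```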

In the loss function, I checked the value of the y_ variable. y_ is a list of 6 elements with the following shapes:

(5,512)
(5,256)
(5,128)
(5,96)
(5,64)
(5,1)
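Since `y_` is a list of six tensors of different widths, passing it as a single `y_pred` makes the loss try to stack it, which is where the traceback fails. A minimal sketch of two common workarounds, assuming each output has its own target: compute one scalar loss per output and add them, or feed only the output the loss is meant for. The tensors, the zero targets and the switch to `MeanSquaredError` are purely illustrative; which output and which loss fit the asker's task is not stated in the question.

```python
import tensorflow as tf

# Hypothetical stand-ins for two of the encoder outputs (batch of 5)
# and one matching target per output.
y_pred_list = [tf.random.normal((5, 512)), tf.random.normal((5, 256))]
y_true_list = [tf.zeros((5, 512)), tf.zeros((5, 256))]

mse = tf.keras.losses.MeanSquaredError()

# Option 1: one scalar loss per output, then their sum.
per_output = [mse(t, p) for t, p in zip(y_true_list, y_pred_list)]
total_loss = tf.add_n(per_output)

# Option 2: use only the single output the loss is meant for (here the first).
first_only = mse(y_true_list[0], y_pred_list[0])

print(float(total_loss), float(first_only))
```

Option 1 mirrors what Keras does for a compiled multi-output model, where one loss is computed per output and the (optionally weighted) sum is minimized.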

Any ideas?

If you need more details, please ask.

Here is the full call stack:

<ipython-input-133-22827956a9f6> in train_encoder_unet_custom(model,dataset,feat_type,show)
     22     y_valid = query_set[:,:]
     23 
---> 24     loss_value,y_train)
     25 
     26     optimizer.apply_gradients(zip(grads,model.trainable_variables))

<ipython-input-143-58ff4de686d6> in grad(model,targets)
     10 def grad(model,targets):
     11   with tf.GradientTape() as tape:
---> 12     loss_value = loss(model,training=False)
     13   return loss_value,model.trainable_variables)

<ipython-input-143-58ff4de686d6> in loss(model,training)
      6   y_ = model(x,training=training)
      7 
----> 8   return loss_object(y_true=y,y_pred=y_)
      9 
     10 def grad(model,targets):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in __call__(self,y_true,y_pred,sample_weight)
    147     with K.name_scope(self._name_scope),graph_ctx:
    148       ag_call = autograph.tf_convert(self.call,ag_ctx.control_status_ctx())
--> 149       losses = ag_call(y_true,y_pred)
    150       return losses_utils.compute_weighted_loss(
    151           losses,sample_weight,reduction=self._get_reduction())

/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args,**kwargs)
    253       try:
    254         with conversion_ctx:
--> 255           return converted_call(f,args,kwargs,options=options)
    256       except Exception as e:  # pylint:disable=broad-except
    257         if hasattr(e,'ag_error_metadata'):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in converted_call(f,caller_fn_scope,options)
    455   if conversion.is_in_whitelist_cache(f,options):
    456     logging.log(2,'Whitelisted %s: from cache',f)
--> 457     return _call_unconverted(f,options,False)
    458 
    459   if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f,update_cache)
    337 
    338   if kwargs is not None:
--> 339     return f(*args,**kwargs)
    340   return f(*args)
    341 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in call(self,y_pred)
    251           y_pred,y_true)
    252     ag_fn = autograph.tf_convert(self.fn,ag_ctx.control_status_ctx())
--> 253     return ag_fn(y_true,**self._fn_kwargs)
    254 
    255   def get_config(self):

/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py in wrapper(*args,**kwargs)
    199     """Call target,and fall back on dispatchers if there is a TypeError."""
    200     try:
--> 201       return target(*args,**kwargs)
    202     except (TypeError,ValueError):
    203       # Note: convert_to_eager_tensor currently raises a ValueError,not a

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/losses.py in sparse_categorical_crossentropy(y_true,from_logits,axis)
   1562     Sparse categorical crossentropy loss value.
   1563   """
-> 1564   y_pred = ops.convert_to_tensor_v2(y_pred)
   1565   y_true = math_ops.cast(y_true,y_pred.dtype)
   1566   return K.sparse_categorical_crossentropy(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in convert_to_tensor_v2(value,dtype,dtype_hint,name)
   1380       name=name,
   1381       preferred_dtype=dtype_hint,
-> 1382       as_ref=False)
   1383 
   1384 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value,name,as_ref,preferred_dtype,ctx,accepted_result_types)
   1497 
   1498     if ret is None:
-> 1499       ret = conversion_func(value,dtype=dtype,name=name,as_ref=as_ref)
   1500 
   1501     if ret is NotImplemented:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py in _autopacking_conversion_function(v,as_ref)
   1500   elif dtype != inferred_dtype:
   1501     v = nest.map_structure(_cast_nested_seqs_to_dtype(dtype),v)
-> 1502   return _autopacking_helper(v,name or "packed")
   1503 
   1504 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py in _autopacking_helper(list_or_tuple,name)
   1406     # checking.
   1407     if all(isinstance(elem,core.Tensor) for elem in list_or_tuple):
-> 1408       return gen_array_ops.pack(list_or_tuple,name=name)
   1409   must_pack = False
   1410   converted_elems = []

/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_array_ops.py in pack(values,axis,name)
   6457       return _result
   6458     except _core._NotOkStatusException as e:
-> 6459       _ops.raise_from_not_ok_status(e,name)
   6460     except _core._FallbackException:
   6461       pass

/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e,name)
   6841   message = e.message + (" name: " + name if name is not None else "")
   6842   # pylint: disable=protected-access
-> 6843   six.raise_from(core._status_to_exception(e.code,message),None)
   6844   # pylint: enable=protected-access
   6845 

/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value,from_value)
