TensorFlow object detection eager_few_shot_od_training_tf2 classification head error

Problem description

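I'm trying to train an object detection model on a custom dataset with three classes, following the official tutorial in the TF Object Detection API repo. Because the tutorial's dataset has only one class, the tutorial does not restore the model's classification head; it points out the line to uncomment if you want to restore that head too, as you can see in this snippet. However, when I uncomment that line, I get an error. Thanks for your help.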

# Set up object-based checkpoint restore --- RetinaNet has two prediction
# `heads` --- one for classification, the other for box regression.  We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    _prediction_heads=detection_model._box_predictor._prediction_heads,  # I uncommented this line
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()

# Run model through a dummy image so that variables are created
image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
prediction_dict = detection_model.predict(image, shapes)
_ = detection_model.postprocess(prediction_dict, shapes)
print('Weights restored!')
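To see why uncommenting `_prediction_heads` matters, here is a rough, framework-free analogy of an object-based partial restore: only the sub-objects attached to the `Checkpoint` are matched against the saved values, and every matched variable must have exactly the saved shape. The helper `partial_restore`, the keys `box_head`/`class_head`, and the shapes (assuming 6 anchors per location, a 256-channel head tower, and 90 COCO classes in the checkpoint) are illustrative assumptions, not TensorFlow's real checkpoint keys or API.

```python
# Hypothetical, framework-free analogy of an object-based partial restore.
# Keys and shapes are illustrative assumptions, not real checkpoint entries.

def partial_restore(checkpoint, model, selected):
    """Return the names restored from `checkpoint` into `model`.

    Both dicts map a variable name to its shape. Names not in `selected`
    are skipped and keep their fresh initialization. A selected name whose
    shapes differ raises ValueError, mirroring the error in the traceback.
    """
    restored = []
    for name in selected:
        saved_shape, new_shape = checkpoint[name], model[name]
        if saved_shape != new_shape:
            raise ValueError(
                "Tensor's shape %s is not compatible with supplied shape %s"
                % (saved_shape, new_shape))
        restored.append(name)
    return restored


# Saved COCO model (90 classes) vs. a freshly built three-class model.
ckpt = {"box_head": (3, 3, 256, 24), "class_head": (3, 3, 256, 546)}
model = {"box_head": (3, 3, 256, 24), "class_head": (3, 3, 256, 24)}

# Tutorial default: the class head is not selected, so the restore succeeds.
print(partial_restore(ckpt, model, ["box_head"]))  # ['box_head']

# With the class head selected, the last dimension differs (546 vs. 24).
try:
    partial_restore(ckpt, model, ["box_head", "class_head"])
except ValueError as err:
    print(err)
```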

This is the error I get when I run the notebook after uncommenting that line:

ValueError                                Traceback (most recent call last)
    <ipython-input-7-96e77f9f8468> in <module>
         24 
         25 image, shapes = detection_model.preprocess(tf.zeros([1, 640, 640, 3]))
    ---> 26 prediction_dict = detection_model.predict(image, shapes)
         27 _ = detection_model.postprocess(prediction_dict, shapes)
         28 print('Weights restored!')


C:\Python\lib\site-packages\object_detection\meta_architectures\ssd_meta_arch.py in predict(self, preprocessed_inputs, true_image_shapes)
    589     self._anchors = box_list_ops.concatenate(boxlist_list)
    590     if self._box_predictor.is_keras_model:
--> 591       predictor_results_dict = self._box_predictor(feature_maps)
    592     else:
    593       with slim.arg_scope([slim.batch_norm],

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
    983 
    984         with ops.enable_auto_cast_variables(self._compute_dtype_object):
--> 985           outputs = call_fn(inputs, **kwargs)
    986 
    987         if self._activity_regularizer:

C:\Python\lib\site-packages\object_detection\core\box_predictor.py in call(self, image_features, **kwargs)
    200           feature map in the input `image_features` list.
    201     """
--> 202     return self._predict(image_features, **kwargs)
    203 
    204   @abstractmethod

C:\Python\lib\site-packages\object_detection\predictors\convolutional_keras_box_predictor.py in _predict(self, image_features, **kwargs)
    482               self._base_tower_layers_for_heads[head_name][index],
    483               image_feature)
--> 484         prediction = head_obj(head_tower_feature)
    485         predictions[head_name].append(prediction)
    486     return predictions

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
    986 
    987         if self._activity_regularizer:

C:\Python\lib\site-packages\object_detection\predictors\heads\head.py in call(self, features)
     67   def call(self, features):
     68     """The Keras model call will delegate to the `_predict` method."""
---> 69     return self._predict(features)
     70 
     71   @abstractmethod

C:\Python\lib\site-packages\object_detection\predictors\heads\keras_class_head.py in _predict(self, features)
    339     for layer in self._class_predictor_layers:
    340       class_predictions_with_background = layer(
--> 341           class_predictions_with_background)
    342     batch_size = features.get_shape().as_list()[0]
    343     if batch_size is None:

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in __call__(self, *args, **kwargs)
    980       with ops.name_scope_v2(name_scope):
    981         if not self.built:
--> 982           self._maybe_build(inputs)
    983 
    984         with ops.enable_auto_cast_variables(self._compute_dtype_object):

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in _maybe_build(self, inputs)
   2641         # operations.
   2642         with tf_utils.maybe_init_scope(self):
-> 2643           self.build(input_shapes)  # pylint:disable=not-callable
   2644       # We must set also ensure that the layer is marked as built, and the build
   2645       # shape is stored since user defined build functions may not be calling

C:\Python\lib\site-packages\tensorflow\python\keras\layers\convolutional.py in build(self, input_shape)
    202         constraint=self.kernel_constraint,
    203         trainable=True,
--> 204         dtype=self.dtype)
    205     if self.use_bias:
    206       self.bias = self.add_weight(

C:\Python\lib\site-packages\tensorflow\python\keras\engine\base_layer.py in add_weight(self, name, shape, dtype, initializer, regularizer, trainable, constraint, partitioner, use_resource, synchronization, aggregation, **kwargs)
    612         synchronization=synchronization,
    613         aggregation=aggregation,
--> 614         caching_device=caching_device)
    615     if regularizer is not None:
    616       # TODO(fchollet): in the future, this should be handled at the

C:\Python\lib\site-packages\tensorflow\python\training\tracking\base.py in _add_variable_with_custom_getter(self, getter, overwrite, **kwargs_for_getter)
    729         # there is nothing to restore.
    730         checkpoint_initializer = self._preload_simple_restoration(
--> 731             name=name, shape=shape)
    732       else:
    733         checkpoint_initializer = None

C:\Python\lib\site-packages\tensorflow\python\training\tracking\base.py in _preload_simple_restoration(self, name, shape)
    796         key=lambda restore: restore.checkpoint.restore_uid)
    797     return CheckpointInitialValue(
--> 798         checkpoint_position=checkpoint_position, shape=shape)
    799 
    800   def _track_trackable(self, trackable, overwrite=False):

C:\Python\lib\site-packages\tensorflow\python\training\tracking\base.py in __init__(self, checkpoint_position, shape)
     73       # We need to set the static shape information on the initializer if
     74       # possible so we don't get a variable with an unknown shape.
---> 75       self.wrapped_value.set_shape(shape)
     76     self._checkpoint_position = checkpoint_position
     77 

C:\Python\lib\site-packages\tensorflow\python\framework\ops.py in set_shape(self, shape)
   1207       raise ValueError(
   1208           "Tensor's shape %s is not compatible with supplied shape %s" %
-> 1209           (self.shape, shape))
   1210 
   1211   # Methods not supported / implemented for Eager Tensors.

ValueError: Tensor's shape (3, 3, 256, 546) is not compatible with supplied shape (3, 3, 256, 24)
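The checkpoint's kernel shape (3, 3, 256, 546) and the shape the three-class model asks for can be decoded with a little arithmetic. In a RetinaNet-style class head, the last 3x3 convolution outputs one logit per class plus one background logit for every anchor at each feature-map location. Assuming the tutorial's COCO checkpoint uses 6 anchors per location, a 256-channel head tower and 90 classes (my reading of its pipeline config, not verified here), the hypothetical helper `class_head_kernel_shape` below reproduces both shapes:

```python
# Hypothetical helper: kernel shape of the final class-prediction conv layer.
# The defaults are assumptions about the tutorial's RetinaNet config.
def class_head_kernel_shape(num_classes, anchors_per_location=6,
                            kernel_size=3, in_channels=256):
    # Each anchor gets num_classes logits plus one background logit.
    out_channels = anchors_per_location * (num_classes + 1)
    return (kernel_size, kernel_size, in_channels, out_channels)


print(class_head_kernel_shape(90))  # (3, 3, 256, 546): COCO checkpoint
print(class_head_kernel_shape(3))   # (3, 3, 256, 24): three-class model
```

Only the last dimension depends on the number of classes, so the pretrained class predictor can be restored only into a model with the same 90 classes. For a three-class dataset, keeping `_prediction_heads` commented out, as the tutorial does, should avoid the error: the shared tower layers and the box head are still restored, and the classification head is trained from scratch.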
