How do I successfully run the MobileDet models from the TF1 model zoo, with their pretrained weights, using the TensorFlow Object Detection API?

Problem description

I want to test the MobileDet models provided in the TF1 model zoo of the TensorFlow Object Detection API (see the tf1 object detection model zoo).

The pretrained archive contains both a .pb file and checkpoint (ckpt) files [screenshot of the ckpt files], so I tried two ways of loading the pretrained model for inference.

First, I tried loading tflite_graph.pb directly and hit the error below. I tried switching TF versions, but that did not fix it.

The code is as follows:

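import os
import tensorflow as tf  # TF 1.15

MODEL_DIR = '/tf_ckpts/ssdlite_mobiledet_cpu_320x320_coco_2020_05_19/'
MODEL_CHECK_FILE = os.path.join(MODEL_DIR, 'tflite_graph.pb')

graph = tf.Graph()
with graph.as_default():
    # Read the serialized GraphDef from disk
    graph_def = tf.GraphDef()
    with tf.gfile.Open(MODEL_CHECK_FILE, 'rb') as f:
        graph_def.ParseFromString(f.read())
    # Import it into the current graph; this is the call that fails
    tf.import_graph_def(graph_def, name='')

Running this gives: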
Traceback (most recent call last):
  File "/home/zhaoxin/workspace/models-1.12.0/research/inference_demo.py", line 41, in <module>
    tf.import_graph_def(graph_def, name='')
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 505, in _import_graph_def_internal
    raise ValueError(str(e))
ValueError: NodeDef mentions attr 'exponential_avg_factor' not in Op<name=FusedBatchNormV3; signature=x:T, scale:U, offset:U, mean:U, variance:U -> y:T, batch_mean:U, batch_variance:U, reserve_space_1:U, reserve_space_2:U, reserve_space_3:U; attr=T:type,allowed=[DT_HALF, DT_BFLOAT16, DT_FLOAT]; attr=U:type,allowed=[DT_FLOAT]; attr=epsilon:float,default=0.0001; attr=data_format:string,default="NHWC",allowed=["NHWC", "NCHW"]; attr=is_training:bool,default=true>; NodeDef: {{node FeatureExtractor/MobileDetCPU/Conv/BatchNorm/FusedBatchNormV3}}. (Check whether your GraphDef-interpreting binary is up to date with your GraphDef-generating binary.).
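The error says the GraphDef carries an `exponential_avg_factor` attribute on FusedBatchNormV3 that the TF 1.15 op registry does not define, i.e. the .pb was written by a newer TF. A minimal workaround sketch, assuming this is an inference graph (is_training false) where the attribute only holds its default value 1.0, so dropping it should not change the outputs: delete the attribute from the batch-norm nodes before importing. The helper name `strip_attr` and the tuple of op types are my own assumptions, not part of TF or the Object Detection API.

```python
"""Sketch: strip an attribute that an older TF binary does not know about."""

# Batch-norm op types that may carry `exponential_avg_factor` when written by
# newer TF releases (assumption: adjust if your error names another op).
FUSED_BN_OPS = ("FusedBatchNorm", "FusedBatchNormV2", "FusedBatchNormV3")


def strip_attr(graph_def, attr_name="exponential_avg_factor", ops=FUSED_BN_OPS):
    """Delete `attr_name` from every node whose op type is in `ops`.

    Works on a tf.GraphDef (node.attr is a protobuf map that supports
    `in` and `del`) or on any object with the same shape. Returns the
    names of the nodes that were modified.
    """
    modified = []
    for node in graph_def.node:
        if node.op in ops and attr_name in node.attr:
            del node.attr[attr_name]
            modified.append(node.name)
    return modified
```

In the loading code from the question, the call would go between `graph_def.ParseFromString(f.read())` and `tf.import_graph_def(graph_def, name='')`. Other incompatibilities may still surface afterwards, for example ops in tflite_graph.pb that are meant to run only under TFLite, so treat this as a first step rather than a guaranteed fix.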

Next, I tried loading the checkpoint files to run the model:

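import tensorflow as tf  # TF 1.15

mobiledet = 'tf_ckpts/ssdlite_mobiledet_cpu_320x320_coco_2020_05_19/'
meta_path = mobiledet + 'model.ckpt-400000.meta'  # serialized MetaGraphDef
ckpt_path = mobiledet + 'model.ckpt-400000'       # checkpoint prefix (no extension)

with tf.Session() as sess:
    # Rebuild the graph stored in the .meta file; this is the call that fails
    saver = tf.train.import_meta_graph(meta_path)
    # Load the trained variable values into the session
    saver.restore(sess, ckpt_path)
    graph = tf.get_default_graph()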
    graph = tf.get_default_graph()

This fails with the following error:

Traceback (most recent call last):
  File "/home/zhaoxin/workspace/models-1.12.0/research/tf_load.py", line 15, in <module>
    saver=tf.train.import_meta_graph(meta_path)
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1453, in import_meta_graph
    **kwargs)[0]
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1477, in _import_meta_graph_with_return_elements
    **kwargs))
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/framework/meta_graph.py", line 809, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 501, in _import_graph_def_internal
    graph._c_graph, serialized, options)  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered 'LegacyParallelInterleaveDatasetV2' in binary running on localhost.localdomain. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.
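`LegacyParallelInterleaveDatasetV2` is a tf.data op, so this error most likely comes from the input pipeline saved in the training meta graph rather than from the detector itself. To check which op types a saved graph needs that your binary lacks, one can diff the graph's op types against a set of known op names. In this sketch `known_ops` is left as a parameter because how you enumerate the registered ops differs between TF versions; the function names are hypothetical helpers, not TF APIs.

```python
"""Sketch: list op types a graph uses that a given binary does not register."""


def used_op_types(graph_def):
    """Collect op types from top-level nodes and from function-library bodies."""
    ops = {node.op for node in graph_def.node}
    library = getattr(graph_def, "library", None)
    for func in getattr(library, "function", []):
        # tf.data map/interleave functions live in the library, not graph_def.node
        ops.update(node.op for node in func.node_def)
    return ops


def missing_ops(graph_def, known_ops):
    """Return, sorted, the op types used by `graph_def` but absent from `known_ops`."""
    return sorted(used_op_types(graph_def) - set(known_ops))
```

For the checkpoint case, parse model.ckpt-400000.meta into a `tf.compat.v1.MetaGraphDef` with `ParseFromString` and pass its `graph_def` field; the same helpers also work on the GraphDef read from tflite_graph.pb.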

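Both loading errors seem to be caused by a TensorFlow version mismatch, but I have tried many TF versions without success. Has anyone managed to run the MobileDet models from the TF1 object detection model zoo?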

OS: Linux

TF version: 1.15

Solution

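@Shane Zhao - Are you planning to train on a custom dataset, or to use the pretrained graph as-is? As far as I know, the TensorFlow version only matters during training. In any case, take a look at this demo from Google on Colab.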
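One way that may sidestep both errors is to regenerate a frozen graph from the checkpoint with the same TF build you run inference with, so the graph's producer and consumer versions match. The sketch below only assembles and runs the command for the Object Detection API's object_detection/export_inference_graph.py (flags `--input_type`, `--pipeline_config_path`, `--trained_checkpoint_prefix`, `--output_directory`). Assumptions: it is run from models/research, the downloaded archive also contains a pipeline.config, and your checkout of the models repo already includes the MobileDet feature extractors (an older checkout may not). `build_export_cmd` and `run_export` are hypothetical helpers.

```python
"""Sketch: re-export the MobileDet graph from its checkpoint with the local TF."""
import os
import subprocess
import sys


def build_export_cmd(model_dir, output_dir, ckpt_name="model.ckpt-400000"):
    """Build the argv for object_detection/export_inference_graph.py.

    Assumes the current directory is models/research and that model_dir
    holds pipeline.config plus the model.ckpt-400000.* files.
    """
    return [
        sys.executable,
        "object_detection/export_inference_graph.py",
        "--input_type=image_tensor",
        "--pipeline_config_path=" + os.path.join(model_dir, "pipeline.config"),
        "--trained_checkpoint_prefix=" + os.path.join(model_dir, ckpt_name),
        "--output_directory=" + output_dir,
    ]


def run_export(model_dir, output_dir):
    """Run the exporter in a subprocess; raises CalledProcessError on failure."""
    subprocess.run(build_export_cmd(model_dir, output_dir), check=True)
```

If the export succeeds, the output directory should contain a frozen inference graph that can be loaded with `tf.import_graph_def` as in the first attempt, this time written by the same TF version that imports it.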