Problem description
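I want to test the MobileDet model provided in the TF1 model zoo of the TensorFlow Object Detection API (tf1 object detection model zoo).
The pretrained archive contains both a pb file and ckpt files (the screenshot of ckpt files), so I tried two ways of loading the pretrained model for inference.
First, I tried to load tflite_graph.pb directly. I ran into the problem shown below; changing the TF version did not solve it.
The code is as follows: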
import os
import tensorflow as tf  # TF 1.x API (tf.GraphDef, tf.gfile)

MODEL_DIR = '/tf_ckpts/ssdlite_mobiledet_cpu_320x320_coco_2020_05_19/'
MODEL_CHECK_FILE = os.path.join(MODEL_DIR, 'tflite_graph.pb')

graph = tf.Graph()
with graph.as_default():
    # Parse the serialized GraphDef and import it into the new graph.
    graph_def = tf.GraphDef()
    with tf.gfile.Open(MODEL_CHECK_FILE, 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
Traceback (most recent call last):
  File "/home/zhaoxin/workspace/models-1.12.0/research/inference_demo.py", line 41, in <module>
    tf.import_graph_def(graph_def, name='')
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/framework/importer.py", line 505, in _import_graph_def_internal
    raise ValueError(str(e))
ValueError: NodeDef mentions attr 'exponential_avg_factor' not in Op<name=FusedBatchNormV3; signature=x:T, scale:U, offset:U, mean:U, variance:U -> y:T, batch_mean:U, batch_variance:U, reserve_space_1:U, reserve_space_2:U, reserve_space_3:U; attr=T:type,allowed=[DT_HALF, DT_BFLOAT16, DT_FLOAT]; attr=U:type,allowed=[DT_FLOAT]; attr=epsilon:float,default=0.0001; attr=data_format:string,default="NHWC",allowed=["NHWC", "NCHW"]; attr=is_training:bool,default=true>; NodeDef: {{node FeatureExtractor/MobileDetCPU/Conv/BatchNorm/FusedBatchNormV3}}. (Check whether your GraphDef-interpreting binary is up to date with your GraphDef-generating binary.).
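The error message suggests that tflite_graph.pb was written by a TensorFlow build newer than the one reading it: the FusedBatchNormV3 nodes carry an exponential_avg_factor attribute that the op definition in TF 1.15 does not declare. One workaround sometimes used for this kind of mismatch is to delete the unknown attribute from the GraphDef before importing it. The following is only a minimal sketch under that assumption; strip_attr and load_patched_graph are hypothetical helper names, not part of TensorFlow, and the approach has not been verified against this particular model.

```python
def strip_attr(graph_def, op_type, attr_name):
    """Delete `attr_name` from every node whose op is `op_type`.

    Works on any object exposing `.node`, where each node has `.op` and a
    mapping-like `.attr` (as a TF GraphDef does). Returns the number of
    nodes that were modified.
    """
    changed = 0
    for node in graph_def.node:
        if node.op == op_type and attr_name in node.attr:
            del node.attr[attr_name]
            changed += 1
    return changed


def load_patched_graph(pb_path):
    """Hypothetical helper: parse a .pb file, drop the attr, import the graph."""
    # Imported lazily so strip_attr stays usable without TensorFlow installed.
    import tensorflow as tf  # assumes the TF 1.x API

    graph_def = tf.GraphDef()
    with tf.gfile.Open(pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    # Remove the attribute that the older op definition does not know about.
    strip_attr(graph_def, 'FusedBatchNormV3', 'exponential_avg_factor')

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph
```

Even with the attribute removed the import may still fail: tflite_graph.pb is meant as input for the TFLite converter and may contain TFLite-only ops that a regular TensorFlow session does not register.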
Then I tried loading the ckpt files to run the model.
import tensorflow as tf  # TF 1.x API

mobiledet = 'tf_ckpts/ssdlite_mobiledet_cpu_320x320_coco_2020_05_19/'
meta_path = mobiledet + 'model.ckpt-400000.meta'
ckpt_path = mobiledet + 'model.ckpt-400000'

with tf.Session() as sess:
    # Rebuild the graph from the .meta file, then restore the weights.
    saver = tf.train.import_meta_graph(meta_path)
    saver.restore(sess, ckpt_path)
    graph = tf.get_default_graph()
This fails with the following error:
Traceback (most recent call last):
  File "/home/zhaoxin/workspace/models-1.12.0/research/tf_load.py", line 15, in <module>
    saver=tf.train.import_meta_graph(meta_path)
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1453, in import_meta_graph
    **kwargs)[0]
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/training/saver.py", line 1477, in _import_meta_graph_with_return_elements
    **kwargs))
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/framework/meta_graph.py", line 809, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "/home/zhaoxin/tools/miniconda3/envs/tf115/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 501, in _import_graph_def_internal
    graph._c_graph, serialized, options) # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered 'LegacyParallelInterleaveDatasetV2' in binary running on localhost.localdomain. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.
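Both loading errors above seem to be caused by a TF version mismatch, but I have tried many TF versions without success. Has anyone successfully run the MobileDet model from the TF1 Object Detection model zoo?
OS: Linux
TF version: 1.15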
Solution
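@Shane Zhao - Do you intend to train on a custom dataset, or to use the pretrained graph as is? As far as I know, the TensorFlow version only matters during training. In any case, please refer to this demo from Google in Colab.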
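If the goal is only to run the pretrained weights as is, one option that avoids importing the .meta file (which apparently also contains the training input pipeline, hence the dataset op in the second error) is to re-export an inference graph with the Object Detection API's export_inference_graph.py script. What follows is a minimal sketch, not a verified recipe. It assumes that the downloaded archive contains a pipeline.config next to the checkpoint, that the script is run from the research/ directory of a models checkout recent enough to include MobileDet, and that build_export_command and export_frozen_graph are hypothetical helper names introduced here for illustration.

```python
import os
import subprocess
import sys


def build_export_command(pipeline_config, checkpoint_prefix, output_dir,
                         script="object_detection/export_inference_graph.py"):
    """Return the argv list for the TF1 Object Detection export script."""
    return [
        sys.executable, script,
        "--input_type", "image_tensor",
        "--pipeline_config_path", pipeline_config,
        "--trained_checkpoint_prefix", checkpoint_prefix,
        "--output_directory", output_dir,
    ]


def export_frozen_graph(model_dir, output_dir):
    """Hypothetical helper: run the export for a model zoo directory."""
    cmd = build_export_command(
        # Assumption: pipeline.config ships alongside the checkpoint files.
        os.path.join(model_dir, "pipeline.config"),
        os.path.join(model_dir, "model.ckpt-400000"),
        output_dir,
    )
    # Raises CalledProcessError if the export script exits with an error.
    subprocess.run(cmd, check=True)
```

If the export succeeds, the output directory should contain a frozen_inference_graph.pb that can be loaded with tf.import_graph_def in the same TensorFlow version that produced it.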