TensorFlow学习记录:saved_model模块的用法

TensorFlow中的saved_model模块用于生成冻结图文件,并且saved_model模块封装了平常用的Saver类。与Saver类不同的是,saved_model模块生成的模型文件集成了打标签的操作,可以更方便地部署在生产环境中。

关于为什么要用saved_model模块,这篇文章讲得挺好的。请点击这里

一个saved_model对象可以存储一个或多个MetaGraphDef。那什么时候需要多个MetaGraphDef呢?也许你想同时保存模型cpu版本和GPU版本,或者你想同时保存模型的开发版本和生产版本。这个时候你就可以用tag(标签)来区分它们了。在加载模型的时候能根据tag标签来加载不同的MetaGraphDef。

TensorFlow中的saved_model模块可以给MetaGraphDef添加多个签名(signature)。每个签名的的结构都由输入节点、输出节点、名字3部分组成。并且,输入节点,输出节点的名字可以任意指定。

1.导出带有签名的模型文件

假设之前训练了一个模型,让模型在一组混乱的数据中找到y≈2x的规律。其中
(1)用saved_model模块的builder.SavedModelBuilder类实例一个builder对象。
(2)构建签名的输入节点inputs。该输入节点的名字为“input_x”。该名字是模型文件中输入节点的名字(可以任意取名)。
(3)构建标签输出节点outputs。该输出节点的名字为“output”。
(4)调用build_signature_def函数,并将输入节点、输出节点和名字(sig_name)传入,生成一个签名对象。
(5)用builder对象的add_Meta_graph_and_variables方法将签名加入到模型中。
(6)调用builder对象的save方法导出带有签名的模型文件

代码如下:

from tensorflow.python.saved_model import tag_constants
	#saveddir+"tfservingmodel"为模型的保存路径
    builder = tf.saved_model.builder.SavedModelBuilder(savedir+'tfservingmodel')
    
    #定义输入签名,X为输入tensor
    inputs = {'input_x': tf.saved_model.utils.build_tensor_info(X)}
    #定义输出签名, z为最终需要的输出结果tensor 
    outputs = {'output' : tf.saved_model.utils.build_tensor_info(z)}
    #调用build_signature_def()函数,并将输入节点、输出节点和名字(sig_name)传入,生成具体的签名对象
    signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'sig_name')
    
    #将节点的定义和值加到builder中,同时还加入了tag标签(tag_constants.SERVING), 还可以使用TRAINING、GPU或自定义
    builder.add_Meta_graph_and_variables(sess, [tag_constants.SERVING], {'my_signature':signature})
    builder.save()     

运行后,会生成如下图所示文件


其中variables文件里的内容如下所示

在这里插入图片描述


从第一张图可以看出,tfservingmodel文件夹包含了一个文件一个文件夹,文件save_model.pb是模型的定义文件文件夹variables中放置了具体的模型文件
从第二张图可以看出,variables文件夹包含了两个模型文件,variables.data-00000-of-00001文件保存了模型中参数的值,variables.index文件保存了模型中节点符号的定义。

我们可以看下saved_model.pb文件中保存的张量名字和属性

import tensorflow  as tf
from tensorflow import saved_model as sm

model_path = "log/tfservingmodel"

with tf.Session()as sess:
	Meta_graph_def = tf.saved_model.loader.load(sess,[sm.tag_constants.SERVING],model_path)
	
	op_list = sess.graph.get_operations()   #load完后可以直接从sess.graph中获取所有节点
	with open("operations.txt",'a+')as f:
		for index,op in enumerate(op_list):
			f.write(str(op.name)+"\n")
			f.write(str(op.values())+"\n")

运行结果截图(部分):

在这里插入图片描述

2.根据tag标签导入模型文件,并根据签名找到网络节点

导入刚刚保存的模型
(1)用saved_model模块中的loader.load方法根据tag标签导入对应的模型文件
(2)用signature_def方法从导入的模型中提取签名。
(3)以字典取值的方式取出输入、输出节点。
(4)向模型注入数据,并输出结果。

代码如下:

from tensorflow.python.saved_model import tag_constants
with tf.Session() as sess:
	#根据tag_constants.SERVING标签找到对应的计算图
    Meta_graph_def = tf.saved_model.loader.load(sess, [tag_constants.SERVING], savedir+'tfservingmodel')
    # 从Meta_graph_def中取出SignatureDef对象
    signature = Meta_graph_def.signature_def
    
    # 从signature中找出具体输入输出的tensor name 
    x = signature['my_signature'].inputs['input_x'].name
    result = signature['my_signature'].outputs['output'].name

    y = sess.run(result, Feed_dict={x: 5})#传入5,进行预测
    print(y)       

运行结果:

在这里插入图片描述

3.用saved_model_cli工具查看及使用saved_model模型

在命令行中,用saved_model_cli工具查看和使用生成的saved_model模型。具体内容如下 :
(1)找出tag标签对应的MetaGraphDef。
(2)找出MetaGraphDef中的signature、输入、输出节点等相关信息。
(3)以命令行的方式向模型输入数据,使其运行并输出结果。

saved_model_cli工具工具共有两个主要的参数:

  • show参数:侧重于查看模型中的信息。
  • run参数:侧重于运行模型。

3.1查看模型文件中的tag(标签

saved_model_cli show --dir log/tfservingmodel

运行结果:

在这里插入图片描述


我们可以看到输出结果为serve,表明SavedModel对象里面只有一个MetaGraphDef,这个serve对应于tag_constants.SERVING。

3.2查看serve对应的MetaGraphDef中的签名

saved_model_cli show --dir log/tfservingmodel --tag_set serve

运行结果:

在这里插入图片描述


我们可以看到输出结果为SignatureDef Key:“my_signature”,表明serve对应的MetaGraphDef中有一个签名为"my_signature",与上面1中生成带有签名时的一致。

3.3查看signature中定义的输入、输出节点的名称

saved_model_cli show --dir log/tfservingmodel --tag_set serve --signature_def my_signature

运行结果:

在这里插入图片描述


我们可以看到,模型的输入节点的张量为input_x,输出节点的张量为output。

上面的内容可以用saved_model_cli工具中的“–all”参数查看模型文件中的全部信息。

saved_model_cli show --dir log/tfservingmodel --all

运行结果:

在这里插入图片描述

4.用run参数运行模型

用saved_model_cli 工具的run参数时,需要先指定好模型的路径、tag(标签)及signature(签名),再往模型里面输入数据,并运行。
在输入数据部分,可以用参数来指定不同的输入方式。

  • —inputs:后面跟具体的文件文件类型支持numpy文件(npy、npz)和pickle文件(plk)。
  • —input_exprs:指定某个变量,向模型注入数据。
  • —input_examples:用字典方式向模型注入数据。

以“–input_exprs”为例,具体命令如下:

saved_model_cli run --dir log/tfservingmodel --tag_set serve --signature_def my_signature --input_exprs"input_x=4.2"

运行结果:

在这里插入图片描述

参考书籍:《深度学习之TensorFlow工程化项目实战》 李金洪 编著

相关文章

MNIST数据集可以说是深度学习的入门,但是使用模型预测单张M...
1、新建tensorflow环境(1)打开anacondaprompt,输入命令行...
这篇文章主要介绍“张量tensor是什么”,在日常操作中,相信...
tensorflow中model.fit()用法model.fit()方法用于执行训练过...
https://blog.csdn.net/To_be_little/article/details/12443...
根据身高推测体重const$=require('jquery');const...