问题描述
在TF 1.14和TF 2.1之间使用tf.saved_model.simple_save和tf.saved_model.load时遇到了一些麻烦
如您所见,我附上了代码,
我想看看权重(W),其值必须是节省时间时初始化的状态。
在TF 2.1下, 保存和加载tensorflow模型(pb文件)没有问题。 保存后加载时,我能够识别出相同的重量(W)值
但是,当我使用TF 1.14时, 保存模型还可以..但是,当我加载保存的模型时,结果不是我期望的。 看来tf.saved_model.load无法加载节省的重量,只能随机初始化。
我附上了下面的代码, 您可以通过切换TF_VERSION = 2.1和1.14,SAVE = True和False来运行
TF_VERSION = 2.1
# TF_VERSION = 1.14
SAVE = False
model_dir_path = "./pb"
if TF_VERSION == 1.14:
import tensorflow as tf
else:
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
X = tf.placeholder(tf.float32,shape=[None,2],name='input')
# weight
weight_initer = tf.truncated_normal_initializer(mean=0.0,stddev=0.01)
W = tf.get_variable(name="Weight",dtype=tf.float32,shape=[2,1],initializer=weight_initer)
# bias
bias_initer = tf.constant(0.,shape=[1],dtype=tf.float32)
b = tf.get_variable(name="Bias",initializer=bias_initer)
x_w = tf.matmul(X,W,name="MatMul")
x_w_b = tf.add(x_w,b,name="Add")
#save
if SAVE:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
x_batch = [[2,[3,5]]
Feed_dict = {X: x_batch}
output = sess.run(x_w_b,Feed_dict=Feed_dict)
tf.saved_model.simple_save(sess,model_dir_path,inputs={"inputs": X},outputs={"outputs": W})
# restore
else:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.saved_model.load(sess,[tf.saved_model.tag_constants.SERVING],model_dir_path)
x_batch = [[2,5]]
Feed_dict = {X: x_batch}
weight = sess.run(W,Feed_dict=Feed_dict)
print(weight)
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)