在TF 1.14和TF 2.1下恢复张量流模型问题

问题描述

在TF 1.14和TF 2.1之间使用tf.saved_model.simple_save和tf.saved_model.load时遇到了一些麻烦

如您所见,我附上了代码

我想看看权重(W),其值必须是节省时间时初始化的状态。

在TF 2.1下, 保存和加载tensorflow模型(pb文件)没有问题。 保存后加载时,我能够识别出相同的重量(W)值

但是,当我使用TF 1.14时, 保存模型还可以..但是,当我加载保存的模型时,结果不是我期望的。 看来tf.saved_model.load无法加载节省的重量,只能随机初始化。

我附上了下面的代码, 您可以通过切换TF_VERSION = 2.1和1.14,SAVE = True和False来运行

TF_VERSION = 2.1
# TF_VERSION = 1.14
SAVE = False

model_dir_path = "./pb"


if TF_VERSION == 1.14:
    import tensorflow as tf
else:
    import tensorflow.compat.v1 as tf
    tf.compat.v1.disable_eager_execution()



X = tf.placeholder(tf.float32,shape=[None,2],name='input')

# weight
weight_initer = tf.truncated_normal_initializer(mean=0.0,stddev=0.01)
W = tf.get_variable(name="Weight",dtype=tf.float32,shape=[2,1],initializer=weight_initer)

# bias
bias_initer = tf.constant(0.,shape=[1],dtype=tf.float32)
b = tf.get_variable(name="Bias",initializer=bias_initer)

x_w = tf.matmul(X,W,name="MatMul")
x_w_b = tf.add(x_w,b,name="Add")


#save
if SAVE:
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        x_batch = [[2,[3,5]]

        Feed_dict = {X: x_batch}
        output = sess.run(x_w_b,Feed_dict=Feed_dict)

        tf.saved_model.simple_save(sess,model_dir_path,inputs={"inputs": X},outputs={"outputs": W})


# restore
else:
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.saved_model.load(sess,[tf.saved_model.tag_constants.SERVING],model_dir_path)

        x_batch = [[2,5]]

        Feed_dict = {X: x_batch}
        weight = sess.run(W,Feed_dict=Feed_dict)

        print(weight)

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)