如何在tensorflow中保存和恢复变量?
我遇到了问题.我的代码:
import tensorflow as tf
v1 = tf.Variable(tf.zeros([2, 2], dtype=tf.float32, name='v1'))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(v1)
save_path = saver.save(sess, 'model.ckpt')
print "model saved in file:", save_path
v1 = v1 + 1
print sess.run(v1)
saver = tf.train.import_Meta_graph('model.ckpt.Meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
print sess.run(v1)
结果:
[[ 0. 0.]
[ 0. 0.]]
[[ 1. 1.]
[ 1. 1.]]
[[ 1. 1.]
[ 1. 1.]]
我希望得到:
[[ 0. 0.]
[ 0. 0.]]
[[ 1. 1.]
[ 1. 1.]]
[[ 0. 0.]
[ 0. 0.]]
我犯了什么错误?
请帮我理解.
解决方法:
您的代码中有两个主要问题:
>行v1 = v1 1创建一个新的TensorFlow Tensor并将其绑定到Python变量v1,但不会更改使用名称“v1”创建的TensorFlow变量中的值.因此,当您稍后调用sess.run(v1)时,您正在评估将原始变量加1的新张量,而不是从张量中读取值.
increment_op = v1.assign_add(tf.ones([2, 2]))
sess.run(increment_op)
> tf.train.import_meta_graph()
调用重新创建原始图形,并在此过程中向图形中添加新节点,包括新的tf.train.Saver
.当您尚未构建图形(或者没有程序可用于此图形)时,它非常有用.由于您已经构建了图形,因此只需要使用saver.restore(sess,’model.ckpt’).
以下程序应该产生您预期的行为:
import tensorflow as tf
v1 = tf.Variable(tf.zeros([2, 2], dtype=tf.float32, name='v1'))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(v1)
save_path = saver.save(sess, './model.ckpt')
print "model saved in file:", save_path
# Create an op to increment v1, run it, and print the result.
increment_op = v1.assign_add(tf.ones([2, 2]))
sess.run(increment_op)
print sess.run(v1)
# Restore from the checkpoint saved above.
saver.restore(sess, './model.ckpt')
print sess.run(v1)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 [email protected] 举报,一经查实,本站将立刻删除。