微信公众号搜"智元新知"关注
微信扫一扫可直接关注哦!

python – Tensorflow:保存和恢复变量问题

如何在tensorflow中保存和恢复变量?

我遇到了问题.我的代码

import tensorflow as tf

v1 = tf.Variable(tf.zeros([2, 2], dtype=tf.float32, name='v1'))
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(v1)
    save_path = saver.save(sess, 'model.ckpt')
    print "model saved in file:", save_path
    v1 = v1 + 1
    print sess.run(v1)
    saver = tf.train.import_Meta_graph('model.ckpt.Meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    print sess.run(v1)

结果:

[[ 0.  0.]
 [ 0.  0.]]

[[ 1.  1.]
 [ 1.  1.]]

[[ 1.  1.]
 [ 1.  1.]]

我希望得到:

[[ 0.  0.]
 [ 0.  0.]]

[[ 1.  1.]
 [ 1.  1.]]

[[ 0.  0.]
 [ 0.  0.]]

我犯了什么错误

请帮我理解.

解决方法:

您的代码中有两个主要问题:

>行v1 = v1 1创建一个新的TensorFlow Tensor并将其绑定到Python变量v1,但不会更改使用名称“v1”创建的TensorFlow变量中的值.因此,当您稍后调用sess.run(v1)时,您正在评估将原始变量加1的新张量,而不是从张量中读取值.

相反,要将变量添加到变量,您应该使用以下内容

increment_op = v1.assign_add(tf.ones([2, 2]))
sess.run(increment_op)

> tf.train.import_meta_graph()调用重新创建原始图形,并在此过程中向图形中添加新节点,包括新的tf.train.Saver.当您尚未构建图形(或者没有程序可用于此图形)时,它非常有用.由于您已经构建了图形,因此只需要使用saver.restore(sess,’model.ckpt’).

以下程序应该产生您预期的行为:

import tensorflow as tf

v1 = tf.Variable(tf.zeros([2, 2], dtype=tf.float32, name='v1'))
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(v1)
    save_path = saver.save(sess, './model.ckpt')
    print "model saved in file:", save_path

    # Create an op to increment v1, run it, and print the result.   
    increment_op = v1.assign_add(tf.ones([2, 2]))
    sess.run(increment_op)
    print sess.run(v1)

    # Restore from the checkpoint saved above.
    saver.restore(sess, './model.ckpt')
    print sess.run(v1)

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 [email protected] 举报,一经查实,本站将立刻删除。

相关推荐