尝试运行以下基本示例来运行条件计算我收到以下错误消息:
‘x’ was passed float incompatible with expected float_ref
什么是tensorflow float_ref以及如何修改代码?
import tensorflow as tf
from tensorflow.python.ops.control_flow_ops import cond
a = tf.Variable(tf.constant(0.),name="a")
b = tf.Variable(tf.constant(0.),name="b")
x = tf.Variable(tf.constant(0.),name="x")
def add():
x.assign( a + b)
return x
def last():
return x
calculate= cond(x==0.,add,last)
with tf.Session() as s:
val = s.run([calculate], {a: 1., b: 2., x: 0.})
print(val) # 3
val=s.run([calculate],{a:4.,b:5.,x:val})
print(val) # 3
解决方法:
这并没有解释float_ref是什么,但它修复了问题:
1)需要在会话中创建变量
2)任务不是我们所期望的
这个固定代码有效:
def add():
print("add")
x = a + b
return x
def last():
print("last")
return x
with tf.Session() as s:
a = tf.Variable(tf.constant(0.),name="a")
b = tf.Variable(tf.constant(0.),name="b")
x = tf.constant(-1.)
calculate= cond(x.eval()==-1.,add,last)
val = s.run([calculate], {a: 1., b: 2.})
print(val) # 3
print(s.run([calculate],{a:3.,b:4.})) # 7
print(val) # 3
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 [email protected] 举报,一经查实,本站将立刻删除。