第一种方法：使用 `tf.cond`，代码如下：
def loop_body(step_num, x):
    """Add 1 to ``x`` on the first step and 2 otherwise, then advance the step.

    The branch is expressed with ``tf.cond``, so which lambda's result is used
    is decided from the value of the ``step_num`` tensor when the graph runs.

    Returns:
        A ``(step_num + 1, new_x)`` tuple.
    """
    is_first_step = tf.equal(step_num, 0)
    new_x = tf.cond(is_first_step,
                    true_fn=lambda: x + 1,
                    false_fn=lambda: x + 2)
    next_step = tf.add(step_num, 1)
    return next_step, new_x
第二种方法：使用 `autograph`。分支直接用普通的 Python `if` 语句编写，由 `ag.to_graph` 将其转换为图中的条件运算（`loop_body2` 的定义见下面的例子）：
from tensorflow.contrib import autograph as ag

# ag.to_graph rewrites loop_body2's plain-Python `if` into graph control flow;
# the converted function is then called with the same arguments.
converted_loop_body2 = ag.to_graph(loop_body2)
converted_loop_body2(step_num, x)
一个完整的例子：
import tensorflow as tf
from tensorflow.contrib import autograph as ag
def loop_body(step_num, x):
    """Add 1 to ``x`` on the first step and 2 otherwise, then advance the step.

    The branch is expressed with ``tf.cond``, so which lambda's result is used
    is decided from the value of the ``step_num`` tensor when the graph runs.

    Returns:
        A ``(step_num + 1, new_x)`` tuple.
    """
    is_first_step = tf.equal(step_num, 0)
    new_x = tf.cond(is_first_step,
                    true_fn=lambda: x + 1,
                    false_fn=lambda: x + 2)
    next_step = tf.add(step_num, 1)
    return next_step, new_x
def loop_body2(step_num, x):
    """Plain-Python counterpart of ``loop_body``, intended for autograph.

    Adds 1 to ``x`` when ``step_num == 0`` and 2 otherwise, then advances
    the step. In this example the function is only called after
    ``ag.to_graph`` has converted it.

    NOTE(review): the Python ``if`` on a tensor comparison presumably only
    becomes a graph conditional after autograph conversion; confirm before
    calling this function directly on graph tensors.

    Returns:
        A ``(step_num + 1, new_x)`` tuple.
    """
    if step_num == 0:
        x = x + 1
    else:
        x = x + 2
    next_step = tf.add(step_num, 1)
    return next_step, x
step_num = tf.constant(0)
x = tf.constant(2)

# Build both graphs from the same initial tensors: the hand-written tf.cond
# version and the autograph-converted plain-Python version.
result1 = loop_body(step_num, x)
result2 = ag.to_graph(loop_body2)(step_num, x)

# Evaluate and print each result, in order, within a single session.
with tf.Session() as sess:
    for fetched in (result1, result2):
        print(sess.run(fetched))
# Output:
# (1, 3)
# (1, 3)