代码之家  ›  专栏  ›  技术社区  ›  Karnivaurus

将“tf.常量”与整数进行比较

  •  1
  • Karnivaurus  · 技术社区  · 6 年前

    在TensorFlow中,我有一个 tf.while_loop ,其中 body 参数定义为以下函数:

    def loop_body(step_num, x):
        if step_num == 0:
            x += 1
        else:
            x += 2
        step_num = tf.add(step_num, 1)
        return step_num, x
    

    问题是这条线 step_num == 0 永远不会 True 即使初始值为 step_num 0 . 我假设这是因为 步进 不是整数,但实际上是 tf.constant 在循环外定义: step_num = tf.constant(0) . 所以我比较了 TF-常数 一个python整数,它将 False .

    我应该用什么来代替这个比较呢?

    1 回复  |  直到 6 年前
        1
  •  3
  •   giser_yugang    6 年前

    第一种方法:使用 tf.cond 以下内容:

    def loop_body(step_num, x):
        x = tf.cond(tf.equal(step_num,0),lambda :x+1,lambda :x+2)
        step_num = tf.add(step_num, 1)
        return step_num, x
    

    第二种方法:使用 autograph :

    from tensorflow.contrib import autograph as ag
    ag.to_graph(loop_body2)(step_num, x)
    

    一个例子:

    import tensorflow as tf
    from tensorflow.contrib import autograph as ag
    
    def loop_body(step_num, x):
        x = tf.cond(tf.equal(step_num,0),lambda :x+1,lambda :x+2)
        step_num = tf.add(step_num, 1)
        return step_num, x
    
    def loop_body2(step_num, x):
        if step_num == 0:
            x += 1
        else:
            x += 2
        step_num = tf.add(step_num, 1)
        return step_num, x
    
    step_num = tf.constant(0)
    x = tf.constant(2)
    result1 = loop_body(step_num, x)
    result2 = ag.to_graph(loop_body2)(step_num, x)
    
    with tf.Session() as sess:
        print(sess.run(result1))
        print(sess.run(result2))
    
    #print 
    (1, 3)
    (1, 3)