代码之家  ›  专栏  ›  技术社区  ›  DocDriven

TensorFlow图中还原变量的错误输出

  •  3
  • DocDriven  · 技术社区  · 6 年前

    我目前正在忙于保存和恢复变量。为此,我创建了两个脚本。其中一个保存一个简单的图形,而另一个则恢复它。这里是保存图表的测试脚本:

    import tensorflow as tf
    
    a = tf.Variable(3.0, name='a')
    b = tf.Variable(5.0, name='b')
    
    b = tf.assign_add(b, a)
    
    n_steps = 5
    
    global_step = tf.Variable(0, name='global_step', trainable=False)
    
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
    
        sess.run(tf.global_variables_initializer())
    
        for step in range(n_steps):
            print(sess.run(b))
    
            global_step.assign_add(1).eval()
            print(global_step.eval())
    
            saver.save(sess, './my_test_model', global_step=global_step)
    

    基本上,我想运行5次循环,每次我这样做,我补充说 a b . 我还想通过 global_step . 这是按预期工作的。输出为:

    8.0     # value of b
    1       # step
    11.0
    2
    14.0
    3
    17.0
    4
    20.0
    5
    

    现在,当恢复变量时,我尝试获取所有三个变量。脚本是:

    import tensorflow as tf
    
    from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
    
    # List ALL tensors.
    print_tensors_in_checkpoint_file(tf.train.latest_checkpoint('./'), all_tensors=True, tensor_name='')
    
    tf.reset_default_graph()
    
    a = tf.get_variable('a', shape=[])
    b = tf.get_variable('b', shape=[])
    global_step = tf.get_variable('global_step', shape=[])
    
    saver = tf.train.Saver()
    
    with tf.Session() as sess:
    
        ckpt = tf.train.latest_checkpoint('./')
        if ckpt:
            print(ckpt)
    
            saver.restore(sess, ckpt)
    
        else:
            print('Nothing restored')
    
        print(a.eval())
        print(b.eval())
        print(global_step.eval())
    

    它的输出是

    tensor_name:  a
    3.0
    tensor_name:  b
    20.0
    tensor_name:  global_step
    5
    ./my_test_model-5
    INFO:tensorflow:Restoring parameters from ./my_test_model-5
    3.0
    20.0
    7e-45
    

    全局步骤的值是如何正确存储在检查点中的,但是经过评估,我得到的值很小 7E-45 ?而且,在恢复之后,我似乎无法定义任何附加变量,因为它声明它在检查点中找不到该变量。例如,如何定义变量并将其添加到 恢复的图形?

    谢谢你的帮助!

    1 回复  |  直到 6 年前
        1
  •  2
  •   kww    6 年前

    这似乎没有被tf文档很好地记录,但是您应该为 global_step 变量。

    不正确

    global_step = tf.get_variable('global_step', shape=[], dtype=tf.float32) 结果 global_step=7e-5 . 默认情况下,类型假定为dtf.float32。

    对的

    global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32) 结果 global_step=5