代码之家  ›  专栏  ›  技术社区  ›  Allen Qin

Tensorflow:故障排除tf。估计器。输入。numpy\u input\u fn函数

  •  4
  • Allen Qin  · 技术社区  · 7 年前

    text classification

    test_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={WORDS_FEATURE: x_test},
      y=y_test,
      num_epochs=1,
      shuffle=False)
    classifier.train(input_fn=train_input_fn, steps=100)
    

    我跟踪了代码,发现train\u input\u fn函数将数据馈送到以下2个变量:

    features
    Out[15]: {'words': <tf.Tensor 'random_shuffle_queue_DequeueMany:1' shape=(560, 10) dtype=int64>}
    
    labels
    Out[16]: <tf.Tensor 'random_shuffle_queue_DequeueMany:2' shape=(560,) dtype=int32>
    

    当我试图通过sess评估特征变量时。运行(功能),我的终端似乎卡住并停止响应。

    检查此类变量内容的正确方法是什么?

    非常感谢。

    1 回复  |  直到 7 年前
        1
  •  3
  •   DomJack    7 年前

    numpy_input_fn documentation 我认为底层实现的行为(挂起)取决于队列运行器。队列运行程序未启动时会发生挂起。尝试根据以下内容修改会话运行脚本: this guide

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for step in xrange(1000000):
                if coord.should_stop():
                    break
                features_data = sess.run(features)
                print(features_data)
    
        except Exception, e:
            # Report exceptions to the coordinator.
            coord.request_stop(e)
        finally:
            # Terminate as usual. It is safe to call `coord.request_stop()` twice.
            coord.request_stop()
            coord.join(threads)
    

    或者,我鼓励您查看 tf.data.Dataset 接口(可能 tf.contrib.data.Dataset Dataset.from_tensor_slices . 创建要稍微复杂一些,但接口要灵活得多,并且实现不使用队列运行器,这意味着会话运行要简单得多。

    import tensorflow as tf
    import numpy as np
    
    x_data = np.random.random((100000, 2))
    y_data = np.random.random((100000,))
    
    batch_size = 2
    buff = 100
    
    
    def input_fn():
        # possible tf.contrib.data.Dataset.from... in tf 1.3 or earlier
        dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))
        dataset = dataset.repeat().shuffle(buff).batch(batch_size)
        x, y = dataset.make_one_shot_iterator().get_next()
        return x, y
    
    
    x, y = input_fn()
    with tf.Session() as sess:
        print(sess.run([x, y]))