numpy_input_fn
documentation
我认为底层实现的行为(挂起)取决于队列运行器。队列运行程序未启动时会发生挂起。尝试根据以下内容修改会话运行脚本:
this guide
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
try:
for step in xrange(1000000):
if coord.should_stop():
break
features_data = sess.run(features)
print(features_data)
except Exception, e:
# Report exceptions to the coordinator.
coord.request_stop(e)
finally:
# Terminate as usual. It is safe to call `coord.request_stop()` twice.
coord.request_stop()
coord.join(threads)
或者,我鼓励您查看
tf.data.Dataset
接口(可能
tf.contrib.data.Dataset
Dataset.from_tensor_slices
. 创建要稍微复杂一些,但接口要灵活得多,并且实现不使用队列运行器,这意味着会话运行要简单得多。
import tensorflow as tf
import numpy as np
x_data = np.random.random((100000, 2))
y_data = np.random.random((100000,))
batch_size = 2
buff = 100
def input_fn():
# possible tf.contrib.data.Dataset.from... in tf 1.3 or earlier
dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))
dataset = dataset.repeat().shuffle(buff).batch(batch_size)
x, y = dataset.make_one_shot_iterator().get_next()
return x, y
x, y = input_fn()
with tf.Session() as sess:
print(sess.run([x, y]))