我正在学习如何使用TensorFlow数据API,并努力理解映射是如何工作的。对于上下文,我想加载一组图像并将它们发送到一个神经网络。
下面的MWE就是这么做的(10号的假数据集,
read_image
映射到数据集的函数)。
import tensorflow as tf
import numpy as np
def read_image(filename, label):
return np.random.rand(8, 8, 1), label # simulate data load (generate random data)
# generate fake dataset of filenames (of size 10)
filenames = tf.constant(np.asarray(["file" + str(i) for i in range(10)]))
labels = tf.constant(np.asarray([2*i for i in range(10)]))
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(read_image)
dataset = dataset.repeat().batch(2)
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
X, y = iterator.get_next()
train_init_op = iterator.make_initializer(dataset)
with tf.Session() as session:
tf.global_variables_initializer().run()
session.run(train_init_op)
for _ in range(10):
print(session.run([X]))
运行此代码时(应该不做任何操作,只需打印
读出图像
,结果总是相同的数据:
读出图像
只调用一次。为什么?我用过
dataset.map
,不是应该对数据集的每个元素(此处为10)调用它吗?
提前感谢您的帮助或建议。