代码之家  ›  专栏  ›  技术社区  ›  MeanStreet

如何正确使用TensorFlow数据API中的dataset.map

  •  1
  • MeanStreet  · 技术社区  · 6 年前

    我正在学习如何使用TensorFlow数据API,并努力理解映射是如何工作的。对于上下文,我想加载一组图像并将它们发送到一个神经网络。

    下面的MWE就是这么做的(10号的假数据集, read_image 映射到数据集的函数)。

    import tensorflow as tf
    import numpy as np
    
    
    def read_image(filename, label):
        return np.random.rand(8, 8, 1), label # simulate data load (generate random data)
    
    # generate fake dataset of filenames (of size 10)
    filenames = tf.constant(np.asarray(["file" + str(i) for i in range(10)]))
    labels = tf.constant(np.asarray([2*i for i in range(10)]))
    
    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(read_image)
    
    dataset = dataset.repeat().batch(2)
    iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
    
    
    X, y = iterator.get_next()
    train_init_op = iterator.make_initializer(dataset)
    
    with tf.Session() as session:
        tf.global_variables_initializer().run()
        session.run(train_init_op)
        for _ in range(10):
            print(session.run([X]))
    

    运行此代码时(应该不做任何操作,只需打印 读出图像 ,结果总是相同的数据: 读出图像 只调用一次。为什么?我用过 dataset.map ,不是应该对数据集的每个元素(此处为10)调用它吗?

    提前感谢您的帮助或建议。

    1 回复  |  直到 6 年前
        1
  •  0
  •   MeanStreet    6 年前

    我对它的理解是错误的,我真的知道如何让它工作。这个映射函数被“集成”到张量流图中,因此实际上只称为一次。其中一个需要使用TF操作。

    如果 read_image 取而代之的是:

    def read_image(filename, label):
        return tf.random_normal([4]), tf.random_normal([1])
    

    它按预期工作(每次生成随机值 session.run 被调用)。