代码之家  ›  专栏  ›  技术社区  ›  Karnivaurus

如何对每个历元上的TensorFlow数据集进行完全洗牌

  •  2
  • Karnivaurus  · 技术社区  · 6 年前

    dataset = tf.data.Dataset.from_tensor_slices((inputs, labels))
    dataset = dataset.shuffle(1000)
    dataset = dataset.repeat()
    dataset = dataset.batch(50)
    

    每次从数据集中抽取一批新的50个样本时,它都会从接下来的1000个样本中随机抽取50个样本。但除此之外,我想做的是在每个纪元开始时完全洗牌我的整个数据集。

    一种方法是设置 shuffle_buffer 大小等于整个数据集的大小。但是,这会导致每次绘制批时整个数据集都被完全洗牌,这将非常缓慢。相反,我只希望整个数据集在每个纪元开始时完全洗牌一次。

    我该怎么做?

    2 回复  |  直到 6 年前
        1
  •  2
  •   P-Gn    6 年前

    但是,这会导致每次绘制批时整个数据集都被完全洗牌,这将非常缓慢。

    不完全正确。每次处理一个新样本时,不需要对整个缓冲区进行洗牌,每次处理一个新样本时,只需要进行一次排列。

    enter image description here

    拥有一个大的shuffle缓冲区的代价实际上是内存方面的:拥有一个数据集大小的shuffle缓冲区意味着整个数据集都在内存中,这并不总是可能的。

        2
  •  1
  •   kempy    6 年前

    TF数据集操作并不都是可交换的,但是如果您应用 shuffle repeat batch buffer_size 等于数据集大小。您还必须正确使用迭代器,以确保不会在循环中重新实例化它。