代码之家  ›  专栏  ›  技术社区  ›  Richard_wth

如何在每次迭代中只从一个类中抽样

  •  1
  • Richard_wth  · 技术社区  · 6 年前

    我想在一个ImageNet数据集上训练一个分类器(1000个类,每个类有大约1300个图像)。出于某种原因,我需要每个批包含来自同一类的64个图像,以及来自不同类的连续批。使用最新的张量流是否可能(且有效)?

    tf.contrib.data.sample_from_datasets 在tf 1.9中,允许从 tf.data.Dataset 对象,与 weights 表示概率。我想知道以下想法是否有意义:

    • 将每个类的数据保存为单独的tfrecord文件。
    • 通过A tf.data.Dataset.from_generator 对象作为 砝码 . 对象样本来自分类分布,每个样本看起来像 [0,...,0,1,0,...,0] 用999 0 S和1 1 ;
    • 创建1000 TF.DATA数据集 对象,每个对象链接一个tfrecord文件。

    我想,这样,也许在每次迭代时, sample_from_datasets 将首先对表示 TF.DATA数据集 从中取样,然后从该类中取样。

    这是正确的吗?还有其他有效的方法吗?

    更新

    正如P-GN所建议的,从一个类中抽取数据的一种方法是:

    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(some_parser_fun)  # parse one datum from tfrecord
    dataset = dataset.shuffle(buffer_size)
    
    if sample_same_class:
        group_fun = tf.contrib.data.group_by_window(
            key_func=lambda data_x, data_y: data_y,
            reduce_func=lambda key, d: d.batch(batch_size),
            window_size=batch_size)
        dataset = dataset.apply(group_fun)
    else:
        dataset = dataset.batch(batch_size)
    
    dataset = dataset.repeat()
    data_batch = dataset.make_one_shot_iterator().get_next()
    

    后续问题可在 How to sample batch from a specific class?

    1 回复  |  直到 6 年前
        1
  •  3
  •   P-Gn    6 年前

    如果我理解正确的话,我认为你的解决方案是行不通的,因为 sample_from_dataset 需要其值的列表 weights 不是 Tensor .

    但是如果你不介意有1000个 Dataset 正如你提出的解决方案,那么我建议

    • 创建一个 数据集 每班,
    • batch 每批数据集都有来自单个类的样本,
    • zip 它们都变成一个大的 数据集 批次,
    • shuffle 数据集 洗牌将发生在批次上,而不是样品上,因此不会改变批次是单类的事实。

    更复杂的方法是依靠 tf.contrib.data.group_by_window . 让我用一个合成的例子来说明这一点。

    import numpy as np
    import tensorflow as tf
    
    def gen():
      while True:
        x = np.random.normal()
        label = np.random.randint(10)
        yield x, label
    
    batch_size = 4
    batch = (tf.data.Dataset
      .from_generator(gen, (tf.float32, tf.int64), (tf.TensorShape([]), tf.TensorShape([])))
      .apply(tf.contrib.data.group_by_window(
        key_func=lambda x, label: label,
        reduce_func=lambda key, d: d.batch(batch_size),
        window_size=batch_size))
      .make_one_shot_iterator()
      .get_next())
    
    sess = tf.InteractiveSession()
    sess.run(batch)
    # (array([ 0.04058843,  0.2843775 , -1.8626076 ,  1.1154234 ], dtype=float32),
    # array([6, 6, 6, 6], dtype=int64))
    sess.run(batch)
    # (array([ 1.3600663,  0.5935658, -0.6740045,  1.174328 ], dtype=float32),
    # array([3, 3, 3, 3], dtype=int64))