代码之家  ›  专栏  ›  技术社区  ›  lhlmgr

使用数据集API生成平衡的小批量

  •  10
  • lhlmgr  · 技术社区  · 7 年前

    我有一个关于新数据集API(tensorflow 1.4rc1)的问题。 我的数据集与标签不平衡 0 1 . 我的目标是在预处理期间创建平衡的小批量。

    假设我有两个过滤过的数据集:

    ds_pos = dataset.filter(lambda l, x, y, z: tf.reshape(tf.equal(l, 1), []))
    ds_neg = dataset.filter(lambda l, x, y, z: tf.reshape(tf.equal(l, 0), [])).repeat()
    

    有没有一种方法可以将这两个数据集结合起来,使生成的数据集看起来像 ds = [0, 1, 0, 1, 0, 1] :

    类似这样:

    dataset = tf.data.Dataset.zip((ds_pos, ds_neg))
    dataset = dataset.apply(...)
    # dataset looks like [0, 1, 0, 1, 0, 1, ...]
    dataset = dataset.batch(20)
    

    我目前的做法是:

    def _concat(x, y):
       return tf.cond(tf.random_uniform(()) > 0.5, lambda: x, lambda: y)
    dataset = tf.data.Dataset.zip((ds_pos, ds_neg))
    dataset = dataset.map(_concat)
    

    但我觉得还有一种更优雅的方式。

    1 回复  |  直到 7 年前
        1
  •  7
  •   mrry    7 年前

    你在正确的轨道上。以下示例使用 Dataset.flat_map()

    dataset = tf.data.Dataset.zip((ds_pos, ds_neg))
    
    # Each input element will be converted into a two-element `Dataset` using
    # `Dataset.from_tensors()` and `Dataset.concatenate()`, then `Dataset.flat_map()`
    # will flatten the resulting `Dataset`s into a single `Dataset`.
    dataset = dataset.flat_map(
        lambda ex_pos, ex_neg: tf.data.Dataset.from_tensors(ex_pos).concatenate(
            tf.data.Dataset.from_tensors(ex_neg)))
    
    dataset = dataset.batch(20)