我有一个关于新数据集API(tensorflow 1.4rc1)的问题。
我的数据集与标签不平衡
0
和
1
. 我的目标是在预处理期间创建平衡的小批量。
假设我有两个过滤过的数据集:
ds_pos = dataset.filter(lambda l, x, y, z: tf.reshape(tf.equal(l, 1), []))
ds_neg = dataset.filter(lambda l, x, y, z: tf.reshape(tf.equal(l, 0), [])).repeat()
有没有一种方法可以将这两个数据集结合起来,使生成的数据集看起来像
ds = [0, 1, 0, 1, 0, 1]
:
类似这样:
dataset = tf.data.Dataset.zip((ds_pos, ds_neg))
dataset = dataset.apply(...)
# dataset looks like [0, 1, 0, 1, 0, 1, ...]
dataset = dataset.batch(20)
我目前的做法是:
def _concat(x, y):
return tf.cond(tf.random_uniform(()) > 0.5, lambda: x, lambda: y)
dataset = tf.data.Dataset.zip((ds_pos, ds_neg))
dataset = dataset.map(_concat)
但我觉得还有一种更优雅的方式。