如果我理解正确的话,我认为你的解决方案是行不通的,因为
sample_from_dataset
需要其值的列表
weights
不是
Tensor
.
但是如果你不介意有1000个
Dataset
正如你提出的解决方案,那么我建议
-
创建一个
数据集
每班,
-
batch
每批数据集都有来自单个类的样本,
-
zip
它们都变成一个大的
数据集
批次,
-
shuffle
这
数据集
洗牌将发生在批次上,而不是样品上,因此不会改变批次是单类的事实。
更复杂的方法是依靠
tf.contrib.data.group_by_window
. 让我用一个合成的例子来说明这一点。
import numpy as np
import tensorflow as tf
def gen():
while True:
x = np.random.normal()
label = np.random.randint(10)
yield x, label
batch_size = 4
batch = (tf.data.Dataset
.from_generator(gen, (tf.float32, tf.int64), (tf.TensorShape([]), tf.TensorShape([])))
.apply(tf.contrib.data.group_by_window(
key_func=lambda x, label: label,
reduce_func=lambda key, d: d.batch(batch_size),
window_size=batch_size))
.make_one_shot_iterator()
.get_next())
sess = tf.InteractiveSession()
sess.run(batch)
# (array([ 0.04058843, 0.2843775 , -1.8626076 , 1.1154234 ], dtype=float32),
# array([6, 6, 6, 6], dtype=int64))
sess.run(batch)
# (array([ 1.3600663, 0.5935658, -0.6740045, 1.174328 ], dtype=float32),
# array([3, 3, 3, 3], dtype=int64))