代码之家  ›  专栏  ›  技术社区  ›  bluesummers

tensorflow tf.data.dataset和bucketing

  •  9
  • bluesummers  · 技术社区  · 6 年前

    对于lstm网络,我看到了bucketing的巨大改进。

    我遇到了 bucketing section in the TensorFlow docs 哪个(tf.contrib)。

    虽然在我的网络中,我使用 tf.data.Dataset api,特别是我正在使用tfrecords,所以我的输入管道看起来像这样

    dataset = tf.data.TFRecordDataset(TFRECORDS_PATH)
    dataset = dataset.map(_parse_function)
    dataset = dataset.map(_scale_function)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.padded_batch(batch_size, padded_shapes={.....})
    

    我怎样才能把屈曲方法 tf.data.dataset数据集 管道?

    如果重要的话,在tfrecords文件中的每个记录中,我都将序列长度保存为整数。

    1 回复  |  直到 6 年前
        1
  •  5
  •   Vijay Mariappan    6 年前

    各种各样的 bucketing 用例使用 Dataset API 解释得很好 here .

    bucket_by_sequence_length() 例子:

    def elements_gen():
       text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2]]
       label = [1, 2, 1, 2]
       for x, y in zip(text, label):
           yield (x, y)
    
    def element_length_fn(x, y):
       return tf.shape(x)[0]
    
    dataset = tf.data.Dataset.from_generator(generator=elements_gen,
                                         output_shapes=([None],[]),
                                         output_types=(tf.int32, tf.int32))
    
    dataset =   dataset.apply(tf.contrib.data.bucket_by_sequence_length(element_length_func=element_length_fn,
                                                                  bucket_batch_sizes=[2, 2, 2],
                                                                  bucket_boundaries=[0, 8]))
    
    batch = dataset.make_one_shot_iterator().get_next()
    
    with tf.Session() as sess:
    
       for _ in range(2):
          print('Get_next:')
          print(sess.run(batch))
    

    输出:

    Get_next:
    (array([[1, 2, 3, 0, 0],
       [3, 4, 5, 6, 7]], dtype=int32), array([1, 2], dtype=int32))
    Get_next:
    (array([[1, 2, 0, 0],
       [8, 9, 0, 2]], dtype=int32), array([1, 2], dtype=int32))