代码之家  ›  专栏  ›  技术社区  ›  Miguel Monteiro

TensorFlow 1.4分布式模式下的新数据集API

  •  1
  • Miguel Monteiro  · 技术社区  · 6 年前

    在TensorFlow 1.4的新数据集API之前,我使用以下代码在不同的工作者之间创建文件名的共享队列:

    # queue with the file names that can be shared amongst workers during training
    filename_queue = tf.FIFOQueue(100, tf.string, shared_name=shared_name)
    enque_op = filename_queue.enqueue_many([tf.train.limit_epochs(file_names, num_epochs)])
    close_op = filename_queue.close(cancel_pending_enqueues=True)
    
    # create queue runner and add it to queue runners
    qr = tf.train.QueueRunner(filename_queue, [enque_op], close_op,
                              queue_closed_exception_types=(tf.errors.OutOfRangeError, tf.errors.CancelledError))
    tf.train.add_queue_runner(qr)
    
    # read example from file
    reader = tf.TFRecordReader()
    _, example = reader.read(filename_queue)
    
    # parse example
    image, ground_truth, example_name = parse_example(example)
    

    这段代码使用队列和队列运行器,非常难看和混乱。但它允许选择 shared_name= 在工作人员之间创建共享队列,这样他们就不会处理相同的示例。

    TensorFlow 1.4新版本发布后 input pipelines 变得更加易于使用。所以我想更新我的程序来使用这个新功能。 然而,我在新文档中找不到如何在工作人员之间共享数据集。

    这是自动完成的还是不是功能?

    1 回复  |  直到 6 年前
        1
  •  1
  •   jsimsa    6 年前

    您可以使用 tf.data.Dataset.shard (参见 documentation )为此目的。文档说明了如何“切分”单个文件的元素或(如在您的示例中)“切分”文件名。