代码之家  ›  专栏  ›  技术社区  ›  Karnivaurus

TensorFlow将数据加载到tf.Dataset需要太长时间

  •  1
  • Karnivaurus  · 技术社区  · 6 年前

    我的GPU有3gb的内存,RAM有32gb的内存。每个半数据集的大小为20GB。我的硬盘有足够的可用空间(超过1 TB)。

    我的尝试如下。我创建了一个可初始化的 tf.Dataset

    但是,这非常慢,因为从硬盘加载数据需要很长时间,而且每次用这些数据初始化数据集也需要很长时间。

    有没有更有效的方法来做到这一点?

    在加载数据集的另一半之前,我已经尝试过对数据集的每一半进行多个阶段的培训,这要快得多,但这会使验证数据的性能差得多。这大概是因为模型在每一半上都过拟合,然后没有推广到另一半的数据。

    import tensorflow as tf
    import numpy as np
    import time
    
    # Create and save 2 datasets of test NumPy data
    dataset_num_elements = 100000
    element_dim = 10000
    batch_size = 50
    test_data = np.zeros([2, int(dataset_num_elements * 0.5), element_dim], dtype=np.float32)
    np.savez('test_data_1.npz', x=test_data[0])
    np.savez('test_data_2.npz', x=test_data[1])
    
    # Create the TensorFlow dataset
    data_placeholder = tf.placeholder(tf.float32, [int(dataset_num_elements * 0.5), element_dim])
    dataset = tf.data.Dataset.from_tensor_slices(data_placeholder)
    dataset = dataset.shuffle(buffer_size=dataset_num_elements)
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size=batch_size)
    dataset = dataset.prefetch(1)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    init_op = iterator.initializer
    
    num_batches = int(dataset_num_elements / batch_size)
    
    with tf.Session() as sess:
        while True:
            for dataset_section in range(2):
                # Load the data from the hard drive
                t1 = time.time()
                print('Loading')
                loaded_data = np.load('test_data_' + str(dataset_section + 1) + '.npz')
                x = loaded_data['x']
                print('Loaded')
                t2 = time.time()
                loading_time = t2 - t1
                print('Loading time = ' + str(loading_time))
                # Initialize the dataset with this loaded data
                t1 = time.time()
                sess.run(init_op, feed_dict={data_placeholder: x})
                t2 = time.time()
                initialization_time = t2 - t1
                print('Initialization time = ' + str(initialization_time))
                # Read the data in batches
                for i in range(num_batches):
                    x = sess.run(next_element)
    
    1 回复  |  直到 6 年前
        1
  •  3
  •   Kaihong Zhang    6 年前

    Feed不是输入数据的有效方法。您可以输入如下数据:

    1. 创建包含所有输入文件名的文件名数据集。你可以洗牌,在这里重复数据集。
    2. 预取要训练的数据。

    这只是一个例子。你可以自己设计管道,记住以下几点:

    • 尽可能使用轻量级提要
    • 使用多线程读取和预处理