代码之家  ›  专栏  ›  技术社区  ›  Evan Weissburg

获取Tensorflow中数据集的长度

  •  9
  • Evan Weissburg  · 技术社区  · 7 年前
    source_dataset = tf.data.TextLineDataset('primary.csv')
    target_dataset = tf.data.TextLineDataset('secondary.csv')
    dataset = tf.data.Dataset.zip((source_dataset, target_dataset))
    dataset = dataset.shard(10000, 0)
    dataset = dataset.map(lambda source, target: (tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),
                                                  tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))
    dataset = dataset.map(lambda source, target: (source, tf.concat(([start_token], target), axis=0), tf.concat((target, [end_token]), axis=0)))
    dataset = dataset.map(lambda source, target_in, target_out: (source, tf.size(source), target_in, target_out, tf.size(target_in)))
    
    dataset = dataset.shuffle(NUM_SAMPLES)  #This is the important line of code
    

    我想彻底洗牌我的整个数据集,但是 shuffle() 需要抽取大量样本,以及 tf.Size() 不适用于 tf.data.Dataset .

    我怎样才能正确地洗牌?

    2 回复  |  直到 4 年前
        1
  •  2
  •   Ringo    6 年前

    我和tf一起工作。数据FixedLengthRecordDataset()并遇到类似问题。 在我的例子中,我试图只提取一定比例的原始数据。 因为我知道所有记录都有固定的长度,所以我的解决方法是:

    totalBytes = sum([os.path.getsize(os.path.join(filepath, filename)) for filename in os.listdir(filepath)])
    numRecordsToTake = tf.cast(0.01 * percentage * totalBytes / bytesPerRecord, tf.int64)
    dataset = tf.data.FixedLengthRecordDataset(filenames, recordBytes).take(numRecordsToTake)
    

    在您的情况下,我的建议是在python中直接计算“primary”中的记录数。csv“和”辅助。csv’。或者,我认为出于您的目的,设置buffer\u size参数实际上并不需要计算文件数。根据 the accepted answer about the meaning of buffer_size ,大于数据集中元素数的数字将确保整个数据集中的均匀洗牌。因此,只需输入一个非常大的数字(您认为将超过数据集大小)就可以了。

        2
  •  1
  •   Timbus Calin    4 年前

    对于TensorFlow 2,数据集的长度可以通过 cardinality() 作用

    dataset = tf.data.Dataset.range(42)
    #both print 42 
    dataset_length_v1 = tf.data.experimental.cardinality(dataset).numpy())
    dataset_length_v2 = dataset.cardinality().numpy()
    

    注意:当使用谓词(例如filter)时,长度的返回值可能为-2。你可以参考一个解释 here ,否则请阅读以下段落:

    如果使用过滤器谓词,则基数可能返回值-2,因此未知;如果确实在数据集上使用过滤器谓词,请确保已以另一种方式计算了数据集的长度(例如,在应用之前,pandas dataframe的长度) .from_tensor_slices() 在上面。