
Filtering data in Tensorflow

  •  0
  •  Richard · 5 years ago

    import pandas as pd
    
    import tensorflow.compat.v2 as tf
    import tensorflow.compat.v1 as tfv1
    tfv1.enable_v2_behavior()
    
    csv_file = tf.keras.utils.get_file('heart.csv', 'https://storage.googleapis.com/applied-dl/heart.csv')
    
    df = pd.read_csv(csv_file)
    target = df.pop('target')
    df['thal'] = pd.Categorical(df['thal'])
    df['thal'] = df.thal.cat.codes
    
    # Use interleave() and prefetch() to read many files concurrently.
    #files = tf.data.Dataset.list_files(file_pattern=input_file_pattern, shuffle=True, seed=123456789)
    #dataset = files.interleave(lambda x: tf.data.RecordIODataset(x).prefetch(100), cycle_length=8)
    
    #Pretend I actually had some data files
    dataset = tf.data.Dataset.from_tensor_slices((df.to_dict('list'), target.values))
    
    dataset = dataset.shuffle(1000, seed=123456789)
    dataset = dataset.batch(20)
    #Pretend I did some parsing here
    # dataset = dataset.map(parse_record, num_parallel_calls=20) 
    dataset = dataset.filter(lambda x, label: x['trestbps']<135)
    

    But this produces the error message:

    ValueError: `predicate` return type must be convertible to a scalar boolean tensor. Was TensorSpec(shape=(None,), dtype=tf.bool, name=None).

    How can I filter the data?

  •  1
  •  AlexisBRENON · 5 years ago

    This is because you apply filter after batch. In your lambda expression, x is a batch with shape (None,) (or (20,) if you pass drop_remainder=True), so the predicate returns a vector of booleans instead of the single scalar boolean that filter expects.
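
    The straightforward fix is to apply the filter before batching, while each element is still a single record and the comparison yields a scalar boolean. A minimal sketch, reusing the df and target from your question:

    dataset = tf.data.Dataset.from_tensor_slices((df.to_dict('list'), target.values))
    dataset = dataset.shuffle(1000, seed=123456789)
    # Each element is a single record here, so the predicate is a scalar boolean.
    dataset = dataset.filter(lambda x, label: x['trestbps'] < 135)
    dataset = dataset.batch(20)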

    If you really need to drop elements after batching, you can use map instead. But, as you can see, this has the side effect of producing batches of variable size: you get a batch of 20 as input and then remove the elements that do not match the condition (trestbps < 135), rather than removing the same number of elements from every batch. And this solution performs poorly...

    import timeit
    
    import pandas as pd
    
    import tensorflow.compat.v2 as tf
    import tensorflow.compat.v1 as tfv1
    tfv1.enable_v2_behavior()
    
    def s1(ds):
        dataset = ds
        dataset = dataset.filter(lambda x, label: x['trestbps']<135)
        dataset = dataset.batch(20)
        return dataset
    
    def s2(ds):
        dataset = ds
        dataset = dataset.batch(20)
        # Keep only the rows of each batch whose trestbps is below 135.
        dataset = dataset.map(lambda x, label: (
            tf.nest.map_structure(lambda y: y[x['trestbps'] < 135], x),
            label[x['trestbps'] < 135]))
        return dataset
    
    
    def base_ds():
        csv_file = tf.keras.utils.get_file('heart.csv', 'https://storage.googleapis.com/applied-dl/heart.csv')
    
        df = pd.read_csv(csv_file)
        target = df.pop('target')
        df['thal'] = pd.Categorical(df['thal'])
        df['thal'] = df.thal.cat.codes
    
        return tf.data.Dataset.from_tensor_slices((df.to_dict('list'), target.values))
    
    
    def main():
        ds = base_ds()
        ds1 = s1(ds)
        ds2 = s2(ds)
        tf.print("DS_S1:", [tf.nest.map_structure(lambda x: x.shape, x) for x in ds1])
        tf.print("DS_S2:", [tf.nest.map_structure(lambda x: x.shape, x) for x in ds2])
        tf.print("Are equals?", [x for x in ds1] == [x for x in ds2])
        tf.print("Contains same elements?", [x for x in ds1.unbatch()] == [x for x in ds2.unbatch()])
    
        # Note: timeit here measures how long it takes to *build* each
        # pipeline 100 times; it does not iterate over the data.
        tf.print("Filter and batch:", timeit.timeit(lambda: s1(ds), number=100))
        tf.print("Batch and map:", timeit.timeit(lambda: s2(ds), number=100))
    
    if __name__ == '__main__':
        main()
    

    Results:

    # Tensor shapes
    [...]
    Are equals? False
    Contains same elements? True
    Filter and batch: 0.5571189750007761
    Batch and map: 15.582061060000342
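
    If you need fixed-size batches despite the batch-then-map approach, one option is to flatten the ragged batches and re-batch them. A minimal sketch, reusing the s2 pipeline above together with the unbatch() method already used in the comparison:

    def s2_rebatched(ds):
        dataset = ds.batch(20)
        # Drop the non-matching rows inside each batch, as in s2.
        dataset = dataset.map(lambda x, label: (
            tf.nest.map_structure(lambda y: y[x['trestbps'] < 135], x),
            label[x['trestbps'] < 135]))
        # Flatten back to single records, then re-batch so every batch
        # (except possibly the last) has exactly 20 elements again.
        return dataset.unbatch().batch(20)

    Note that this adds another pass over the elements, so it does not address the performance concern above.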
    

    Kind regards.