
Problem with a custom generator in TF 2.0

  • MNM · Tech Community · 5 years ago

    ValueError: as_list() is not defined on an unknown TensorShape.

    This is the code I am using:

        import os

        import pandas as pd
        import tensorflow as tf
        from sklearn.preprocessing import LabelEncoder
        from tensorflow.keras.layers import Dense, Flatten
        from tensorflow.keras.models import Sequential
        from tensorflow.keras.preprocessing.image import img_to_array, load_img

        csv_path = '../Dataset/ShowData.csv'
        df = pd.read_csv(csv_path)
        base_path = "../Dataset/"
        # Integer-encode the labels (LabelEncoder does not one-hot encode)
        le = LabelEncoder()
        df['Label'] = le.fit_transform(df['Label'])
        num_classes = len(le.classes_)
        print(df.head(1)['Label'])
        print(df.tail(1)['Label'])
        '''
        # View the labels (if you want)
        list(le.classes_)
        # Convert some integers back into their category names
        list(le.inverse_transform([2, 2, 1]))
        '''

        def process_dataframe(dataframe):
            for index, row in dataframe.iterrows():
                # print(row['Path'], row['Label'])
                # Load the image and get its label
                img_path = os.path.join(base_path, row['Path'])
                img = load_img(img_path, target_size=(200, 200))
                img = img_to_array(img)
                img = img / 255  # normalize pixel values to [0, 1]
                label = row['Label']
                # label = to_categorical(label, num_classes, dtype=tf.float32)
                yield img, label


        def generate_dataset(dataframe):
            generator = lambda: process_dataframe(dataframe)
            return tf.data.Dataset.from_generator(generator=generator,
                                                  output_types=(tf.float32, tf.int32))


        dataset = generate_dataset(df)
        data_batch = dataset.shuffle(10000).batch(32)
        print(data_batch)

        model = Sequential([
            Flatten(input_shape=(200, 200, 3)),
            Dense(128, activation='relu'),
            Dense(num_classes, activation='softmax')  # one unit per class
        ])

        model.compile(optimizer=tf.keras.optimizers.Adam(),
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                      # Accuracy() compares labels with raw probabilities;
                      # the sparse variant compares them with the argmax
                      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

        model.fit(data_batch, epochs=10, verbose=1)
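A minimal sketch of the label side of this pipeline, using hypothetical class names and hypothetical softmax outputs in place of the CSV and the model. LabelEncoder maps each class name to an integer id from 0 to len(classes_) - 1 (integer encoding, not one-hot), which is the label format SparseCategoricalCrossentropy expects, so the last Dense layer needs len(classes_) units. SparseCategoricalAccuracy takes the argmax of each probability row before comparing it with the integer label, whereas tf.keras.metrics.Accuracy checks for exact equality between labels and predictions and so does not fit softmax outputs.

```python
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import LabelEncoder

# Hypothetical class names standing in for the 'Label' column
labels = ["cat", "dog", "cat", "bird"]

le = LabelEncoder()
y = le.fit_transform(labels)      # classes_ are sorted: bird=0, cat=1, dog=2
num_classes = len(le.classes_)    # size of the final Dense layer

# Hypothetical softmax outputs: one row per sample, num_classes columns
probs = np.array([
    [0.1, 0.8, 0.1],
    [0.1, 0.1, 0.8],
    [0.2, 0.7, 0.1],
    [0.6, 0.2, 0.2],
], dtype=np.float32)

# Integer labels paired with probability rows: what the sparse loss expects
loss = tf.keras.losses.SparseCategoricalCrossentropy()(y, probs)

# The sparse metric compares y with argmax(probs, axis=-1)
acc = tf.keras.metrics.SparseCategoricalAccuracy()
acc.update_state(y, probs)

print(y.tolist(), num_classes)
print(float(loss), float(acc.result()))
```

Switching to one-hot labels (the commented-out to_categorical line) would instead call for CategoricalCrossentropy and CategoricalAccuracy.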
    

    Oh, and this is what printing the batched dataset shows:

    <DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.float32, tf.int32)>
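The (<unknown>, <unknown>) shapes can be reproduced in isolation. A minimal sketch, assuming a TF 2.x release where Dataset.element_spec is available; toy_generator is a hypothetical stand-in for process_dataframe. When from_generator gets only output_types, each component is given a TensorShape of unknown rank, and calling as_list() on such a shape raises the same ValueError as in the question.

```python
import numpy as np
import tensorflow as tf

def toy_generator():
    # Hypothetical stand-in for process_dataframe: two 4x4 RGB "images"
    for label in range(2):
        yield np.zeros((4, 4, 3), dtype=np.float32), label

# Same call pattern as in the question: types only, no shapes
dataset = tf.data.Dataset.from_generator(
    toy_generator, output_types=(tf.float32, tf.int32))

image_spec, label_spec = dataset.element_spec
print(image_spec.shape, label_spec.shape)  # unknown rank, printed as <unknown>

# as_list() is undefined for an unknown-rank shape
try:
    image_spec.shape.as_list()
except ValueError as err:
    print("ValueError:", err)
```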
    

    And this is the dtype of the labels:

    Name: Label, dtype: int32
    
  • Gabe · 5 years ago

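    You should pass output_shapes to your tf.data.Dataset.from_generator call. Without it, every component of the dataset has an unknown shape, which is why the printout shows (<unknown>, <unknown>), and calling as_list() on an unknown TensorShape is what raises the ValueError. See the from_generator documentation on the Tensorflow page for details.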
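A minimal sketch of that fix, scaled down so it runs without the dataset: IMG_SIZE, NUM_CLASSES and toy_generator are hypothetical stand-ins for the 200x200 images, len(le.classes_) and process_dataframe. Each component gets a static shape (an IMG_SIZE x IMG_SIZE x 3 image and a scalar label), so after batching only the batch dimension is unknown and model.fit can build the model.

```python
import numpy as np
import tensorflow as tf

IMG_SIZE = 8       # hypothetical; the question uses 200
NUM_CLASSES = 3    # hypothetical; use len(le.classes_) in practice

def toy_generator():
    # Hypothetical stand-in for process_dataframe: random images, integer labels
    rng = np.random.default_rng(0)
    for i in range(12):
        img = rng.random((IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)
        yield img, i % NUM_CLASSES

dataset = tf.data.Dataset.from_generator(
    toy_generator,
    output_types=(tf.float32, tf.int32),
    # The fix: one static shape per component (HxWxC image, scalar label)
    output_shapes=(tf.TensorShape([IMG_SIZE, IMG_SIZE, 3]), tf.TensorShape([])),
)
data_batch = dataset.shuffle(100).batch(4)
print(data_batch.element_spec)  # only the batch dimension is None

model = tf.keras.Sequential([
    tf.keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])
model.compile(optimizer="adam",
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
history = model.fit(data_batch, epochs=2, verbose=0)
print(len(history.history["loss"]))
```

On TensorFlow 2.4 and later, output_types and output_shapes are deprecated in favour of a single output_signature argument, e.g. output_signature=(tf.TensorSpec(shape=(IMG_SIZE, IMG_SIZE, 3), dtype=tf.float32), tf.TensorSpec(shape=(), dtype=tf.int32)), which carries the same type and shape information.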