How do I feed the encoded output of an encoder into the input of a classifier?

  • Ahmad · Tech Community · 6 years ago

    I want to use an AutoEncoder to improve the performance of a classifier.

    This is what I wrote for the MNIST data, following the Keras documentation:

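    from keras.datasets import mnist
    from keras.layers import Input, Dense
    from keras.models import Model
    
    # load MNIST, scale pixels to [0, 1] and flatten each 28x28 image to 784 floats
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.astype('float32').reshape(-1, 784) / 255.
    x_test = x_test.astype('float32').reshape(-1, 784) / 255.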
    
    # this is the size of our encoded representations
    # 32 floats -> compression of factor 24.5, assuming the input is 784 floats
    encoding_dim = 32
    
    # this is our input placeholder
    input_img = Input(shape=(784,))
    # "encoded" is the encoded representation of the input
    encoded = Dense(encoding_dim, activation='relu')(input_img)
    # "decoded" is the lossy reconstruction of the input
    decoded = Dense(784, activation='sigmoid')(encoded)
    
    # this model maps an input to its reconstruction
    autoencoder = Model(input_img, decoded)
    # this model maps an input to its encoded representation
    encoder = Model(input_img, encoded)
    # create a placeholder for an encoded (32-dimensional) input
    encoded_input = Input(shape=(encoding_dim,))
    # retrieve the last layer of the autoencoder model
    decoder_layer = autoencoder.layers[-1]
    # create the decoder model
    decoder = Model(encoded_input, decoder_layer(encoded_input))
    autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
    
    autoencoder.fit(x_train, x_train,
                    epochs=50,
                    batch_size=256,
                    shuffle=True,
                    validation_data=(x_test, x_test))
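    One way to feed the encoded output into a classifier is to do it in two stages: run the trained `encoder` once to get the 32-d codes, then train a separate model whose `Input` has shape `(encoding_dim,)`. This is a minimal sketch, not the only approach: it assumes random stand-in data with MNIST's shapes instead of the real dataset (so it runs offline), trains for a single epoch, and the names `num_classes`, `codes`, `code_input` and `code_classifier` are hypothetical.

```python
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

encoding_dim = 32
num_classes = 10  # hypothetical name: one class per digit 0..9

# Stand-in data (assumption): 784 floats in [0, 1] per sample, integer labels 0..9.
rng = np.random.default_rng(0)
x_train = rng.random((512, 784)).astype("float32")
y_train = rng.integers(0, num_classes, size=512)

# Same autoencoder as in the question.
input_img = Input(shape=(784,))
encoded = Dense(encoding_dim, activation="relu")(input_img)
decoded = Dense(784, activation="sigmoid")(encoded)
autoencoder = Model(input_img, decoded)
encoder = Model(input_img, encoded)  # reuses the autoencoder's first Dense layer
autoencoder.compile(optimizer="adadelta", loss="binary_crossentropy")
autoencoder.fit(x_train, x_train, epochs=1, batch_size=256, verbose=0)

# Stage 1: turn every 784-d image into a 32-d code with the trained encoder.
codes = encoder.predict(x_train, verbose=0)

# Stage 2: a separate classifier whose input is the 32-d code, not the raw image.
code_input = Input(shape=(encoding_dim,))
probs = Dense(num_classes, activation="softmax")(code_input)
code_classifier = Model(code_input, probs)
code_classifier.compile(optimizer="adam",
                        loss="sparse_categorical_crossentropy",
                        metrics=["accuracy"])
code_classifier.fit(codes, y_train, epochs=1, batch_size=32, verbose=0)

# At prediction time the two models are chained: image -> encoder -> classifier.
pred = code_classifier.predict(encoder.predict(x_train[:5], verbose=0), verbose=0)
print(codes.shape, pred.shape)
```

    Because the classifier only ever sees `codes`, the encoder weights stay exactly as the autoencoder left them in this two-stage setup.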
    

    Now I want to connect it to a classifier. Following this question, I tried:

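    # the encoder's 32-d output tensor becomes the classifier's input
    x = encoder.output
    # h = Dense(3, activation='relu', name='hidden')(x)
    # a single sigmoid unit: one probability per sample (binary classification)
    y = Dense(1, activation='sigmoid', name='predictions')(x)
    
    # same input layer as the autoencoder, with a new output head
    classifier = Model(inputs=autoencoder.inputs, outputs=y)
    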
    # Compile model
    classifier.compile(loss='binary_crossentropy', optimizer='adam',
                       metrics=['accuracy'])
    
    # Fit the model
    history = classifier.fit(x_train, y_train, 
                             epochs=10, 
                             batch_size=10,
                             validation_split=.1)
    

    First, I don't understand this code. Second, I guess `y` needs 10 units, one for each of the 10 digits, but I can't set it to 10 because I get an error.
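    A likely cause of the error, assuming `y_train` holds integer labels 0 to 9 as `mnist.load_data()` returns them: `Dense(10, ...)` produces outputs of shape `(batch, 10)`, while `binary_crossentropy` with one label per sample expects a single output, and older Keras versions reject that mismatch with a target-shape error. The question does not quote the error, so this is a guess. For ten mutually exclusive digits, a common setup is a 10-unit `softmax` layer trained with `sparse_categorical_crossentropy` (or one-hot labels with `categorical_crossentropy`). The sketch below assumes random stand-in data with MNIST's shapes, a single training epoch, and a hypothetical `num_classes` name:

```python
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model

encoding_dim = 32
num_classes = 10  # hypothetical name: one output unit per digit 0..9

# Stand-in data (assumption): integer labels, like those from mnist.load_data().
rng = np.random.default_rng(0)
x_train = rng.random((512, 784)).astype("float32")
y_train = rng.integers(0, num_classes, size=512)

# Same autoencoder as in the question, pretrained briefly.
input_img = Input(shape=(784,))
encoded = Dense(encoding_dim, activation="relu")(input_img)
decoded = Dense(784, activation="sigmoid")(encoded)
autoencoder = Model(input_img, decoded)
encoder = Model(input_img, encoded)
autoencoder.compile(optimizer="adadelta", loss="binary_crossentropy")
autoencoder.fit(x_train, x_train, epochs=1, batch_size=256, verbose=0)

# `encoded` is the tensor the question reaches through encoder.output.
x = encoded
# 10 softmax units: one probability per digit, summing to 1 for each sample.
y = Dense(num_classes, activation="softmax", name="predictions")(x)
classifier = Model(inputs=input_img, outputs=y)

# Optional: freeze the pretrained encoder so only the new head is trained.
# encoder.trainable = False  # must be set before classifier.compile()

# Integer labels 0..9 pair with sparse_categorical_crossentropy.
classifier.compile(loss="sparse_categorical_crossentropy", optimizer="adam",
                   metrics=["accuracy"])
history = classifier.fit(x_train, y_train, epochs=1, batch_size=32,
                         validation_split=0.1, verbose=0)

probs = classifier.predict(x_train[:3], verbose=0)
print(probs.shape, probs.argmax(axis=1))
```

    Here the encoder's `Dense` layer is shared by `autoencoder`, `encoder` and `classifier`, so fitting `classifier` also updates the pretrained encoder weights unless `encoder.trainable = False` is set before compiling.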
