
Keras overfitting on one class of CIFAR-10

  Rocketq · 6 years ago

    To make everything clear, here is the whole model. It is very simple:

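    from keras.datasets import cifar10
    from keras.layers import Activation, Conv2D, Dense, Dropout, Flatten, MaxPooling2D
    from keras.models import Sequential
    from keras.optimizers import SGD
    from keras.utils import to_categorical

    # simple preprocessing
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    batch_size = 32
    num_classes = 10
    epochs = 50  # not stated in the original post; placeholder value
    # one-hot encode the integer labels
    y_train = to_categorical(y_train, num_classes)
    y_test = to_categorical(y_test, num_classes)
    # scale pixel values from [0, 255] to [0, 1]
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255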
    
    def base_model():
    
        model = Sequential()
        model.add(Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]))
        model.add(Activation('relu'))
        model.add(Conv2D(32,(3, 3)))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))
    
        model.add(Conv2D(64, (3, 3), padding='same'))
        model.add(Activation('relu'))
        model.add(Conv2D(64, (3,3)))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))
    
        model.add(Flatten())
        model.add(Dense(512))
        model.add(Activation('relu'))
        model.add(Dropout(0.5))
        model.add(Dense(num_classes))
        model.add(Activation('softmax'))
    
        # Keras 2.x-era arguments; newer Keras versions expect learning_rate= and no longer support decay=
        sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
        # Compile model
    
        model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
        return model
    
    cnn_n = base_model()
    cnn_n.summary()
    
    # Fit model
    
    cnn = cnn_n.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
                    validation_data=(x_test, y_test), shuffle=True, verbose=0)
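The preprocessing above does two things: it one-hot encodes the integer labels and rescales the pixels to [0, 1]. A NumPy-only sketch of both steps follows, assuming labels shaped (N, 1) as cifar10.load_data() returns them; `one_hot` and `scale_pixels` are hypothetical helper names for illustration, not Keras functions.

```python
import numpy as np

def one_hot(labels, num_classes):
    # hypothetical helper: integer class ids shaped (N, 1) or (N,) -> (N, num_classes) float32
    flat = np.asarray(labels).reshape(-1)
    out = np.zeros((flat.shape[0], num_classes), dtype="float32")
    out[np.arange(flat.shape[0]), flat] = 1.0
    return out

def scale_pixels(images):
    # hypothetical helper: uint8 values in [0, 255] -> float32 values in [0, 1]
    return images.astype("float32") / 255.0

# tiny stand-in for cifar10.load_data() output: 4 RGB 32x32 images, labels shaped (N, 1)
rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
y = np.array([[6], [9], [9], [4]], dtype=np.uint8)

y_oh = one_hot(y, 10)
x_f = scale_pixels(x)
print(y_oh.shape, x_f.dtype, x_f.min() >= 0.0, x_f.max() <= 1.0)
```

Each label row ends up with a single 1.0 at its class index, which is the target format `categorical_crossentropy` expects.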
    

    As you can see, neither the training error nor the validation error decreases at all:

    [image: training and validation error]
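When both curves stay flat from the start, it can help to compare them with a chance baseline. A small sketch, assuming 10 balanced classes: the categorical cross-entropy of a uniform prediction is ln(10) ≈ 2.303, and a model that always outputs the same class reaches about 10% accuracy. If the flat curves sit near these values, the network is doing no better than guessing. The class index 2 below is chosen only for illustration.

```python
import math
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # mean over samples of -sum(y_true * log(y_pred)); clipping avoids log(0)
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

num_classes = 10
labels = np.repeat(np.arange(num_classes), 100)     # balanced toy labels: 100 per class
y_true = np.eye(num_classes)[labels]                # one-hot targets

uniform = np.full_like(y_true, 1.0 / num_classes)   # a prediction that "knows nothing"
loss_uniform = categorical_crossentropy(y_true, uniform)

always_one = np.zeros_like(y_true)
always_one[:, 2] = 1.0                              # every prediction is class index 2
acc_always_one = float(np.mean(always_one.argmax(axis=1) == labels))

print(round(loss_uniform, 3), round(math.log(num_classes), 3), acc_always_one)
```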

    sequential_model_to_ascii_printout(cnn_n)
     OPERATION           DATA DIMENSIONS   WEIGHTS(N)   WEIGHTS(%)
    
                   Input   #####     32   32    3
                  Conv2D    \|/  -------------------       896     0.1%
                    relu   #####     32   32   32
                  Conv2D    \|/  -------------------      9248     0.7%
                    relu   #####     30   30   32
            MaxPooling2D   Y max -------------------         0     0.0%
                           #####     15   15   32
                 Dropout    | || -------------------         0     0.0%
                           #####     15   15   32
                  Conv2D    \|/  -------------------     18496     1.5%
                    relu   #####     15   15   64
                  Conv2D    \|/  -------------------     36928     3.0%
                    relu   #####     13   13   64
            MaxPooling2D   Y max -------------------         0     0.0%
                           #####      6    6   64
                 Dropout    | || -------------------         0     0.0%
                           #####      6    6   64
                 Flatten   ||||| -------------------         0     0.0%
                           #####        2304
                   Dense   XXXXX -------------------   1180160    94.3%
                    relu   #####         512
                 Dropout    | || -------------------         0     0.0%
                           #####         512
                   Dense   XXXXX -------------------      5130     0.4%
                 softmax   #####          10
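The shapes and weight counts in the printout follow from standard formulas: a Conv2D layer has kernel·kernel·in_channels·filters weights plus one bias per filter, a Dense layer has in·out weights plus out biases, and a 3×3 'valid' convolution shrinks each side by 2. The plain-Python sketch below reproduces the numbers as a check, and shows why the first Dense layer holds about 94% of all weights.

```python
def conv2d_params(kernel, in_ch, out_ch):
    # kernel*kernel*in_ch weights per filter, plus one bias per filter
    return kernel * kernel * in_ch * out_ch + out_ch

def dense_params(in_units, out_units):
    # full weight matrix plus one bias per output unit
    return in_units * out_units + out_units

def conv_out(size, kernel, padding):
    # stride 1: 'same' keeps the size, 'valid' shrinks it by kernel - 1
    return size if padding == "same" else size - kernel + 1

def pool_out(size, pool=2):
    # MaxPooling2D default: stride equals pool size, 'valid' padding (floor division)
    return size // pool

s = 32
s = conv_out(s, 3, "same")   # 32
s = conv_out(s, 3, "valid")  # 30
s = pool_out(s)              # 15
s = conv_out(s, 3, "same")   # 15
s = conv_out(s, 3, "valid")  # 13
s = pool_out(s)              # 6
flat = s * s * 64            # 2304 inputs to the first Dense layer

layers = {
    "conv1": conv2d_params(3, 3, 32),
    "conv2": conv2d_params(3, 32, 32),
    "conv3": conv2d_params(3, 32, 64),
    "conv4": conv2d_params(3, 64, 64),
    "dense1": dense_params(flat, 512),
    "dense2": dense_params(512, 10),
}
total = sum(layers.values())
share_dense1 = 100 * layers["dense1"] / total
print(layers, total, round(share_dense1, 1))
```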
    

    Confusion matrix; the model is completely overfitted to the third class:

    [image: confusion matrix]
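A confusion matrix whose counts all fall in a single column means the network returns the same argmax for every input, whatever the true class. A toy NumPy sketch of that pattern, using synthetic data (class index 2 is chosen only for illustration; `confusion_matrix` here is a hypothetical helper, not a library call):

```python
import numpy as np

def confusion_matrix(true_ids, pred_ids, num_classes):
    # rows = true class, columns = predicted class
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (true_ids, pred_ids), 1)
    return cm

num_classes = 10
true_ids = np.repeat(np.arange(num_classes), 5)   # 5 toy samples per class
probs = np.full((true_ids.size, num_classes), 0.05)
probs[:, 2] = 0.55                                # every softmax row peaks at class index 2
pred_ids = probs.argmax(axis=1)

cm = confusion_matrix(true_ids, pred_ids, num_classes)
print(cm[:, 2])                                   # every true class lands in column 2
print(cm.sum() - cm[:, 2].sum())                  # no predictions in any other column
```

Only the diagonal cell of that column counts as correct, so accuracy stays at the share of that one class in the data.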

    y_test does contain the other classes as well:

    y_test
    array([[0., 0., 0., ..., 0., 0., 0.],
           [0., 0., 0., ..., 0., 1., 0.],
           [0., 0., 0., ..., 0., 1., 0.],
           ...,
           [0., 0., 0., ..., 0., 0., 0.],
           [0., 1., 0., ..., 0., 0., 0.],
           [0., 0., 0., ..., 1., 0., 0.]])
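The truncated printout only shows a few rows. One way to confirm that the labels cover all ten classes is to turn the one-hot rows back into integer ids and count them; on the real CIFAR-10 test set each class should appear 1000 times. A minimal sketch with a hypothetical `class_counts` helper and toy labels:

```python
import numpy as np

def class_counts(one_hot_labels):
    # recover integer class ids from one-hot (or softmax) rows, then count each id
    ids = np.argmax(one_hot_labels, axis=1)
    return np.bincount(ids, minlength=one_hot_labels.shape[1])

# toy one-hot labels standing in for y_test
y_toy = np.eye(10)[[0, 8, 8, 3, 1, 6]]
print(class_counts(y_toy))
```

Because argmax also works on softmax outputs, the same helper applied to the model's predictions (for example `class_counts(cnn_n.predict(x_test))`) would show the one-class collapse directly.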
    

    Why does the model predict only one class?

    PS: I followed this guide: https://blog.plon.io/tutorials/cifar-10-classification-using-keras-tutorial/

    1 answer

  Wendong Zheng · 6 years ago

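    I think you could use the Adam optimizer for this CIFAR-10 task; it converges faster than SGD. Also, the learning rate you set is too large (try lr=0.01 or lr=0.001): with a step that big, the optimizer keeps bouncing around near the minimum instead of settling into it. Here is my code: CIFAR-10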
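The effect of a too-large step can be seen on a one-dimensional quadratic. This is a hedged illustration, not the answerer's code: it uses plain gradient descent (no momentum, unlike the Nesterov SGD in the question) on f(x) = ½·c·x² with an arbitrarily chosen curvature c = 25. Each update multiplies x by (1 − lr·c), so the iteration converges only when lr < 2/c. With lr = 0.1 every step overshoots the minimum and the iterate grows while flipping sign; 0.01 converges, and 0.001 converges more slowly.

```python
def gradient_descent(lr, curvature=25.0, x0=1.0, steps=20):
    # minimise f(x) = 0.5 * curvature * x**2, whose gradient is curvature * x
    x = x0
    path = [x]
    for _ in range(steps):
        x = x - lr * curvature * x
        path.append(x)
    return path

for lr in (0.1, 0.01, 0.001):
    path = gradient_descent(lr)
    print(lr, [round(v, 3) for v in path[:4]], abs(path[-1]))
```

Real loss surfaces have many curvatures at once, so the safe learning rate is set by the sharpest directions; this is one reason a smaller lr (or an adaptive method such as Adam) can stop training from stalling.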