代码之家  ›  专栏  ›  技术社区  ›  Soerendip

如何使用tensorflow.models.sequential()提前停止?

  •  0
  • Soerendip  · 技术社区  · 6 年前

    使用这样生成的顺序模型:

    def generate_model():
        model = Sequential()
        model.add(Conv1D(64, kernel_size=10, strides=1,
                         activation='relu', padding='same',
                         input_shape=(MAXLENGTH, NAMESPACELENGTH)))
        model.add(MaxPooling1D(pool_size=4, strides=2))
        model.add(Conv1D(32, 3, activation='relu', padding='same'))
        model.add(MaxPooling1D(pool_size=4))
        model.add(Flatten())
        model.add(Dense(10, activation='relu'))
        model.add(Dense(1, activation='linear'))
        model.compile(loss='mean_squared_error', 
                      optimizer='adam', metrics=['mean_squared_error'])
        return model
    

    我想做KFOLD交叉验证建模。所以,我在一个循环中训练K模型:

    models = []
    for ndx_train, ndx_val in kfold.split(X, y):
        model = generate_model()
        N_train = len(ndx_train)
        X_batch = X[ndx_train]
        y_batch = y[ndx_train]
        model.fit(X_batch, y_batch, epochs=100, verbose=1, steps_per_epoch=10,
                 validation_data=(X[ndx_val], y[ndx_val]), validation_steps=100)
    
        models.append(model)
    

    现在,我可以通过查看输出来查看每个模型何时停止。即当验证错误再次增加时。有没有可能用纯的 tf 有了这个更高级别的API设置?有一些建议使用沿线使用 tflearn here .

    1 回复  |  直到 6 年前
        1
  •  2
  •   Soerendip    6 年前

    通过使用 EarlyStopping 回调:

    from tensorflow.keras.callbacks import EarlyStopping
    callbacks = [
        EarlyStopping(monitor='val_mean_squared_error', patience=2, verbose=1),
    ]
    model.fit(..., callbacks=callbacks)