代码之家  ›  专栏  ›  技术社区  ›  Belter

用TensorFlow2训练的简单线性回归模型的池性能

  •  0
  • Belter  · 技术社区  · 4 年前

    y = 2*x + 200 + error ,但我不能用简单的方法得到正确的结果。我不知道发生了什么。

    import numpy as np
    from tensorflow import keras
    x = np.arange(100)
    error = np.random.rand(100,1).ravel()
    y = 2*x + 200 + error
    
    opt = keras.optimizers.Adam(lr=0.0005)
    model = keras.Sequential([keras.layers.Dense(1, input_shape=[1])])
    model.compile(optimizer=opt, loss='mse', metrics=['mae'])
    early_stopping_callback = keras.callbacks.EarlyStopping(
            patience=10,
            monitor='val_loss',
            mode='min',
            restore_best_weights=True)
    history = model.fit(x, y, epochs=2000, batch_size=16, verbose=1,
                        validation_split=0.2, callbacks=[early_stopping_callback])
    

    当验证损失很大时,我的模型总是停止运行:

    纪元901/2000 5/5[==================================]-0s 3ms/步- 损失:14767.1357-价值损失:166.8979

    经过训练后,我一直觉得不合适:

    model.weights
    

    [<tf.变量'密集'28/内核:0'shape=(1,1)dtype=float32,numpy=array([[4.2019334]],dtype=float32)>,
    <tf.变量'密集'28/偏差:0'shape=(1,)dtype=float32,numpy=array([2.611792], dtype=float32)>]

    请帮我找出我的代码有什么问题。

    我使用tensorflow-v2.3.0

    0 回复  |  直到 4 年前
        1
  •  0
  •   Belter    4 年前

    我明白了,主要问题是 EarlyStopping 太早停止了我的训练过程!另一个问题是学习率太低。

    所以当我改变双参数设置时,我得到了正确的结果:

    import numpy as np
    from tensorflow import keras
    x = np.arange(100)
    error = np.random.rand(100,1).ravel()
    y = 2*x + 200 + error
    
    opt = keras.optimizers.Adam(lr=0.8)  # <--- bigger lr
    model = keras.Sequential([keras.layers.Dense(1, input_shape=[1])])
    model.compile(optimizer=opt, loss='mse', metrics=['mae'])
    early_stopping_callback = keras.callbacks.EarlyStopping(
            patience=100,  # <--- longer patience to training
            monitor='val_loss',
            mode='min',
            restore_best_weights=True)
    history = model.fit(x, y, epochs=2000, batch_size=16, verbose=1,
                        validation_split=0.2, callbacks=[early_stopping_callback])