代码之家  ›  专栏  ›  技术社区  ›  Ricky

RNN从哪里获取批量大小?

  •  0
  • Ricky  · 技术社区  · 6 年前

    我正在通过以下方式培训RNN:

    def create_rnn_model(stateful,length):
        model = Sequential()
        model.add(SimpleRNN(20,return_sequences=False,stateful=stateful,batch_input_shape=(1,length,1)))
        adam = optimizers.Adam(lr=0.001)
        model.add(Dense(1))
        model.compile(loss='mean_absolute_error', optimizer=adam, metrics=[root_mean_squared_error])
        print(model.summary())
        return model
    

    和适合的

    model_info = model_rnn_stateful.fit(x=x_train, y=y_train, validation_data=(x_test, y_test), batch_size=1, epochs=10,verbose=1)
    

    预测

    predicted_rnn_stateful = model_rnn_stateful.predict(x_test)
    

    但当我预测的时候就会出错

    可除以批量大小的样本数。发现:200 样品。批量:32。

    编辑

    1 回复  |  直到 6 年前
        1
  •  1
  •   Mael Galliffet    6 年前

    Keras documentation

    • 批量大小 :整数或无。每次渐变更新的样本数。如果未指定,批次大小将默认为32。

    1可能是不正确的batch\u size值,然后它采用默认值32。请尝试使用2或20作为batch\u size