代码之家  ›  专栏  ›  技术社区  ›  Artem

如何使用批量训练模型,对单个输入进行预测?

  •  0
  • Artem  · 技术社区  · 5 年前

    我有经过数据集培训的RNN模型:

    train = tf.data.Dataset.from_tensor_slices((data_x[:train_size],
                              data_y[:train_size])).batch(batch_size).repeat()
    

    模型:

        model = tf.keras.Sequential()
        model.add(tf.keras.layers.GRU(units=lstm_num_units,
                                       return_sequences=True,
                                       kernel_initializer='random_uniform',
                                       recurrent_initializer='random_uniform',
                                       bias_initializer='random_uniform',
                                       batch_size=batch_size,
                                       input_shape = [seq_len, num_features]))
        model.add(tf.keras.layers.LSTM(units=lstm_num_units,
                                       batch_size=batch_size,
                                       return_sequences=True,
                                       input_shape = [seq_len, num_features]))
        model.add(tf.keras.layers.Flatten())
        model.add(tf.keras.layers.Dense(units=dence_units))
        model.add(tf.keras.layers.Dropout(drop_flat))
        model.add(tf.keras.layers.Dense(units=out_units))
        model.add(tf.keras.layers.Softmax())   
    
        model.compile(loss="sparse_categorical_crossentropy",
                optimizer=tf.train.RMSPropOptimizer(opt),
                metrics=['accuracy'])
    
     model.fit(train, epochs=EPOCHS,
                            steps_per_epoch=repeat_size_train,
                            validation_data=validate,
                            validation_steps=repeat_size_validate,
                            verbose=1,
                            shuffle=True)
                            callbacks=[tensorboard, cp_callback])
    

    我需要对seq-len的单个输入进行预测,但看起来我的输入必须是一个批大小:

    ar = np.random.randint(98, size=[batch_size, seq_len])
    ar = np.reshape(ar, [batch_size, seq_len, 1])
    prediction = model.m.predict(ar)
    

    有没有一种方法可以使它在形状[1,序列长度,1]的单个输入上工作?

    1 回复  |  直到 5 年前
        1
  •  2
  •   Daniel Möller    5 年前

    是的,只需重建第一层中没有批量大小的模型。

    复制旧模型的权重。

    newModel.set_weights(oldModel.get_weights())
    

    批量大小的目的仅存在于 stateful=True 保持批次间一致性的模型。

    尽管如此,由于批量大小没有数学上的变化。