代码之家  ›  专栏  ›  技术社区  ›  ScalaBoy

如何获取Keras的历史回调指标?

  •  0
  • ScalaBoy  · 技术社区  · 6 年前

    如何检索回调度量的历史记录?我有一节课 Metrics 我把它用在 fit Keras模型的功能如下 callbacks=[model_metrics] .

    这是全班的代码 韵律学 适合 功能。

    class Metrics(Callback):
    
        def on_train_begin(self, logs={}):
            self.val_f1s = []
            self.val_bal_accs = []
    
        def on_epoch_end(self, epoch, logs={}):
            val_predict = np.argmax((np.asarray(self.model.predict(self.validation_data[0]))).round(), axis=1)
            val_targ = np.argmax(self.validation_data[1], axis=1)
            _val_f1 = metrics.f1_score(val_targ, val_predict, average='weighted')
            _val_bal_acc = metrics.balanced_accuracy_score(val_targ, val_predict)    
            self.val_f1s.append(_val_f1)
            self.val_bal_accs.append(_val_bal_acc)
            print(" — val_f1: {:f} — val_bal_acc: {:f}".format(_val_f1, _val_bal_acc))
            return
    
    model_metrics = Metrics()
    
    history = model.fit(np.array(X_train), y_train, 
                        validation_data=(np.array(X_test), y_test),
                        epochs=5,
                        batch_size=2,
                        callbacks=[model_metrics],
                        shuffle=False,
                        verbose=1)
    

    我怎么才能拿到 history 属于 val_f1 val_bal_acc ?现在我只能进入 loss , val_loss , acc , val_acc :

    print(history.history.keys())
    
    1 回复  |  直到 6 年前
        1
  •  1
  •   Primusa    6 年前

    与…互动 keras 需要为其传入参数的历史API metrics 而不是 callbacks .

    在当前状态下, val_f1 val_bal_acc 不会存储在History对象中,而是存储在 model_metrics 对象。

    您可以这样访问它们:

    model_metrics.val_f1s
    

    它与访问任何对象的属性相同。

    最后,如果您确实希望创建自定义度量并希望从历史记录中访问它,则需要定义自定义度量(作为函数),然后将其传递到 韵律学 克瓦格 model.compile . 这样做:

    def my_metric(y_true y_pred):
        return y_true # just a dummy return value
    
    # assume that the model is defined somewhere
    model.compile(loss=..., optimizer=..., metrics = [my_metric]
    

    然后你就能找到 val_my_metric 在History对象中,您不适合。