代码之家  ›  专栏  ›  技术社区  ›  Jürgen K.

Confusion MatrixDisplay(Scikit Learn)绘图标签超出范围

  •  1
  • Jürgen K.  · 技术社区  · 3 年前

    以下代码绘制了一个混淆矩阵:

    from sklearn.metrics import ConfusionMatrixDisplay
    
    confusion_matrix = confusion_matrix(y_true, y_pred)
    target_names = ["aaaaa", "bbbbbb", "ccccccc", "dddddddd", "eeeeeeeeee", "ffffffff", "ggggggggg"]
    disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=target_names)
    disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
    plt.savefig("conf.png")
    

    Confusion Matrix

    这个情节有两个问题。

    1. y轴标签被截断(True label)。x标签也被剪掉了。
    2. x轴上的名称太长。

    为了解决我试图使用的第一个问题 poof(bbox_inches='tight') 遗憾的是,sklearn无法使用该功能。 在第二种情况下,我尝试了以下解决方案 2. 这导致了一个完全扭曲的情节。

    总的来说,这两个问题都让我感到困扰。

    0 回复  |  直到 3 年前
        1
  •  6
  •   Alexander L. Hayes    3 年前

    我认为最简单的方法是切换到 tight_layout 并添加 pad_inches= 什么东西。

    from sklearn.metrics import confusion_matrix
    from sklearn.metrics import ConfusionMatrixDisplay
    import matplotlib.pyplot as plt
    from numpy.random import default_rng
    
    rand = default_rng()
    y_true = rand.integers(low=0, high=7, size=500)
    y_pred = rand.integers(low=0, high=7, size=500)
    
    
    confusion_matrix = confusion_matrix(y_true, y_pred)
    target_names = ["aaaaa", "bbbbbb", "ccccccc", "dddddddd", "eeeeeeeeee", "ffffffff", "ggggggggg"]
    disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=target_names)
    disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
    
    plt.tight_layout()
    plt.savefig("conf.png", pad_inches=5)
    

    结果:

    Confusion matrix where all text in the axes is visible.