代码之家  ›  专栏  ›  技术社区  ›  Maystro

利用混淆矩阵理解多标签分类器

  •  9
  • Maystro  · 技术社区  · 6 年前

    我有一个12类的多标签分类问题。我正在使用 slim 属于 Tensorflow 使用预先训练的模型训练模型 ImageNet . 以下是每个班级参加培训的百分比;验证

                Training     Validation
      class0      44.4          25
      class1      55.6          50
      class2      50            25
      class3      55.6          50
      class4      44.4          50
      class5      50            75
      class6      50            75
      class7      55.6          50
      class8      88.9          50
      class9     88.9           50
      class10     50            25
      class11     72.2          25
    

    问题是模型没有收敛,并且 ROC 曲线图( Az )验证集上的错误,例如:

                   Az 
      class0      0.99
      class1      0.44
      class2      0.96  
      class3      0.9
      class4      0.99
      class5      0.01
      class6      0.52
      class7      0.65
      class8      0.97
      class9     0.82
      class10     0.09
      class11     0.5
      Average     0.65
    

    我不知道为什么它对某些课程有效,而对其他课程无效。我决定深入研究细节,看看神经网络在学习什么。我知道混淆矩阵只适用于二进制或多类分类。因此,为了能够绘制它,我必须将问题转换为成对的多类分类。即使模型是使用 sigmoid 乙状结肠 矩阵行中的类存在而列中的类不存在的图像的张量流预测函数。这已应用于验证集图像。通过这种方式,我想我可以获得关于模型学习内容的更多细节。为了便于显示,我只是将对角线元素圈起来。

    enter image description here

    我的解释是:

    1. 0类(&A);4在存在时检测到存在,在不存在时检测到不存在。这意味着可以很好地检测这些类。
    2. 2级、6级和;7始终被检测为不存在。这不是我要找的。
    3. 3级、8级和;9始终检测到存在。这不是我要找的。这可以应用于11类。
    4. 类别5在不存在时检测为存在,在存在时检测为不存在。它被反向检测到。
    5. 类别3和;10: 我认为我们不能为这两个类提取太多的信息。

    我的问题是口译。。我不确定问题出在哪里,也不确定产生这种结果的数据集中是否存在偏差。我还想知道是否有一些指标可以帮助解决多标签分类问题?你能和我分享一下你对这种混淆矩阵的解释吗?接下来要看什么/哪里?对其他指标提出一些建议会很好。

    谢谢

    编辑:

    我将问题转换为多类分类,以便为每对类(例如0,1)计算概率(类0,类1),表示为 p(0,1) : 我对工具0存在和工具1不存在的图像中的工具1进行预测,并通过应用sigmoid函数将其转换为概率,然后显示这些概率的平均值。对于 p(1, 0) ,我对工具0执行相同的操作,但现在使用存在工具1而不存在工具0的图像。对于 p(0, 0) ,我使用存在工具0的所有图像。考虑到 p(0,4) 在上图中,不适用表示没有工具0和工具4存在的图像。

    以下是两个子集的图像数:

    1. 169320张培训图片
    2. 37440张验证图像

    下面是在训练集上计算的混淆矩阵(计算方法与前面描述的验证集相同),但这次颜色代码是用于计算每个概率的图像数: enter image description here

    已编辑: 对于数据扩充,我对网络中的每个输入图像进行随机平移、旋转和缩放。此外,以下是有关这些工具的一些信息:

    class 0 shape is completely different than the other objects.
    class 1 resembles strongly to class 4.
    class 2 shape resembles to class 1 & 4 but it's always accompanied by an object different than the others objects in the scene. As a whole, it is different than the other objects.
    class 3 shape is completely different than the other objects.
    class 4 resembles strongly to class 1
    class 5 have common shape with classes 6 & 7 (we can say that they are all from the same category of objects)
    class 6 resembles strongly to class 7
    class 7 resembles strongly to class 6
    class 8 shape is completely different than the other objects.
    class 9 resembles strongly to class 10
    class 10 resembles strongly to class 9
    class 11 shape is completely different than the other objects.
    

    已编辑: 以下是针对培训集提出的代码输出:

    Avg. num labels per image =  6.892700212615167
    On average, images with label  0  also have  6.365296803652968  other labels.
    On average, images with label  1  also have  6.601033718926901  other labels.
    On average, images with label  2  also have  6.758548914659531  other labels.
    On average, images with label  3  also have  6.131520940484937  other labels.
    On average, images with label  4  also have  6.219187208527648  other labels.
    On average, images with label  5  also have  6.536933407946279  other labels.
    On average, images with label  6  also have  6.533908387864367  other labels.
    On average, images with label  7  also have  6.485973817793214  other labels.
    On average, images with label  8  also have  6.1241642788920725  other labels.
    On average, images with label  9  also have  5.94092288040875  other labels.
    On average, images with label  10  also have  6.983303518187239  other labels.
    On average, images with label  11  also have  6.1974066621953945  other labels.
    

    对于验证集:

    Avg. num labels per image =  6.001282051282051
    On average, images with label  0  also have  6.0  other labels.
    On average, images with label  1  also have  3.987080103359173  other labels.
    On average, images with label  2  also have  6.0  other labels.
    On average, images with label  3  also have  5.507731958762887  other labels.
    On average, images with label  4  also have  5.506459948320414  other labels.
    On average, images with label  5  also have  5.00169779286927  other labels.
    On average, images with label  6  also have  5.6729452054794525  other labels.
    On average, images with label  7  also have  6.0  other labels.
    On average, images with label  8  also have  6.0  other labels.
    On average, images with label  9  also have  5.506459948320414  other labels.
    On average, images with label  10  also have  3.0  other labels.
    On average, images with label  11  also have  4.666095890410959  other labels.
    

    评论: 我认为这不仅与分布之间的差异有关,因为如果模型能够很好地概括类10(意味着对象在训练过程中被正确识别,如类0),验证集的准确性就足够了。我的意思是,问题在于训练集本身以及它是如何构建的,而不仅仅是两种分布之间的差异。它可以是:类或对象存在的频率非常相似(如类10非常相似于类9)或数据集或薄对象内部的偏差(可能代表类2输入图像中像素的1%或2%)。我并不是说问题是其中之一,但我只是想指出,我认为这不仅仅是两种分布之间的差异。

    1 回复  |  直到 6 年前
        1
  •  7
  •   Dennis Soemers    6 年前

    输出校准

    首先,我认为有一点很重要,那就是神经网络的输出可能很差 已校准 . 我的意思是,它给不同实例的输出可能会导致很好的排名(标签为L的图像往往比没有标签的图像在该标签上的得分更高),但这些得分不能总是可靠地解释为概率(它可能会给出很高的分,如 0.9 ,对于没有标签的实例,只需给出更高的分数,如 0.99 ,到带有标签的实例)。我想这是否会发生取决于你选择的损失函数。

    有关此方面的更多信息,请参阅示例: https://arxiv.org/abs/1706.04599


    逐一完成所有课程

    0级: AUC(曲线下面积)=0.99。那是一个很好的分数。混淆矩阵中的列0看起来也很好,所以这里没有问题。

    第1类: AUC=0.44。这太可怕了,低于0.5,如果我没有弄错的话,这意味着你最好还是故意 您的网络对此标签的预测。

    看看你的混淆矩阵中的第1列,它在所有地方的分数都差不多。对我来说,这表明网络并没有对这门课有太多的了解,而只是根据训练集中包含这个标签的图像的百分比(55.6%)进行“猜测”。由于这个百分比在验证集中下降到了50%,这个策略确实意味着它会比随机策略做得稍差。尽管如此,第1行在该列中的行数仍然最多,因此它似乎至少学到了一点点,但并没有学到多少。

    第2类: AUC=0.96。那很好。

    您对这个类的解释是,根据整个列的明暗处理,它总是被预测为不存在。但我不认为这种解释是正确的。查看其得分情况>对角线上为0,列中其他位置仅为0。该行的分数可能相对较低,但很容易与同一列中的其他行分离。您可能只需要设置阈值来选择该标签是否相对较低。我怀疑这是由于上面提到的校准问题。

    这也是AUC实际上非常好的原因;可以选择一个阈值,以便分数高于该阈值的大多数实例正确地具有标签,而分数低于该阈值的大多数实例则不具有标签。但该阈值可能不是0.5,如果假设校准良好,这是您可能期望的阈值。绘制该特定标签的ROC曲线可能有助于您确定阈值的确切位置。

    第3类: AUC=0.9,很好。

    您将其解释为总是被检测到存在,并且混淆矩阵确实在列中有很多高数字,但AUC很好,对角线上的单元格确实有足够高的值,可以很容易地将其与其他单元格分离。我怀疑这与第2类类似(只是翻了一下,到处都是高预测,因此正确决策需要高阈值)。

    如果您希望能够确定一个精心选择的阈值是否确实能够正确地将大多数“积极”(具有类3的实例)与大多数“消极”(没有类3的实例)区分开来,那么您需要根据标签3的预测分数对所有实例进行排序,然后遍历整个列表,并在每对连续条目之间计算验证集的精度,如果您决定将阈值放在那里,则会得到该精度,然后选择最佳阈值。

    第4类: 与0类相同。

    第5类: AUC=0.01,显然很糟糕。也同意您对混淆矩阵的解释。很难确定为什么这里的表现如此糟糕。也许这是一种很难识别的物体?可能还存在一些过度拟合(从第二个矩阵中的列判断,训练数据中的0个误报,尽管还有其他类会发生这种情况)。

    从培训到验证数据,标签5图像的比例增加了,这可能也没有帮助。这意味着,对于网络来说,在培训期间在该标签上表现良好的重要性不如在验证期间表现良好的重要性。

    第6类: AUC=0.52,仅略好于随机。

    根据第一个矩阵中的第6列判断,这实际上可能与第2类类似。但是,如果我们也考虑AUC,它似乎也没有很好地学习对实例进行排名。与5班相似,只是没那么糟糕。同样,培训和验证分布也非常不同。

    第7类:

    第8类: AUC=0.97,非常好,类似于3级。

    第9类: AUC=0.82,虽然不太好,但仍然很好。矩阵中的列有如此多的暗细胞,而且数字非常接近,因此AUC在我看来出人意料地好。它几乎出现在训练数据中的每一张图像中,所以它被预测为经常出现也就不足为奇了。也许有些非常暗的细胞仅仅基于少量的绝对图像?这将是一个有趣的问题。

    第10类: AUC=0.09,很糟糕。对角线上的0非常令人担忧(您的数据标记是否正确?)。根据第一个矩阵的第10行,第3类和第9类似乎很容易混淆(棉花和primary\u cution\u刀看起来像secondary\u cution\u刀吗?)。也可能是对训练数据的过度拟合。

    第11类: AUC=0.5,不优于随机。由于大多数训练图像中都存在此标签,但只有少数验证图像中存在此标签,因此可能会出现性能不佳(以及矩阵中明显过高的分数)。


    还需要绘制/测量什么?

    为了更深入地了解您的数据,我将首先绘制每个类共同发生的频率的热图(一个用于培训,一个用于验证数据)。单元格(i,j)将根据包含标签i和j的图像的比率进行着色。这将是一个对称图,对角线上的单元格将根据您问题中的第一组数字进行着色。比较这两个热图,看看它们有什么不同,看看这是否有助于解释模型的性能。

    此外,了解(对于两个数据集)每个图像平均有多少个不同的标签,以及每个标签平均有多少个其他标签共享一个图像可能很有用。例如,我怀疑标签为10的图像在训练数据中的其他标签相对较少。如果网络识别出其他事物,这可能会阻止网络预测标签10,如果标签10突然在验证数据中更经常地与其他对象共享图像,则会导致性能不佳。由于伪代码可能比单词更容易理解这一点,所以打印以下内容可能会很有趣:

    # Do all of the following once for training data, AND once for validation data    
    tot_num_labels = 0
    for image in images:
        tot_num_labels += len(image.get_all_labels())
    avg_labels_per_image = tot_num_labels / float(num_images)
    print("Avg. num labels per image = ", avg_labels_per_image)
    
    for label in range(num_labels):
        tot_shared_labels = 0
        for image in images_with_label(label):
            tot_shared_labels += (len(image.get_all_labels()) - 1)
        avg_shared_labels = tot_shared_labels / float(len(images_with_label(label)))
        print("On average, images with label ", label, " also have ", avg_shared_labels, " other labels.")
    

    对于单个数据集,这并不能提供太多有用的信息,但如果您对训练集和验证集这样做,那么如果数字非常不同,则可以看出它们的分布非常不同

    最后,我有点担心第一个矩阵中的一些列 确切地 相同的平均预测出现在许多不同的行上。我不太确定是什么导致了这种情况,但这可能有助于调查。


    如何改进?

    如果你还没有,我建议你 数据扩充 用于您的培训数据。由于您正在处理图像,因此可以尝试将现有图像的旋转版本添加到数据中。

    具体来说,对于多标签情况,目标是检测不同类型的对象,也可以尝试简单地将一组不同的图像(例如,两个或四个图像)连接在一起。然后,可以将它们缩小到原始图像大小,并作为标签指定原始标签集的并集。在合并图像的边缘会出现有趣的间断,我不知道这是否有害。在我看来,也许你的多目标检测不值得一试。