代码之家  ›  专栏  ›  技术社区  ›  buydadip

无法打印模型的混淆矩阵

  •  1
  • buydadip  · 技术社区  · 7 年前

    我实现了 MLP 它工作得很好。然而,我在打印混淆矩阵时遇到了一个问题。

    我的模型定义为。。。

    logits = layers(X, weights, biases)
    

    哪里

    def layers(x, weights, biases):
        layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])
        layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])
        out_layer = tf.matmul(layer_2, weights['out']) + biases['out']
    
        return out_layer
    

    我在 mnist 数据集。经过培训,我能够成功打印出模型的准确性。。。

    pred = tf.nn.softmax(logits)
    
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    print("Accuracy: ", accuracy.eval({X:mnist.test.images, y:mnist.test.labels}))
    

    准确率达到90%。现在我想打印出结果的混淆矩阵。我尝试了以下方法。。。

    confusion = tf.confusion_matrix(
             labels=mnist.test.labels, predictions=correct_prediction)
    

    但这给了我错误。。。

    ValueError:无法挤压尺寸[1],预期尺寸为1,对于输入形状为[10000,10]的“混乱矩阵/删除\u可挤压尺寸/挤压”(op:“挤压”)获得10。

    打印混淆矩阵的正确方法是什么?我已经挣扎了一段时间了。

    3 回复  |  直到 7 年前
        1
  •  2
  •   buydadip    7 年前

    看起来是 tf.confusion_matrix 具有10作为秒dim。问题是如果 mnist.test.labels correct_prediction 是否有一个热编码?这就可以解释了。这里需要标签作为一维张量。你能打印出这两个张量的形状吗?

    而且看起来 正确的\u预测 是一个布尔张量,用于标记预测是否准确。对于需要预测标签的混淆矩阵,应该是 tf.argmax( pred, 1 ) 相反类似地,如果您的标签是一个热编码的标签,则需要对其进行解码以获得混淆矩阵。所以试试这句话 confusion :

    confusion = tf.confusion_matrix(
         labels = tf.argmax( mnist.test.labels, 1 ),
         predictions = tf.argmax( pred, 1 ) )
    

    为了打印混淆矩阵本身,有必要使用 eval 最终结果如下:

    print(confusion.eval({x:mnist.test.images, y:mnist.test.labels}))
    
        2
  •  2
  •   Max Kleiner    6 年前

    这对我很有用:

    confusion = tf.confusion_matrix(
           labels = tf.argmax( mnist.test.labels, 1 ),
           predictions = tf.argmax( y, 1 ) )
       print(confusion.eval({x:mnist.test.images, y_:mnist.test.labels})) 
    
    [[ 960    0    2    2    1    5    7    2    1    0]
     [   0 1113    3    2    0    1    4    2   10    0]
     [   6    7  941   15   12    2   10    8   27    4]
     [   2    1   27  926    1   12    1    8   24    8]
     [   1    2    6    1  928    0    9    2    9   24]
     [   9    2    8   51   12  729   15    9   50    7]
     [  13    3   10    2    9    9  905    2    5    0]
     [   1    9   28    8   11    1    0  938    3   29]
     [   6   10    7   19    9   13    8    5  891    6]
     [   9    7    2    9   43    5    0   14   12  908]]
    
        3
  •  0
  •   Max Kleiner    6 年前

    对于NLTK混淆矩阵,您需要一个列表

    classifier = NaiveBayesClassifier.train(trainfeats)
    refsets = collections.defaultdict(set)
    testsets = collections.defaultdict(set)
    
    lsum = []
    tsum = []
    
    for i, (feats, label) in enumerate(testfeats):
      refsets[label].add(i)
      observed = classifier.classify(feats)
      testsets[observed].add(i)
      lsum.append(label)
      tsum.append(observed
    
    print (nltk.ConfusionMatrix(lsum,tsum))