代码之家  ›  专栏  ›  技术社区  ›  Paul Terwilliger

在chainer中使用数组作为MNIST数据的标签

  •  1
  • Paul Terwilliger  · 技术社区  · 7 年前

    python模块 chainer 有一个 introduction 它使用其神经网络识别来自 MNIST database .

    假设一个特定的手写数字 D.png 标记为 3 . 我已经习惯了以数组形式出现的标签,如下所示:

    label = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
    

    然而 chainer 改为使用整数标记:

    label = 3
    

    数组标签对我来说更直观,因为输出预测也是一个数组。在不处理图像的神经网络中,我希望能够灵活地将标签设置为特定的数组。

    我直接从chainer简介中包含了下面的代码。如果你通过解析 train test 请注意,所有标签都是整数,而不是浮点。

    如何使用数组作为标签而不是整数来运行训练/测试数据?

    import numpy as np
    import chainer
    from chainer import cuda, Function, gradient_check, report, training, utils, Variable
    from chainer import datasets, iterators, optimizers, serializers
    from chainer import Link, Chain, ChainList
    import chainer.functions as F
    import chainer.links as L
    from chainer.training import extensions
    
    class MLP(Chain):
        def __init__(self, n_units, n_out):
            super(MLP, self).__init__()
            with self.init_scope():
                # the size of the inputs to each layer will be inferred
                self.l1 = L.Linear(None, n_units)  # n_in -> n_units
                self.l2 = L.Linear(None, n_units)  # n_units -> n_units
                self.l3 = L.Linear(None, n_out)    # n_units -> n_out
    
        def __call__(self, x):
            h1 = F.relu(self.l1(x))
            h2 = F.relu(self.l2(h1))
            y = self.l3(h2)
            return y
    
    train, test = datasets.get_mnist()
    
    train_iter = iterators.SerialIterator(train, batch_size=100, shuffle=True)
    test_iter = iterators.SerialIterator(test, batch_size=100, repeat=False, shuffle=False)
    
    model = L.Classifier(MLP(100, 10))  # the input size, 784, is inferred
    optimizer = optimizers.SGD()
    optimizer.setup(model)
    
    updater = training.StandardUpdater(train_iter, optimizer)
    trainer = training.Trainer(updater, (20, 'epoch'), out='result')
    
    trainer.extend(extensions.Evaluator(test_iter, model))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'validation/main/accuracy']))
    trainer.extend(extensions.ProgressBar())
    trainer.run()
    
    1 回复  |  直到 7 年前
        1
  •  1
  •   TulakHord    6 年前

    分类器接受包含图像或其他数据的元组作为数组(float32),标签作为int。这是chainer的约定以及它在那里的工作方式。 如果打印标签,您将看到得到一个具有dtype int的数组。图像/非图像数据和标签都将位于数组中,但分别具有dtype float和int。

    为了回答您的问题:标签本身是数组格式的,带有dtype int(标签应该是这样的)。

    如果希望标签为0和1,而不是1到10,则使用一种热编码( https://blog.cambridgespark.com/robust-one-hot-encoding-in-python-3e29bfcec77e ).