代码之家  ›  专栏  ›  技术社区  ›  Jame

如何将自己模型的权重传递到最后一层中相同网络但不同数量的类?

  •  0
  • Jame  · 技术社区  · 6 年前

    我在Pytorch有自己的网络。它首先针对二进制分类器(2类)进行训练。经过10公里的时间,我获得了训练后的体重 10000_model.pth . 现在,我想用这个模型来解决使用相同网络的4类分类器问题。因此,我想将二进制分类器中所有训练过的权重转移到4类问题中,而不需要lass层进行随机初始化。我怎么做呢?这是我的模型

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5, 1)
            self.conv2 = nn.Conv2d(20, 50, 5, 1)
            self.conv_classify= nn.Conv2d(50, 2, 1, 1, bias=True) # number of class
    
        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.max_pool2d(x, 2, 2)
            x = F.relu(self.conv2(x))
            x = F.max_pool2d(x, 2, 2)
            x = F.relu(self.conv_classify(x))
            return x
    

    这就是我所做的

    model = Net ()
    checkpoint_dict = torch.load('10000_model.pth')        
    pretrained_dict = checkpoint_dict['state_dict']
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    

    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    pretrained_dict.pop('conv_classify.weight', None)
    pretrained_dict.pop('conv_classify.bias', None)
    

    这意味着 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 什么都不做。

    怎么了?我正在使用pytorch 1.0。谢谢

    1 回复  |  直到 6 年前
        1
  •  2
  •   Jatentaki    6 年前

    两个网络具有相同的层,因此在网络中具有相同的密钥 state_dict 的确如此

    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    

    什么也不做。两者的区别在于 权张量 (他们的形状)而不是他们的名字。换句话说,您可以通过 [v.shape for v in model.state_dict().values()] 但不是 model.state_dict().keys()

    merged_dict = {}
    for key in model_dict.keys():
        if 'conv_classify' in key: # or perhaps a more complex criterion
            merged_dict[key] = model_dict[key]
        else:
            merged_dict[key] = pretrained_dict[key]