
How to remove a max-pooling layer from a PyTorch network

  •  3
  • Rocketq · Tech community · 6 years ago

    I have a "simple" UNet with a ResNet encoder on PyTorch v3.1, and it works pretty well:

    # imports needed by this snippet
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torchvision

    class UNetResNet(nn.Module):
        """PyTorch U-Net model using ResNet(34, 101 or 152) encoder.
    
        UNet: https://arxiv.org/abs/1505.04597
        ResNet: https://arxiv.org/abs/1512.03385
    
        Args:
                encoder_depth (int): Depth of a ResNet encoder (34, 101 or 152).
                num_classes (int): Number of output classes.
                num_filters (int, optional): Number of filters in the last layer of decoder. Defaults to 32.
                dropout_2d (float, optional): Probability factor of dropout layer before output layer. Defaults to 0.2.
                pretrained (bool, optional):
                    False - no pre-trained weights are being used.
                    True  - ResNet encoder is pre-trained on ImageNet.
                    Defaults to False.
                is_deconv (bool, optional):
                    False: bilinear interpolation is used in decoder.
                    True: deconvolution is used in decoder.
                    Defaults to False.
    
        """
    
        def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                     pretrained=False, is_deconv=False):
            super().__init__()
            self.num_classes = num_classes
            self.dropout_2d = dropout_2d
    
            if encoder_depth == 34:
                self.encoder = torchvision.models.resnet34(pretrained=pretrained)
                bottom_channel_nr = 512
            elif encoder_depth == 101:
                self.encoder = torchvision.models.resnet101(pretrained=pretrained)
                bottom_channel_nr = 2048
            elif encoder_depth == 152:
                self.encoder = torchvision.models.resnet152(pretrained=pretrained)
                bottom_channel_nr = 2048
            else:
                raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')
    
            self.pool = nn.MaxPool2d(2, 2)
    
            self.relu = nn.ReLU(inplace=True)
    
            self.conv1 = nn.Sequential(self.encoder.conv1,
                                       self.encoder.bn1,
                                       self.encoder.relu,
                                       self.pool)
    
            self.conv2 = self.encoder.layer1
    
            self.conv3 = self.encoder.layer2
    
            self.conv4 = self.encoder.layer3
    
            self.conv5 = self.encoder.layer4
    
            self.center = DecoderBlockV2(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, is_deconv)
            self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
            self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8,
                                       is_deconv)
            self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2,
                                       is_deconv)
            self.dec2 = DecoderBlockV2(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                       is_deconv)
            self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
            self.dec0 = ConvRelu(num_filters, num_filters)
            self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
    
        def forward(self, x):
            conv1 = self.conv1(x)
            conv2 = self.conv2(conv1)
            conv3 = self.conv3(conv2)
            conv4 = self.conv4(conv3)
            conv5 = self.conv5(conv4)
    
            pool = self.pool(conv5) # that pool layer I would like to delete
            center = self.center(pool)
    
    
            dec5 = self.dec5(torch.cat([center, conv5], 1))
    
            dec4 = self.dec4(torch.cat([dec5, conv4], 1))
            dec3 = self.dec3(torch.cat([dec4, conv3], 1))
            dec2 = self.dec2(torch.cat([dec3, conv2], 1))
            dec1 = self.dec1(dec2)
            dec0 = self.dec0(dec1)
    
            return self.final(F.dropout2d(dec0, p=self.dropout_2d))
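
    For reference, a hypothetical way to run this model (the 128x128 input size is my assumption; it makes conv5 come out as 4x4, matching the sizes discussed below):

        import torch

        model = UNetResNet(encoder_depth=34, num_classes=1)
        x = torch.randn(2, 3, 128, 128)   # conv5 is then 4x4; pool(conv5) is 2x2
        y = model(x)                      # the decoder upsamples back to the input size
        print(y.shape)                    # torch.Size([2, 1, 128, 128])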
    

    Now, to get rid of that pooling layer before center, forward becomes the following:

    def forward(self, x):
            conv1 = self.conv1(x)
            conv2 = self.conv2(conv1)
            conv3 = self.conv3(conv2)
            conv4 = self.conv4(conv3)
            conv5 = self.conv5(conv4)
    
            center = self.center(conv5) #now center connects with conv5
    
            dec5 = self.dec5(torch.cat([center, conv5], 1))
    
            dec4 = self.dec4(torch.cat([dec5, conv4], 1))
            dec3 = self.dec3(torch.cat([dec4, conv3], 1))
            dec2 = self.dec2(torch.cat([dec3, conv2], 1))
            dec1 = self.dec1(dec2)
            dec0 = self.dec0(dec1)
    
            return self.final(F.dropout2d(dec0, p=self.dropout_2d))
    

    With only this change applied, we run into the following error:

    ~/Desktop/ml/salt/unet_models.py in forward(self, x)
        397         center = self.center(conv5)
        398 
    --> 399         dec5 = self.dec5(torch.cat([center, conv5], 1))
        400 
        401         dec4 = self.dec4(torch.cat([dec5, conv4], 1))
    
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 4 and 8 in dimension 2 at /pytorch/torch/lib/THC/generic/THCTensorMath.cu:111
    

    So we apparently have to change dec4 somehow, but how?
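
    To see what the traceback is complaining about, here is a minimal reproduction of the failing cat (the shapes are illustrative, taken from the error message and the channel math above):

        import torch

        center = torch.randn(1, 256, 8, 8)  # upsampled twice as much as needed
        conv5 = torch.randn(1, 512, 4, 4)
        # cat along dim 1 requires all other dimensions to match:
        torch.cat([center, conv5], 1)       # RuntimeError: sizes of tensors must match except in dimension 1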

    More info:

    class DecoderBlockV2(nn.Module):
        def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
            super(DecoderBlockV2, self).__init__()
            self.in_channels = in_channels

            if is_deconv:
                """
                Parameters for the deconvolution were chosen to avoid checkerboard
                artifacts, following https://distill.pub/2016/deconv-checkerboard/
                """
                self.block = nn.Sequential(
                    ConvRelu(in_channels, middle_channels),
                    nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
                                       padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            else:
                self.block = nn.Sequential(
                    nn.Upsample(scale_factor=2, mode='bilinear'),
                    ConvRelu(in_channels, middle_channels),
                    ConvRelu(middle_channels, out_channels),
                )

        def forward(self, x):
            return self.block(x)
    
    class ConvRelu(nn.Module):
        def __init__(self, in_, out):
            super().__init__()
            self.conv = conv3x3(in_, out)  # conv3x3: 3x3 conv helper, not shown in the question (see sketch below)
            self.activation = nn.ReLU(inplace=True)
    
        def forward(self, x):
            x = self.conv(x)
            x = self.activation(x)
            return x
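
    conv3x3 is not defined in the snippet above; in UNet codebases of this style it is typically a padded 3x3 convolution. A minimal sketch of the assumed helper:

        def conv3x3(in_, out):
            # assumed definition: padding=1 keeps the spatial size unchanged
            return nn.Conv2d(in_, out, kernel_size=3, padding=1)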
    
    1 answer · 6 years ago
  •  2
  •  Shai · 6 years ago

    The error comes from the torch.cat(...) call: you are concatenating center and conv5, but the spatial dimensions of these tensors no longer match.

    In the original network (taking the answer's example sizes, where the bottom of the encoder is spatially 4x4):

    • conv5: 4x4
    • pool: 2x2
    • center: upsampled back to 4x4, so you can cat center and conv5, because they are both 4x4 tensors.

    However, if you delete the pool layer, you end up with:

    • conv5: 4x4
    • center: 8x8, i.e. upsampled more than needed. Now you can no longer cat center and conv5, because their spatial dimensions do not match (the "Got 4 and 8 in dimension 2" in the error).
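
    A quick sketch of this size bookkeeping (the 4x4 size is the answer's example; the bilinear nn.Upsample branch is taken from DecoderBlockV2):

        import torch
        import torch.nn as nn

        conv5 = torch.randn(1, 512, 4, 4)
        pool = nn.MaxPool2d(2, 2)
        up = nn.Upsample(scale_factor=2, mode='bilinear')

        print(up(pool(conv5)).shape)  # torch.Size([1, 512, 4, 4]) -> matches conv5, cat works
        print(up(conv5).shape)        # torch.Size([1, 512, 8, 8]) -> mismatch, cat fails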

    Two ways to fix it:

    1. Remove center together with the pool, leaving you with

      dec5 = self.dec5(conv5)

       (see the sketch after this list for the channel change this also requires).
      
    2. The other option is to remove the nn.Upsample part from the DecoderBlockV2 block used for center, leaving you with

      self.block = nn.Sequential(
                      ConvRelu(in_channels, middle_channels),
                      ConvRelu(middle_channels, out_channels),
                  )
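
    A minimal sketch of option 1 (the adjusted in_channels for dec5 is my addition: without center, dec5 receives conv5 alone, i.e. bottom_channel_nr channels instead of bottom_channel_nr + num_filters * 8):

        # in __init__: drop self.center; dec5 now consumes conv5 directly
        self.dec5 = DecoderBlockV2(bottom_channel_nr, num_filters * 8 * 2,
                                   num_filters * 8, is_deconv)

        # in forward: no pool, no center, no cat for dec5
        conv5 = self.conv5(conv4)
        dec5 = self.dec5(conv5)
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))

    With option 2, center keeps its channel math but no longer changes the spatial size, so center and conv5 are both 4x4 again and the original cat works unchanged.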