
LibTorch: using a DeepLab model gives a segfault on forward

  anti · Tech Community · 5 years ago

    I am trying to segment an image with a DeepLab model in LibTorch. In PyTorch, I convert the DeepLabv3 model as follows:

    import torch
    import torchvision
    from torchvision import models
    
    deeplap_model = models.segmentation.deeplabv3_resnet101(pretrained=True)
    deeplap_model.eval()
    
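    class wrapper(torch.nn.Module):
        def __init__(self, model):
            super(wrapper, self).__init__()
            self.model = model
    
        def forward(self, input):
            # The torchvision segmentation model returns a dict ('out' and, when the
            # auxiliary classifier is present, 'aux'); return its values as a tuple.
            results = []
            output = self.model(input)
            for k, v in output.items():
                results.append(v)
            return tuple(results)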
    
    model = wrapper(deeplap_model)
    
    example = torch.rand(1, 3, 224, 224)
    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(model, example)
    
    traced_script_module.save("model.pt")
    

    Now, in C++ with LibTorch, I try to load the model and run data through it. However, it crashes with a segfault at the forward call:

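    #include <torch/script.h>      // torch::jit::load, torch::jit::IValue, torch::from_blob
    #include <opencv2/opencv.hpp>  // cv::imread, cv::resize, cv::cvtColor
    #include <cassert>
    #include <iostream>
    #include <memory>
    #include <vector>
    
    int main() {
        // Load the traced module (LibTorch 1.0/1.1 API: load returns a shared_ptr)
        std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("model.pt");
        assert(module != nullptr);  // check before the module is used
        module->to(torch::kCUDA);
        std::cout << "ok\n";
    
        std::vector<torch::jit::IValue> inputs;
        // Read as BGR, resize to the 224x224 size used for tracing, convert to RGB
        cv::Mat image = cv::imread("pic.jpeg", cv::IMREAD_COLOR);
        cv::Mat image_resized;
        cv::resize(image, image_resized, cv::Size(224, 224));
        cv::cvtColor(image_resized, image_resized, cv::COLOR_BGR2RGB);
        // Scale 8-bit pixels to 32-bit floats in [0, 1]
        cv::Mat image_resized_float;
        image_resized.convertTo(image_resized_float, CV_32F, 1.0 / 255);
    
        // Wrap the HWC float buffer without copying; image_resized_float must outlive it
        auto img_tensor = torch::from_blob(image_resized_float.data, { 1, 224, 224, 3 }, torch::kFloat32);
        std::cout << "img tensor loaded..\n";
        img_tensor = img_tensor.permute({ 0, 3, 1, 2 });  // NHWC -> NCHW
        // Normalize each channel with the ImageNet mean and standard deviation
        img_tensor[0][0] = img_tensor[0][0].sub(0.485).div(0.229);
        img_tensor[0][1] = img_tensor[0][1].sub(0.456).div(0.224);
        img_tensor[0][2] = img_tensor[0][2].sub(0.406).div(0.225);
    
        // to GPU
        img_tensor = img_tensor.to(at::kCUDA);
    
        // forward() returns an IValue; the traced wrapper returns a tuple of tensors
        torch::Tensor out_tensor2 = module->forward({ img_tensor }).toTensor(); //SEGFAULT
    }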
    

    What am I doing wrong?

    0 replies  |  last activity 5 years ago