代码之家  ›  专栏  ›  技术社区  ›  utengr

使用DEV PyTr火炬1将PyTrac模型加载到C++中

  •  3
  • utengr  · 技术社区  · 6 年前

    详细信息在本教程中。 https://pytorch.org/tutorials/advanced/cpp_export.html

    import torch
    import torchvision
    
    # An instance of your model.
    model = A UNET MODEL FROM FASTAI which has hooks as required by UNET
    
    # An example input you would normally provide to your model's forward() method.
    example = torch.rand(1, 3, 224, 224)
    
    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(model, example)
    

    在我的用例中,我使用一个UNET模型进行语义分割。但是,我使用这种方法跟踪模型,得到以下错误。

    Forward or backward hooks can't be compiled 
    

    UNET模型使用钩子来保存中间特征,这些特征将在网络的后续层中使用。有办法吗?或者这仍然是这种新方法的一个局限性,即它无法处理使用这种钩子的模型。

    1 回复  |  直到 6 年前
        1
  •  0
  •   Abhishek Sharma    4 年前

    如果您可以使用Pytorch hub的UNET模型。它将与TorchScript一起工作。

    import torch
    
    # downloading the model from torchhub
    model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
        in_channels=3, out_channels=1, init_features=32, pretrained=True)
    
    #  downloading the sample
    import urllib
    url, filename = ("https://github.com/mateuszbuda/brain-segmentation-pytorch/raw/master/assets/TCGA_CS_4944.png", "TCGA_CS_4944.png")
    try: urllib.URLopener().retrieve(url, filename)
    except: urllib.request.urlretrieve(url, filename)
        
    # reading the sample and some prerequisites for transformation
    import numpy as np
    from PIL import Image
    from torchvision import transforms
    
    input_image = Image.open(filename)
    
    m, s = np.mean(input_image, axis=(0, 1)), np.std(input_image, axis=(0, 1))
    preprocess = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=m, std=s),])
    
    input_tensor = preprocess(input_image)
    
    input_batch = input_tensor.unsqueeze(0)
    
    # creating the trace
    traced_module = torch.jit.trace(model,input_batch)
    
    # running the trace
    traced_module(input_batch)
    

    PS:torch.jit.trace/torch.jit.script都不支持所有的torch功能,所以在外部库中使用它们总是很棘手的。

        2
  •  -1
  •   Yanlin Qiu    5 年前