代码之家  ›  专栏  ›  技术社区  ›  ch1maera

通过图像通道求平均值和标准偏差PyTorch

  •  0
  • ch1maera  · 技术社区  · 5 年前

    假设我有一批尺寸为(B x C x W x H)的张量形式的图像,其中B是批大小,C是图像中的通道数,W和H分别是图像的宽度和高度。我想用 transforms.Normalize() 通过C图像通道 ,这意味着我想要一个形式为1x C的结果张量。有没有一种简单的方法可以做到这一点?

    我试过了 torch.view(C, -1).mean(1) torch.view(C, -1).std(1) 但我得到了错误:

    view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
    

    编辑

    view() 在PyTorch工作,我知道为什么我的方法不起作用;但是,我仍然不知道如何得到每个通道的平均值和标准偏差。

    0 回复  |  直到 5 年前
        1
  •  2
  •   trsvchn    5 年前

    您只需要以正确的方式重新排列批处理张量:从 [B, C, W, H] [B, C, W * H] 签署人:

    batch = batch.view(batch.size(0), batch.size(1), -1)
    

    代码:

    import torch
    from torch.utils.data import TensorDataset, DataLoader
    
    data = torch.randn(64, 3, 28, 28)
    labels = torch.zeros(64, 1)
    dataset = TensorDataset(data, labels)
    loader = DataLoader(dataset, batch_size=8)
    
    nimages = 0
    mean = 0.
    std = 0.
    for batch, _ in loader:
        # Rearrange batch to be the shape of [B, C, W * H]
        batch = batch.view(batch.size(0), batch.size(1), -1)
        # Update total number of images
        nimages += batch.size(0)
        # Compute mean and std here
        mean += batch.mean(2).sum(0) 
        std += batch.std(2).sum(0)
    
    # Final step
    mean /= nimages
    std /= nimages
    
    print(mean)
    print(std)
    

    输出:

    tensor([-0.0029, -0.0022, -0.0036])
    tensor([0.9942, 0.9939, 0.9923])