代码之家  ›  专栏  ›  技术社区  ›  singa1994

Pytorch中的批量矩阵乘法-与输出维数的处理混淆

  •  2
  • singa1994  · 技术社区  · 5 年前

    我有两个阵列:

    A
    B
    

    阵列 A 包含一批RGB图像,形状:

    [batch, Width, Height, 3]
    

    B 包含对图像进行“类似变换”操作所需的系数,形状:

    [batch, 4, 4, 3]
    

    简单地说,对单个图像的操作是输出环境映射的乘法( normalMap * Coefficients ).

    我想要的输出应该保持形状:

    [批次,宽度,高度,3]
    

    我试着用 torch.bmm 但是失败了。这有可能吗?

    0 回复  |  直到 5 年前
        1
  •  1
  •   prosti    4 年前

    我想你需要计算一下Pythorch的工作原理

    BxCxHxW : number of mini-batches, channels, height, width
    

    格式,也可以使用 matmul ,因为 bmm 使用张量或ndim/dim/rank=3。

    我知道你可能会在网上找到这个,但无论如何:

    batch1 = torch.randn(10, 3, 20, 10)
    batch2 = torch.randn(10, 3, 10, 30)
    res = torch.matmul(batch1, batch2)
    res.size() # torch.Size([10, 3, 20, 30])