代码之家  ›  专栏  ›  技术社区  ›  aerin

pytorch广播是如何工作的?

  •  3
  • aerin  · 技术社区  · 6 年前
    torch.add(torch.ones(4,1), torch.randn(4))
    

    生成一个张量,其大小为: torch.Size([4,4]) .

    有人能提供背后的逻辑吗?

    1 回复  |  直到 6 年前
        1
  •  10
  •   kmario23 Mazdak    6 年前
    numpy broadcasting rules PyTorch broadcasting guide

    In [27]: t_rand
    Out[27]: tensor([ 0.23451,  0.34562,  0.45673])
    
    In [28]: t_ones
    Out[28]: 
    tensor([[ 1.],
            [ 1.],
            [ 1.],
            [ 1.]])
    

    torch.add(t_rand, t_ones)

                   # shape of (3,)
                   tensor([ 0.23451,      0.34562,       0.45673])
          # (4, 1)          | | | |       | | | |        | | | |
          tensor([[ 1.],____+ | | |   ____+ | | |    ____+ | | |
                  [ 1.],______+ | |   ______+ | |    ______+ | |
                  [ 1.],________+ |   ________+ |    ________+ |
                  [ 1.]])_________+   __________+    __________+
    

    (4,3)

    # shape of (4,3)
    In [33]: torch.add(t_rand, t_ones)
    Out[33]: 
    tensor([[ 1.23451,  1.34562,  1.45673],
            [ 1.23451,  1.34562,  1.45673],
            [ 1.23451,  1.34562,  1.45673],
            [ 1.23451,  1.34562,  1.45673]])
    

    # shape of (4, 3)
    In [34]: torch.add(t_ones, t_rand)
    Out[34]: 
    tensor([[ 1.23451,  1.34562,  1.45673],
            [ 1.23451,  1.34562,  1.45673],
            [ 1.23451,  1.34562,  1.45673],
            [ 1.23451,  1.34562,  1.45673]])
    


    Example-1:

    broadcasting-1


    Example-2:

    theano broadcasting

    T F True False Theano


    Example-3:

    b a

    broadcastable shapes