代码之家  ›  专栏  ›  技术社区  ›  Wasi Ahmad

参数维数对聚集函数的影响

  •  3
  • Wasi Ahmad  · 技术社区  · 7 年前

    我正在尝试使用 gather 在Pytork中的功能,但无法理解 dim 参数

    t = torch.Tensor([[1,2],[3,4]])
    print(torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]])))
    

    输出:

     1  2
     3  2
    [torch.FloatTensor of size 2x2]
    

    print(torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])))
    

    输出变为:

     1  1
     4  3
    [torch.FloatTensor of size 2x2]
    

    怎样 gather 功能实际工作?

    2 回复  |  直到 7 年前
        1
  •  4
  •   Wasi Ahmad    6 年前

    我了解了收集函数的工作原理。

    t = torch.Tensor([[1,2],[3,4]])
    index = torch.LongTensor([[0,0],[1,0]])
    torch.gather(t, 0, index)
    

    dimension

    | t[index[0, 0], 0]   t[index[0, 1], 1] |
    | t[index[1, 0], 0]   t[index[1, 1], 1] |
    

    如果 设置为1时,输出将变为:

    | t[0, index[0, 0]]   t[0, index[0, 1]] |
    | t[1, index[1, 0]]   t[1, index[1, 1]] |
    

    所以公式是:

    For a 3-D tensor the output is specified by:
    
    out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
    out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
    out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
    

    http://pytorch.org/docs/master/torch.html?highlight=gather#torch.gather

        2
  •  3
  •   GabrielChu    6 年前

    只需在现有答案的基础上添加一个 gather 是沿着指定维度收集分数。

    例如,我们有这样的设置:

    • 3类5例
    • 每个班级都有一个分数,每个例子都要做
    • y

    代码如下

    torch.manual_seed(0)
    
    num_examples = 5
    num_classes = 3
    scores = torch.randn(5, 3)
    
    #print of scores
    scores: tensor([[ 1.5410, -0.2934, -2.1788],
            [ 0.5684, -1.0845, -1.3986],
            [ 0.4033,  0.8380, -0.7193],
            [-0.4033, -0.5966,  0.1820],
            [-0.8567,  1.1006, -1.0712]])
    
    
    y = torch.LongTensor([1, 2, 1, 0, 2])
    res = scores.gather(1, y.view(-1, 1)).squeeze()
    

    输出:

    #print of gather results
    tensor([-0.2934, -1.3986,  0.8380, -0.4033, -1.0712])