代码之家  ›  专栏  ›  技术社区  ›  sachinruk

PyTorch用索引列表修改数组

  •  1
  • sachinruk  · 技术社区  · 6 年前

    假设我有一个索引列表,并希望用这个列表修改一个现有的数组。目前唯一能做到这一点的方法是使用for循环,如下所示。只是想知道是否有更快/更有效的方法。

    torch.manual_seed(0)
    a = torch.randn(5,3)
    idx = torch.Tensor([[1,2], [3,2]], dtype=torch.long)
    for i,j in idx:
        a[i,j] = 1
    

    我最初假设 gather index_select 在回答这个问题时会有一些办法,但是看看 documentation 这似乎不是答案。

    在我的特殊情况下,a是5维向量,idx是nx5向量。所以输出(订阅之后 a[idx] )我想是一个 (N,) 形状向量。

    回答

    感谢下面的@shai,我想要的答案是: a[idx.t().chunk(chunks=2,dim=0)] . 从中吸取 SO answer .

    1 回复  |  直到 6 年前
        1
  •  1
  •   Shai    6 年前

    这很简单

    a[idx[:,0], idx[:,1]] = 1
    

    您可以在 this thread .