代码之家  ›  专栏  ›  技术社区  ›  plhn

按索引在numpy中选取

  •  3
  • plhn  · 技术社区  · 5 年前

    假设两个数组

    ind = 
    array([[1, 3, 2, 4, 0],
           [0, 1, 3, 2, 4],
           [3, 4, 2, 0, 1]])
    
    x =
    array([[[24, 97, 28, 57, 59],
            [97, 67, 94, 77, 50],
            [56, 89, 25, 55, 76],
            [88, 21,  1, 50, 24]],
    
           [[54, 83, 64, 81, 12],
            [89, 49, 15, 26, 97],
            [94, 97, 32, 55, 79],
            [24, 63, 63, 15, 40]],
    
           [[41, 99, 84, 64, 21],
            [12,  9, 85, 43, 28],
            [75, 98, 48, 10,  0],
            [93, 94, 37, 22, 63]]])
    

    我想根据第一个数组对第二个数组重新排序(第一个数组是索引)

    所以,也许结果如下。

    array([[[97, 57, 28, 59, 24], 
            [67, 77, 94, 50, 97],
            [89, 55, 25, 76, 56],
            [21, 50,  1, 24, 88]],
    
           [[54, 83, 81, 64, 12],
            [89, 49, 26, 15, 97],
            [94, 97, 55, 32, 79],
            [24, 63, 15, 63, 40]],
    
           [[64, 21, 84, 41, 99],
            [43, 28, 85, 12,  9],
            [10,  0, 48, 75, 98],  
            [22, 63, 37, 93, 94]]])
    # x[0]s are reordered by ind[0] and so on.
    

    这有可能吗 np.take ?

    1 回复  |  直到 5 年前
        1
  •  2
  •   Paul Panzer    5 年前

    很容易使用 take_along_axis :

    >>> np.take_along_axis(x, ind[:, None, :], 2)
    array([[[97, 57, 28, 59, 24],
            [67, 77, 94, 50, 97],
            [89, 55, 25, 76, 56],
            [21, 50,  1, 24, 88]],
    
           [[54, 83, 81, 64, 12],
            [89, 49, 26, 15, 97],
            [94, 97, 55, 32, 79],
            [24, 63, 15, 63, 40]],
    
           [[64, 21, 84, 41, 99],
            [43, 28, 85, 12,  9],
            [10,  0, 48, 75, 98],
            [22, 63, 37, 93, 94]]])
    

    如果你在1.15年前,你可以:

    >>> m,n,k = x.shape
    >>> m,n,k = np.ogrid[:m, :n, :k]
    >>> x[m,n,ind[:, None, :]]
    array([[[97, 57, 28, 59, 24],
            [67, 77, 94, 50, 97],
            [89, 55, 25, 76, 56],
            [21, 50,  1, 24, 88]],
    
           [[54, 83, 81, 64, 12],
            [89, 49, 26, 15, 97],
            [94, 97, 55, 32, 79],
            [24, 63, 15, 63, 40]],
    
           [[64, 21, 84, 41, 99],
            [43, 28, 85, 12,  9],
            [10,  0, 48, 75, 98],
            [22, 63, 37, 93, 94]]])