
Get indices of matching elements in arrays, taking duplicates into account

  •  3
  • Hameer Abbasi  ·  7 years ago

    I would like something like an SQL WHERE expression for two arrays in NumPy. Suppose I have two arrays like these:

    import numpy as np
    dt = np.dtype([('f1', np.uint8), ('f2', np.uint8), ('f3', np.float64)])
    a = np.rec.fromarrays([[3,    4,    4,   7,    9,    9],
                           [1,    5,    5,   4,    2,    2],
                           [2.0, -4.5, -4.5, 1.3, 24.3, 24.3]], dtype=dt)
    b = np.rec.fromarrays([[ 1,    4,   7,    9,    9],
                           [ 7,    5,   4,    2,    2],
                           [-3.5, -4.5, 1.3, 24.3, 24.3]], dtype=dt)
    

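    I want to return the indices of the original arrays such that all possible pairs of matches are covered. Additionally, I want to take advantage of the fact that both arrays are sorted, so the worst-case O(mn) algorithm is not needed. In this case, since (4, 5, -4.5) matches but occurs twice in the first array, it will occur twice in the resulting indices, and since (9, 2, 24.3) occurs twice in both arrays, it will occur a total of 4 times. Since (3, 1, 2.0) does not occur in the second array, it will be skipped, and so will (1, 7, -3.5) from the second array. The function should work with any dtype.

    In this case, the result would be as follows: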

    a_idx, b_idx = match_arrays(a, b)
    a_idx = np.array([1, 2, 3, 4, 4, 5, 5])
    b_idx = np.array([1, 1, 2, 3, 4, 3, 4])
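    To pin down the expected semantics, here is a minimal brute-force reference. It assumes the records can be compared as plain tuples via `tolist()`, and the name `match_arrays_reference` is hypothetical. It is the O(mn) approach the question wants to avoid, so it is only meant as a correctness oracle for small inputs.

```python
import numpy as np

dt = np.dtype([('f1', np.uint8), ('f2', np.uint8), ('f3', np.float64)])
a = np.rec.fromarrays([[3,    4,    4,   7,    9,    9],
                       [1,    5,    5,   4,    2,    2],
                       [2.0, -4.5, -4.5, 1.3, 24.3, 24.3]], dtype=dt)
b = np.rec.fromarrays([[ 1,    4,   7,    9,    9],
                       [ 7,    5,   4,    2,    2],
                       [-3.5, -4.5, 1.3, 24.3, 24.3]], dtype=dt)


def match_arrays_reference(a, b):
    """Hypothetical O(m*n) oracle: every (i, j) with a[i] == b[j]."""
    rows_a, rows_b = a.tolist(), b.tolist()  # records become plain tuples
    pairs = [(i, j)
             for i, x in enumerate(rows_a)
             for j, y in enumerate(rows_b)
             if x == y]
    a_idx = np.array([i for i, _ in pairs], dtype=int)
    b_idx = np.array([j for _, j in pairs], dtype=int)
    return a_idx, b_idx


a_idx, b_idx = match_arrays_reference(a, b)
print(a_idx, b_idx)
```

    Because the pairs are generated with the index into `a` in the outer loop, the output comes out ordered by `a` index first and `b` index second, which is the ordering used in the expected result above.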
    

    Another example with the same output:

    dt2 = np.dtype([('f1', np.uint8), ('f2', dt)])
    a2 = np.rec.fromarrays([[3, 4, 4, 7, 9, 9], a], dtype=dt2)
    b2 = np.rec.fromarrays([[1, 4, 7, 9, 9], b], dtype=dt2)
    

    I have a pure Python implementation, but it is very slow for my use case. I was hoping for something more vectorized. Here is what I have so far:

    def match_arrays(a, b):
        len_a = len(a)
        len_b = len(b)
    
        a_idx = []
        b_idx = []
    
        i, j = 0, 0
    
        while i < len_a and j < len_b:
            if a[i] < b[j]:
                # a[i] cannot match anything at or after b[j]
                i += 1
            elif a[i] > b[j]:
                # b[j] cannot match anything at or after a[i]
                j += 1
            else:
                # a[i] == b[j]: pair a[i] with the whole run of equal values in b
                k = j
                while k < len_b and a[i] == b[k]:
                    a_idx.append(i)
                    b_idx.append(k)
                    k += 1
                # keep j at the start of the run, since a[i + 1] may be equal too
                i += 1
    
        return np.array(a_idx, dtype=int), np.array(b_idx, dtype=int)
    

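    Edit: As Divakar pointed out in his answer, I can use a_idx, b_idx = np.where(np.equal.outer(a, b)) . However, this seems to be exactly the worst-case O(mn) solution that I hoped to avoid by pre-sorting the arrays. In particular, it would be nice if the solution were O(m + n) when there are no duplicates.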
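    A small sketch of why the outer comparison scales as O(mn) in memory as well as time: the intermediate boolean mask has one entry per pair of elements. The plain `uint8` arrays below are an assumption made for brevity.

```python
import numpy as np

a = np.array([3, 4, 4, 7, 9, 9], dtype=np.uint8)
b = np.array([1, 4, 7, 9, 9], dtype=np.uint8)

# The outer comparison materializes one boolean per (a[i], b[j]) pair.
mask = np.equal.outer(a, b)
print(mask.shape)    # (len(a), len(b))

# NumPy booleans take one byte each, so the mask costs len(a) * len(b) bytes.
mask_bytes = mask.nbytes
print(mask_bytes)

# For two inputs of 10**6 elements each, the mask alone would need 10**12 bytes.
bytes_for_1e6 = 10**6 * 10**6
print(bytes_for_1e6)
```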

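    Edit 2: Paul Panzer's answer is not O(m + n) if it uses only NumPy, but it is usually faster. He also provided an O(m + n) variant, so I am accepting that answer. I will post a performance comparison using timeit, hopefully soon.

    Edit 3: Here are the performance results, as promised: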

    ╔════════════════╦═══════════════════╦═══════════════════╦═══════════════════╦══════════════════╦═══════════════════╗
    ║ User           ║ Version           ║ n = 10 ** 2       ║ n = 10 ** 4       ║ n = 10 ** 6      ║ n = 10 ** 8       ║
    ╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
    ║ Paul Panzer    ║ USE_HEAPQ = False ║ 115 µs ± 385 ns   ║ 793 µs ± 8.43 µs  ║ 105 ms ± 1.57 ms ║ 18.2 s ± 116 ms   ║
    ║                ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
    ║                ║ USE_HEAPQ = True  ║ 189 µs ± 3.6 µs   ║ 6.38 ms ± 28.8 µs ║ 650 ms ± 2.49 ms ║ 1min 11s ± 420 ms ║
    ╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
    ║ SigmaPiEpsilon ║ Generator         ║ 936 µs ± 1.52 µs  ║ 9.17 s ± 57 ms    ║ N/A              ║ N/A               ║
    ║                ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
    ║                ║ for loop          ║ 144 µs ± 526 ns   ║ 15.6 ms ± 18.6 µs ║ 1.74 s ± 33.9 ms ║ N/A               ║
    ╠════════════════╬═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
    ║ Divakar        ║ np.where          ║ 39.1 µs ± 281 ns  ║ 302 ms ± 4.49 ms  ║ Out of memory    ║ N/A               ║
    ║                ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
    ║                ║ recarrays 1       ║ 69.9 µs ± 491 ns  ║ 1.6 ms ± 24.2 µs  ║ 230 ms ± 3.52 ms ║ 41.5 s ± 543 ms   ║
    ║                ╠═══════════════════╬═══════════════════╬═══════════════════╬══════════════════╬═══════════════════╣
    ║                ║ recarrays 2       ║ 82.6 µs ± 1.01 µs ║ 1.4 ms ± 4.51 µs  ║ 212 ms ± 2.59 ms ║ 36.7 s ± 900 ms   ║
    ╚════════════════╩═══════════════════╩═══════════════════╩═══════════════════╩══════════════════╩═══════════════════╝
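    For readers who want to run a comparison of their own, here is a minimal `timeit` harness. It is only a sketch: the helper names `best_time` and `outer_match` are hypothetical, the seeded `mock_data` is an assumption modeled on the data generator in Paul Panzer's answer, and it is not the script that produced the table above.

```python
import timeit

import numpy as np


def mock_data(n, seed=0):
    # Two sorted integer arrays with many repeated values (assumed test data).
    rng = np.random.default_rng(seed)
    return np.cumsum(rng.integers(0, 3, (2, n)), axis=1)


def outer_match(a, b):
    # The O(m*n) baseline from the question's first edit.
    return np.where(np.equal.outer(a, b))


def best_time(func, a, b, repeat=3, number=5):
    """Best average wall time in seconds for one call of func(a, b)."""
    times = timeit.repeat(lambda: func(a, b), repeat=repeat, number=number)
    return min(times) / number


a, b = mock_data(200)
t = best_time(outer_match, a, b)
print(f"outer_match, n = 200: {t * 1e6:.1f} µs per call")
```

    Taking the minimum over several repeats is the usual way to reduce noise from other processes; memory usage, as reported below, has to be measured separately.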
    

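    It seems Paul Panzer's answer wins hands down with USE_HEAPQ = False . I expected USE_HEAPQ = True to win for large inputs because it is O(m + n), but that is not the case. Another observation: the USE_HEAPQ = False version used less memory, at most 5.79 GB, versus at most 10.18 GB for USE_HEAPQ = True with n = 10 ** 8 . Keep in mind that this is process memory, which includes the inputs held in the console and other things. Divakar's recarrays answer 1 used 8.42 GB of memory, and recarrays answer 2 used 10.61 GB.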

    3 Answers
        1
  •  2
  •   Paul Panzer    7 years ago

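    Here is an O(n)-ish solution ("ish" because if the runs of repeated values are long, it obviously cannot be O(n)). In practice, depending on the input length, you can probably save some time by sacrificing O(n) and replacing heapq.merge with a stable np.argsort . As it stands, it takes about one second for N = 10^6.

    Code: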

    import numpy as np
    
    USE_HEAPQ = True
    
    def sqlwhere(a, b):
        asw = np.r_[0, 1 + np.flatnonzero(a[:-1]!=a[1:]), len(a)]
        bsw = np.r_[0, 1 + np.flatnonzero(b[:-1]!=b[1:]), len(b)]
        al, bl = np.diff(asw), np.diff(bsw)
        na, nb = len(al), len(bl)
        abunq = np.r_[a[asw[:-1]], b[bsw[:-1]]]
        if USE_HEAPQ:
            from heapq import merge
            m = np.fromiter(merge(range(na), range(na, na+nb), key=abunq.__getitem__), int, na+nb)
        else:
            m = np.argsort(abunq, kind='mergesort')
        mv = abunq[m]
        midx = np.flatnonzero(mv[:-1]==mv[1:])
        ai, bi = m[midx], m[midx+1] - na
        aic = np.r_[0, np.cumsum(al[ai])]
        a_idx = np.ones((aic[-1],), dtype=int)
        a_idx[aic[:-1]] = asw[ai]
        a_idx[aic[1:-1]] -= asw[ai[:-1]] + al[ai[:-1]] - 1
        a_idx = np.repeat(np.cumsum(a_idx), np.repeat(bl[bi], al[ai]))
        bi = np.repeat(bi, al[ai])
        bic = np.r_[0, np.cumsum(bl[bi])]
        b_idx = np.ones((bic[-1],), dtype=int)
        b_idx[bic[:-1]] = bsw[bi]
        b_idx[bic[1:-1]] -= bsw[bi[:-1]] + bl[bi[:-1]] - 1
        b_idx = np.cumsum(b_idx)
        return a_idx, b_idx
    
    def f_D(a, b):
        return np.where(np.equal.outer(a,b))
    
    def mock_data(n):
        return np.cumsum(np.random.randint(0, 3, (2, n)), axis=1)
    
    
    a = np.array([3, 4, 4, 7, 9, 9], dtype=np.uint8)
    b = np.array([1, 4, 7, 9, 9], dtype=np.uint8)
    
    # check correct
    a, b = mock_data(1000)
    ai0, bi0 = f_D(a, b)
    ai1, bi1 = sqlwhere(a, b)
    print(np.all(ai0 == ai1), np.all(bi0 == bi1))
    
    # check fast
    a, b = mock_data(1000000)
    sqlwhere(a, b)
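    To make the two key steps of `sqlwhere` easier to follow, here is a small sketch on the integer example from the question. It shows how the run boundaries are found and that, on this input, `heapq.merge` and a stable `np.argsort` produce the same merged order of unique values. The variable names `m_heap` and `m_sort` are introduced here for illustration only.

```python
from heapq import merge

import numpy as np

a = np.array([3, 4, 4, 7, 9, 9])
b = np.array([1, 4, 7, 9, 9])

# Start index of every run of equal values, plus a sentinel at the end.
asw = np.r_[0, 1 + np.flatnonzero(a[:-1] != a[1:]), len(a)]
bsw = np.r_[0, 1 + np.flatnonzero(b[:-1] != b[1:]), len(b)]
al, bl = np.diff(asw), np.diff(bsw)   # run lengths
na, nb = len(al), len(bl)

# Unique values of a followed by unique values of b.
abunq = np.r_[a[asw[:-1]], b[bsw[:-1]]]

# Both orderings are stable, so on ties the entry from a precedes the one from b.
m_heap = np.fromiter(
    merge(range(na), range(na, na + nb), key=abunq.__getitem__), int, na + nb)
m_sort = np.argsort(abunq, kind='mergesort')

# Equal neighbours in the merged order are values present in both arrays.
mv = abunq[m_sort]
midx = np.flatnonzero(mv[:-1] == mv[1:])
ai, bi = m_sort[midx], m_sort[midx + 1] - na   # matching run numbers in a and b
print(asw, bsw, m_heap, m_sort, ai, bi)
```

    The rest of `sqlwhere` expands each matching pair of runs into `al[ai] * bl[bi]` index pairs using `np.repeat` and `np.cumsum`.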
    
        2
  •  2
  •   Community CDub    4 years ago

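    Approach #1: Broadcasting-based approach

    Use an outer equality comparison between the two arrays to leverage vectorized broadcasting, then get the row and column indices, which are the matching indices for the two arrays -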

    a_idx, b_idx = np.where(a[:,None]==b)
    a_idx, b_idx = np.where(np.equal.outer(a,b))
    

    We could also use np.nonzero in place of np.where .
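    As a quick check, the three spellings mentioned in this approach can be compared on the integer example from the question. This is only a sketch; the names `r_broadcast`, `r_outer` and `r_nonzero` are introduced here for illustration.

```python
import numpy as np

a = np.array([3, 4, 4, 7, 9, 9], dtype=np.uint8)
b = np.array([1, 4, 7, 9, 9], dtype=np.uint8)

r_broadcast = np.where(a[:, None] == b)        # explicit broadcasting
r_outer = np.where(np.equal.outer(a, b))       # ufunc outer method
r_nonzero = np.nonzero(a[:, None] == b)        # np.nonzero on the same mask

a_idx, b_idx = r_broadcast
print(a_idx, b_idx)
```

    With a single argument, `np.where(mask)` behaves like `np.nonzero(mask)`, so all three return the same pair of index arrays.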

    Approach #2: Specific-case solution

    With no duplicates and sorted input arrays, we can use np.searchsorted , like so -

    idx0 = np.searchsorted(a,b)
    idx1 = np.searchsorted(b,a)
    idx0[idx0==len(a)] = 0
    idx1[idx1==len(b)] = 0
    
    a_idx = idx0[a[idx0] == b]
    b_idx = idx1[b[idx1] == a]
    

    With a slight modification, a possibly more efficient version would be -

    idx0 = np.searchsorted(a,b)
    idx0[idx0==len(a)] = 0
    
    a_idx = idx0[a[idx0] == b]
    b_idx = np.searchsorted(b,a[a_idx])
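    A minimal sketch that wraps the modified version in a function (the name `match_unique_sorted` is hypothetical) and cross-checks it against `np.intersect1d` with `return_indices=True`, which covers the same no-duplicates case. The sample arrays are made up for illustration.

```python
import numpy as np


def match_unique_sorted(a, b):
    # Assumes a and b are sorted and contain no duplicates.
    idx0 = np.searchsorted(a, b)      # insertion point of every b value in a
    idx0[idx0 == len(a)] = 0          # clamp values larger than a.max()
    a_idx = idx0[a[idx0] == b]        # keep only true matches
    b_idx = np.searchsorted(b, a[a_idx])
    return a_idx, b_idx


a = np.array([1, 3, 5, 7, 9])
b = np.array([0, 3, 4, 7, 8, 9, 10])

a_idx, b_idx = match_unique_sorted(a, b)

# Library alternative for unique inputs.
common, ia, ib = np.intersect1d(a, b, assume_unique=True, return_indices=True)
print(a_idx, b_idx, common)
```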
    

    Approach #3: Generic case

    Here is a solution for the generic case (duplicates allowed) -

    def findwhere(a, b):
        c = np.bincount(b, minlength=a.max()+1)[a]
        a_idx1 = np.repeat(np.flatnonzero(c),c[c!=0])
        
        b_idx1 = np.searchsorted(b,a[a_idx1])
        m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
        idx11 = np.flatnonzero(m1[1:] != m1[:-1])
        id_arr = m1.astype(int)
        id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
        b_idx1 += id_arr.cumsum()[:-1]
        return a_idx1, b_idx1
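    The counting idea behind `findwhere` can be traced step by step on the integer example from the question. The sketch below is an alternative spelling of the same idea, not the answer's exact code: the within-group offsets are computed with `np.arange` and `np.cumsum` instead of the masked `id_arr` trick, and it assumes sorted, non-negative integer inputs.

```python
import numpy as np

a = np.array([3, 4, 4, 7, 9, 9])
b = np.array([1, 4, 7, 9, 9])

# c[i] = number of elements of b that equal a[i].
c = np.bincount(b, minlength=a.max() + 1)[a]
nz = np.flatnonzero(c)            # positions in a with at least one match
counts = c[nz]

# Every matching position of a is repeated once per partner in b.
a_idx = np.repeat(nz, counts)

# First matching position in b for every output pair ...
first_b = np.searchsorted(b, a[a_idx])
# ... plus the offset 0, 1, ..., count - 1 inside each group.
group_start = np.repeat(np.cumsum(counts) - counts, counts)
offset = np.arange(len(a_idx)) - group_start
b_idx = first_b + offset
print(c, a_idx, b_idx)
```

    Note that `np.bincount` allocates an array as long as the largest value, so this route suits small non-negative integer keys rather than arbitrary dtypes.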
    

    Timings

    Using mock_data from @Paul Panzer's solution to set up the inputs:

    In [295]: a, b = mock_data(1000000)
    
    # @Paul Panzer's soln
    In [296]: %timeit sqlwhere(a, b) # USE_HEAPQ = False
    10 loops, best of 3: 118 ms per loop
    
    # Approach #3 from this post
    In [297]: %timeit findwhere(a,b)
    10 loops, best of 3: 61.7 ms per loop
    

    Converting recarrays (uint8 data) to 1D arrays

    def convert_recarrays_to_1Darrs(a, b):
        a2D = a.view('u1').reshape(-1,2)
        b2D = b.view('u1').reshape(-1,2)
        s = max(a2D[:,0].max(), b2D[:,0].max())+1
        
        a1D = s*a2D[:,1] + a2D[:,0]
        b1D = s*b2D[:,1] + b2D[:,0]
        return a1D, b1D
    

    Sample run -

    In [90]: dt = np.dtype([('f1', np.uint8), ('f2', np.uint8)])
        ...: a = np.rec.fromarrays([[3, 4, 4, 7, 9, 9],
        ...:                        [1, 5, 5, 4, 2, 2]], dtype=dt)
        ...: b = np.rec.fromarrays([[1, 4, 7, 9, 9],
        ...:                        [7, 5, 4, 2, 2]], dtype=dt)
    
    In [91]: convert_recarrays_to_1Darrs(a, b)
    Out[91]: 
    (array([13, 54, 54, 47, 29, 29], dtype=uint8),
     array([71, 54, 47, 29, 29], dtype=uint8))
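    The sample output above has dtype uint8, so packed keys above 255 would wrap around for larger field values. A hedged variant is sketched below: it accesses the fields by name and packs them into int64. The function name `recarrays_to_keys` is hypothetical, and treating `f1` as the most significant field is an assumption about the intended sort order.

```python
import numpy as np

dt = np.dtype([('f1', np.uint8), ('f2', np.uint8)])
a = np.rec.fromarrays([[3, 4, 4, 7, 9, 9],
                       [1, 5, 5, 4, 2, 2]], dtype=dt)
b = np.rec.fromarrays([[1, 4, 7, 9, 9],
                       [7, 5, 4, 2, 2]], dtype=dt)


def recarrays_to_keys(a, b):
    # Radix just above the largest f2 value, so distinct records get distinct keys.
    s = int(max(a['f2'].max(), b['f2'].max())) + 1
    # Widen to int64 before multiplying so the product cannot wrap around.
    a_key = a['f1'].astype(np.int64) * s + a['f2']
    b_key = b['f1'].astype(np.int64) * s + b['f2']
    return a_key, b_key


a_key, b_key = recarrays_to_keys(a, b)
print(a_key, b_key)
```

    With `f1` as the major key, records sorted lexicographically by `(f1, f2)` map to sorted integer keys, which is what the `np.searchsorted`-based approaches rely on.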
    

    Generic version that also covers rec-arrays

    Version #1:

    def findwhere_generic_v1(a, b):
        cidx = np.flatnonzero(np.r_[True,b[1:] != b[:-1],True])
        count = np.diff(cidx)
        b_starts = b[cidx[:-1]]
        
        a_starts = np.searchsorted(a,b_starts)
        a_starts[a_starts==len(a)] = 0
        
        valid_mask = (b_starts == a[a_starts])
        count_valid = count[valid_mask]
        
        idx2m0 = np.searchsorted(a,b_starts[valid_mask],'right')    
        idx1m0 = a_starts[valid_mask]
        
        id_arr = np.zeros(len(a)+1, dtype=int)
        id_arr[idx2m0] -= 1
        id_arr[idx1m0] += 1
        
        n = idx2m0 - idx1m0
        r1 = np.flatnonzero(id_arr.cumsum()!=0)
        r2 = np.repeat(count_valid,n)
        a_idx1 = np.repeat(r1, r2)
        
        b_idx1 = np.searchsorted(b,a[a_idx1])
        m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
        idx11 = np.flatnonzero(m1[1:] != m1[:-1])
        id_arr = m1.astype(int)
        id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
        b_idx1 += id_arr.cumsum()[:-1]
        return a_idx1, b_idx1
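    The `id_arr` step in the function above marks every position of `a` that falls inside a matching run by adding 1 at each run start, subtracting 1 at each run stop, and taking a cumulative sum. Here is a small sketch of just that step on plain integers; the values of `b_starts` and `count_valid` are written out by hand for the example from the question.

```python
import numpy as np

a = np.array([3, 4, 4, 7, 9, 9])
b_starts = np.array([4, 7, 9])       # unique values of b that also occur in a
count_valid = np.array([1, 1, 2])    # how often each of them occurs in b

starts = np.searchsorted(a, b_starts, 'left')    # first index of each run in a
stops = np.searchsorted(a, b_starts, 'right')    # one past the last index

# +1 at every start, -1 at every stop; the running sum is nonzero inside runs.
id_arr = np.zeros(len(a) + 1, dtype=int)
id_arr[stops] -= 1
id_arr[starts] += 1
covered = np.flatnonzero(id_arr.cumsum() != 0)

# Repeat every covered position of a once per matching element of b.
n = stops - starts
a_idx = np.repeat(covered, np.repeat(count_valid, n))
print(starts, stops, covered, a_idx)
```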
    

    Version #2:

    def findwhere_generic_v2(a, b):    
        cidx = np.flatnonzero(np.r_[True,b[1:] != b[:-1],True])
        count = np.diff(cidx)
        b_starts = b[cidx[:-1]]
        
        idxx = np.flatnonzero(np.r_[True,a[1:] != a[:-1],True])
        av = a[idxx[:-1]]
        idxxs = np.searchsorted(av,b_starts)
        idxxs[idxxs==len(av)] = 0
        valid_mask0 = av[idxxs] == b_starts
        
        starts = idxx[idxxs]
        stops = idxx[idxxs+1]
        
        idx1m0 = starts[valid_mask0]
        idx2m0 = stops[valid_mask0]  
        
        count_valid = count[valid_mask0]
        
        id_arr = np.zeros(len(a)+1, dtype=int)
        id_arr[idx2m0] -= 1
        id_arr[idx1m0] += 1
        
        n = idx2m0 - idx1m0
        r1 = np.flatnonzero(id_arr.cumsum()!=0)
        r2 = np.repeat(count_valid,n)
        a_idx1 = np.repeat(r1, r2)
        
        b_idx1 = np.searchsorted(b,a[a_idx1])
        m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
        idx11 = np.flatnonzero(m1[1:] != m1[:-1])
        id_arr = m1.astype(int)
        id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
        b_idx1 += id_arr.cumsum()[:-1]
        return a_idx1, b_idx1
    
        3
  •  1
  •   SigmaPiEpsilon    7 years ago

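    Pure Python approaches

    Generator comprehension

    Here is another pure Python implementation, using a generator expression. It may be more memory-efficient than your code, but it will probably be slower than the NumPy versions. It will be faster for sorted arrays.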

    def pywheregen(a, b):
    
        l = ((ia,ib) for ia,j in enumerate(a) for ib,k in enumerate(b) if j == k)
        a_idx,b_idx = zip(*l)
        return a_idx,b_idx
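    One detail worth knowing: `zip(*l)` yields nothing when there are no matches, so the tuple unpacking in `pywheregen` raises a `ValueError` in that case. A minimal sketch of a variant that tolerates empty results follows; the name `pywheregen_safe` is hypothetical.

```python
def pywheregen_safe(a, b):
    # Same lazy generator of matching index pairs as above.
    pairs = ((ia, ib)
             for ia, x in enumerate(a)
             for ib, y in enumerate(b)
             if x == y)
    a_idx, b_idx = [], []
    for ia, ib in pairs:          # consumes the generator one pair at a time
        a_idx.append(ia)
        b_idx.append(ib)
    return tuple(a_idx), tuple(b_idx)


print(pywheregen_safe([3, 4, 4, 7, 9, 9], [1, 4, 7, 9, 9]))
print(pywheregen_safe([1, 2], [3, 4]))   # no matches: two empty tuples
```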
    

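    Python for loop that takes the sorting into account

    Here is an alternative version that uses simple Python for loops and takes into account that the arrays are sorted, so it only checks the pairs that need to be checked.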

    def pywhere(a, b):
    
        l = []
        a.sort()
        b.sort()
        match = 0
        for ia,j in enumerate(a):
            ib = match
            while ib < len(b) and j >= b[ib]:
                if j == b[ib]:
                    l.append(((ia,ib)))
                    if b[match] < b[ib]:
                        match = ib
                ib += 1
    
        a_ind,b_ind = zip(*l)
    
        return a_ind, b_ind
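    As a related pure Python sketch, the standard-library `bisect` module can locate the run of equal values in `b` for each element of `a` directly. This is an alternative technique, not the answer's method: it costs roughly O(m log n) plus the number of matches, assumes `b` is already sorted, and does not sort or modify its inputs. The name `pywhere_bisect` is hypothetical.

```python
from bisect import bisect_left, bisect_right


def pywhere_bisect(a, b):
    # b must already be sorted; a only needs to be iterable.
    a_idx, b_idx = [], []
    for ia, x in enumerate(a):
        lo = bisect_left(b, x)     # first position in b with b[pos] >= x
        hi = bisect_right(b, x)    # first position in b with b[pos] > x
        for ib in range(lo, hi):   # empty range when x does not occur in b
            a_idx.append(ia)
            b_idx.append(ib)
    return tuple(a_idx), tuple(b_idx)


print(pywhere_bisect([3, 4, 4, 7, 9, 9], [1, 4, 7, 9, 9]))
```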
    

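    Timings

    I compared the timings using the mock_data() function from @Paul Panzer, against findwhere() and the f_D() np.equal.outer approach from @Divakar. findwhere() still performs best, but pywhere() is not so bad considering that it is pure Python. pywheregen() performs badly and, surprisingly, f_D() takes longer than pywhere(). Both of them fail at N = 10^6. sqlwhere() is not timed here because of a problem with the heapq module.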

    In [2]: a, b = mock_data(10000)
    In [10]: %timeit -n 10 findwhere(a,b)                                     
    10 loops, best of 3: 1.62 ms per loop
    
    In [11]: %timeit -n 10 pywhere(a,b)                                       
    10 loops, best of 3: 20.6 ms per loop
    
    In [12]: %timeit pywheregen(a,b)                                          
    1 loop, best of 3: 12.7 s per loop
    
    In [13]: %timeit -n 10 f_D(a,b)                                           
    10 loops, best of 3: 476 ms per loop
    
    In [14]: a, b = mock_data(1000000)
    In [15]: %timeit -n 10 findwhere(a,b)                                     
    10 loops, best of 3: 109 ms per loop
    
    In [16]: %timeit -n 10 pywhere(a,b)                                       
    10 loops, best of 3: 2.51 s per loop