方法#1:基于Broadacasting的方法
使用
outer
两个数组之间的相等性比较以利用矢量化
broadcasting
然后得到行、列索引,这是两个数组对应的匹配索引所需要的-
a_idx, b_idx = np.where(a[:,None]==b)
a_idx, b_idx = np.where(np.equal.outer(a,b))
我们还可以使用
np.nonzero
代替
np.where
.
方法#2:具体案例解决方案
由于没有重复和排序的输入数组,我们可以使用
np.searchsorted
,就像这样-
idx0 = np.searchsorted(a,b)
idx1 = np.searchsorted(b,a)
idx0[idx0==len(a)] = 0
idx1[idx1==len(b)] = 0
a_idx = idx0[a[idx0] == b]
b_idx = idx1[b[idx1] == a]
稍微修改一下,可能更有效的方法是-
idx0 = np.searchsorted(a,b)
idx0[idx0==len(a)] = 0
a_idx = idx0[a[idx0] == b]
b_idx = np.searchsorted(b,a[a_idx])
方法#3:一般情况
以下是一般情况的解决方案(允许重复)-
def findwhere(a, b):
c = np.bincount(b, minlength=a.max()+1)[a]
a_idx1 = np.repeat(np.flatnonzero(c),c[c!=0])
b_idx1 = np.searchsorted(b,a[a_idx1])
m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
idx11 = np.flatnonzero(m1[1:] != m1[:-1])
id_arr = m1.astype(int)
id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
b_idx1 += id_arr.cumsum()[:-1]
return a_idx1, b_idx1
计时
使用
mock_data
来自@Paul Panzer的soln,用于设置输入:
In [295]: a, b = mock_data(1000000)
# @Paul Panzer's soln
In [296]: %timeit sqlwhere(a, b) # USE_HEAPQ = False
10 loops, best of 3: 118 ms per loop
# Approach #3 from this post
In [297]: %timeit findwhere(a,b)
10 loops, best of 3: 61.7 ms per loop
将重新排列(uint8数据)转换为
1D
阵列
def convert_recarrays_to_1Darrs(a, b):
a2D = a.view('u1').reshape(-1,2)
b2D = b.view('u1').reshape(-1,2)
s = max(a2D[:,0].max(), b2D[:,0].max())+1
a1D = s*a2D[:,1] + a2D[:,0]
b1D = s*b2D[:,1] + b2D[:,0]
return a1D, b1D
样本运行-
In [90]: dt = np.dtype([('f1', np.uint8), ('f2', np.uint8)])
...: a = np.rec.fromarrays([[3, 4, 4, 7, 9, 9],
...: [1, 5, 5, 4, 2, 2]], dtype=dt)
...: b = np.rec.fromarrays([[1, 4, 7, 9, 9],
...: [7, 5, 4, 2, 2]], dtype=dt)
In [91]: convert_recarrays_to_1Darrs(a, b)
Out[91]:
(array([13, 54, 54, 47, 29, 29], dtype=uint8),
array([71, 54, 47, 29, 29], dtype=uint8))
覆盖的通用版本
rec-arrays
版本#1:
def findwhere_generic_v1(a, b):
cidx = np.flatnonzero(np.r_[True,b[1:] != b[:-1],True])
count = np.diff(cidx)
b_starts = b[cidx[:-1]]
a_starts = np.searchsorted(a,b_starts)
a_starts[a_starts==len(a)] = 0
valid_mask = (b_starts == a[a_starts])
count_valid = count[valid_mask]
idx2m0 = np.searchsorted(a,b_starts[valid_mask],'right')
idx1m0 = a_starts[valid_mask]
id_arr = np.zeros(len(a)+1, dtype=int)
id_arr[idx2m0] -= 1
id_arr[idx1m0] += 1
n = idx2m0 - idx1m0
r1 = np.flatnonzero(id_arr.cumsum()!=0)
r2 = np.repeat(count_valid,n)
a_idx1 = np.repeat(r1, r2)
b_idx1 = np.searchsorted(b,a[a_idx1])
m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
idx11 = np.flatnonzero(m1[1:] != m1[:-1])
id_arr = m1.astype(int)
id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
b_idx1 += id_arr.cumsum()[:-1]
return a_idx1, b_idx1
版本#2:
def findwhere_generic_v2(a, b):
cidx = np.flatnonzero(np.r_[True,b[1:] != b[:-1],True])
count = np.diff(cidx)
b_starts = b[cidx[:-1]]
idxx = np.flatnonzero(np.r_[True,a[1:] != a[:-1],True])
av = a[idxx[:-1]]
idxxs = np.searchsorted(av,b_starts)
idxxs[idxxs==len(av)] = 0
valid_mask0 = av[idxxs] == b_starts
starts = idxx[idxxs]
stops = idxx[idxxs+1]
idx1m0 = starts[valid_mask0]
idx2m0 = stops[valid_mask0]
count_valid = count[valid_mask0]
id_arr = np.zeros(len(a)+1, dtype=int)
id_arr[idx2m0] -= 1
id_arr[idx1m0] += 1
n = idx2m0 - idx1m0
r1 = np.flatnonzero(id_arr.cumsum()!=0)
r2 = np.repeat(count_valid,n)
a_idx1 = np.repeat(r1, r2)
b_idx1 = np.searchsorted(b,a[a_idx1])
m1 = np.r_[False,a_idx1[1:] == a_idx1[:-1],False]
idx11 = np.flatnonzero(m1[1:] != m1[:-1])
id_arr = m1.astype(int)
id_arr[idx11[1::2]+1] = idx11[::2]-idx11[1::2]
b_idx1 += id_arr.cumsum()[:-1]
return a_idx1, b_idx1