这可能不明显,但是
pd.Series.isin
使用
O(1)
-抬起头来。
经过分析,这证明了上述声明,我们将使用其见解创建一个Cython原型,可以轻松击败最快的箱外解决方案。
假设“集合”有
n
元素和“系列”有
m
元素。运行时间为:
T(n,m)=T_preprocess(n)+m*T_lookup(n)
对于纯python版本,这意味着:
-
T_preprocess(n)=0
-不需要预处理
-
T_lookup(n)=O(1)
-python集合的众所周知的行为
-
结果在
T(n,m)=O(m)
发生了什么
pd.Series.isin(x_arr)
?显然,如果我们跳过预处理并在线性时间内搜索,我们将得到
O(n*m)
,这是不可接受的。
在调试器或探查器的帮助下(我使用了valgrind callgrind+kcachegrind),很容易看到正在发生的事情:工作马是函数
__pyx_pw_6pandas_5_libs_9hashtable_23ismember_int64
. 可以找到它的定义
here
:
-
在预处理步骤中,哈希映射(pandas使用
khash from klib
)是由
n
元素来自
x_arr
,即在运行时间内
O(n)
.
-
米
查找发生在
O(1)
每个或
O(m)
在构造的哈希映射中总计。
-
结果在
T(n,m)=O(m)+O(n)
我们必须记住-numpy数组的元素是原始c整数,而不是原始集合中的python对象-所以我们不能像现在这样使用集合。
将一组python对象转换为一组c-int的另一种方法是将单个c-int转换为python对象,从而能够使用原始集合。这就是发生在
[i in x_set for i in ser.values]
-变体:
-
没有预处理。
-
M查找发生在
O(1)
每个时间或
o(m)
总的来说,但是由于需要创建一个python对象,查找速度较慢。
-
结果在
t(n,m)=o(m)
显然,你可以使用cython来加速这个版本。
但是足够的理论,让我们看看不同的运行时间
n
固定S
米
S:
我们可以看到:预处理的线性时间占主导地位的NUMPY版本的大。
n
从numpy到纯python的转换版本(
numpy->python
)与纯python版本具有相同的常量行为,但速度较慢,因为需要进行转换—这都符合我们的分析。
在图表中看不清楚:如果
n < m
numpy版本变得更快-在这种情况下,更快的查找
khash
-lib扮演着最重要的角色,而不是预处理部分。
我从这个分析中得到的启示:
-
n& lt;m
:
PD系列
应该是因为
o(n)
-预处理并没有那么昂贵。
-
n > m
(可能是cythonized版本的)
[i in x_为i in ser.值设置]
应该采取这样的措施
o(n)
避免。
-
很明显有一个灰色地带
n
和
米
近似相等,不经测试很难判断哪种解决方案最好。
-
如果你能控制它,最好的办法就是
set
直接作为c整数集(
哈什
(
already wrapped in pandas
(或者甚至一些C++实现),从而消除了预处理的需要。我不知道pandas中是否有可以重用的东西,但是用cython编写函数可能没什么大不了的。
问题是,最后一个建议不是现成的,因为熊猫和numpy在它们的接口中都没有一个集合的概念(至少在我有限的知识范围内)。但拥有raw-c-set-interfaces将是两方面的最佳选择:
-
不需要预处理,因为值已作为集合传递
-
不需要转换,因为传递的集合由原始c值组成
我已经编码了一个快速和肮脏的
Cython-wrapper for khash
(灵感来自熊猫的包装),可以通过
pip install https://github.com/realead/cykhash/zipball/master
然后和cython一起使用
isin
版本:
%%cython
import numpy as np
cimport numpy as np
from cykhash.khashsets cimport Int64Set
def isin_khash(np.ndarray[np.int64_t, ndim=1] a, Int64Set b):
cdef np.ndarray[np.uint8_t,ndim=1, cast=True] res=np.empty(a.shape[0],dtype=np.bool)
cdef int i
for i in range(a.size):
res[i]=b.contains(a[i])
return res
作为C++的另一种可能
unordered_map
可以被包装(见清单C),它的缺点是需要C++库和(如我们所看到的)稍微慢一点。
比较方法(有关创建计时的信息,请参见清单D):
卡什比
巨蟒
,大约比纯python快6个因子(但是纯python并不是我们想要的),甚至比cpp的版本快3个因子。
清单
1)Valgrind仿形:
#isin.py
import numpy as np
import pandas as pd
np.random.seed(0)
x_set = {i for i in range(2*10**6)}
x_arr = np.array(list(x_set))
arr = np.random.randint(0, 20000, 10000)
ser = pd.Series(arr)
for _ in range(10):
ser.isin(x_arr)
现在:
>>> valgrind --tool=callgrind python isin.py
>>> kcachegrind
导致以下调用图:
B:IPython生成运行时间的代码:
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
np.random.seed(0)
x_set = {i for i in range(10**2)}
x_arr = np.array(list(x_set))
x_list = list(x_set)
arr = np.random.randint(0, 20000, 10000)
ser = pd.Series(arr)
lst = arr.tolist()
n=10**3
result=[]
while n<3*10**6:
x_set = {i for i in range(n)}
x_arr = np.array(list(x_set))
x_list = list(x_set)
t1=%timeit -o ser.isin(x_arr)
t2=%timeit -o [i in x_set for i in lst]
t3=%timeit -o [i in x_set for i in ser.values]
result.append([n, t1.average, t2.average, t3.average])
n*=2
#plotting result:
for_plot=np.array(result)
plt.plot(for_plot[:,0], for_plot[:,1], label='numpy')
plt.plot(for_plot[:,0], for_plot[:,2], label='python')
plt.plot(for_plot[:,0], for_plot[:,3], label='numpy->python')
plt.xlabel('n')
plt.ylabel('running time')
plt.legend()
plt.show()
C:CPP包装:
%%cython --cplus -c=-std=c++11 -a
from libcpp.unordered_set cimport unordered_set
cdef class HashSet:
cdef unordered_set[long long int] s
cpdef add(self, long long int z):
self.s.insert(z)
cpdef bint contains(self, long long int z):
return self.s.count(z)>0
import numpy as np
cimport numpy as np
cimport cython
@cython.boundscheck(False)
@cython.wraparound(False)
def isin_cpp(np.ndarray[np.int64_t, ndim=1] a, HashSet b):
cdef np.ndarray[np.uint8_t,ndim=1, cast=True] res=np.empty(a.shape[0],dtype=np.bool)
cdef int i
for i in range(a.size):
res[i]=b.contains(a[i])
return res
D:使用不同的集合包装器绘制结果:
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
from cykhash import Int64Set
np.random.seed(0)
x_set = {i for i in range(10**2)}
x_arr = np.array(list(x_set))
x_list = list(x_set)
arr = np.random.randint(0, 20000, 10000)
ser = pd.Series(arr)
lst = arr.tolist()
n=10**3
result=[]
while n<3*10**6:
x_set = {i for i in range(n)}
x_arr = np.array(list(x_set))
cpp_set=HashSet()
khash_set=Int64Set()
for i in x_set:
cpp_set.add(i)
khash_set.add(i)
assert((ser.isin(x_arr).values==isin_cpp(ser.values, cpp_set)).all())
assert((ser.isin(x_arr).values==isin_khash(ser.values, khash_set)).all())
t1=%timeit -o isin_khash(ser.values, khash_set)
t2=%timeit -o isin_cpp(ser.values, cpp_set)
t3=%timeit -o [i in x_set for i in lst]
t4=%timeit -o [i in x_set for i in ser.values]
result.append([n, t1.average, t2.average, t3.average, t4.average])
n*=2
#ploting result:
for_plot=np.array(result)
plt.plot(for_plot[:,0], for_plot[:,1], label='khash')
plt.plot(for_plot[:,0], for_plot[:,2], label='cpp')
plt.plot(for_plot[:,0], for_plot[:,3], label='pure python')
plt.plot(for_plot[:,0], for_plot[:,4], label='numpy->python')
plt.xlabel('n')
plt.ylabel('running time')
ymin, ymax = plt.ylim()
plt.ylim(0,ymax)
plt.legend()
plt.show()