代码之家  ›  专栏  ›  技术社区  ›  mistakeNot

numpy函数cythonization

  •  2
  • mistakeNot  · 技术社区  · 6 年前

    我在纯python中有以下函数:

    import numpy as np
    
    def subtractPython(a, b):
        xAxisCount = a.shape[0]
        yAxisCount = a.shape[1]
    
        shape = (xAxisCount, yAxisCount, xAxisCount)
        results = np.zeros(shape)
        for index in range(len(b)):
            subtracted = (a - b[index])
            results[:, :, index] = subtracted
        return results
    

    我试着用这种方式来解释它:

    import numpy as np
    cimport numpy as np
    
    DTYPE = np.int
    ctypedef np.int_t DTYPE_t
    
    def subtractPython(np.ndarray[DTYPE_t, ndim=2] a, np.ndarray[DTYPE_t, ndim=2] b):
        cdef int xAxisCount = a.shape[0]
        cdef int yAxisCount = a.shape[1]
    
        cdef np.ndarray[DTYPE_t, ndim=3] results = np.zeros([xAxisCount, yAxisCount, xAxisCount], dtype=DTYPE)
    
        cdef int lenB = len(b)
    
        cdef np.ndarray[DTYPE_t, ndim=2] subtracted
        for index in range(lenB):
            subtracted = (a - b[index])
            results[:, :, index] = subtracted
        return results
    

    然而,我没有看到任何加速。是否有什么地方我遗漏了,或者这个过程无法加快?

    编辑->我意识到我并没有在上面的代码中对减法算法进行简单化。我已经成功地对其进行了cythonize,但它的运行时与a-b完全相同[:,无],因此我猜这是此操作的最大速度。

    这基本上是a-b[:,无]->具有相同的运行时

    %%cython
    
    import numpy as np
    cimport numpy as np
    
    
    DTYPE = np.int
    ctypedef np.int_t DTYPE_t
    
    cimport cython
    @cython.boundscheck(False) # turn off bounds-checking for entire function
    @cython.wraparound(False)  # turn off negative index wrapping for entire function
    def subtract(np.ndarray[DTYPE_t, ndim=2] a, np.ndarray[DTYPE_t, ndim=2] b):
        cdef np.ndarray[DTYPE_t, ndim=3] result = np.zeros([b.shape[0], a.shape[0], a.shape[1]], dtype=DTYPE)
    
        cdef int lenB = b.shape[0]
        cdef int lenA = a.shape[0]
        cdef int lenColB = b.shape[1]
    
        cdef int rowA, rowB, column
    
        for rowB in range(lenB):
            for rowA in range(lenA):
                for column in range(lenColB):
                    result[rowB, rowA, column] = a[rowA, column] - b[rowB, column]
        return result
    
    1 回复  |  直到 6 年前
        1
  •  4
  •   ead    6 年前

    当你试图优化一个函数时,你应该知道这个函数的瓶颈是什么——没有它,你会花很多时间在错误的方向上运行。

    让我们使用python函数作为基线(实际上我使用 result=np.zeros(shape,dtype=a.dtype) 否则,您的方法将返回 floats 这可能是一个bug):

    >>> import numpy as np
    >>> a=np.random.randint(1,1000,(300,300), dtype=np.int)
    >>> b=np.random.randint(1,1000,(300,300), dtype=np.int)
    >>> %timeit subtractPython(a,b)
    274 ms ± 3.61 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    我们应该问自己的第一个问题是:这个任务是内存还是CPU受限?显然,这是一项内存受限的任务——与所需的内存读写访问相比,减法算不了什么。

    这意味着,所有这些我们都必须优化内存布局,以减少缓存未命中。根据经验,我们的内存访问应该一个接一个地访问连续的内存地址。

    是这样吗?不,阵列 result 按C顺序,即行主顺序,因此访问

    results[:, :, index] = subtracted
    

    不是连续的。另一方面

    results[index, :, :] = subtracted
    

    将是连续访问。让我们改变信息存储的方式 后果 :

    def subtract1(a, b):
        xAxisCount = a.shape[0]
        yAxisCount = a.shape[1]
    
        shape = (xAxisCount,  xAxisCount, yAxisCount) #<=== Change order
        results = np.zeros(shape, dtype=a.dtype)
        for index in range(len(b)):
            subtracted = (a - b[index])
            results[index, :, :] = subtracted   #<===== consecutive access
        return results
    

    现在的计时是:

    >>> %timeit subtract1(a,b)
    >>> 35.8 ms ± 285 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    

    还有两个小的改进:我们不必用零初始化结果,我们可以节省一些python开销,但这只给了我们大约5%:

    def subtract2(a, b):
        xAxisCount = a.shape[0]
        yAxisCount = a.shape[1]
    
        shape = (xAxisCount,  xAxisCount, yAxisCount) 
        results = np.empty(shape, dtype=a.dtype)        #<=== no need for zeros
        for index in range(len(b)):
            results[index, :, :] = (a-b[index])   #<===== less python overhead
        return results
    
    >>> %timeit subtract2(a,b)
    34.5 ms ± 203 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    

    现在,这大约比原始版本快8倍。

    您可以使用Cython来进一步加快速度,但任务可能仍然是内存受限的,所以不要期望它明显更快,毕竟Cython无法使内存更快地工作。然而,如果没有适当的评测,很难说有多少改进是可能的——如果有人能想出一个更快的版本,那也就不足为奇了。