
Numba/NumPy: speed difference between multiple runs / optimization

  • misantroop · 4 years ago

    import pandas as pd
    import numpy as np
    from datetime import datetime, timedelta
    from time import time
    import numba
    
    # 10-minute bars from 2000-01-01 up to 2020-02-01
    times =             np.arange(datetime(2000, 1, 1), datetime(2020, 2, 1), timedelta(minutes=10)).astype(np.datetime64)
    tlen =              len(times)
    # Code points of 'A' and 'Z', used to build one random 7-letter symbol name
    A, Z =              np.array(['A', 'Z']).view('int32')
    symbol_names =      np.random.randint(low=A, high=Z, size=1 * 7, dtype='int32').view(f'U{7}')
    times =             np.concatenate([times] * 1)
    names =             np.array([y for x in [[s] * tlen for s in symbol_names] for y in x])
    # Random OHLC prices
    open_column =       np.random.randint(low=40, high=60, size=len(times), dtype='uint32')
    high_column =       np.random.randint(low=50, high=70, size=len(times), dtype='uint32')
    low_column =        np.random.randint(low=30, high=50, size=len(times), dtype='uint32')
    close_column =      np.random.randint(low=40, high=60, size=len(times), dtype='uint32')
    df = pd.DataFrame({'open': open_column, 'high': high_column, 'low': low_column, 'close': close_column}, index=[names, times])
    df.index = df.index.set_names(['Symbol', 'Date'])
    # Entry signal: the close price when open is above the previous open, otherwise NaN
    df['entry'] = np.select( [df.open > df.open.shift(), False], (df.close, -1), np.nan)
    # Exit signal: the close price when high exceeds open by more than 33%, otherwise NaN
    df['exit'] =  df.close.where(df.high > df.open*1.33, np.nan)
    

    def timing(f):
        def wrap(*args):
            time1 = time()
            ret = f(*args)
            time2 = time()
            print('{:s} function took {:.3f} s'.format(f.__name__, (time2-time1)))
            return ret
        return wrap
    

    JIT-compiled function:

    @numba.jit(nopython=True)
    def entry_exit(arr, limit=0, stop=0, tbe=0):
        # arr columns: 0 open, 1 high, 2 low, 3 close, 4 entry, 5 exit
        # limit / stop: take-profit / stop-loss as multiples of the entry bar's close
        # tbe: maximum number of bars a position is held
        is_active = 0
        bars_held = 0
        limit_target = np.inf
        stop_target = -np.inf
        result = np.empty(arr.shape[0], dtype='float32')
    
        for n in range(arr.shape[0]):
            ret = 0.0
            if is_active == 1:
                bars_held += 1
                if arr[n, 2] < stop_target:       # low breached the stop
                    ret = stop_target
                    is_active = 0
                elif arr[n, 1] > limit_target:    # high reached the limit
                    ret = limit_target
                    is_active = 0
                elif bars_held >= tbe:            # held for too many bars
                    ret = arr[n, 3]
                    is_active = 0
                elif arr[n, 5] > 0:               # exit signal
                    ret = arr[n, 3]
                    is_active = 0
            if is_active == 0:
                if arr[n, 4] > 0:                 # entry signal: open a position
                    is_active = 1
                    bars_held = 0
                    if stop != 0:
                        stop_target = arr[n, 3] * stop
                    if limit != 0:
                        limit_target = arr[n, 3] * limit
            result[n] = ret
        return result
    

    Test:

    @timing
    def run_one(arr):
        entry_exit(arr, limit=1.20, stop=0.50, tbe=5)
    
    @timing
    def run_ten(arr):
        for _ in range(10):
            entry_exit(arr, limit=1.20, stop=0.50, tbe=5)
    
    arr = df[['open', 'high', 'low', 'close', 'entry', 'exit']].to_numpy()
    run_one(arr)
    run_ten(arr)
    

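    When I run it as plain Python (without the @numba.jit decorator), I get:

    • run_one function took 0.680 s
    • run_ten function took 5.816 s

    Makes sense.

    When I run the same code with the JIT, I get completely different results:

    • run_one function took 0.753 s
    • run_ten function took 0.105 s

    Why is this? I'm also very interested in how to speed the function up further, because although the current speed gain is large, it is not enough yet.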

  • Michael Anderson · 4 years ago

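    numba.jit compiles a function the first time it is used. That makes the first execution of the function expensive and every later execution much cheaper.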

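    Your test runs run_one first, which calls entry_exit for the first time and therefore has to compile it. Then it runs run_ten, but by then entry_exit has already been compiled, so the compiled version is reused and the calls are fast.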

    In summary, I think the breakdown is probably:

    run_one: 0.74s compile + 1 x 0.01s run
    run_ten: no compile + 10 x 0.01s run
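The breakdown follows from the question's own timings, assuming (as the answer does) that every compiled call costs the same and that compilation happens only inside run_one. A small worked calculation:

```python
# Timings reported in the question, in seconds.
jit_run_one = 0.753   # JIT: 1 call, including first-call compilation
jit_run_ten = 0.105   # JIT: 10 calls, entry_exit already compiled
py_run_ten = 5.816    # plain Python: 10 calls

# Assumption: each compiled call costs the same, and only run_one compiles.
per_call = jit_run_ten / 10
compile_time = jit_run_one - per_call
speedup = (py_run_ten / 10) / per_call

print(f"per compiled call: {per_call:.4f} s")
print(f"compilation:       {compile_time:.3f} s")
print(f"speedup per call:  {speedup:.0f}x")
```

Under that assumption each compiled call is roughly 55 times faster than a plain-Python call, a much bigger gain than the raw run_one comparison suggests.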
    

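    To check this, just make sure you call the function once (so that numba compiles it) before you start timing it. Alternatively, you can set flags that tell numba to compile it ahead of time.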

    You only need to change your test script to:

    @timing
    def run_one(arr):
        entry_exit(arr, limit=1.20, stop=0.50, tbe=5)
    
    @timing
    def run_ten(arr):
        for _ in range(10):
            entry_exit(arr, limit=1.20, stop=0.50, tbe=5)
    
    arr = df[['open', 'high', 'low', 'close', 'entry', 'exit']].to_numpy()
    
    # Run it once so that numba compiles it
    entry_exit(arr, limit=1.20, stop=0.50, tbe=5)
    
    # Use the compiled version
    run_one(arr)
    run_ten(arr)