import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from time import time
import numba
times = np.arange(datetime(2000, 1, 1), datetime(2020, 2, 1), timedelta(minutes=10)).astype(np.datetime64)
tlen = len(times)
A, Z = np.array(['A', 'Z']).view('int32')
symbol_names = np.random.randint(low=A, high=Z, size=1 * 7, dtype='int32').view(f'U{7}')
times = np.concatenate([times] * 1)
names = np.array([y for x in [[s] * tlen for s in symbol_names] for y in x])
open_column = np.random.randint(low=40, high=60, size=len(times), dtype='uint32')
high_column = np.random.randint(low=50, high=70, size=len(times), dtype='uint32')
low_column = np.random.randint(low=30, high=50, size=len(times), dtype='uint32')
close_column = np.random.randint(low=40, high=60, size=len(times), dtype='uint32')
df = pd.DataFrame({'open': open_column, 'high': high_column, 'low': low_column, 'close': close_column}, index=[names, times])
df.index = df.index.set_names(['Symbol', 'Date'])
df['entry'] = np.select( [df.open > df.open.shift(), False], (df.close, -1), np.nan)
df['exit'] = df.close.where(df.high > df.open*1.33, np.nan)
def timing(f):
def wrap(*args):
time1 = time()
ret = f(*args)
time2 = time()
print('{:s} function took {:.3f} s'.format(f.__name__, (time2-time1)))
return ret
return wrap
JIT编译函数:
@numba.jit(nopython=True)
def entry_exit(arr, limit=0, stop=0, tbe=0):
is_active = 0
bars_held = 0
limit_target = np.inf
stop_target = -np.inf
result = np.empty(arr.shape[0], dtype='float32')
for n in range(arr.shape[0]):
ret = 0
if is_active == 1:
bars_held += 1
if arr[n][2] < stop_target:
ret = stop_target
is_active = 0
elif arr[n][1] > limit_target:
ret = limit_target
is_active = 0
elif bars_held >= tbe:
ret = arr[n][3]
is_active = 0
elif arr[n][5] > 0:
ret = arr[n][3]
is_active = 0
if is_active == 0:
if arr[n][4] > 0:
is_active = 1
bars_held = 0
if stop != 0:
stop_target = arr[n][3] * stop
if limit != 0:
limit_target = arr[n][3] * limit
result[n] = ret
return result
测验:
@timing
def run_one(arr):
entry_exit(arr, limit=1.20, stop=0.50, tbe=5)
@timing
def run_ten(arr):
for _ in range(10):
entry_exit(arr, limit=1.20, stop=0.50, tbe=5)
arr = df[['open', 'high', 'low', 'close', 'entry', 'exit']].to_numpy()
run_one(arr)
run_ten(arr)
在本机Python中运行时,我得到:
-
运行一个函数需要0.680秒
-
运行函数需要5.816秒
有道理。
当我在JIT中运行相同的程序时,得到的结果完全不同:
-
运行一个函数需要0.753秒
-
运行函数需要0.105秒
为什么会这样?
我还很有兴趣知道如何进一步加速函数,因为当前的速度增益虽然很大,但还不够。