代码之家  ›  专栏  ›  技术社区  ›  Isky Mathews

Python-Numba多项式根下误差

  •  1
  • Isky Mathews  · 技术社区  · 6 年前

    我创建了一个函数,给定系数的范围,用这些系数构造多项式,并输出其所有根的列表。然而,Numba不喜欢它。是这样的:

    import math
    import numpy as np
    import itertools
    from numba import jit
    from sympy.solvers import solve
    from sympy import Symbol
    from sympy import Poly
    
    @jit
    def polyn(ranges=[[-20,20],[-20,20],[-20,20],[-20,20]],step=4):
        l = []
        x = Symbol('x')
        rangl = [np.linspace(i[0],i[1],math.floor((i[1]-i[0])/step)) for i in ranges]
        coeffl = iter(itertools.product(*rangl))
        leng = 1
        for i in rangl:
            leng *= len(i)
        for i in range(0, leng):
            a = solve(Poly(list(next(coeffl)),x),x)
            for j in a:
                l.append(j)
        return np.array(l)
    

    当我尝试运行此程序时,它会输出一个神秘的: AssertionError:在对象处失败(对象模式前端) 我不明白。。。有人能帮忙吗?

    1 回复  |  直到 6 年前
        1
  •  1
  •   John Zwinck    6 年前

    你的代码中有很多东西是Numba目前无法处理的。第一个是你构建的列表理解 rangl :

    [np.linspace(i[0],i[1],math.floor((i[1]-i[0])/step)) for i in ranges]
    

    您应该将其替换为NumPy解决方案,如:

    rangl = np.empty((len(ranges), step))
    for i in ranges:
        rangl[i] = np.linspace(i[0],i[1],math.floor((i[1]-i[0])/step))
    

    第二件Numba无法应付的事情是itertools。产品您也可以用NumPy和for循环来替换它。

    一般来说,试着通过注释代码的较低部分来减少代码,直到Numba接受它,然后自上而下地工作,看看哪些部分无法编译。要有条不紊,一步一步走,尽量坚持简单的结构,比如简单 for 循环和数组。