代码之家  ›  专栏  ›  技术社区  ›  ma3oun

如何嵌套numba jitclass

  •  16
  • ma3oun  · 技术社区  · 8 年前

    我试图理解@jitclass装饰器如何与嵌套类一起工作。我写了两个虚拟类:fifi和toto fifi具有toto属性。这两个类都有@jitclass修饰符,但编译失败。代码如下:

    菲菲公司

    from numba import jitclass, float64
    from toto import toto
    
    spec = [('a',float64),('b',float64),('c',toto)]
    
    @jitclass(spec)
    class fifi(object):
      def __init__(self, combis):
        self.a = combis
        self.b = 2
        self.c = toto(combis)
    
      def mySqrt(self,x):
        s = x
        for i in xrange(self.a):
          s = (s + x/s) / 2.0
        return s
    

    toto.py(总计):

    from numba import jitclass,int32
    
    spec = [('n',int32)]
    
    @jitclass(spec)
    class toto(object):
      def __init__(self,n):
        self.n = 42 + n
    
      def work(self,y):
        return y + self.n
    

    启动代码的脚本:

    from datetime import datetime
    from fifi import fifi
    from numba import jit
    
    @jit(nopython = True)
    def run(n,results):
      for i in xrange(n):
        q = fifi(200)
        results[i+1] = q.mySqrt(i + 1)
    
    if __name__ == '__main__':
      n = int(1e6)
      results = [0.0] * (n+1)
      starttime = datetime.now()
      run(n,results)
      endtime = datetime.now()
    
      print("Script running time: %s"%str(endtime-starttime))
      print("Sqrt of 144 is %f"%results[145])
    

    当我运行脚本时,我得到[…]

    键入错误:未键入全局名称“toto” 文件“fifi.py”,第11行

    请注意,如果我在“fifi”中删除了对“toto”的引用,代码就可以正常工作了,多亏了numba,我得到了x16的速度提升。

    1 回复  |  直到 8 年前
        1
  •  23
  •   JoshAdel    8 年前

    deferred_type 例子这在Numba 0.27和更早版本中有效。改变 fifi.py 至:

    from numba import jitclass, float64, deferred_type
    from toto import toto
    
    toto_type = deferred_type()
    toto_type.define(toto.class_type.instance_type)
    
    spec = [('a',float64),('b',float64),('c',toto_type)]
    
    @jitclass(spec)
    class fifi(object):
      def __init__(self, combis):
        self.a = combis
        self.b = 2
        self.c = toto(combis)
    
      def mySqrt(self,x):
        s = x
        for i in xrange(self.a):
          s = (s + x/s) / 2.0
        return s
    

    然后作为输出:

    $ python test.py
    Script running time: 0:00:01.991600
    Sqrt of 144 is 12.041595
    

    在一些更高级的jitclass数据结构示例中可以看到此功能,例如: