【问题标题】:Using Numba jitclass with input from dicts and tuples将 Numba jitclass 与来自字典和元组的输入一起使用
【发布时间】:2019-07-10 17:05:36
【问题描述】:

我正在优化我拥有的一些代码,这些代码主要包含在单个 python 类中。它对 python 对象的操作很少,所以我认为使用 Numba 将是一个很好的匹配,但我在创建对象期间需要大量参数,而且我认为我不完全理解 Numba 相对较新的 dict 支持(documentation here)。我拥有的参数都是单个浮点数或整数,并被传递到对象中,存储,然后在整个代码运行过程中使用,如下所示:

import numpy as np
from numba import jitclass, float64

spec = [
    ('p', dict),
    ('shape', tuple),               # the shape of the array
    ('array', float64[:,:]),          # an array field
]

params_default = {
    par_1 = 1,
    par_2 = 0.5
    }

@jitclass(spec)
class myObj:
    def __init__(self,params = params_default,shape = (100,100)):
        self.p = params
        self.shape = shape
        self.array = self.p['par_2']*np.ones(shape)

    def inc_arr(self):
        self.array += self.p['par_1']*np.ones(shape)

我认为我不明白 Numba 为此需要什么。如果我想使用 nopython 模式使用 Numba 优化它,是否需要将规范传递给 jitclass 装饰器?如何定义字典的规范?我还需要声明形状元组吗?我查看了在 jitclass 装饰器上找到的 documentation 以及 dict numba 文档,但我不知道该怎么做。当我运行上面的代码时,我得到以下错误:

TypeError: spec values should be Numba type instances, got <class 'dict'>

我是否需要以某种方式在规范中包含 dict 元素?从文档中不清楚正确的语法是什么。

或者,有没有办法让 Numba 推断输入类型?

【问题讨论】:

    标签: python dictionary jit numba


    【解决方案1】:

    spec 需要由 numba 特定类型 组成,而不是 python 类型! 所以规范中的tupledict 必须是typed numba 类型(并且afaik 只允许使用同质字典)。

    因此,要么在 jitted 函数中指定 params_default dict,如 here 所示,要么显式键入 numba dict as shown here

    对于这种情况,我将采用后一种方法:

    import numpy as np
    from numba import jitclass, float64
    
    # Explicitly define the types of the key and value:
    params_default = nb.typed.Dict.empty(
        key_type=nb.typeof('par_1'),
        value_type=nb.typeof(0.5)
    )
    
    # assign your default values
    params_default['par_1'] = 1.  # Same type required, thus setting to float
    params_default['par_2'] = .5
    
    spec = [
        ('p', nb.typeof(params_default)),
        ('shape', nb.typeof((100, 100))),               # the shape of the array
        ('array', float64[:, :]),          # an array field
    ]
    
    @jitclass(spec)
    class myObj:
        def __init__(self, params=params_default, shape=(100, 100)):
            self.p = params
            self.shape = shape
            self.array = self.p['par_2'] * np.ones(shape)
    
        def inc_arr(self):
            self.array += self.p['par_1'] * np.ones(shape)
    

    正如已经指出的:dict 是,afaik,同质类型。因此,所有键/值必须属于同一类型。所以将intfloat 存储在同一个字典中是行不通的。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2018-10-06
      • 1970-01-01
      • 1970-01-01
      • 2016-12-05
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多