【问题标题】:numba jit cannot understand that single element of array is not arraynumba jit 无法理解数组的单个元素不是数组
【发布时间】:2020-07-24 07:46:30
【问题描述】:

我们正在编写一个模拟赛车,以应对学习驾驶模拟赛车的机器学习挑战。作为其中的一部分,我们使用汽车的动态模型。这是一个复杂的模型,需要实时运行,所以我正在尝试用 numba 加速它。

我设法使用带有 deferred_type() 实例的花哨的 jitclass 规范解决了大多数嵌套类对象问题,但现在遇到了以下问题:

import numba as nb
from numba import float64 as f64  # use f64 to type explicitly for list elements to tell numba it is just scalar value
from numba import jit, deferred_type
# .... skips some lines

#consider steering constraints
pl=p.longitudinal
ps=p.steering
x3=f64(x[3])
u1=f64(uInit[1])
a=accelerationConstraints(x3,u1,pl)

运行此命令会导致 numba 输出以下错误:

  File "C:\Users\tobid\miniconda3\envs\l2race\lib\site-packages\numba\core\utils.py", line 81, in reraise
    raise value
numba.core.errors.LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
Cannot cast array(float64, 1d, A) to float64: %".218" = load {i8*, i8*, i64, i64, double*, [1 x i64], [1 x i64]}, {i8*, i8*, i64, i64, double*, [1 x i64], [1 x i64]}* %"$36binary_subscr.14"

File "commonroad\vehicleDynamics_KS.py", line 63:
def vehicleDynamics_KS(x,uInit,p):
    <source elided>
    ps=p.steering
    x3=f64(x[3])
    ^

During: lowering "$38call_function.15 = call $30load_global.11($36binary_subscr.14, func=$30load_global.11, args=[Var($36binary_subscr.14, vehicleDynamics_KS.py:63)], kws=(), vararg=None)" at F:\tobi\Dropbox (Personal)\Share Marcin Tobi\l2race\commonroad\vehicleDynamics_KS.py (63)

即 uInit 的元素(类型为 float32 的一维向量)不被视为浮点数。关键错误行是

无法将数组(float64, 1d, A) 转换为 float64: %".218"

即我取了 uUnit 的一个元素,甚至明确地将其转换为 numba float32,推理似乎失败了。 我一定是错过了什么。

如果我将代码更改如下

x3=x[3]
u1=uInit[1]
a=accelerationConstraints(x3,u1,pl)

结果

Invalid use of type(CPUDispatcher(<function accelerationConstraints at 0x0000010AE7D599D8>)) with parameters (array(float64, 1d, A), array(float64, 1d, A), DeferredType#1146309640904)
Known signatures:
 * (float64, float64, DeferredType#1146309640904) -> float64

明确说明 numba 不能正确推断 float32 数组的单个元素是标量 float32。 IE。 numba 认为我传递的是浮点数组,而不是标量值(“parameters (array(float64, 1d, A), array(float64, 1d, A), DeferredType#1146309640904)” )

如何让 numba 相信我的向量的一个元素是标量值?还是我在这里做一些根本错误的事情?谢谢!

【问题讨论】:

    标签: arrays casting element jit numba


    【解决方案1】:

    我至少找到了这个问题的根本原因。 我的函数的签名是

    fa=nb.types.List(nb.float64[:], reflected=False) # define numba type of list of float
    params_type=deferred_type()
    params_type.define(VehicleParameters.class_type.instance_type) # define numba type for VehicleParameters class instance that has model parameters (it also has @jitclass)
    
    @jit(fa(fa, fa, params_type),nopython=True)
    def vehicleDynamics_KS(x,uInit,p):
    

    这意味着我告诉 numba 我正在传递一个浮动列表的列表,而不仅仅是一个浮动列表。

    我将浮点数组“fa”的定义从 fa=nb.types.List(nb.float64[:],reflected=False) 更改为 fa=nb.types。 List(nb.float64,reflected=False) 和我的方法的签名 @jit(fa(fa(fa, fa, params_type),nopython=True) 现在 numba 正确推断列表的元素是一个标量。

    还有更多问题需要解决,但这是朝着正确方向迈出的一步。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2020-08-08
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2022-01-13
      • 1970-01-01
      • 1970-01-01
      • 2018-05-04
      相关资源
      最近更新 更多