【发布时间】:2020-07-24 07:46:30
【问题描述】:
我们正在编写一个模拟赛车,以应对学习驾驶模拟赛车的机器学习挑战。作为其中的一部分,我们使用汽车的动态模型。这是一个复杂的模型,需要实时运行,所以我正在尝试用 numba 加速它。
我设法使用带有 deferred_type() 实例的花哨的 jitclass 规范解决了大多数嵌套类对象问题,但现在遇到了以下问题:
import numba as nb
from numba import float64 as f64 # use f64 to type explicitly for list elements to tell numba it is just scalar value
from numba import jit, deferred_type
# .... skips some lines
#consider steering constraints
pl=p.longitudinal
ps=p.steering
x3=f64(x[3])
u1=f64(uInit[1])
a=accelerationConstraints(x3,u1,pl)
运行此命令会导致 numba 输出以下错误:
File "C:\Users\tobid\miniconda3\envs\l2race\lib\site-packages\numba\core\utils.py", line 81, in reraise
raise value
numba.core.errors.LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
Cannot cast array(float64, 1d, A) to float64: %".218" = load {i8*, i8*, i64, i64, double*, [1 x i64], [1 x i64]}, {i8*, i8*, i64, i64, double*, [1 x i64], [1 x i64]}* %"$36binary_subscr.14"
File "commonroad\vehicleDynamics_KS.py", line 63:
def vehicleDynamics_KS(x,uInit,p):
<source elided>
ps=p.steering
x3=f64(x[3])
^
During: lowering "$38call_function.15 = call $30load_global.11($36binary_subscr.14, func=$30load_global.11, args=[Var($36binary_subscr.14, vehicleDynamics_KS.py:63)], kws=(), vararg=None)" at F:\tobi\Dropbox (Personal)\Share Marcin Tobi\l2race\commonroad\vehicleDynamics_KS.py (63)
即 uInit 的元素(类型为 float32 的一维向量)不被视为浮点数。关键错误行是
无法将数组(float64, 1d, A) 转换为 float64: %".218"
即我取了 uUnit 的一个元素,甚至明确地将其转换为 numba float32,推理似乎失败了。 我一定是错过了什么。
如果我将代码更改如下
x3=x[3]
u1=uInit[1]
a=accelerationConstraints(x3,u1,pl)
结果
Invalid use of type(CPUDispatcher(<function accelerationConstraints at 0x0000010AE7D599D8>)) with parameters (array(float64, 1d, A), array(float64, 1d, A), DeferredType#1146309640904)
Known signatures:
* (float64, float64, DeferredType#1146309640904) -> float64
明确说明 numba 不能正确推断 float32 数组的单个元素是标量 float32。 IE。 numba 认为我传递的是浮点数组,而不是标量值(“parameters (array(float64, 1d, A), array(float64, 1d, A), DeferredType#1146309640904)” )
如何让 numba 相信我的向量的一个元素是标量值?还是我在这里做一些根本错误的事情?谢谢!
【问题讨论】:
标签: arrays casting element jit numba