【问题标题】:How to pass array pointer to Numba function?如何将数组指针传递给 Numba 函数?
【发布时间】:2020-08-14 00:28:33
【问题描述】:

我想创建一个 Numba 编译函数,该函数将指针或数组的内存地址作为参数并对其进行计算,例如,修改基础数据。

用于说明这一点的纯 python 版本如下所示:

import ctypes
import numba as nb
import numpy as np

arr = np.arange(5).astype(np.double)  # create arbitrary numpy array


def modify_data(addr):
    """ a function taking the memory address of an array to modify it """
    ptr = ctypes.c_void_p(addr)
    data = nb.carray(ptr, arr.shape, dtype=arr.dtype)
    data += 2

addr = arr.ctypes.data
modify_data(addr)
arr
# >>> array([2., 3., 4., 5., 6.])

正如您在示例中看到的那样,数组arr 被修改了,但没有将其显式传递给函数。在我的用例中,数组的形状和 dtype 是已知的,并且始终保持不变,这应该会简化界面。

1。尝试:天真的jitting

我现在尝试编译modify_data 函数,但失败了。我的第一次尝试是使用

shape = arr.shape
dtype = arr.dtype

@nb.njit
def modify_data_nb(ptr):
    data = nb.carray(ptr, shape, dtype=dtype)
    data += 2


ptr = ctypes.c_void_p(addr)
modify_data_nb(ptr)   # <<< error

cannot determine Numba type of &lt;class 'ctypes.c_void_p'&gt; 失败,即它不知道如何解释指针。

2。尝试:显式类型

我尝试放置显式类型,

arr_ptr_type = nb.types.CPointer(nb.float64)
shape = arr.shape

@nb.njit(nb.types.void(arr_ptr_type))
def modify_data_nb(ptr):
    """ a function taking the memory address of an array to modify it """
    data = nb.carray(ptr, shape)
    data += 2

但这没有帮助。它没有抛出任何错误,但我不知道如何调用函数modify_data_nb。我尝试了以下选项

modify_data_nb(arr.ctypes.data)
# TypeError: No matching definition for argument type(s) int64

ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject

ptr = ctypes.c_void_p(arr.ctypes.data)
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject

有没有办法从 arr 获取正确的指针格式,以便我可以将它传递给 Numba 编译的 modify_data_nb 函数?或者,是否有另一种方法可以将内存位置传递给函数。

3。尝试:使用scipy.LowLevelCallable

我通过使用scipy.LowLevelCallable 及其魔力取得了一些进展:

arr = np.arange(3).astype(np.double)
print(arr)
# >>> array([0., 1., 2.])

# create the function taking a pointer
shape = arr.shape
dtype = arr.dtype

@nb.cfunc(nb.types.void(nb.types.CPointer(nb.types.double)))
def modify_data(ptr):
    data = nb.carray(ptr, shape, dtype=dtype)
    data += 2

modify_data_llc = LowLevelCallable(modify_data.ctypes).function    

# create pointer to array
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))

# call the function only with the pointer
modify_data_llc(ptr)

# check whether array got modified
print(arr)
# >>> array([2., 3., 4.])

我现在可以调用一个函数来访问数组,但是这个函数不再是一个 Numba 函数。特别是它不能用于其他 Numba 函数。

【问题讨论】:

  • 我不知道如何从解释代码中调用这个函数,但是从另一个 jitted 函数中它可以工作。例如。 @nb.njit() def call(arr): modify_data_nb(arr.ctypes)
  • 这是一个有趣的观察,但不幸的是它对我没有帮助,因为我不会在函数 call 的范围之外更改 arr 的内容。

标签: python arrays numpy numba


【解决方案1】:

感谢伟大的@stuartarchibald,我现在有了一个可行的解决方案:

import ctypes
import numba as nb
import numpy as np

arr = np.arange(5).astype(np.double)  # create arbitrary numpy array
print(arr)

@nb.extending.intrinsic
def address_as_void_pointer(typingctx, src):
    """ returns a void pointer from a given memory address """
    from numba.core import types, cgutils
    sig = types.voidptr(src)

    def codegen(cgctx, builder, sig, args):
        return builder.inttoptr(args[0], cgutils.voidptr_t)
    return sig, codegen

addr = arr.ctypes.data

@nb.njit
def modify_data():
    """ a function taking the memory address of an array to modify it """
    data = nb.carray(address_as_void_pointer(addr), arr.shape, dtype=arr.dtype)
    data += 2

modify_data()
print(arr)

关键是新的address_as_void_pointer 函数,它将内存地址(以int 形式给出)转换为numba.carray 可用的指针。

【讨论】:

    猜你喜欢
    • 2013-02-15
    • 1970-01-01
    • 1970-01-01
    • 2012-08-08
    • 2012-04-15
    • 2017-06-29
    • 2016-03-27
    相关资源
    最近更新 更多