【问题标题】:Python/Numba - Custom class object as input typePython/Numba - 自定义类对象作为输入类型
【发布时间】:2018-05-22 12:22:01
【问题描述】:

我从numba 开始,我的第一个目标是尝试使用嵌套循环加速一个不那么复杂的函数。

给定以下类:

class TestA:
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def get_mult(self):
        return self.a * self.b

还有一个包含TestA 类对象的numpy ndarray。维度(N,) 其中N 通常长度约为300 万。

现在给出以下函数:

def test_no_jit(custom_class_obj_container):
    container_length = len(custom_class_obj_container)
    sum = 0
    for i in range(container_length):
        for j in range(i + 1, container_length):
            obj_i = custom_class_obj_container[i]
            obj_j = custom_class_obj_container[j]
            sum += (obj_i.get_mult() + obj_j.get_mult())

    return sum

我尝试过使用numba 让它与上面的函数一起工作,但是我似乎无法让它与nopython=True 标志一起工作,如果它设置为false,那么运行时间高于no-jit 函数。

这是我尝试 jit 函数的最新尝试(也使用 nb.prange):

@nb.jit(nopython=False, parallel=True)
def test_jit(custom_class_obj_container):
    container_length = len(custom_class_obj_container)
    sum = 0
    for i in nb.prange(container_length):
        for j in nb.prange(i + 1, container_length):
            obj_i = custom_class_obj_container[i]
            obj_j = custom_class_obj_container[j]
            sum += (obj_i.get_mult() + obj_j.get_mult())

    return sum

我试图四处搜索,但似乎找不到关于如何在签名中定义自定义类的教程,以及我将如何去加速此类功能并让它在 GPU 上运行并且可能(任何有关此事的信息都将受到高度赞赏)让它与 cuda 库一起运行 - 这些库已安装并可以使用(以前与 tensorflow 一起使用)

【问题讨论】:

  • 如果这能起作用,我感到很惊讶。 Numba 依赖于具有底层 ctype 映射并且不能扩展到任意对象 dtype AFAIK。
  • Numba 仅针对 atomic 类型和 numpy 类型编译为 C(因为它们已经在 C 中)。它无法使用 nopython 处理自定义对象,因为它没有将对象映射到 C 的方法。

标签: python python-3.x numpy numba


【解决方案1】:

numba 文档给出了创建自定义类型的示例,即使对于 nopython 模式也是如此:https://numba.pydata.org/numba-doc/latest/extending/interval-example.html

但在您的情况下,除非这是您实际想要做的真正精简的版本,否则似乎最简单的方法是重用现有类型。此外,构建 3M 长度的对象数组会很慢,并且会产生碎片内存(因为对象没有存储在连续的块中)。

一个如何使用记录数组来解决问题的例子:

x_dt = np.dtype([('a', np.float64),
                 ('b', np.float64)])
n = 30000
buf = np.arange(n*2).reshape((n, 2)).astype(np.float64)
vec3 = np.recarray(n, dtype=x_dt, buf=buf) 

@numba.njit
def mult(a):
    return a.a * a.b

@numba.jit(nopython=True, parallel=True)
def sum_of_prod(vector):
    sum = 0
    vector_len = len(vector)
    for i in numba.prange(vector_len):
        for j in numba.prange(i + 1, vector_len):
            sum += mult(vector[i]) + mult(vector[j])
    return sum

sum_of_prod(vec3)

FWIW,我不是 numba 专家。我在搜索如何在 numba 中为非数字内容实现自定义类型时发现了这个问题。在您的情况下,因为这是高度数字化的,我认为自定义类型可能是矫枉过正。

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2011-06-21
    • 1970-01-01
    • 1970-01-01
    • 2011-07-30
    • 1970-01-01
    • 2016-05-22
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多