【问题标题】:Indexing multidimensional numpy array inside numba's jitclass在 numba 的 jitclass 中索引多维 numpy 数组
【发布时间】:2019-06-07 16:02:09
【问题描述】:

我正在尝试将一个小的多维数组插入到 numba jitclass 中的一个较大的数组中。小数组是由索引列表定义的大数组的特定位置。

以下 MWE 显示了没有 numba 的问题 - 一切都按预期工作

import numpy as np

class NumbaClass(object):

    def __init__(self, n, m):
        self.A = np.zeros((n, m))

    # solution 1 using pure python
    def nonNumbaFunction1(self, idx, values):
        self.A[idx[:, None], idx] = values

    # solution 2 using pure python
    def nonNumbaFunction2(self, idx, values):
        self.A[np.ix_(idx, idx)] = values

if __name__ == "__main__":
    n = 6
    m = 8
    obj = NumbaClass(n, m)
    print(f'A =\n{obj.A}')

    idx = np.array([0, 2, 5])
    values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
    print(f'values =\n{values}')

    obj.nonNumbaFunction1(idx, values)
    print(f'A =\n{obj.A}')

    obj.nonNumbaFunction2(idx, values)
    print(f'A =\n{obj.A}')

nonNumbaFunction1nonNumbaFunction2 两个函数都不能在 numba 类中工作。所以我目前的解决方案看起来像这样,在我看来这不是很好

import numpy as np

from numba import jitclass      
from numba import int64, float64
from collections import OrderedDict

specs = OrderedDict()
specs['A'] = float64[:, :]

@jitclass(specs)
class NumbaClass(object):

    def __init__(self, n, m):
        self.A = np.zeros((n, m))

    # solution for numba jitclass
    def numbaFunction(self, idx, values):
        for i in range(len(values)):
            idxi = idx[i]
            for j in range(len(values)):
                idxj = idx[j]
                self.A[idxi, idxj] = values[i, j]

if __name__ == "__main__":
    n = 6
    m = 8
    obj = NumbaClass(n, m)
    print(f'A =\n{obj.A}')

    idx = np.array([0, 2, 5])
    values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
    print(f'values =\n{values}')

    obj.numbaFunction(idx, values)
    print(f'A =\n{obj.A}')

所以我的问题是:

  • 有谁知道 numba 中这种索引的解决方案,或者是否有其他矢量化解决方案?
  • nonNumbaFunction1 有更快的解决方案吗?

知道插入的数组很小(4x4 到 10x10)可能很有用,但是这种索引出现在嵌套循环中,所以它也必须快速安静!稍后我也需要对三维对象进行类似的索引。

【问题讨论】:

    标签: python multidimensional-array indexing jit numba


    【解决方案1】:

    由于 numba 对索引支持的限制,我认为您最好自己编写 for 循环。要使其跨维度通用,您可以使用 generated_jit 装饰器进行专业化。像这样的:

    def set_2d(target, values, idx):
        for i in range(values.shape[0]):
            for j in range(values.shape[1]):
                target[idx[i], idx[j]] = values[i, j]
    
    def set_3d(target, values, idx):
        for i in range(values.shape[0]):
            for j in range(values.shape[1]):
                for k in range(values.shape[2]):
                    target[idx[i], idx[j], idx[k]] = values[i, j, l]
    
    @numba.generated_jit
    def set_nd(target, values, idx):
        if target.ndim == 2:
            return set_2d
        elif target.ndim == 3:
            return set_3d
    

    然后,这可以在你的 jitclass 中使用

    specs = OrderedDict()
    specs['A'] = float64[:, :]
    
    @jitclass(specs)
    class NumbaClass(object):
        def __init__(self, n, m):
            self.A = np.zeros((n, m))
        def numbaFunction(self, idx, values):
            set_nd(self.A, values, idx)
    

    【讨论】:

    • 谢谢你形成很好的答案。我还不知道generated_jit。对我来说看起来很有希望。我明天试试!
    猜你喜欢
    • 2019-04-21
    • 2015-04-19
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2013-06-05
    • 1970-01-01
    • 2020-03-25
    • 2019-01-04
    相关资源
    最近更新 更多