【发布时间】:2019-06-07 16:02:09
【问题描述】:
我正在尝试将一个小的多维数组插入到 numba jitclass 中的一个较大的数组中。小数组是由索引列表定义的大数组的特定位置。
以下 MWE 显示了没有 numba 的问题 - 一切都按预期工作
import numpy as np
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
# solution 1 using pure python
def nonNumbaFunction1(self, idx, values):
self.A[idx[:, None], idx] = values
# solution 2 using pure python
def nonNumbaFunction2(self, idx, values):
self.A[np.ix_(idx, idx)] = values
if __name__ == "__main__":
n = 6
m = 8
obj = NumbaClass(n, m)
print(f'A =\n{obj.A}')
idx = np.array([0, 2, 5])
values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
print(f'values =\n{values}')
obj.nonNumbaFunction1(idx, values)
print(f'A =\n{obj.A}')
obj.nonNumbaFunction2(idx, values)
print(f'A =\n{obj.A}')
nonNumbaFunction1 和 nonNumbaFunction2 两个函数都不能在 numba 类中工作。所以我目前的解决方案看起来像这样,在我看来这不是很好
import numpy as np
from numba import jitclass
from numba import int64, float64
from collections import OrderedDict
specs = OrderedDict()
specs['A'] = float64[:, :]
@jitclass(specs)
class NumbaClass(object):
def __init__(self, n, m):
self.A = np.zeros((n, m))
# solution for numba jitclass
def numbaFunction(self, idx, values):
for i in range(len(values)):
idxi = idx[i]
for j in range(len(values)):
idxj = idx[j]
self.A[idxi, idxj] = values[i, j]
if __name__ == "__main__":
n = 6
m = 8
obj = NumbaClass(n, m)
print(f'A =\n{obj.A}')
idx = np.array([0, 2, 5])
values = np.arange(len(idx)**2).reshape(len(idx), len(idx))
print(f'values =\n{values}')
obj.numbaFunction(idx, values)
print(f'A =\n{obj.A}')
所以我的问题是:
- 有谁知道 numba 中这种索引的解决方案,或者是否有其他矢量化解决方案?
-
nonNumbaFunction1有更快的解决方案吗?
知道插入的数组很小(4x4 到 10x10)可能很有用,但是这种索引出现在嵌套循环中,所以它也必须快速安静!稍后我也需要对三维对象进行类似的索引。
【问题讨论】:
标签: python multidimensional-array indexing jit numba