【问题标题】:indexing arrays with arrays in numba (njit): variable-dimension ndarrays使用 numba (njit) 中的数组索引数组:可变维度 ndarrays
【发布时间】:2021-08-06 21:17:28
【问题描述】:

我有一个带有 D 元素的一维整数数组(即idx = np.array([i0, i1, ...]), s.t. idx.size = D),其中每个元素对应于具有 D 维的 ND 数组的该维的索引(即data s.t. data.ndim = D)。如何使用索引数组idx 索引data 数组?


在 python 中我会使用 data[tuple(idx)],但在 numba nopython 模式下不支持 tuple

我当前的解决方法是使用 data.ravel() 并将 ND 索引转换为扁平数组的一维索引,但似乎必须有一个更简单(并且计算速度更快)的解决方案。某处有take_along_each_axis(data, idx) 方法吗?

【问题讨论】:

  • 为什么需要使用njit?这是一个基本的 numpy 操作,在编译后的代码中实现。它没有在 python 中迭代。
  • date.__getitem__(idx)
  • @hpaulj 我的最终目标是比索引数组稍微复杂一些的计算。 __getitem__() 使用数组(而不是元组)与使用它进行索引相同,它顺序获取每个元素而不是使用每个元素来索引每个维度,即 data.__getitem__([0, 0, ... 0]) 返回 data 本身,而不是第 0 个条目。

标签: python arrays numpy numba


【解决方案1】:

让我们做一些时间测试:

In [135]: data = np.ones((100,100,100,100)); idx = (50,50,50,50)

这几乎是 1 Gb 的内存 - 不足以造成内存错误,但仍然应该是一个合理的测试。实际上,对于更小数组的基本索引,我得到了相同的时间。对于其他idx

In [136]: timeit data[idx]
212 ns ± 9.25 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

解释器将其翻译成方法调用:

In [137]: timeit data.__getitem__(idx)
283 ns ± 4.37 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

索引“平面”数组,可以这样做:

In [138]: timeit data.flat[np.ravel_multi_index(idx,data.shape)]
6.65 µs ± 75.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

或将转换排除在循环之外:

In [139]: %%timeit x=np.ravel_multi_index(idx,data.shape)
     ...: data.flat[x] 
574 ns ± 23.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [142]: %%timeit x=np.ravel_multi_index(idx,data.shape);df=data.flat
     ...: df[x]
345 ns ± 6.39 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

我认为在某些情况下平面索引更快,但这不是一个。

所以一个独立的操作我看不出写njit 版本的意义。我想如果它是一些更大的操作的一部分,它可能是值得的。

【讨论】:

  • 感谢您的回答。所以,扁平化确实慢了 2 倍以上。实际上ravel_multi_index 没有在 numba 中实现,所以我不得不编写自己的版本,这可能更慢(尽管预先缓存了数组形状的产品)。
猜你喜欢
  • 1970-01-01
  • 2019-06-07
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2018-01-22
  • 2019-11-20
相关资源
最近更新 更多