【发布时间】:2021-01-27 16:27:55
【问题描述】:
我目前正在替换我编写的一些代码,假设输入是 numpy 数组,因此它将任意列表作为输入。不幸的是,到目前为止我制作的解决方案比原始代码慢得多。有人可以建议我如何恢复原始解决方案的速度吗?
该代码应该为上三角矩阵表示生成一个布尔索引。如果没有输入检查和类似的东西,这就是代码的核心:
一些导入和示例输入:
import numpy as np
descriptor = list(range(100))
descriptor_arr = np.array(descriptor)
value = [0, 2, 13, 14, 11, 23, 45, 16]
这是我当前基于列表的版本:
def get_idx_slow(descriptor, value):
ix, iy = np.triu_indices(len(descriptor), 1)
pattern_in_value = [p in value for p in descriptor]
return [(pattern_in_value[idx_x] & pattern_in_value[idx_y]) for idx_x, idx_y in zip(ix, iy)]
这是之前基于数组的版本:
def get_idx_fast(descriptor, value):
ix, iy = np.triu_indices(len(descriptor), 1)
selection_x = np.any(np.array([descriptor[ix] == v for v in value]), axis=0)
selection_y = np.any(np.array([descriptor[iy] == v for v in value]), axis=0)
return selection_x & selection_y
我的计时结果:
%timeit get_idx_slow(descriptor, value)
1.2 ms ± 33.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit get_idx_fast(descriptor_arr, value)
217 µs ± 1.88 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
【问题讨论】:
-
这可能很难或不可能。 numpy 高度优化并利用多处理和/或向量指令。