【问题标题】:how to "broadcast" np.searchsorted [duplicate]如何“广播” np.searchsorted [重复]
【发布时间】:2021-12-13 08:59:14
【问题描述】:

我有一个二维 NumPy 数组。每行都经过排序并包含少量元素(如 10),但有大量行(如 1e6)。它可能看起来像这样:

haystacks = [
    [1, 4, 7, 8],
    [2, 5, 5, 7],
    [10, 11, 25, 30],
    ...
]

我也有一个一维数组。该数组的元素数与第一个数组的行数一样多。所以,也许:

needles = [10, 6, 15 ...]

我想对 2d 数组中相应行上的 1d 数组中的每个元素执行二进制搜索。我会使用np.searchsorted,但它似乎不支持这个用例。

我在物理系统的大型模拟中使用它。所以,性能非常重要。下面的代码可以运行,但是速度太慢了。

positions = []
for needle, haystack in zip(needles, haystacks):
   positions.append(np.searchsorted(haystack, needle))

print(positions)

NumPy 解决方案是首选。其他库还可以,但我无法让 Numba 正常工作。

有人有什么想法吗?

【问题讨论】:

  • 你试过分配内存给positions吗?使用 NumPy 数组而不是列表来存储变量。应该会加快速度。
  • 考虑到行有多小,在 Python 循环中使用广播 线性 搜索而不是二进制搜索可能会更快。

标签: python numpy performance


【解决方案1】:

这是一个有效的 numba 解决方案。如果您考虑到多处理,您可能希望将 enumerate 替换为 numba.prange

import numpy as np
import numba
haystacks = np.array([
    [1, 4, 7, 8],
    [2, 5, 5, 7],
    [10, 11, 25, 30],
])

needles = np.array([10, 6, 16])
@numba.njit
def search(needles, haystacks):
    positions = np.zeros_like(needles)

    for idx, _ in enumerate(needles):
        
        positions[idx] = np.searchsorted( haystacks[idx], needles[idx],)

    return positions

print(search(needles, haystacks))

Numba 提供更好的性能:

import timeit
print(timeit.timeit(lambda: search_np(needles, haystacks), number=100_000))
print(timeit.timeit(lambda: search_nb(needles, haystacks), number=100_000))
#1.103232195999908 for np
#0.3278064189998986 for nb

【讨论】:

    【解决方案2】:

    这是一个聪明的方法,将指针添加到数组 (og_combo),跨轴 1 排序(快速),并找到差异为零的索引:

    import numpy as np
    import pandas as pd
    
    haystack = np.random.randint(0, 10, size=(1_000_000, 4))
    needles = np.random.randint(0, 10, size=(1_000_000,))
    
    haystack = np.sort(haystack, axis=1)
    haystack_inf = np.hstack((haystack, (np.ones(len(haystack)) * 1e10).reshape((-1, 1))))
    
    og_combo = np.hstack((haystack_inf, needles.reshape(-1, 1)))
    combo_sorted = np.sort(og_combo, axis=1)
    diff = og_combo - combo_sorted
    df = pd.DataFrame(np.argwhere(diff != 0), columns=["col1", "col2"])
    final = df.groupby("col1")["col2"].first().values
    

    Haystack(随机示例):

    array([[1, 2, 7, 8],
           [0, 2, 7, 8],
           [2, 2, 4, 9],
           ...,
           [0, 3, 5, 6],
           [0, 1, 4, 9],
           [0, 4, 6, 7]])
    

    针(相同的随机示例):

    array([5, 6, 2, ..., 9, 0, 2])
    

    最终位置(相同的随机示例):

    array([2, 2, 2, ..., 4, 1, 1])
    

    【讨论】:

      猜你喜欢
      • 2021-07-23
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-11-01
      • 1970-01-01
      • 2012-01-27
      • 2017-07-25
      • 2019-11-25
      相关资源
      最近更新 更多