【Question title】:Efficiently return insertion point indices with fractional components using NumPy
【Posted】:2020-03-08 07:46:54
【Question description】:

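I want to efficiently compute the index at which an element should be inserted into an array to maintain order, but with a fractional component that represents the "distance" between the two nearest points in the array.

It should be possible to recover the original value from the index and the fraction. In practice, and this is why performance matters, I need to do this for a large number of data points.

To demonstrate what I mean, I have come up with some working logic using np.searchsorted and a few if statements, but I have not been able to vectorize it with NumPy. I would also be happy to see an efficient solution that uses numba and performs comparably to or better than NumPy, or even an off-the-shelf solution in NumPy, SciPy, etc. that I am not aware of.

I have also included some benchmarking code below.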

import numpy as np

np.random.seed(0)

datapoint = np.random.random() * np.random.choice([1, -1]) * 500  # -274.4067
line = np.linspace(-500, 500, 101)  # [-500, -490, ... , 0, ..., 490, 500] - an ordered array, may not be linspace

def get_position(line, point):
    position = np.searchsorted(line, point, side='right')
    size = line.shape[0]
    if position == 0:
        main = 0
        fraction = 0
    elif position == size:
        main = size-1
        fraction = 0
    else:
        main = position - 1
        fraction = (point - line[position-1]) / (line[position] - line[position-1])
    return main, fraction

idx, frac = get_position(line, datapoint)              # (22, 0.55932480363376269)
print(line[idx] + frac * (line[idx + 1] - line[idx]))  # -274.4067; test to see if you get back original value

def run_multiple(line, data):
    out = np.empty((data.shape[0], 3))
    for i in range(data.shape[0]):
        idx, frac = get_position(line, data[i])
        out[i, 0] = data[i]
        out[i, 1] = idx
        out[i, 2] = frac
    return out
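
The round-trip check above indexes line[idx + 1], which would run past the end of the array when idx is the last index. A minimal sketch of a safer reconstruction follows; the helper name reconstruct is hypothetical and not part of the original post.

```python
import numpy as np

def reconstruct(line, idx, frac):
    # Hypothetical helper (not from the original post): rebuild values
    # from integer indices and fractional parts.
    idx = np.asarray(idx, dtype=np.intp)
    frac = np.asarray(frac, dtype=float)
    # Clamp the upper neighbour so idx == len(line) - 1 stays in bounds.
    nxt = np.minimum(idx + 1, len(line) - 1)
    return line[idx] + frac * (line[nxt] - line[idx])

line = np.linspace(-500, 500, 101)
values = reconstruct(line, [22, 100], [0.5, 0.0])
print(values)
```

Points outside the range of line are clamped by get_position (fraction 0 at either end), so for those the round trip returns the nearest endpoint rather than the original value.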

Benchmark

# Python 3.6.0, NumPy 1.11.3, Numba 0.30.1
# Note: Numba 0.30.1 does not support "side" argument of np.searchsorted; not able to upgrade

n = 10**5  # Actual n will be larger
res = run_multiple(line, np.random.random(n) * np.random.choice([1, -1], n) * 500)  # 901 ms per loop

# array([[ -4.22132874e+02,   7.00000000e+00,   7.86712571e-01],
#        [ -4.28972809e+02,   7.00000000e+00,   1.02719119e-01],
#        [  4.23625869e+02,   9.20000000e+01,   3.62586939e-01],
#        ..., 
#        [ -1.88627877e+02,   3.10000000e+01,   1.37212282e-01],
#        [  4.98162640e+01,   5.40000000e+01,   9.81626397e-01],
#        [  1.35777097e+02,   6.30000000e+01,   5.77709684e-01]])

【Comments】:

Tags: python arrays algorithm performance numpy


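【Solution 1】:

If Numba (or the version you are using) does not support a feature, it is worth looking at the Numba source code to see what is already there. Quite often at least part of the problem has already been implemented.

Code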

import numpy as np
import numba as nb

#almost copied from Numba source
#https://github.com/numba/numba/blob/master/numba/targets/arraymath.py
"""Copyright (c) 2012, Anaconda, Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.

Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
@nb.njit()
def searchsorted_right(a, v):
    n = len(a)
    if np.isnan(v):
        # Find the first nan (i.e. the last from the end of a,
        # since there shouldn't be many of them in practice)
        for i in range(n, 0, -1):
            if not np.isnan(a[i - 1]):
                return i
        return 0
    lo = 0
    hi = n
    while hi > lo:
        mid = (lo + hi) >> 1
        if a[mid]<= v:
            # mid is too low => go up
            lo = mid + 1
        else:
            # mid is too high, or is a NaN => go down
            hi = mid
    return lo

@nb.njit()
def get_position(line, point):
    position = searchsorted_right(line, point)
    size = line.shape[0]
    if position == 0:
        main = 0
        fraction = 0
    elif position == size:
        main = size-1
        fraction = 0
    else:
        main = position - 1
        fraction = (point - line[position-1]) / (line[position] - line[position-1])
    return main, fraction

@nb.njit(parallel=True)
def run_multiple(line, data):
    out = np.empty((data.shape[0], 3))
    for i in nb.prange(data.shape[0]):
        idx, frac = get_position(line, data[i])
        out[i, 0] = data[i]
        out[i, 1] = idx
        out[i, 2] = frac
    return out
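
As a quick sanity check, the binary search can be compared against np.searchsorted with side='right' in an environment where that argument is available. The sketch below uses a plain-Python copy of the loop (no Numba decorator, NaN branch omitted); the name searchsorted_right_py and the test data are assumptions made for illustration.

```python
import numpy as np

def searchsorted_right_py(a, v):
    # Plain-Python copy of the bisection loop above, for checking only.
    lo, hi = 0, len(a)
    while hi > lo:
        mid = (lo + hi) >> 1
        if a[mid] <= v:
            lo = mid + 1   # mid is too low => go up
        else:
            hi = mid       # mid is too high => go down
    return lo

rng = np.random.RandomState(0)
line = np.linspace(-500, 500, 101)
# Include values outside [-500, 500] and exact grid hits to exercise the edges.
points = np.concatenate([rng.uniform(-600, 600, 1000), line])

expected = np.searchsorted(line, points, side='right')
got = np.array([searchsorted_right_py(line, p) for p in points])
print(np.array_equal(got, expected))
```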

Timings

n = 10**5
line = np.linspace(-500, 500, 101)
points = np.random.random(n) * np.random.choice([1, -1], n) * 500

%timeit run_multiple(line, points)
#1.08 ms ± 14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

# frac() is the vectorized function from @user3483203's answer (Solution 2)
%timeit frac(line, points)
#8.65 ms ± 266 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

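【Comments】:

    【Solution 2】:

    To vectorize this, I would mask out the edge cases and deal with them at the end. You only need to handle the position == size condition anyway, because the low case (position == 0) only needs zeros in the index and fraction columns, and the zero-initialized out array already provides them.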

    def frac(line, points):
        pos = np.searchsorted(line, points, side='right')
        low = pos == 0
        high = pos == line.shape[0]
        m = ~(low | high)
        ii = points[m]
        jj = pos[m]
        frac = (ii - line[jj-1]) / (line[jj] - line[jj-1])
        out = np.zeros((points.shape[0], 3))
        out[:, 0] = points
        out[m, 1] = jj - 1
        out[m, 2] = frac
        out[high, 1] = line.shape[0] - 1
        return out
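
To see how the masks treat the edges, here is a small worked example that repeats the same steps inline on five points. The three-element line and the sample points are made up for illustration.

```python
import numpy as np

line = np.array([0.0, 10.0, 20.0])
points = np.array([-5.0, 0.0, 5.0, 20.0, 25.0])

pos = np.searchsorted(line, points, side='right')  # [0, 1, 1, 3, 3]
low = pos == 0                  # below the first element
high = pos == line.shape[0]     # at or above the last element
m = ~(low | high)               # interior points only

jj = pos[m]
interior_frac = (points[m] - line[jj - 1]) / (line[jj] - line[jj - 1])

out = np.zeros((points.shape[0], 3))  # zeros already cover the low case
out[:, 0] = points
out[m, 1] = jj - 1
out[m, 2] = interior_frac
out[high, 1] = line.shape[0] - 1      # only the high case needs a fix-up
print(out)
```

Only the rows where high is set are touched after the interior assignment; the single low row (-5.0) keeps index 0 and fraction 0 from np.zeros.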
    

    Benchmark

    n = 10**5
    line = np.linspace(-500, 500, 101)
    points = np.random.random(n) * np.random.choice([1, -1], n) * 500
    
    In [5]: %timeit run_multiple(line, points)
    1.23 s ± 53.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    In [7]: %timeit frac(line, points)
    13.4 ms ± 290 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    
    In [8]: np.allclose(frac(line, points), run_multiple(line, points))
    Out[8]: True
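
As a possible off-the-shelf variant (not part of this answer and not benchmarked here), np.interp can map each point directly to a floating-point position in index space, which is then split into an integer part and a fraction. This is a sketch under the assumption that line is strictly increasing; the name frac_interp is hypothetical.

```python
import numpy as np

def frac_interp(line, points):
    # Hypothetical alternative: interpolate points onto the index axis.
    # np.interp clamps values outside [line[0], line[-1]] to the end indices.
    pos = np.interp(points, line, np.arange(line.shape[0]))
    idx = np.floor(pos)
    out = np.empty((points.shape[0], 3))
    out[:, 0] = points
    out[:, 1] = idx
    out[:, 2] = pos - idx
    return out

line = np.array([0.0, 10.0, 20.0])
points = np.array([-5.0, 0.0, 5.0, 20.0, 25.0])
result = frac_interp(line, points)
print(result)
```

Because the fraction is obtained by subtracting the integer part from the interpolated position, it can differ from the direct division by small rounding errors, and a point extremely close to the next grid value might be rounded onto it. Compare against frac on your own data before relying on it.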
    

    【Comments】:
