【问题标题】:Indexing ranges of columns of array when only the indexes of the ranges are given仅给定范围的索引时索引数组列的范围
【发布时间】:2021-08-14 11:30:38
【问题描述】:

当只给出所需范围的索引时,我正在寻找一种有效的方法来索引具有多个范围的 numpy 数组的列。

例如,给定以下数组,范围大小为r_size=3

import numpy as np
arr = np.arange(18).reshape((2,9))

array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
       [ 9, 10, 11, 12, 13, 14, 15, 16, 17]])

这意味着总共有 3 组范围 [r0, r1, r2],它们在数组中的元素分布为:

[[r0_00, r0_01, r0_02, r1_00, r1_01, r1_02, r2_00, r2_01, r2_02]
 [r0_10, r0_11, r0_12, r1_10, r1_11, r1_12, r2_10, r2_11, r2_12]]

因此,如果我想访问r0r2 范围,那么我将获得:

arr    = np.arange(18).reshape((2,9))
r_size = 3
ranges = [0, 2]
# --------------------------------------------------------
# Line that index arr, with the variable ranges... Output:
# --------------------------------------------------------
array([[ 0,  1,  2,  6,  7,  8],
       [ 9, 10, 11, 15, 16, 17]])

我发现最快的方法如下:

import numpy as np
from itertools import chain

arr    = np.arange(18).reshape((2,9))
r_size = 3
ranges = [0,2]

arr[:, list(chain(*[range(r_size*x,r_size*x+r_size) for x in ranges]))]

array([[ 0,  1,  2,  6,  7,  8],
       [ 9, 10, 11, 15, 16, 17]])

但我不确定在速度方面是否可以提高。

提前致谢!

【问题讨论】:

    标签: python numpy indexing range


    【解决方案1】:

    您可以先将数组拆分为 r_size 块:

    >>> splits = np.split(arr, r_size, axis=1)
    [array([[ 0,  1,  2],
            [ 9, 10, 11]]), 
     array([[ 3,  4,  5],
            [12, 13, 14]]), 
     array([[ 6,  7,  8],
            [15, 16, 17]])]
    

    np.stack堆叠并选择正确的ranges

    >>> stack = np.stack(splits)[ranges]
    array([[[ 0,  1,  2],
            [ 9, 10, 11]],
    
           [[ 6,  7,  8],
            [15, 16, 17]]])
    

    并在axis=1 上与np.hstacknp.concantenate 水平连接:

    >>> np.stack(stack)
    array([[ 0,  1,  2,  6,  7,  8],
           [ 9, 10, 11, 15, 16, 17]])
    

    总体来说是这样的:

    >>> np.hstack(np.stack(np.split(arr, r_size, axis=1))[ranges])
    array([[ 0,  1,  2,  6,  7,  8],
           [ 9, 10, 11, 15, 16, 17]])
    

    或者,您可以专门使用np.reshapes,这样会更快:

    初始重塑:

    >>> arr.reshape(len(arr), -1, r_size)
    array([[[ 0,  1,  2],
            [ 3,  4,  5],
            [ 6,  7,  8]],
    
           [[ 9, 10, 11],
            [12, 13, 14],
            [15, 16, 17]]])
    

    使用ranges 进行索引:

    >>> arr.reshape(len(arr), -1, r_size)[:, ranges]
    array([[[ 0,  1,  2],
            [ 6,  7,  8]],
    
           [[ 9, 10, 11],
            [15, 16, 17]]])
    

    然后,重新塑造成最终形式:

    >>> arr.reshape(len(arr),  -1, r_size)[:, ranges].reshape(len(arr), -1)
    

    【讨论】:

    • 感谢您的回答。它完美地完成了任务,但我担心它可能比问题末尾提出的方法慢:%timeit np.hstack(np.stack(np.split(arr, r_size, axis=1))[ranges]) 给出:36.9 µs ± 278 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 而:%timeit arr[:, list(chain(*[range(r_size*x,r_size*x+r_size) for x in ranges]))] 给出:4.84 µs ± 14.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    • 我用np.reshape提供了一个替代解决方案
    • 是的 :) 它更快。非常感谢!。它给出:4.41 µs ± 19.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)%timeit arr.reshape(len(arr), r_size, -1)[:, ranges].reshape(len(arr), -1)
    • 您的编辑确实非常重要。我通常倾向于使用在我的最小示例中尺寸是唯一的张量。这样落入这种陷阱的机会就更少了!
    • 如果我没记错的话,花哨的索引(使用索引列表来访问数组的元素)将始终复制数据并将其重新整形为最终形式将强制 reshape 进行复制。这意味着数据将被复制两次,如果您正在处理大型数组,这可能会降低性能。我认为您可以通过直接更改数组的形状和步幅来规避第一个副本(通过使用精美的索引)。
    【解决方案2】:

    您将不可避免地需要复制数据以在连续数组中获得所需的结果。尽管为了提高效率,我建议尽量减少复制数据的次数。任何类型的整形操作都可以用np.lib.stride_tricks.as_strided表示。

    假设原始数组包含 64 位整数,那么每个元素是 8 个字节排列成某种形状:

    import numpy as np
    arr = np.arange(18).reshape((2,9))
    arr.shape, arr.strides
    

    输出:

    ((2, 9), (72, 8))
    

    所以每列跳过 8 个字节,每行跳过 72 个字节。 arr.reshape(len(arr), -1, r_size)可以表示为:

    np.lib.stride_tricks.as_strided(arr, (2,3,3), (72,24,8))
    

    输出:

    array([[[ 0,  1,  2],
            [ 3,  4,  5],
            [ 6,  7,  8]],
    
           [[ 9, 10, 11],
            [12, 13, 14],
            [15, 16, 17]]])
    

    arr.reshape(len(arr), -1, r_size)[:, ranges]可以表示为:

    np.lib.stride_tricks.as_strided(arr, (2,2,3), (72,24*2,8))
    

    输出:

    array([[[ 0,  1,  2],
            [ 6,  7,  8]],
    
           [[ 9, 10, 11],
            [15, 16, 17]]])
    

    到目前为止,我们只更改了数组的元数据,这意味着没有数据被复制。此操作的性能成本几乎为零。但是要获得最终的数组,您需要以某种方式复制数据:

    np.lib.stride_tricks.as_strided(arr, (2,2,3), (72,24*2,8)).reshape(len(arr), -1)
    

    输出:

    array([[ 0,  1,  2,  6,  7,  8],
           [ 9, 10, 11, 15, 16, 17]])
    

    这不是一个通用的解决方案,但它可能会给您一些关于如何优化的想法。

    不幸的是,我的时间安排并不支持这些说法,但它仍然很直观,值得对一些更大的数组进行测试。

    【讨论】:

    • 非常感谢。这是一个有趣的选择。
    猜你喜欢
    • 2022-11-06
    • 1970-01-01
    • 2016-03-29
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多