【问题标题】:Efficient Way to Repeatedly Split Large NumPy Array and Record Middle重复拆分大型 NumPy 数组和记录中间的有效方法
【发布时间】:2021-04-15 01:35:02
【问题描述】:

我有一个大型 NumPy 数组 nodes = np.arange(100_000_000),我需要通过以下方式重新排列这个数组:

  1. 记录然后删除数组中的中间值
  2. 将数组拆分为left 一半和right 一半
  3. 对每一半重复步骤 1-2
  4. 当所有值都用完时停止

因此,对于较小的输入示例nodes = np.arange(10),输出将是:

[5 2 8 1 4 7 9 0 3 6]

这是天真的做的:

import numpy as np

def split(node, out):
    mid = len(node) // 2
    out.append(node[mid])
    return node[:mid], node[mid+1:]


def reorder(a):
    nodes = [a.tolist()]
    out = []

    while nodes:
        tmp = []
        for node in nodes:
            for n in split(node, out):
                if n:
                    tmp.append(n)
        nodes = tmp

    return np.array(out)

if __name__ == "__main__":
    nodes = np.arange(10)
    print(reorder(nodes))

但是,这对于 nodes = np.arange(100_000_000) 来说太慢了,所以我正在寻找一个更快的解决方案。

【问题讨论】:

  • 太慢有多慢?即现在有多快,你需要多快?
  • 我认为任何小于 2 秒(最好小于 1 秒)的时间都可以。
  • 输入数组中的数字是否总是像从 arange 一样按顺序开始,还是可以是任意的数字组合?也可以有重复吗?
  • 是的,订单总是从0n-1,如arange。没有重复。
  • 我更新了我的答案并投票给了另一个答案,我认为这可能是你摆脱 numpy 的最佳选择(使用相同的算法,Cython 或 C/C++ 或 rust 会更快)。您会看到我在对该答案的评论中指出,由于 a 中的每个值都只是该值的索引,因此无需实际创建该数组,因此答案仅取决于其大小。

标签: python performance numpy


【解决方案1】:

您可以通过处理切片组使用 Numpy 向量化您的函数

这是一个实现:

# Similar to [e for tmp in zip(a, b) for e in tmp] ,
# but on Numpy arrays and much faster
def interleave(a, b):
    assert len(a) == len(b)
    return np.column_stack((a, b)).reshape(len(a) * 2)

# n is the length of the input range (len(a) in your example)
def fast_reorder(n):
    if n == 0:
        return np.empty(0, dtype=np.int32)

    startSlices = np.array([0], dtype=np.int32)
    endSlices = np.array([n], dtype=np.int32)
    allMidSlices = np.empty(n, dtype=np.int32)  # Similar to "out" in your implementation
    midInsertCount = 0                               # Actual size of allMidSlices

    # Generate a bunch of middle values as long as there is valid slices to split
    while midInsertCount < n:
        # Generate the new mid/left/right slices
        midSlices = (endSlices + startSlices) // 2

        # Computing the next slices is not needed for the last step
        if midInsertCount + len(midSlices) < n:
            # Generate the nexts slices (possibly with invalid ones)
            newStartSlices = interleave(startSlices, midSlices+1)
            newEndSlices = interleave(midSlices, endSlices)

            # Discard invalid slices
            isValidSlices = newStartSlices < newEndSlices
            startSlices = newStartSlices[isValidSlices]
            endSlices = newEndSlices[isValidSlices]

        # Fast appending
        allMidSlices[midInsertCount:midInsertCount+len(midSlices)] = midSlices
        midInsertCount += len(midSlices)

    return allMidSlices[0:midInsertCount]

在我的机器上,这比您的标量实现快 89 倍,输入 np.arange(100_000_000) 从 2min35 下降到 1.75s。它还消耗更少的内存(大约少 3~4 倍)。请注意,如果您想要更快的代码,那么您可能需要使用本地语言,如 C 或 C++。

【讨论】:

  • 我没有测试你的解决方案,但我相信你的话,我认为它涵盖了我之前在评论中提到的所有内容(尽管当时范围是 1000 亿个值,所以我无论如何都看不到可能解决问题的方法:P)我要做的唯一评论是,您可以在代码中看到 a 中的值甚至不需要,因为 a 中的每个值都等于它的索引因此,如果您将 func 更改为仅获取数组的长度而不是数组本身,则无需创建 a。
  • @DavidOldford 是的 100e9 值有点太大了;)。一开始我不确定输入 a 是否始终是从 0 到 len(a)-1 的范围。但最后,我含蓄地提出了这个假设。很好的收获,谢谢!
【解决方案2】:

编辑: 该问题已更新为具有更小的输入数组,因此出于历史原因,我将其保留在下面。基本上,这可能是一个错字,但我们经常习惯于计算机处理非常大的数字,当涉及到内存时,它们可能是一个真正的问题。

已经有一个我认为符合要求的其他人提交的基于 numpy 的解决方案。

您的代码需要大量 RAM 才能容纳 1000 亿个 64 位整数。你有 800GB 的内存吗?然后将 numpy 数组转换为比数组大得多的列表(numpy 数组中的每个打包 64 位 int 将成为内存效率低得多的 python int 对象,并且列表将具有指向该对象的指针)。然后,您制作了许多列表切片,这些切片不会复制数据,但会复制指向数据的指针并使用更多 RAM。您还可以一次将所有结果值附加到一个列表中。列表通常可以非常快地添加项目,但是具有如此极端的大小,这不仅会很慢,而且分配列表的方式可能会非常浪费 RAM 并导致重大问题(我相信当它们获得时它们的大小会翻倍到一定程度的丰满度,因此您最终会分配比您需要的更多的 RAM,并进行许多分配和可能的副本)。你在什么机器上运行这个?有一些方法可以改进您的代码,但除非您在超级计算机上运行它,否则我不知道您是否会完成该计算。我只有……只有?有 32GB 的 RAM,我什至不会尝试创建 100B int_64 numpy 数组,因为我不想用 ssd 写入寿命来获得大量虚拟内存。

至于改进您的代码坚持使用 numpy 数组,不要更改为 python 列表,它将大大增加您需要的 RAM。预先分配一个 numpy 数组来放入答案。然后你需要一个新算法。任何递归或递归(即分割输入的循环)都需要跟踪大量状态,您的节点列表将非常庞大,并且再次使用大量 RAM。您可以使用 len(a) 来指示从列表中删除的值,并每次扫描整个数组以找出下一步该做什么,但这将节省 RAM,从而有利于大量搜索巨大的数组。我觉得有一种算法可以从每一端截取数字并将它们放在输出中并跟踪开头和结尾,但至少我还没有弄清楚。

我还认为有一种更简单的算法,您只需跟踪已完成的拆分次数,而不是制作一个巨大的切片列表并将其全部保存在内存中。取左半边的中间,然后取右半边的中间,然后加一,当你取左半边的中间时,你知道你必须跳到右半边,然后计数是一,所以你跳到原来的右半边是左半边,一直在...根据两半的深度和输入的长度,您应该能够在不扫描或跟踪所有这些切片的情况下跳来跳去,尽管我无法专注于有很多时间在我的脑海中思考这个问题。

如果您确实需要突破限制,那么您应该考虑使用 C/C++,这样您就可以尽可能高效地使用 RAM,并且因为您正在做大量的小事情t 很好地映射到 python 性能。

【讨论】:

  • 糟糕!你说的对。应该是100_000_000。感谢您发现我被1000 关闭了。我已经更新了我的问题
猜你喜欢
  • 2015-10-06
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2016-04-19
  • 1970-01-01
  • 2021-08-31
相关资源
最近更新 更多