您提到要避免使用大型掩码数组。除非您的“大数组”和“特定值”数组都非常大,否则我不会尝试避免这种情况。通常,numpy 最好允许创建相对较大的临时数组。
但是,如果您确实需要更严格地控制内存使用,您有多种选择。一个典型的技巧是只对操作的一部分进行向量化并迭代较短的输入(这在下面的第二个示例中显示)。它避免了 Python 中的嵌套循环,并且可以显着减少所涉及的内存使用量。
我将展示三种不同的方法。还有其他几个(如果你真的需要严格的控制和性能,包括下降到 C 或 Cython),但希望这能给你一些想法。
附带说明,对于这些小输入,数组创建的开销将压倒差异。我所指的速度和内存使用仅适用于大型(>~1e6 个元素)数组。
完全矢量化,但内存占用最多
最简单的方法是一次计算所有距离,然后将蒙版缩小到与初始数组相同的形状。例如:
import numpy as np
vals = np.array([14,21,48,54,92,215])
other = np.array([20,50,90,210])
dist = np.abs(vals[:,None] - other[None,:])
mask = np.all(dist > 3, axis=1)
result = vals[mask]
部分矢量化,中间内存使用
另一种选择是为“特定值”数组中的每个元素迭代地构建掩码。这将遍历较短的“特定值”数组(在本例中为 other)的所有元素:
import numpy as np
vals = np.array([14,21,48,54,92,215])
other = np.array([20,50,90,210])
mask = np.ones(len(vals), dtype=bool)
for num in other:
dist = np.abs(vals - num)
mask &= dist > 3
result = vals[mask]
最慢,但内存使用率最低
最后,如果你真的想减少内存使用,你可以遍历你的大数组中的每一项:
import numpy as np
vals = np.array([14,21,48,54,92,215])
other = np.array([20,50,90,210])
result = []
for num in vals:
if np.all(np.abs(num - other) > 3):
result.append(num)
这种情况下的临时列表可能会比之前版本中的掩码占用更多的内存。但是,如果您愿意,可以使用 np.fromiter 来避免临时列表。下面的时序比较显示了一个例子。
时序比较
让我们比较一下这些函数的速度。我们将在“大数组”中使用 10,000,000 个元素,在“特定值”数组中使用 4 个值。这些函数的相对速度和内存使用在很大程度上取决于两个数组的大小,因此您应该只将其视为一个模糊的准则。
import numpy as np
vals = np.random.random(1e7)
other = np.array([0.1, 0.5, 0.8, 0.95])
tolerance = 0.05
def basic(vals, other, tolerance):
dist = np.abs(vals[:,None] - other[None,:])
mask = np.all(dist > tolerance, axis=1)
return vals[mask]
def intermediate(vals, other, tolerance):
mask = np.ones(len(vals), dtype=bool)
for num in other:
dist = np.abs(vals - num)
mask &= dist > tolerance
return vals[mask]
def slow(vals, other, tolerance):
def func(vals, other, tolerance):
for num in vals:
if np.all(np.abs(num - other) > tolerance):
yield num
return np.fromiter(func(vals, other, tolerance), dtype=vals.dtype)
在这种情况下,部分矢量化版本胜出。在vals 明显长于other 的大多数情况下,这是可以预料的。但是,第一个示例 (basic) 几乎一样快,并且可以说更简单。
In [7]: %timeit basic(vals, other, tolerance)
1 loops, best of 3: 1.45 s per loop
In [8]: %timeit intermediate(vals, other, tolerance)
1 loops, best of 3: 917 ms per loop
In [9]: %timeit slow(vals, other, tolerance)
1 loops, best of 3: 2min 30s per loop
无论您选择哪种方式来实现,这些都是常见的矢量化“技巧”,会出现在许多问题中。在 Python、Matlab、R 等高级语言中,尝试完全向量化通常很有用,然后如果内存使用存在问题,则混合使用向量化和显式循环。哪个最好通常取决于输入的相对大小,但在高级科学编程中优化速度与内存使用时,这是一种常见的尝试模式。