如果您需要非常快速地处理大型数组,您甚至可以使用 numbas prange 并行处理计数(对于小型数组,由于并行处理开销,它会更慢)。
import numpy as np
from numba import njit, prange
@njit(parallel=True)
def parallel_nonzero_count(arr):
flattened = arr.ravel()
sum_ = 0
for i in prange(flattened.size):
sum_ += flattened[i] != 0
return sum_
请注意,当您使用 numba 时,您通常希望写出循环,因为 numba 非常擅长优化。
我实际上是根据这里提到的其他解决方案来计时的(使用我的 Python 模块 simple_benchmark):
要重现的代码:
import numpy as np
from numba import njit, prange
@njit
def n_nonzero(a):
return a[a != 0].size
@njit
def count_non_zero(np_arr):
return len(np.nonzero(np_arr)[0])
@njit()
def methodB(a):
return (a!=0).sum()
@njit(parallel=True)
def parallel_nonzero_count(arr):
flattened = arr.ravel()
sum_ = 0
for i in prange(flattened.size):
sum_ += flattened[i] != 0
return sum_
@njit()
def count_loop(a):
s = 0
for i in a:
if i != 0:
s += 1
return s
from simple_benchmark import benchmark
args = {}
for exp in range(2, 20):
size = 2**exp
arr = np.random.random(size)
arr[arr < 0.3] = 0.0
args[size] = arr
b = benchmark(
funcs=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop),
arguments=args,
argument_name='array size',
warmups=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop)
)