【问题标题】:Count the number of non zero values in a numpy array in Numba计算 Numba 中 numpy 数组中非零值的数量
【发布时间】:2019-02-22 15:21:15
【问题描述】:

非常简单。我试图在用 Numba (njit()) 编译的 NumPy jit 中计算数组中非零值的数量。 Numba 不允许我尝试以下操作。

  1. a[a != 0].size
  2. np.count_nonzero(a)
  3. len(a[a != 0])
  4. len(a) - len(a[a == 0])

如果还有更快、更 Python 和优雅的方式,我不想使用 for 循环。

对于想要查看完整代码示例的评论者...

import numpy as np
from numba import njit

@njit()
def n_nonzero(a):
    return a[a != 0].size

【问题讨论】:

  • 请显示您尝试过的至少一段实际的、完整的代码,包括import 语句、装饰器和示例测试工具。
  • [list(filter((0).__ne__, l)) for l in a]
  • @MarkSetchell 当然...

标签: python numpy numba


【解决方案1】:

您也可以考虑计算非零值:

import numba as nb

@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

我知道这似乎不对,但请耐心等待:

import numpy as np
import numba as nb

@nb.njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

@nb.njit()
def count_len_nonzero(a):
    return len(np.nonzero(a)[0])

@nb.njit()
def count_sum_neq_zero(a):
    return (a != 0).sum()

np.random.seed(100)
a = np.random.randint(0, 3, 1000000000, dtype=np.uint8)
c = np.count_nonzero(a)
assert count_len_nonzero(a) == c
assert count_sum_neq_zero(a) == c
assert count_loop(a) == c

%timeit count_len_nonzero(a)
# 5.94 s ± 141 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_sum_neq_zero(a)
# 848 ms ± 80.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit count_loop(a)
# 189 ms ± 4.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

它实际上比np.count_nonzero 快,由于某种原因可能会变得相当慢:

%timeit np.count_nonzero(a)
# 4.36 s ± 69.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

【讨论】:

  • 是的,当 numba 看到可以优化的循环时,它确实表现出色。 +1
【解决方案2】:

如果您需要非常快速地处理大型数组,您甚至可以使用 numbas prange 并行处理计数(对于小型数组,由于并行处理开销,它会更慢)。

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_nonzero_count(arr):
    flattened = arr.ravel()
    sum_ = 0
    for i in prange(flattened.size):
        sum_ += flattened[i] != 0
    return sum_

请注意,当您使用 numba 时,您通常希望写出循环,因为 numba 非常擅长优化。

我实际上是根据这里提到的其他解决方案来计时的(使用我的 Python 模块 simple_benchmark):

要重现的代码:

import numpy as np
from numba import njit, prange

@njit
def n_nonzero(a):
    return a[a != 0].size

@njit
def count_non_zero(np_arr):
    return len(np.nonzero(np_arr)[0])

@njit() 
def methodB(a): 
    return (a!=0).sum()

@njit(parallel=True)
def parallel_nonzero_count(arr):
    flattened = arr.ravel()
    sum_ = 0
    for i in prange(flattened.size):
        sum_ += flattened[i] != 0
    return sum_

@njit()
def count_loop(a):
    s = 0
    for i in a:
        if i != 0:
            s += 1
    return s

from simple_benchmark import benchmark

args = {}
for exp in range(2, 20):
    size = 2**exp
    arr = np.random.random(size)
    arr[arr < 0.3] = 0.0
    args[size] = arr

b = benchmark(
    funcs=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop),
    arguments=args,
    argument_name='array size',
    warmups=(n_nonzero, count_non_zero, methodB, np.count_nonzero, parallel_nonzero_count, count_loop)
)

【讨论】:

  • 在并行循环迭代之间共享sum_ 是否安全? (我不太了解并行化 Numba 的保证)
  • 是的,numba 有一些可以安全并行化的缩减。求和和乘法就是其中之一。这是因为 numba 意识到它可以使用 sum_ = 0 并行处理每个进程,然后在每个进程完成后添加它们。我还检查了与np.count_nonzero 的一致性。
【解决方案3】:

您可以使用np.nonzero 并诱导它的长度:

@njit
def count_non_zero(np_arr):
    return len(np.nonzero(np_arr)[0])

count_non_zero(np.array([0,1,0,1]))
# 2

【讨论】:

  • 那个 [0] 似乎是它做的事情。非常感谢!
  • 有趣的是np.nonzero 使用np.count_nonzero(在c-api 级别)来确定它将在第二次迭代中填充的数组的大小。我虽然使用numba 的全部意义在于能够不受惩罚地迭代。 :)
【解决方案4】:

不确定我是否在这里犯了错误,但这似乎快了 6 倍:

# Make something worth checking
a=np.random.randint(0,3,1000000000,dtype=np.uint8)  

In [41]: @njit() 
    ...: def methodA(a): 
    ...:     return len(np.nonzero(a)[0])                                                                                           

# Call and check result
In [42]: methodA(a)                                                                                 
Out[42]: 666644445

In [43]: %timeit methodA(a)                                                                         
4.65 s ± 28.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [44]: @njit() 
    ...: def methodB(a): 
    ...:     return (a!=0).sum()                                                                                         

# Call and check result    
In [45]: methodB(a)                                                                                 
Out[45]: 666644445

In [46]: %timeit methodB(a)                                                                         
724 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

【讨论】:

    猜你喜欢
    • 2013-12-23
    • 2021-01-05
    • 1970-01-01
    • 2020-02-23
    • 2018-03-31
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多