【发布时间】:2022-01-14 15:29:04
【问题描述】:
我正在尝试优化numpy.packbits:
import numpy as np
from numba import njit, prange
@njit(parallel=True)
def _numba_pack(arr, div, su):
for i in prange(div):
s = 0
for j in range(i*8, i*8+8):
s = 2*s + arr[j]
su[i] = s
def numba_packbits(arr):
div, mod = np.divmod(arr.size, 8)
su = np.zeros(div + (mod>0), dtype=np.uint8)
_numba_pack(arr[:div*8], div, su)
if mod > 0:
su[-1] = sum(x*y for x,y in zip(arr[div*8:], (128, 64, 32, 16, 8, 4, 2, 1)))
return su
>>> X = np.random.randint(2, size=99, dtype=bool)
>>> print(numba_packbits(X))
[ 75 24 79 61 209 189 203 187 47 226 170 61 0]
它看起来比np.packbits(X) 慢 2 - 2.5 倍。这是如何在numpy 内部实现的?这可以在numba 中改进吗?
我在通过conda install 安装的numpy == 1.21.2 和numba == 0.53.1 上工作。我的平台是:
结果:
import benchit
from numpy import packbits
%matplotlib inline
benchit.setparams(rep=5)
sizes = [100000, 300000, 1000000, 3000000, 10000000, 30000000]
N = sizes[-1]
arr = np.random.randint(2, size=N, dtype=bool)
fns = [numba_packbits, packbits]
in_ = {s/1000000: (arr[:s], ) for s in sizes}
t = benchit.timings(fns, in_, multivar=True, input_name='Millions of bits')
t.plot(logx=True, figsize=(12, 6), fontsize=14)
更新
Jérôme 的回应:
@njit('void(bool_[::1], uint8[::1], int_)', inline='never')
def _numba_pack_x64_byJérôme(arr, su, pos):
for i in range(64):
j = i * 8
su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]
@njit(parallel=True)
def _numba_pack_byJérôme(arr, div, su):
for i in prange(div//64):
_numba_pack_x64_byJérôme(arr[i*8:(i+64)*8], su[i:i+64], i)
for i in range(div//64*64, div):
j = i * 8
su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]
def numba_packbits_byJérôme(arr):
div, mod = np.divmod(arr.size, 8)
su = np.zeros(div + (mod>0), dtype=np.uint8)
_numba_pack_byJérôme(arr[:div*8], div, su)
if mod > 0:
su[-1] = sum(x*y for x,y in zip(arr[div*8:], (128, 64, 32, 16, 8, 4, 2, 1)))
return su
用法:
>>> print(numba_packbits_byJérôme(X))
[ 75 24 79 61 209 189 203 187 47 226 170 61 0]
结果:
【问题讨论】:
-
当
X.size可以被8 整除时,你的代码是否会变得更快,从而mod > 0为假?例如,X = np.random.randint(2, size=80, dtype=bool) -
@ForceBru 我专注于大数据,所以没有效果
-
numpy implementation 是用 C 语言编写的,如果可用,则使用 SIMD。是什么让您认为 numba 可以加速它?
-
检查
numpy函数(或方法)的source。如果是built-in,那就已经编译好了。你自己的numba代码可能不会有太大的改进——除非numba已经实现了它。在转换使用 python 级迭代的代码/函数时,您将获得最大的改进。 -
谢谢。在 Windows 上,Numba 版本可以很好地扩展,但在 Linux 上根本不行:它只有 1.5 倍,使用 6 核......这非常令人惊讶,因为这是一个令人尴尬的并行代码,并且没有一个实现似乎是内存绑定的我的 Linux。我想 Numba 版本肯定有问题(可能与您的代码无关)。我有一个可能的解释,但如果这是我认为的,这将是相当复杂/深刻的。我会检查我的假设。
标签: python numpy numba bit-packing