最新更新
我改写了下面的代码以使用 numba:对于更大的输入,它的速度提高了约 2 倍,并且能够处理更大的输入,而不会在我的本地机器上被系统终止(killed)。
numba 会把用 numba.njit 装饰的函数编译成机器码,在输入量很大时表现尤为出色。由于首次运行包含编译开销,后续运行总是比第一次运行更快。我知道对于 10k–20k 的输入这可能没有必要,但对于 2**64 这样的规模可能就需要了。(P.S. 我在工作中就用 numba 来提速,可参考此处了解它能带来多大的加速。)
使用 numba
import numpy as np
import numba as nb
# Turn the four input lists into NumPy arrays for `count`, which reads
# `ojb.shape[0]`; conversion order is ojb_A, ojb_B, cnt_A, cnt_B.
ojb_A, ojb_B, cnt_A, cnt_B = map(np.array, (ojb_A, ojb_B, cnt_A, cnt_B))
@nb.njit
def count(out, ojb, cnt, i):
    """Accumulate ``cnt[k]`` into ``out[ojb[k]]`` for every position ``k``.

    Compiled to machine code by numba. ``out`` is the caller's typed
    int64 -> int64 dict and is mutated in place; ``ojb`` and ``cnt`` are
    indexed in lockstep, so they are expected to have the same length.

    ``i`` is ignored. It is accepted only so existing calls such as
    ``count(out, ojb_A, cnt_A, 0)`` keep working; each object's key is
    ``ojb[index]``, never ``ojb[i]``.
    """
    # Plain ``range``: ``nb.prange`` only parallelises under
    # ``njit(parallel=True)``, and this read-modify-write of a shared dict
    # is not written to be safe under parallel execution.
    for index in range(ojb.shape[0]):
        key = ojb[index]
        out[key] = out.get(key, 0) + cnt[index]
def split_out(out):
    """Split mapping *out* into a list of its keys and a list of its values.

    Both lists follow the mapping's own iteration order, so position k of
    the values list belongs to position k of the keys list.
    """
    keys = list(out)
    values = list(out.values())
    return keys, values
# numba's typed int64 -> int64 dict, which njit-compiled `count` can mutate
# (njit code works with numba's typed containers, not a plain Python dict).
out = nb.typed.Dict.empty(key_type=nb.types.int64, value_type=nb.types.int64)
# Fold list A, then list B, into the same dict; the last argument is the list tag.
count(out, ojb_A, cnt_A, 0)
count(out, ojb_B, cnt_B, 1)
ojb_C, cnt_C = split_out(out)
new_test.py
import random
import sys
import time
import numba as nb
# Fixed seed: every run of this script draws the same pseudo-random inputs.
random.seed(31212223)
# Input length per list: 2 raised to the second command-line argument.
MAX = pow(2, int(sys.argv[2]))
def f3():
    """Benchmark quamrana's ``defaultdict`` merge of two (object, count) lists.

    Draws four lists of ``MAX`` distinct random ints (``random.sample``),
    then times only the merge: counts of equal objects from both inputs are
    summed and the result is split into ``ojb_C`` (objects) and ``cnt_C``
    (totals). Prints the label and the elapsed seconds.
    """
    ojb_A = random.sample(range(1, 2**30), MAX)
    cnt_A = random.sample(range(1, 2**30), MAX)
    ojb_B = random.sample(range(1, 2**30), MAX)
    cnt_B = random.sample(range(1, 2**30), MAX)

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # if the system clock is adjusted during the measurement.
    s1 = time.perf_counter()
    from collections import defaultdict

    def count(out, ojb, cnt):
        """Add each ``cnt[index]`` to ``out[ojb[index]]`` (missing keys start at 0)."""
        for index, obj in enumerate(ojb):
            out[obj] += cnt[index]

    def split_out(out):
        """Return the keys and the values of *out* as two parallel lists."""
        return list(out.keys()), list(out.values())

    out = defaultdict(int)
    count(out, ojb_A, cnt_A)
    count(out, ojb_B, cnt_B)
    ojb_C, cnt_C = split_out(out)
    s2 = time.perf_counter()
    print('quamrana', s2 - s1)
@nb.njit
def _nb_count(out, ojb, cnt):
    """Accumulate ``cnt[k]`` into ``out[ojb[k]]`` for every position ``k`` (numba-compiled).

    Defined at module level so its compiled code is reused across calls to
    ``f3_1``. A jitted function defined inside ``f3_1`` would be a new
    dispatcher on every call and would be compiled again each time.
    """
    # Plain ``range``: ``nb.prange`` only parallelises under
    # ``njit(parallel=True)``, and this shared-dict update is not written
    # to be safe under parallel execution.
    for index in range(ojb.shape[0]):
        key = ojb[index]
        out[key] = out.get(key, 0) + cnt[index]


def f3_1():
    """Benchmark eroot163pi's numba typed-dict merge of two (object, count) lists.

    Draws four lists of ``MAX`` distinct random ints, then times the NumPy
    conversion, the jitted accumulation (keyed by each object, summing the
    counts of equal objects from both inputs) and the split into ``ojb_C``
    and ``cnt_C``. Prints the label and the elapsed seconds. The first call
    in a process also pays the numpy import and numba compilation cost.
    """
    ojb_A = random.sample(range(1, 2**30), MAX)
    cnt_A = random.sample(range(1, 2**30), MAX)
    ojb_B = random.sample(range(1, 2**30), MAX)
    cnt_B = random.sample(range(1, 2**30), MAX)

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # if the system clock is adjusted during the measurement.
    s1 = time.perf_counter()
    import numpy as np

    ojb_A = np.array(ojb_A)
    ojb_B = np.array(ojb_B)
    cnt_A = np.array(cnt_A)
    cnt_B = np.array(cnt_B)
    out = nb.typed.Dict.empty(
        key_type=nb.core.types.int64,
        value_type=nb.core.types.int64,
    )
    _nb_count(out, ojb_A, cnt_A)
    _nb_count(out, ojb_B, cnt_B)
    ojb_C, cnt_C = list(out.keys()), list(out.values())
    s2 = time.perf_counter()
    print('eroot163pi', s2 - s1)
if __name__ == '__main__':
    # Look the benchmark up by name instead of eval()-ing command-line text,
    # which would execute any Python expression passed as the first argument.
    benchmark = globals().get(sys.argv[1])
    if not callable(benchmark):
        sys.exit(f'unknown benchmark: {sys.argv[1]!r}')
    # Two runs to show subsequent run is faster
    benchmark()
    benchmark()
在 2^22 到 2^27 规模的大型用例上运行
- Numba 版本成功跑完了 2^26,而纯 Python 版本被系统终止(Killed)
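- 在这些输入上 Numba 始终更快,并且能处理更大的输入(最多到 2^26,2^27 时同样被终止)。注意:这些计时是用 `count` 以 `ojb[i]`(而非 `ojb[index]`)作为键的版本测得的。该版本只会向字典写入一两个键,因此其速度和内存优势可能被高估,应在修正索引后重新测量。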
(base) xxx@xxx:~$ python test.py f3 22
quamrana 3.2768890857696533
quamrana 3.2760112285614014
(base) xxx@xxx:~$ python test.py f3_1 22
eroot163pi 2.4150922298431396
eroot163pi 1.8658664226531982
(base) xxx@xxx:~$ python test.py f3 23
quamrana 6.903605937957764
quamrana 7.187666654586792
(base) xxx@xxx:~$ python test.py f3_1 23
eroot163pi 4.326314926147461
eroot163pi 3.6970062255859375
(base) xxx@xxx:~$ python test.py f3 24
quamrana 14.135217905044556
quamrana 14.102455615997314
(base) xxx@xxx:~$ python test.py f3_1 24
eroot163pi 8.097218751907349
eroot163pi 7.514840602874756
(base) xxx@xxx:~$ python test.py f3 25
quamrana 29.825793743133545
quamrana 30.300193786621094
(base) xxx@xxx:~$ python test.py f3_1 25
eroot163pi 16.243808031082153
eroot163pi 15.114825010299683
(base) xxx@xxx:~$ python test.py f3 26
Killed
(base) xxx@xxx:~$ python test.py f3_1 26
eroot163pi 35.73880386352539
eroot163pi 34.74332834847338
(base) xxx@xxx:~$ python test.py f3_1 27
Killed
其他方法的表现
下面比较了上述几种方法在大规模测试用例上的表现。出于对性能的好奇,我做了这个比较,结果 quamrana 的方案表现非常好。
当数组长度大于 2**26 时,上述所有方案都会被系统终止(受限于我的机器配置:16 GB 内存,12 核)。
在其余规模上,quamrana 的方案始终表现最好,Kapil 的方案排第二:按下面的计时,quamrana 的方案比 Kapil 的快约 2.5–3 倍,Kapil 的方案又比 fsimonjetz 的快约 1.4–1.7 倍。
test.py
import random
import sys
import time
# Fixed seed so every invocation benchmarks identical inputs.
random.seed(31212223)
# Input length per list: 2 raised to the second command-line argument.
MAX = pow(2, int(sys.argv[2]))
# Four independent samples of MAX distinct ints from [1, 2**30), drawn in
# the order ojb_A, cnt_A, ojb_B, cnt_B.
ojb_A, cnt_A, ojb_B, cnt_B = (
    random.sample(range(1, 2**30), MAX) for _ in range(4)
)
def f1():
    """Benchmark Kapil's pandas approach: concatenate both inputs and groupby-sum.

    Reads the module-level ``ojb_A``/``cnt_A``/``ojb_B``/``cnt_B`` lists and
    prints the label and the elapsed seconds (the pandas import is inside
    the timed region).
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # if the system clock is adjusted during the measurement.
    s1 = time.perf_counter()
    import pandas as pd
    df1 = pd.DataFrame(zip(ojb_A, cnt_A))
    df2 = pd.DataFrame(zip(ojb_B, cnt_B))
    # Column 0 holds the objects and column 1 the counts; summing per object
    # merges the counts from both inputs.
    df_combined = pd.concat([df1, df2]).groupby(0).sum()
    ojb_C = list(df_combined.index.values)
    cnt_C = list(df_combined[1])
    s2 = time.perf_counter()
    print('Kapil', s2 - s1)
def f2():
    """Benchmark fsimonjetz's ``collections.Counter`` approach.

    Reads the module-level ``ojb_A``/``cnt_A``/``ojb_B``/``cnt_B`` lists and
    prints the label and the elapsed seconds.

    Caveat: ``dict(zip(...))`` keeps only the last count for an object that
    appears more than once within the same list, so this matches the other
    approaches only when each input list has distinct objects (true for the
    ``random.sample`` data generated by this script).
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # if the system clock is adjusted during the measurement.
    s1 = time.perf_counter()
    from collections import Counter
    counts = Counter(dict(zip(ojb_A, cnt_A)))
    # Counter.update with a mapping adds the counts instead of replacing them.
    counts.update(dict(zip(ojb_B, cnt_B)))
    ojb_C, cnt_C = zip(*counts.items())
    s2 = time.perf_counter()
    print('fsimonjetz', s2 - s1)
def f3():
    """Benchmark quamrana's ``defaultdict`` merge of two (object, count) lists.

    Reads the module-level ``ojb_A``/``cnt_A``/``ojb_B``/``cnt_B`` lists,
    sums the counts of equal objects from both inputs, splits the result
    into ``ojb_C`` (objects) and ``cnt_C`` (totals), and prints the label
    and the elapsed seconds.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # if the system clock is adjusted during the measurement.
    s1 = time.perf_counter()
    from collections import defaultdict

    def count(out, ojb, cnt):
        """Add each ``cnt[index]`` to ``out[ojb[index]]`` (missing keys start at 0)."""
        for index, obj in enumerate(ojb):
            out[obj] += cnt[index]

    def split_out(out):
        """Return the keys and the values of *out* as two parallel lists."""
        return list(out.keys()), list(out.values())

    out = defaultdict(int)
    count(out, ojb_A, cnt_A)
    count(out, ojb_B, cnt_B)
    ojb_C, cnt_C = split_out(out)
    s2 = time.perf_counter()
    print('quamrana', s2 - s1)
if __name__ == '__main__':
    # Look the benchmark up by name instead of eval()-ing command-line text,
    # which would execute any Python expression passed as the first argument.
    benchmark = globals().get(sys.argv[1])
    if not callable(benchmark):
        sys.exit(f'unknown benchmark: {sys.argv[1]!r}')
    benchmark()
在 2^20、2^21、2^22 上的表现
(base) xxx@xxx:~$ python test.py f1 20
Kapil 2.0021448135375977
(base) xxx@xxx:~$ python test.py f2 20
fsimonjetz 2.720785617828369
(base) xxx@xxx:~$ python test.py f3 20
quamrana 0.717628002166748
(base) xxx@xxx:~$ python test.py f1 21
Kapil 4.06165337562561
(base) xxx@xxx:~$ python test.py f2 21
fsimonjetz 6.2198286056518555
(base) xxx@xxx:~$ python test.py f3 21
quamrana 1.563591718673706
(base) xxx@xxx:~$ python test.py f1 22
Kapil 8.361187934875488
(base) xxx@xxx:~$ python test.py f2 22
fsimonjetz 14.354418992996216
(base) xxx@xxx:~$ python test.py f3 22
quamrana 3.355391025543213