【问题标题】:Improve performance of python script using numba jit使用 numba jit 提高 python 脚本的性能
【发布时间】:2020-04-08 20:49:27
【问题描述】:

我正在运行一个示例 python 模拟来预测加权和常规骰子。我想使用 numba 来帮助加速我的脚本,但我收到一个错误:

<timed exec>:6: NumbaWarning: 
Compilation is falling back to object mode WITH looplifting enabled because Function "roll" failed type inference due to: Untyped global name 'sum': cannot determine Numba type of <class 'builtin_function_or_method'>

File "<timed exec>", line 9:
<source missing, REPL/exec in use?>

这是我的原始代码:我可以使用另一种类型的 numba 表达式吗?现在我正在使用 2500 卷的输入进行测试;想把这个时间缩短到 4 秒(目前是 8.5 秒)。

%%time
from numba import jit
import random
import matplotlib.pyplot as plt
import numpy

@jit
def roll(sides, bias_list):
    assert len(bias_list) == sides, "Enter correct number of dice sides"
    number = random.uniform(0, sum(bias_list))
    current = 0
    for i, bias in enumerate(bias_list):
        current += bias
        if number <= current:
            return i + 1

no_of_rolls = 2500
weighted_die = {}
normal_die = {}
#weighted die

for i in range(no_of_rolls):
        weighted_die[i+1]=roll(6,(0.15, 0.15, 0.15, 0.15, 0.15, 0.25))
#regular die  
for i in range(no_of_rolls):
        normal_die[i+1]=roll(6,(0.167, 0.167, 0.167, 0.167, 0.167, 0.165))

plt.bar(*zip(*weighted_die.items()))
plt.show()
plt.bar(*zip(*normal_die.items()))
plt.show()

【问题讨论】:

  • 如果您只是想要更快的模拟,您可以使用random.choices 生成 weighted_die 和 normal_die 的列表。
  • 你能举个例子吗?随意使用我的代码并随意更改。
  • @LeeWhieldon——提供了一个代码示例。但是,不明白您为什么说代码需要 8.5 秒才能运行。我得到了几毫秒的时间(不包括绘图)。
  • @DarrylG,谢谢!这确实提高了性能。欣赏!

标签: python numba


【解决方案1】:

使用随机选择

重构代码

import random
import matplotlib.pyplot as plt

no_of_rolls = 2500

# weights
normal_weights = (0.167, 0.167, 0.167, 0.167, 0.167, 0.165)
bias_weights = (0.15, 0.15, 0.15, 0.15, 0.15, 0.25)

# Replaced roll function with random.choices 
# Reference: https://www.w3schools.com/python/ref_random_choices.asp
bias_rolls = random.choices(range(1, 7), weights = bias_weights, k = no_of_rolls)
normal_rolls = random.choices(range(1, 7), weights = normal_weights, k = no_of_rolls)

# Create dictionaries with same structure as posted code
weighted_die = dict(zip(range(no_of_rolls), bias_rolls))
normal_die = dict(zip(range(no_of_rolls), normal_rolls))

# Use posted plotting calls
plt.bar(*zip(*weighted_die.items()))
plt.show()
plt.bar(*zip(*normal_die.items()))
plt.show()

性能

*Not including plotting.*
Original code: ~6 ms
Revised code:  ~2 ms
(3x improvement, but not sure why the post mentions 8 seconds to run)

【讨论】:

  • 谢谢!这确实加快了我的代码速度。我的电脑一定没有你的那么强大,但不管更新是否有帮助。欣赏!
  • @LeeWhieldon--嗯,仍然令人惊讶,因为我的电脑是一台 8 年前的旧台式机,配备 i7 CPU 920 @ 2.67 GHz,存在间歇性硬件问题,因此确实需要更换。
  • 我有一个 i7 CPU @ 1.80GHz。这就是我想象的原因:)
【解决方案2】:

您可以使用 guvectorize 加速它

%%time
from numba import guvectorize
import matplotlib.pyplot as plt
import numpy as np
import random

sides = 6
bias_list = (0.15, 0.15, 0.15, 0.15, 0.15, 0.25)

@guvectorize(["f8[:,:], uint8[:]"], "(n, k) -> (n)", nopython=True)
def roll(biases, side):
    for i in range(biases.shape[0]):
        number = random.uniform(0, np.sum(biases[i,:]))
        current = 0
        for j, bias in enumerate(biases[i,:]):
            current += bias
            if number <= current:
                side[i] = j + 1
                break

no_of_rolls = 2500
biases = np.zeros((no_of_rolls,len(bias_list)))

biases[:,] = np.array(bias_list)

normal_die = roll(biases)

print(normal_die)

这在我的 PC 上花费了大约 200 毫秒,而您的代码大约需要 6 秒。

【讨论】:

  • 我运行了您建议的代码,但在我的机器上似乎并没有给我太多的速度(实际上要多 1 秒)。
  • 对。刚刚意识到绘图需要很多时间。
猜你喜欢
  • 2022-01-24
  • 1970-01-01
  • 2014-07-17
  • 2010-11-09
  • 1970-01-01
  • 2015-05-12
  • 2015-05-08
  • 2015-06-15
  • 2020-02-29
相关资源
最近更新 更多