【问题标题】:Why is my implementation of numpy.random.choice faster?为什么我的 numpy.random.choice 实现更快?
【发布时间】:2020-03-18 21:26:14
【问题描述】:

我想实现numpy.random.choice(除了它的replace参数)来看看它是如何工作的。

This is what I came up with:

from random import uniform
from math import fsum

def select(array, total_count, probability):
    probability_accumulative = []
    last_element = 0
    for i in range(len(probability)):
        probability_accumulative.append(last_element + probability[i])
        last_element = probability_accumulative[i]

    result = []

    if(len(array) != len(probability)):
        raise ValueError("array and probability must have the same size.")
    elif(fsum(probability) != 1.0):
        raise ValueError("probabilities do not sum to 1.")
    else:
        for i in range(total_count):
            rand = uniform(0, 1)
            for j in range(len(probability_accumulative)):
                if(rand < probability_accumulative[j]):
                    result.append(array[j])
                    break

    return result

它似乎工作得很好,所以我决定编写另一个脚本来检查我的实现比numpy.random.choice慢了多少。

This is the test script I wrote for it:

from random_selection import select
from collections import Counter
from numpy.random import choice
from time import time

def test(array, total_count, probability, method):
    methods = {
        "numpy.random.choice": choice(array, total_count, p=probability),
        "random_selection.select": select(array, total_count, probability)
    }

    if(method in methods):
        probability_dict = {}
        rand_items = methods[method]
        items_counter = Counter(rand_items)

        for item, count in items_counter.most_common():
            probability_dict[item] = f"{100 * count / total_count:.1f}%"
        return probability_dict
    else:
        raise ValueError(f"Method {method} has not been defined.")


def main():
    total_count = 1000000
    array = ['a', 'b', 'c', 'd']
    probability = [0.7, 0.1, 0.1, 0.1]

    print(f"array: {array}")
    print(f"probability: {probability}")
    print(f"size: {total_count}")

    print()

    print('random_selection.select: ')
    start_time = time()
    result = test(array, total_count, probability, 'random_selection.select')
    end_time = time()
    print(result)
    print(f"{(end_time - start_time):.4f} s")

    print()

    print('numpy.random.choice: ')
    start_time = time()
    result = test(array, total_count, probability, 'numpy.random.choice')
    end_time = time()
    print(result)
    print(f"{(end_time - start_time):.4f} s")


if __name__ == "__main__":
    main()

我很惊讶我的实现速度更快!

这是一百万个数组大小的结果:

array: ['a', 'b', 'c', 'd']
probability: [0.7, 0.1, 0.1, 0.1]
size: 1000000

random_selection.select:
{'a': '70.0%', 'c': '10.0%', 'd': '10.0%', 'b': '10.0%'}
2.5119 s

numpy.random.choice:
{'a': '70.0%', 'b': '10.0%', 'd': '10.0%', 'c': '10.0%'}
3.1098 s

如果我将大小增加到 1000 万,差异会变得更加明显:

array: ['a', 'b', 'c', 'd']
probability: [0.7, 0.1, 0.1, 0.1]
size: 10000000

random_selection.select:
{'a': '70.0%', 'b': '10.0%', 'd': '10.0%', 'c': '10.0%'}
25.6174 s

numpy.random.choice:
{'a': '70.0%', 'b': '10.0%', 'c': '10.0%', 'd': '10.0%'}
31.8087 s

这是为什么呢?

【问题讨论】:

  • 好吧,对于初学者来说,numpy 用于处理 arrays 而不是 list 对象。
  • @juanpa.arrivillaga 在 Python 中不是 arrays represented as lists 吗?我相信我理解数组和列表之间的区别; list 可以有混合类型的元素,但 array 不能,对吧?但是我们在 Python 中没有内置数组。必须导入array 模块。我不明白这如何使这个实现更快。
  • numpy 的上下文中,说“一个数组”意味着numpy.ndarray 的一个实例。您也可以更随意地使用术语“数组”,从这个意义上说,Python 列表很像数组。但它们不是 numpy 数组。
  • 不,数组不表示为列表。当人们谈论 Python 中的数组时,他们通常谈论的是numpy.ndarray 对象,或者来自array 内置模块的数组。它们都不同于 list 对象。但我认为有人在你如何分析它时发现了一个问题。

标签: python numpy


【解决方案1】:

您的测试代码没有按照您的预期进行。 test 函数 always 调用两个随机选择函数中的 both。您的计时仅检测您的分析代码在与请求的函数相对应的结果上的性能差异。

问题在于以下几行:

methods = {
    "numpy.random.choice": choice(array, total_count, p=probability),
    "random_selection.select": select(array, total_count, probability)
}

这些无条件调用choiceselect函数,并将返回值放入字典。这几乎肯定不是你所期望的。您可能希望将lambda 函数放入字典中,以便在调用时使用适当的参数调用所需的函数。

【讨论】:

    【解决方案2】:

    这并不奇怪。 Python 内置的随机库比 numpy.random 库更轻量级,因此我们希望基于随机库的简单实现会稍微快一些。有更深入的解释here

    【讨论】:

      【解决方案3】:

      正如大家评论的那样,当您调用列表中的任何numpy 函数时,从listnp.ndarray 的转换一定是一个耗时的过程。如果您直接使用 np.ndarray 对象尝试此操作,不要希望击败像 numpy 这样的基于 Cython 的库。

      【讨论】:

        猜你喜欢
        • 2013-09-08
        • 1970-01-01
        • 2011-05-06
        • 1970-01-01
        • 2021-10-14
        • 2021-10-12
        • 2017-05-17
        • 2014-03-12
        • 2014-03-29
        相关资源
        最近更新 更多