【Posted】: 2020-03-18 21:26:14
【Question】:
I wanted to implement numpy.random.choice (except for its replace parameter) to see how it works.
from random import uniform
from math import fsum

def select(array, total_count, probability):
    probability_accumulative = []
    last_element = 0
    for i in range(len(probability)):
        probability_accumulative.append(last_element + probability[i])
        last_element = probability_accumulative[i]
    result = []
    if(len(array) != len(probability)):
        raise ValueError("array and probability must have the same size.")
    elif(fsum(probability) != 1.0):
        raise ValueError("probabilities do not sum to 1.")
    else:
        for i in range(total_count):
            rand = uniform(0, 1)
            for j in range(len(probability_accumulative)):
                if(rand < probability_accumulative[j]):
                    result.append(array[j])
                    break
    return result
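For comparison, here is a minimal sketch of the same cumulative-probability idea written with the standard library's `itertools.accumulate` and `bisect.bisect_right`. The name `select_bisect` is hypothetical, and this is not how numpy itself is implemented. It assumes a tolerance-based check (`math.isclose`) is acceptable in place of the exact `!= 1.0` comparison, and it clamps the index so that a rounding error in the running total cannot silently drop a sample.

```python
from bisect import bisect_right
from itertools import accumulate
from math import isclose
from random import random


def select_bisect(array, total_count, probability):
    """Draw total_count items from array according to probability."""
    if len(array) != len(probability):
        raise ValueError("array and probability must have the same size.")
    # Assumption: a small floating-point tolerance is acceptable here.
    if not isclose(sum(probability), 1.0):
        raise ValueError("probabilities do not sum to 1.")
    # Running totals of the probabilities (the last one is approximately 1.0).
    cumulative = list(accumulate(probability))
    last = len(array) - 1
    # bisect_right returns the index of the first running total that is
    # strictly greater than the random draw, matching `rand < cumulative[j]`.
    # min() clamps the index in case rounding leaves the final total below 1.0.
    return [array[min(bisect_right(cumulative, random()), last)]
            for _ in range(total_count)]


sample = select_bisect(['a', 'b', 'c', 'd'], 1000, [0.7, 0.1, 0.1, 0.1])
print(len(sample))
```

The binary search replaces the inner `for j` loop; with only four categories the difference is small, but it grows with the number of categories.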
It seemed to work fine, so I decided to write another script to check how much slower my implementation is than numpy.random.choice.
This is the test script I wrote for it:
from random_selection import select
from collections import Counter
from numpy.random import choice
from time import time

def test(array, total_count, probability, method):
    methods = {
        "numpy.random.choice": choice(array, total_count, p=probability),
        "random_selection.select": select(array, total_count, probability)
    }
    if(method in methods):
        probability_dict = {}
        rand_items = methods[method]
        items_counter = Counter(rand_items)
        for item, count in items_counter.most_common():
            probability_dict[item] = f"{100 * count / total_count:.1f}%"
        return probability_dict
    else:
        raise ValueError(f"Method {method} has not been defined.")

def main():
    total_count = 1000000
    array = ['a', 'b', 'c', 'd']
    probability = [0.7, 0.1, 0.1, 0.1]
    print(f"array: {array}")
    print(f"probability: {probability}")
    print(f"size: {total_count}")
    print()
    print('random_selection.select: ')
    start_time = time()
    result = test(array, total_count, probability, 'random_selection.select')
    end_time = time()
    print(result)
    print(f"{(end_time - start_time):.4f} s")
    print()
    print('numpy.random.choice: ')
    start_time = time()
    result = test(array, total_count, probability, 'numpy.random.choice')
    end_time = time()
    print(result)
    print(f"{(end_time - start_time):.4f} s")

if __name__ == "__main__":
    main()
I was surprised to find that my implementation is faster!
Here are the results for a size of one million:
array: ['a', 'b', 'c', 'd']
probability: [0.7, 0.1, 0.1, 0.1]
size: 1000000
random_selection.select:
{'a': '70.0%', 'c': '10.0%', 'd': '10.0%', 'b': '10.0%'}
2.5119 s
numpy.random.choice:
{'a': '70.0%', 'b': '10.0%', 'd': '10.0%', 'c': '10.0%'}
3.1098 s
If I increase the size to 10 million, the difference becomes even more noticeable:
array: ['a', 'b', 'c', 'd']
probability: [0.7, 0.1, 0.1, 0.1]
size: 10000000
random_selection.select:
{'a': '70.0%', 'b': '10.0%', 'd': '10.0%', 'c': '10.0%'}
25.6174 s
numpy.random.choice:
{'a': '70.0%', 'b': '10.0%', 'c': '10.0%', 'd': '10.0%'}
31.8087 s
Why is that?
【Comments】:
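- Well, for starters, numpy is meant to work with arrays, not list objects.
- @juanpa.arrivillaga Aren't arrays represented as lists in Python? I believe I understand the difference between arrays and lists: a list can have elements of mixed types but an array can't, right? But we don't have built-in arrays in Python; the array module has to be imported. I don't see how this makes this implementation faster.
- In the context of numpy, saying "an array" means an instance of numpy.ndarray. You can also use the term "array" more casually, and in that sense a Python list is a lot like an array. But lists are not numpy arrays.
- No, arrays are not represented as lists. When people talk about arrays in Python, they are usually talking about numpy.ndarray objects, or arrays from the built-in array module. Both are different from list objects. But I think someone found a problem in how you are profiling it.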
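
The profiling problem mentioned in the last comment appears to be this: in `test`, the `methods` dict calls both `choice(...)` and `select(...)` while the dict literal is being built, so each timed call runs both implementations regardless of which key is requested. The sketch below illustrates that behaviour with two stand-in functions (`slow_a` and `slow_b` are hypothetical and replace the real samplers so the example needs no third-party packages) and shows one possible fix: store the callables and call only the selected one.

```python
from time import perf_counter

calls = []  # records which stand-in functions actually ran


def slow_a():
    calls.append("a")
    return sum(range(100_000))


def slow_b():
    calls.append("b")
    return sum(range(100_000))


def run_eager(method):
    # Both values are computed while the dict literal is built,
    # so both functions run no matter which key is requested.
    methods = {"a": slow_a(), "b": slow_b()}
    return methods[method]


def run_lazy(method):
    # Storing the functions themselves defers the call until after lookup.
    methods = {"a": slow_a, "b": slow_b}
    return methods[method]()


run_eager("a")
eager_calls = list(calls)  # ['a', 'b']: both ran
calls.clear()

start = perf_counter()
run_lazy("a")
elapsed = perf_counter() - start
lazy_calls = list(calls)   # ['a']: only the requested one ran

print(eager_calls, lazy_calls, f"{elapsed:.4f} s")
```

Applied to the test script, this would mean storing something like `lambda: choice(array, total_count, p=probability)` in the dict and timing only the selected call, ideally separately from the `Counter` step, which is also included in the measured interval.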