累加和平分
在任何一般情况下,似乎建议计算权重的累积总和,并使用 bisect 模块中的 bisect 在生成的排序数组中找到一个随机点
def weighted_choice(weights):
cs = numpy.cumsum(weights)
return bisect.bisect(cs, numpy.random.random() * cs[-1])
如果速度是一个问题。下面给出更详细的分析。
注意:如果数组不是平面的,numpy.unravel_index 可以用来将平面索引转换为整形索引,见https://stackoverflow.com/a/19760118/1274613
实验分析
使用numpy 内置函数有四种或多或少明显的解决方案。使用timeit 比较它们会得到以下结果:
import timeit
weighted_choice_functions = [
"""import numpy
wc = lambda weights: numpy.random.choice(
range(len(weights)),
p=weights/weights.sum())
""",
"""import numpy
# Adapted from https://stackoverflow.com/a/19760118/1274613
def wc(weights):
cs = numpy.cumsum(weights)
return cs.searchsorted(numpy.random.random() * cs[-1], 'right')
""",
"""import numpy, bisect
# Using bisect mentioned in https://stackoverflow.com/a/13052108/1274613
def wc(weights):
cs = numpy.cumsum(weights)
return bisect.bisect(cs, numpy.random.random() * cs[-1])
""",
"""import numpy
wc = lambda weights: numpy.random.multinomial(
1,
weights/weights.sum()).argmax()
"""]
for setup in weighted_choice_functions:
for ps in ["numpy.ones(40)",
"numpy.arange(10)",
"numpy.arange(200)",
"numpy.arange(199,-1,-1)",
"numpy.arange(4000)"]:
timeit.timeit("wc(%s)"%ps, setup=setup)
print()
结果输出是
178.45797914802097
161.72161589498864
223.53492237901082
224.80936180002755
1901.6298267539823
15.197789980040397
19.985687876993325
20.795070077001583
20.919113760988694
41.6509403079981
14.240949985047337
17.335801470966544
19.433710905024782
19.52205040602712
35.60536142199999
26.6195822560112
20.501282756973524
31.271995796996634
27.20013752405066
243.09768892999273
这意味着 numpy.random.choice 非常慢,甚至专用的 numpy searchsorted 方法也比 type-naive bisect 变体慢。 (这些结果是使用 Python 3.3.5 和 numpy 1.8.1 获得的,因此对于其他版本可能会有所不同。)基于 numpy.random.multinomial 的函数对于大权重的效率低于基于累积求和的方法。据推测,argmax 必须遍历整个数组并在每一步运行比较这一事实起着重要作用,这也可以从增加和减少权重列表之间的 4 秒差异中看出。