【发布时间】:2018-12-20 00:29:16
【问题描述】:
作为我正在编写的项目的一部分,我正在生成很多很多列联表。
工作流程是:
- 获取具有连续(浮点)行的大型数据数组,并通过分箱将其转换为离散整数值(例如,结果行的值是 0-9)
- 将两行切成向量 X 和 Y 并从中生成contingency table,这样我就有了二维频率分布
- 例如,我有一个 10 x 10 的数组,计算出现的 (xi, yi) 的数量
- 使用列联表做一些信息论数学
最初,我是这样写的:
def make_table(x, y, num_bins):
ctable = np.zeros((num_bins, num_bins), dtype=np.dtype(int))
for xn, yn in zip(x, y):
ctable[xn, yn] += 1
return ctable
这很好用,但是太慢了,占用了整个项目 90% 的运行时间。
我能想到的最快的纯 python 优化是这样的:
def make_table(x, y, num_bins):
ctable = np.zeros(num_bins ** 2, dtype=np.dtype(int))
reindex = np.dot(np.stack((x, y)).transpose(),
np.array([num_bins, 1]))
idx, count = np.unique(reindex, return_counts=True)
for i, c in zip(idx, count):
ctable[i] = c
return ctable.reshape((num_bins, num_bins))
这(不知何故)要快得多,但对于看起来不应该成为瓶颈的东西来说,它仍然相当昂贵。是否有任何有效的方法可以做到这一点,我只是没有看到,或者我应该放弃并在 cython 中做到这一点?
另外,这里有一个基准函数。
def timetable(func):
size = 5000
bins = 10
repeat = 1000
start = time.time()
for i in range(repeat):
x = np.random.randint(0, bins, size=size)
y = np.random.randint(0, bins, size=size)
func(x, y, bins)
end = time.time()
print("Func {na}: {ti} Ms".format(na=func.__name__, ti=(end - start)))
【问题讨论】:
-
除了 Cython,您可能还想考虑 Numba (numba.pydata.org) - 哪个做得更好会有所不同,但 Numba 可能更容易启动和运行。跨度>
标签: python numpy information-theory