【发布时间】:2017-12-13 04:10:24
【问题描述】:
我正在尝试计算 numpy 数组中包含的多个时间序列之间的成对距离。请看下面的代码
print(type(sales))
print(sales.shape)
<class 'numpy.ndarray'>
(687, 157)
所以,sales 包含 687 个长度为 157 的时间序列。使用 pdist 计算时间序列之间的 DTW 距离。
import fastdtw
import scipy.spatial.distance as sd
def my_fastdtw(sales1, sales2):
return fastdtw.fastdtw(sales1,sales2)[0]
distance_matrix = sd.pdist(sales, my_fastdtw)
---编辑:尝试不使用pdist()-----
distance_matrix = []
m = len(sales)
for i in range(0, m - 1):
for j in range(i + 1, m):
distance_matrix.append(fastdtw.fastdtw(sales[i], sales[j]))
---编辑:并行化内部for循环-----
from joblib import Parallel, delayed
import multiprocessing
import fastdtw
num_cores = multiprocessing.cpu_count() - 1
N = 687
def my_fastdtw(sales1, sales2):
return fastdtw.fastdtw(sales1,sales2)[0]
results = [[] for i in range(N)]
for i in range(0, N- 1):
results[i] = Parallel(n_jobs=num_cores)(delayed(my_fastdtw) (sales[i],sales[j]) for j in range(i + 1, N) )
所有方法都很慢。并行方法大约需要 12 分钟。有人可以建议一种有效的方法吗?
---编辑:按照下面答案中提到的步骤---
lib 文件夹如下所示:
VirtualBox:~/anaconda3/lib/python3.6/site-packages/fastdtw-0.3.2-py3.6- linux-x86_64.egg/fastdtw$ ls
_fastdtw.cpython-36m-x86_64-linux-gnu.so fastdtw.py __pycache__
_fastdtw.py __init__.py
所以,那里有一个 cython 版本的 fastdtw。安装时,我没有收到任何错误。即使是现在,当我在程序执行过程中按下CTRL-C时,我可以看到正在使用纯python版本(fastdtw.py):
/home/vishal/anaconda3/lib/python3.6/site-packages/fastdtw/fastdtw.py in fastdtw(x, y, radius, dist)
/home/vishal/anaconda3/lib/python3.6/site-packages/fastdtw/fastdtw.py in __fastdtw(x, y, radius, dist)
代码仍然像以前一样缓慢。
【问题讨论】:
-
阅读
pdist所说的关于提供自己的函数的内容。注意它调用了多少次。fastdtw产生什么?dm中的内容是什么?我认为pdist期望距离函数有一个简单的数字。 -
@hpaulj,你是对的,每次调用
fastdtw都会产生一个float,这是pdist需要的距离,它还返回一个路径。请参阅我更新的帖子。 -
看起来
pdist是在给定 Python 函数时进行相同类型的迭代。只有在使用它自己的编译指标之一时它才会更快。任何速度提升都必须来自fastdtw端。
标签: python numpy cython joblib