您必须使您的问题更加精确,其中还包括您已经尝试过的工作代码示例。例如,不清楚为什么在这种张量收缩中使用字典。字典查找看起来对于这个计算来说是一件很麻烦的事情,但也许我没有明白你真正想要做什么。
张量收缩实际上在 Python (Numpy) 中很容易实现,有一些方法可以找到收缩张量的最佳方法,而且它们真的很容易使用 (np.einsum)。
创建一些数据(这应该是问题的一部分)
import numpy as np
import time
i=20
j=20
k=20
l=20
m=20
n=20
p=20
q=20
#I don't know what complex 2 means, I assume it is complex128 (real and imaginary part are in float64)
#size of all arrays is 1.6e5
Sum_=np.random.rand(m,n,p,q).astype(np.complex128)
S_=np.random.rand(i,j,m,n).astype(np.complex128)
B_=np.random.rand(m,n,p,q).astype(np.complex128)
T_=np.random.rand(p,q,k,l).astype(np.complex128)
天真的方式
此代码与使用 Cython 或 Numba 在循环中编写它基本相同,无需调用 BLAS 例程 (ZGEMM) 或优化收缩顺序 -> 8 个嵌套循环来完成这项工作。
t1=time.time()
A=np.einsum("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_)
print(time.time()-t1)
这导致运行时间非常慢,大约 330 秒。
如何将速度提高 7700 倍
%timeit A=np.einsum("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize="optimal")
#42.9 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
为什么这么快?
让我们看看收缩路径和内部结构。
path=np.einsum_path("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize="optimal")
print(path[1])
# Complete contraction: mnpq,ijmn,mnpq,pqkl->ijkl
# Naive scaling: 8
# Optimized scaling: 6
# Naive FLOP count: 1.024e+11
# Optimized FLOP count: 2.562e+08
# Theoretical speedup: 399.750
# Largest intermediate: 1.600e+05 elements
#--------------------------------------------------------------------------
#scaling current remaining
#--------------------------------------------------------------------------
# 4 mnpq,mnpq->mnpq ijmn,pqkl,mnpq->ijkl
# 6 mnpq,ijmn->ijpq pqkl,ijpq->ijkl
# 6 ijpq,pqkl->ijkl ijkl->ijkl
和
path=np.einsum_path("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize="optimal",einsum_call=True)
print(path[1])
#[((2, 0), set(), 'mnpq,mnpq->mnpq', ['ijmn', 'pqkl', 'mnpq'], False), ((2, 0), {'n', 'm'}, 'mnpq,ijmn->ijpq', ['pqkl', 'ijpq'], True), ((1, 0), {'p', 'q'}, 'ijpq,pqkl->ijkl', ['ijkl'], True)]
在多个精心选择的步骤中进行收缩可将所需的触发器减少 400 倍。但这并不是 einsum 在这里所做的唯一事情。看看'mnpq,ijmn->ijpq', ['pqkl', 'ijpq'], True), ((1, 0) True 代表 BLAS 收缩 -> tensordot call ->(矩阵矩阵乘法)。
在内部看起来基本上如下:
#consider X as a 4th order tensor {mnpq}
#consider Y as a 4th order tensor {ijmn}
X_=X.reshape(m*n,p*q) #-> just another view on the data (2D), costs almost nothing (no copy, just a view)
Y_=Y.reshape(i*j,m*n) #-> just another view on the data (2D), costs almost nothing (no copy, just a view)
res=np.dot(Y_,X_) #-> dot is just a wrapper for highly optimized BLAS functions, in case of complex128 ZGEMM
output=res.reshape(i,j,p,q) #-> just another view on the data (4D), costs almost nothing (no copy, just a view)