【问题标题】:Efficient tensor contraction with Python使用 Python 进行高效的张量收缩
【发布时间】:2020-06-15 18:58:02
【问题描述】:

我有一段涉及张量收缩的瓶颈计算代码。假设我想计算一个张量 A_{i,j,k,l}( X ),其单个 x\in X 的非零条目为 N ~ 10^5,并且 X 表示具有 M 个总点的网格, M~1000 左右。对于张量 A 的单个元素,方程的 rhs 类似于:

A_{ijkl}(M) = Sum_{m,n,p,q} S_{i,j, m,n }(M) B_{m,n,p,q}(M) T_{ p ,q,k,l }(M)

另外,中间张量B_{m,n,p,q}(M)是通过数组的数值卷积得到的,使得:

B_{m,n,p,q}(M) = ( L_{m,n} * F_{p,q} )(M)

其中“*”是卷积算子,所有张量的元素数量都与 A 大致相同。我的问题与求和的效率有关;考虑到问题的复杂性,计算 A 的单个 rhs 需要很长时间。我有一个“键”系统,其中每个张量元素都通过从字典中获取的唯一键组合(例如 T 的 ( p,q,k,l ) )访问。然后该特定键的字典提供与该键关联的 Numpy 数组以执行操作,并且所有操作(卷积、乘法...)都使用 Numpy 完成。我已经看到最耗时的部分实际上是由于嵌套循环(我循环遍历 A 张量的所有键(i,j,k,l),并且对于每个键,需要像上面那样的 rhs计算)。有没有有效的方法来做到这一点?考虑一下:

1) 使用简单的 4 +1 D numpy 数组会导致高内存使用,因为所有张量都是复杂类型 2)我尝试了几种方法:Numba 在使用字典时非常有限,并且目前不支持我需要的一些重要的 Numpy 功能。例如,numpy.convolve() 仅采用前 2 个参数,但不采用“模式”参数,这在这种情况下大大减少了所需的卷积间隔,我不需要卷积的“完整”输出

3) 我最近的方法是尝试使用 Cython 来实现这部分的所有内容......但是考虑到代码的逻辑,这非常耗时并且更容易出错。

关于如何使用 Python 处理这种复杂性的任何想法?

谢谢!

【问题讨论】:

标签: python complexity-theory tensor


【解决方案1】:

您必须使您的问题更加精确,其中还包括您已经尝试过的工作代码示例。例如,不清楚为什么在这种张量收缩中使用字典。字典查找看起来对于这个计算来说是一件很麻烦的事情,但也许我没有明白你真正想要做什么。

张量收缩实际上在 Python (Numpy) 中很容易实现,有一些方法可以找到收缩张量的最佳方法,而且它们真的很容易使用 (np.einsum)。

创建一些数据(这应该是问题的一部分)

import numpy as np
import time

i=20
j=20
k=20
l=20

m=20
n=20
p=20
q=20

#I don't know what complex 2 means, I assume it is complex128 (real and imaginary part are in float64)

#size of all arrays is 1.6e5
Sum_=np.random.rand(m,n,p,q).astype(np.complex128)
S_=np.random.rand(i,j,m,n).astype(np.complex128)
B_=np.random.rand(m,n,p,q).astype(np.complex128)
T_=np.random.rand(p,q,k,l).astype(np.complex128)

天真的方式

此代码与使用 Cython 或 Numba 在循环中编写它基本相同,无需调用 BLAS 例程 (ZGEMM) 或优化收缩顺序 -> 8 个嵌套循环来完成这项工作。

t1=time.time()
A=np.einsum("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_)
print(time.time()-t1)

这导致运行时间非常慢,大约 330 秒。

如何将速度提高 7700 倍

%timeit A=np.einsum("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize="optimal")
#42.9 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

为什么这么快?

让我们看看收缩路径和内部结构。

path=np.einsum_path("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize="optimal")
print(path[1])

    #  Complete contraction:  mnpq,ijmn,mnpq,pqkl->ijkl
#         Naive scaling:  8
#     Optimized scaling:  6
#      Naive FLOP count:  1.024e+11
#  Optimized FLOP count:  2.562e+08
#   Theoretical speedup:  399.750
#  Largest intermediate:  1.600e+05 elements
#--------------------------------------------------------------------------
#scaling                  current                                remaining
#--------------------------------------------------------------------------
#   4             mnpq,mnpq->mnpq                     ijmn,pqkl,mnpq->ijkl
#   6             mnpq,ijmn->ijpq                          pqkl,ijpq->ijkl
#   6             ijpq,pqkl->ijkl                               ijkl->ijkl

path=np.einsum_path("mnpq,ijmn,mnpq,pqkl",Sum_,S_,B_,T_,optimize="optimal",einsum_call=True)
print(path[1])

#[((2, 0), set(), 'mnpq,mnpq->mnpq', ['ijmn', 'pqkl', 'mnpq'], False), ((2, 0), {'n', 'm'}, 'mnpq,ijmn->ijpq', ['pqkl', 'ijpq'], True), ((1, 0), {'p', 'q'}, 'ijpq,pqkl->ijkl', ['ijkl'], True)]

在多个精心选择的步骤中进行收缩可将所需的触发器减少 400 倍。但这并不是 einsum 在这里所做的唯一事情。看看'mnpq,ijmn->ijpq', ['pqkl', 'ijpq'], True), ((1, 0) True 代表 BLAS 收缩 -> tensordot call ->(矩阵矩阵乘法)。

在内部看起来基本上如下:

#consider X as a 4th order tensor {mnpq}
#consider Y as a 4th order tensor {ijmn}

X_=X.reshape(m*n,p*q)       #-> just another view on the data (2D), costs almost nothing (no copy, just a view)
Y_=Y.reshape(i*j,m*n)       #-> just another view on the data (2D), costs almost nothing (no copy, just a view)
res=np.dot(Y_,X_)           #-> dot is just a wrapper for highly optimized BLAS functions, in case of complex128 ZGEMM
output=res.reshape(i,j,p,q) #-> just another view on the data (4D), costs almost nothing (no copy, just a view)

【讨论】:

  • 非常感谢您的详细回复。不幸的是,数组是 5D,这意味着在您的具体示例中,缺少一个额外的维度,为此我遇到了内存问题。这个额外的维度总共包含约 1000 个点,因此张量的大小很快就会在内存中变得无法管理。进行外循环的原因有两个:首先,因为我遇到的内存问题,其次,因为考虑的张量是“稀疏的”,这意味着我只考虑非零条目。我读过 Scipy 为 2D 提供稀疏矩阵支持,但是 d>2D 呢?
  • 在大多数情况下,优化的第一部分是避免使用较大的临时数组。也不清楚数据实际上有多稀疏。每种编程语言(C、Fortran、...)都是如此。如果问题足够大,你永远不会用具有次优复杂性的算法赢得比赛。处理器上的大量电路用于缓存数据也是有原因的。简而言之,提供一个简单但相关的示例。如果数据足够小以适合较低级别的缓存甚至处理器寄存器,则您可以在一秒钟内对近 TB 的数据进行操作。
猜你喜欢
  • 1970-01-01
  • 2016-10-29
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2022-08-18
  • 2016-06-20
  • 2010-12-12
相关资源
最近更新 更多