【发布时间】:2022-01-20 11:04:54
【问题描述】:
这是我的问题。我有两个矩阵A 和B,具有复杂的条目,维度分别为(n,n,m,m) 和(n,n)。
下面是我为得到一个矩阵C而执行的操作-
C = np.sum(B[:,:,None,None]*A, axis=(0,1))
计算一次以上大约需要 6-8 秒。因为我必须计算很多这样的Cs,所以需要很多时间。有没有更快的方法来做到这一点? (我在多核 CPU 上使用 JAX NumPy 来做这些;普通的 NumPy 需要更长的时间)
n=77 和 m=512,如果您想知道的话。我可以在处理集群时进行并行化,但是数组的绝对大小会消耗大量内存。
【问题讨论】:
标签: python arrays numpy linear-algebra jax