【问题标题】:Can I speed up this basic linear algebra code?我可以加快这个基本的线性代数代码吗?
【发布时间】:2014-03-10 16:13:34
【问题描述】:

我想知道是否可以使用 Numpy 或数学技巧来优化以下内容。

def f1(g, b, dt, t1, t2):
  p = np.copy(g)
  for i in range(dt):
    p += t1*np.tanh(np.dot(p, b)) + t2*p
  return p

其中g 是长度为n 的向量,bnxn 矩阵,dt 是迭代次数,t1t2 是标量。

我很快就没有关于如何进一步优化的想法,因为p 在循环中使用,在等式的所有三个项中:添加到自身时;在点积中;并在标量乘法中。

但也许有不同的方式来表示这个函数,或者有其他技巧来提高它的效率。如果可能的话,我宁愿不使用Cython 等,但如果速度提升显着,我愿意使用它。提前致谢,如果问题超出范围,我们深表歉意。

更新:

到目前为止提供的答案更侧重于输入/输出的值可能是什么,以避免不必要的操作。我现在已经用变量的适当初始化值更新了 MWE(我没想到优化想法来自那方面——抱歉)。 g 将在[-1, 1] 范围内,b 将在[-infinity, infinity] 范围内。逼近输出不是一个选项,因为返回的向量稍后会被提供给评估函数——对于非常相似的输入,逼近可能会返回相同的向量,所以它不是一个选项。


MWE:

import numpy as np
import timeit

iterations = 10000

setup = """
import numpy as np
n  = 100
g  = np.random.uniform(-1, 1, (n,)) # Updated.
b  = np.random.uniform(-1, 1, (n,n)) # Updated.
dt = 10
t1 = 1
t2 = 1/2

def f1(g, b, dt, t1, t2):
  p = np.copy(g)
  for i in range(dt):
    p += t1*np.tanh(np.dot(p, b)) + t2*p
  return p
"""

functions = [
  """
    p = f1(g, b, dt, t1, t2)
  """
]

if __name__ == '__main__':
  for function in functions:
    print(function)
    print('Time = {}'.format(timeit.timeit(function, setup=setup,
                                           number=iterations)))

【问题讨论】:

  • 你的意思是长度为 n 的向量,而不是 n-dimensional vector
  • 措辞很好的问题!然而,这将是一个很难加速的事情......这是numba.jitnumba.pydata.org)应该非常适合(Cython)的那种东西。但是,这是一个需要添加的严重依赖项,并且您确实说过,如果可能的话,您宁愿坚持“直接”的 numpy。
  • @JoeKington 谢谢。我记得大约一个月前在这段特定的代码上尝试了Cython,但速度提升可以忽略不计。我想那是因为 (a) 我是 Cython 的新手; (b) 我保留了大部分内容,而不是使用C 循环代替点积。如果它可以显着提高性能,我会很高兴切换到Cython。现在看看numba.jit——谢谢! :-)
  • 如果迭代次数只有 10 次,cython/numba 对你没有多大帮助。如果没有数学技巧,对整体速度的最大贡献将是DGEMM 调用的效率。您目前是否正在使用优化的 BLAS 和 numpy?这当然是假设您将使用比您描述的矩阵更大的矩阵,因为每次调用只需 250us,如图所示。
  • @Ophion 当我使用 Cython 运行测试时,我使用了超过 1000 次迭代,但我没有看到任何显着的改进——但是,我将其归因于我缺乏 Cython 知识,而不是缺乏潜力。另外,是的,我正在使用优化的 BLAS 和 Numpy。矩阵的迭代次数和长度一般会少于 1000 次,但是这个函数被调用超过一百万次。即使只减少 10% 的执行时间也将是一项重大改进。

标签: python performance math numpy


【解决方案1】:

要让代码在没有cythonjit 的情况下运行得更快会非常困难,一些数学技巧可能是更简单的方法。在我看来,如果我们在正 N 中为n 定义一个k(g, b) = f1(g, b, n+1, t1, t2)/f1(g, b, n, t1, t2),那么k 函数的限制应该是t1+t2(还没有确凿的证据,只是一种直觉;它可能也是 E(g)=0 & E(p)=0 的特例。)。对于t1=1t2=0.5k() 似乎很快接近极限,对于N>100,它几乎是1.5 的常数。

所以我认为数值近似方法应该是最简单的方法。

In [81]:

t2=0.5
data=[f1(g, b, i+2, t1, t2)/f1(g, b, i+1, t1, t2) for i in range(1000)]
In [82]:

plt.figure(figsize=(10,5))
plt.plot(data[0], '.-', label='1')
plt.plot(data[4], '.-', label='5')
plt.plot(data[9], '.-', label='10')
plt.plot(data[49], '.-', label='50')
plt.plot(data[99], '.-', label='100')
plt.plot(data[999], '.-', label='1000')
plt.xlim(xmax=120)
plt.legend()
plt.savefig('limit.png')

In [83]:

data[999]
Out[83]:
array([ 1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5])

【讨论】:

  • 这也支持我的回答,因为函数关系本质上是线性递归。然而,毫无价值的是,OP 声明迭代次数约为 10,因此大 N 的限制可能无效(尽管在这种情况下它会有效)。
  • +1!好主意!但是请参阅我更新的帖子以了解预期的输入。为混淆道歉——我只是没想到优化想法会以输入/输出值和近似值的形式出现。你是对的:输出会饱和,但值并不总是足够大以至于发生这种情况 - 事实上,输入以非常小的幅度开始,通常随着函数的每次调用而增加。
【解决方案2】:

我不愿给出这个答案,因为我认为这可能是您提供给我们的输入数据的产物。尽管如此,请注意tanh(x) ~ 1 代表x>>1。你的输入数据,在我运行它的任何时候都有x = np.dot(p,b) >> 1,因此我们可以用f2替换f1

def f1(g, b, dt, t1, t2):
  p = np.copy(g)
  for i in range(dt):
      p += t1*np.tanh(np.dot(p, b)) + t2*p
  return p

def f2(g, b, dt, t1, t2):
  p = np.copy(g)
  for i in range(dt):
      p += t1 + t2*p
  return p

print np.allclose(f1(g,b,dt,t1,t2), f2(g,b,dt,t1,t2))

这确实表明这两个函数在数值上是等价的。请注意,f2 是non-homogeneous linear recurrence relation,如果您选择这样做,可以一步解决。

【讨论】:

  • +1!谢谢。但不幸的是,点积的 tanh 不一定近似为 1。请参阅我更新的帖子,了解预期的输入是什么,以及为什么在我的情况下近似不是一个可行的选择。为混乱道歉。
猜你喜欢
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2021-11-23
  • 2013-09-04
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
相关资源
最近更新 更多