问题在于,对于f 和n > 2 的每次调用,都会导致对f 的三个额外调用。例如,如果我们调用f(5),我们会收到以下调用:
- f(5)
- f(4)
- f(3)
- f(2)
- f(1)
- f(0)
- g(3)
- f(2)
- f(1)
- g(4)
- f(3)
- f(2)
- f(1)
- f(0)
- g(3)
- f(2)
- g(5)
因此,我们拨打了 1 次 f(5)、1 次 f(4)、2 次 f(3)、4 次 f(2)、3 次 f(1) 和 2 次 f(0) 电话。
p>
由于我们多次调用例如f(3),因此意味着每次都会消耗资源,特别是因为f(3) 本身会进行额外调用。
我们可以让 Python 存储函数调用的结果,并返回结果,例如使用lru_cache [Python-doc]。这种技术称为记忆化:
from functools import lru_cache
def g(n):
return n * n * (n+1)
@lru_cache(maxsize=32)
def f(n):
if n <= 2:
return (p, q, r)[n]
else:
return a*f(n-1) + b*f(n-2) + c*f(n-3) + g(n)
这将产生如下调用图:
- f(5)
- f(4)
- f(3)
- f(2)
- f(1)
- f(0)
- g(3)
- g(4)
- g(5)
所以现在我们只计算一次f(3),lru_cache 会将它存储在缓存中,如果我们第二次调用f(3),我们将永远不会计算f(3) 本身,缓存将返回预先计算的值。
不过这里可以优化一下,因为我们每次调用f(n-1)、f(n-2)和f(n-3),我们只需要存储最后三个值,每次根据最后三个计算下一个值值,并移动变量,例如:
def f(n):
if n <= 2:
return (p, q, r)[n]
f3, f2, f1 = p, q, r
for i in range(3, n+1):
f3, f2, f1 = f2, f1, a * f1 + b * f2 + c * f3 + g(i)
return f1