查看您分享的链接,有很多有趣的解决方案。我特别受到this 的启发并改变了一些事情。回顾一下,您需要一个尾递归装饰器,它既可以缓存函数先前执行的结果,又支持相互递归(?)。还有另一个有趣的discussion 是关于尾递归上下文中的相互递归的,它可能会帮助您理解主要问题。
我已经编写了一个同时进行缓存和相互递归的装饰器:我认为它可以进一步简化/改进,但它适用于我选择的测试样本:
from collections import namedtuple
import functools
TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
def tail_recursive(f):
f._first_call = True
f._cache = {}
@functools.wraps(f)
def wrapper(*args, **kwargs):
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
while True:
cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
if cache_key in f._cache:
return f._cache[cache_key]
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[cache_key] = result
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
return wrapper
乍一看似乎相当复杂,但它重用了链接中讨论的一些概念。
初始化
f._first_call = True
f._cache = {}
除了START、CONTINUE 和RETURN 之类的状态,在这种情况下,我只需要区分_first_call 和以下状态即可。实际上,第一次调用函数后,下一次调用会返回一个存储参数的TailRecArgument。
f._cache 是该特定功能的缓存。
尾递归
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
while True:
result = f(*f._new_args, **f._new_kwargs)
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
这个版本的尾递归如何工作?在while 循环中,在第一次调用装饰函数后,使用返回的新参数不断调用该函数。
我什么时候可以退出循环?一旦返回值不是TailRecArguments类型,这意味着最后一次函数调用并没有递归调用自身,而是返回了一个实际值。在这种情况下,我只需要返回结果并设置f._first_call = True。不幸的是,它比这复杂一点,因为它不适用于相互递归。这里的解决方法是将调用的函数存储在TailRecArguments 中。通过这种方式,我可以检查用于下一个循环的参数是用于同一函数(result.wrapped_func == f)还是用于另一个尾递归函数。在后一种情况下,我不想处理这些参数,因为它们与另一个函数相关,而是我可以返回它们,因为它们肯定会在遇到的第一个尾递归函数的while 循环中执行。唯一的缺点是每次参数属于另一个函数时,f._first_call 都会重置。
缓存
while True:
cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
if cache_key in f._cache:
return f._cache[cache_key]
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[cache_key] = result
在评论缓存机制(这是非常流行的记忆技术)之前,正确放置缓存代码很重要:注意我将它放在while 循环中。不可能,因为只有在 while 循环内,函数才会被持续调用,我可以检查缓存命中。
我在创建cache_key 时有点作弊,因为我使用了functools 模块的内部函数。它是同一模块中@cache 装饰器使用的那个,您可以使用
import inspect
import functools
print(inspect.getsource(functools._make_key))
还有其他方法可以从*args 和**kwargs 创建缓存键,例如this one,这又指向_make_key 的实现。为了让你的代码更稳定,当然要避免使用私有成员。
正如我所说,剩下的就是记忆,还有一个额外的检查:if not isinstance(result, TailRecArguments): ...。我想缓存值,而不是尾递归调用的参数。
(实际上,我认为您可以将所有TailRecArguments 临时存储在一个列表中,并在递归调用返回实际值时在缓存中添加与该列表大小一样多的条目。这会使解决方案复杂化,但如果您有性能问题,仍然可以接受。这可能会在相互递归的情况下引发一些错误,如果需要,我将继续处理。
测试
这些是我用来测试装饰器的几个基本函数:
@tail_recursive
def even(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> even(100)
True
>>> even(101)
False
"""
return True if n == 0 else odd(n - 1)
@tail_recursive
def odd(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> odd(100)
False
>>> odd(101)
True
"""
return False if n == 0 else even(n - 1)
@tail_recursive
def fact(n, acc=1):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> fact(30)
265252859812191058636308480000000
"""
return acc if n <= 1 else fact(n - 1, acc * n)
@tail_recursive
def fib(n, a = 0, b = 1):
"""
>>> import sys
>>> sys.setrecursionlimit(20)
>>> fib(30)
832040
"""
return a if n == 0 else b if n == 1 else fib(n - 1, b, a + b)
if __name__ == '__main__':
import doctest
doctest.testmod()
请注意,在这些示例中缓存不是很有用,以阶乘为例:fact(10) 实际上永远不会使用fact(8)
fact(8) |
fact(10) |
|
fact(10, 1) |
|
fact(9, 10) |
| fact(8, 1) |
fact(8, 90) |
| ... |
... |
累加器是缓存键的一部分,因此您应该通过自定义要缓存的参数来更改缓存策略(同样,如果需要,我也可以为此提出解决方案)。
更新 - 缓存优化
这是对原始答案中使用的缓存策略的部分修复。主要问题是考虑到通用尾递归算法的工作原理,在缓存键中包含所有参数效率低下(参见阶乘示例)。
第一个可能的优化是让用户选择哪些参数用于键,哪些参数用于值。由于类型提示,它的可读性要差得多,但测试让一切都变得更加清晰:
class Logger:
def __init__(self, name):
self._name = name
self._entries = []
def log(self, s):
self._entries.append(s)
def print(self):
log_prefix = f"[{self._name}] - "
print(log_prefix + f"\n{log_prefix}".join(self._entries))
TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
default_logger = Logger('default')
def tail_recursive(logger: Logger = default_logger, \
get_cache_key: Callable[[Iterable, Dict], Hashable] = lambda args, kwargs: \
functools._make_key(args, kwargs, False),\
get_result_after_cache_hit: Callable[[Any, Iterable, Dict], Any] = lambda value, args, kwargs: \
value):
def decorator(f):
f._first_call = True
f._cache = {}
@functools.wraps(f)
def wrapper(*args, **kwargs):
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
f._initial_key = get_cache_key(f._new_args, f._new_kwargs)
while True:
cache_key = get_cache_key(f._new_args, f._new_kwargs)
if cache_key in f._cache:
logger.log('cache hit for ' + str(cache_key))
return get_result_after_cache_hit(f._cache[cache_key], f._new_args, f._new_kwargs)
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[f._initial_key] = result
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
return wrapper
return decorator
除了用于确认缓存命中的Logger 类之外,主要区别在于每个函数现在都有一个名为_initial_key 的新成员,它存储第一次调用的键。这样,如果我调用fact(5),5就变成了_initial_key,结果放到f._cache[5]中。
这可以优化相互递归和尾递归函数,但在某些情况下无效。让我们从最好的情况开始:
fact_logger = Logger('fact')
@tail_recursive(logger=fact_logger, get_cache_key=lambda args, kwargs: args[0],\
get_result_after_cache_hit=lambda value, args, kwargs: value * args[1])
def fact(n, acc=1):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> fact(5)
120
>>> fact(30)
265252859812191058636308480000000
>>> fact_logger.print()
[fact] - cache hit for 5
"""
return acc if n <= 1 else fact(n - 1, acc * n)
@tail_recursive 装饰器初始化包括(记录器)get_cache_key,它指定只有第一个参数 n 应该是缓存键的一部分,get_result_after_cache_hit 指定如何在一个缓存命中。在上述情况下,当fact(30) 达到fact(5, <partial_factorial>) 时,结果立即计算为<partial_factorial> * f._cache[5]。
even-odd 也是如此,只是在这种情况下tail_recursive 的默认参数绰绰有余:
even_logger = Logger('even')
@tail_recursive(logger=even_logger)
def even(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> even(100)
True
>>> even(101)
False
>>> even(104)
True
>>> even_logger.print()
[even] - cache hit for 100
"""
return True if n == 0 else odd(n - 1)
不幸的是,这不适用于例如斐波那契函数。您应该通过在每次调用期间打印参数来轻松说服自己,结果如下:
30 0 1
29 1 1
28 1 2
27 2 3
26 3 5
25 5 8
...
建立缓存键规则需要一个更复杂的逻辑,这可能会使tail_recursive 装饰器变得非常不可读且不易移植。