【问题标题】:How to speed up the trampolined cps version fib function and support mutual recursion in python?python中如何加速蹦床cps版fib函数并支持相互递归?
【发布时间】:2022-02-17 19:26:13
【问题描述】:

我尝试为 cps 版本的斐波那契函数实现蹦床。但我不能让它快速(添加缓存)并支持mutual_recursion。

实现代码:

import functools
from dataclasses import dataclass
from typing import Optional, Any, Callable

START = 0
CONTINUE = 1
CONTINUE_END = 2
RETURN = 3


@dataclass
class CTX:
    kind: int
    result: Any    # TODO ......
    f: Callable
    args: Optional[list]
    kwargs: Optional[dict]


def trampoline(f):
    ctx = CTX(START, None, None, None, None)

    @functools.wraps(f)
    def decorator(*args, **kwargs):
        nonlocal ctx
        if ctx.kind in (CONTINUE, CONTINUE_END):
            ctx.args = args
            ctx.kwargs = kwargs
            ctx.kind = CONTINUE
            return
        elif ctx.kind == START:
            ctx.args = args
            ctx.kwargs = kwargs
            ctx.kind = CONTINUE

        result = None
        while ctx.kind != RETURN:
            args = ctx.args
            kwargs = ctx.kwargs
            result = f(*args, **kwargs)
            if ctx.kind == CONTINUE_END:
                ctx.kind = RETURN
            else:
                ctx.kind = CONTINUE_END

        return result

    return decorator

这是可运行的示例。

@functools.lru_cache
def fib(n):
    if n == 0:
        return 1
    elif n == 1:
        return 1
    else:
        return fib(n - 1) + fib(n - 2)

@trampoline
def fib_cps(n, k):
    if n == 0:
        return k(1)
    elif n == 1:
        return k(1)
    else:
        return fib_cps(n - 1, lambda v1: fib_cps(n - 2, lambda v2: k(v1 + v2)))

def fib_cps_wrapper(n):
    return fib_cps(n, lambda i:i)


@trampoline
def fib_tail(n, acc1=1, acc2=1):
    if n < 2:
        return acc1
    else:
        return fib_tail(n - 1, acc1 + acc2, acc1)


if __name__ == "__main__":
    print(fib(100))
    print(fib_tail(10000))
    print(fib_cps_wrapper(40))

运行号码40太慢了。 当n 更大时,fib 超出了最大递归深度。但是在添加lru_cache 之后会很快。 iter trampolined 版本的递归深度还可以,运行速度非常快。

这是其他人的作品:

  1. 支持cps版本缓存:https://davywybiral.blogspot.com/2008/11/trampolining-for-recursion.html
  2. supportmutual_recursion:https://github.com/0x65/trampoline 但是太难理解了。

【问题讨论】:

  • 我很困惑你在这里问什么。您拨打的哪个电话太慢了?如果它与您正在使用的 trampoline 装饰器有关,您真的应该将它包含在您的代码中,而不是从某个神秘的地方导入它。
  • 我更新了我的问题。并添加一些其他人的作品(他们的作品不容易理解)

标签: python trampolines cps


【解决方案1】:

查看您分享的链接,有很多有趣的解决方案。我特别受到this 的启发并改变了一些事情。回顾一下,您需要一个尾递归装饰器,它既可以缓存函数先前执行的结果,又支持相互递归(?)。还有另一个有趣的discussion 是关于尾递归上下文中的相互递归的,它可能会帮助您理解主要问题。


我已经编写了一个同时进行缓存和相互递归的装饰器:我认为它可以进一步简化/改进,但它适用于我选择的测试样本:

from collections import namedtuple
import functools

TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
def tail_recursive(f):
    f._first_call = True
    f._cache = {}

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if f._first_call:
            f._new_args = args
            f._new_kwargs = kwargs
        
            try:
                f._first_call = False
                while True:
                    cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
                    if cache_key in f._cache:
                        return f._cache[cache_key]

                    result = f(*f._new_args, **f._new_kwargs)

                    if not isinstance(result, TailRecArguments):
                        f._cache[cache_key] = result

                    if isinstance(result, TailRecArguments) and result.wrapped_func == f:
                        f._new_args = result.args
                        f._new_kwargs = result.kwargs
                    else:
                        break

                return result
            finally:
                f._first_call = True
        else:
            return TailRecArguments(f, args, kwargs)

    return wrapper

乍一看似乎相当复杂,但它重用了链接中讨论的一些概念。


初始化

f._first_call = True
f._cache = {}

除了STARTCONTINUERETURN 之类的状态,在这种情况下,我只需要区分_first_call 和以下状态即可。实际上,第一次调用函数后,下一次调用会返回一个存储参数的TailRecArgument

f._cache 是该特定功能的缓存。


尾递归

if f._first_call:
    f._new_args = args
    f._new_kwargs = kwargs

    try:
        f._first_call = False
        while True:
            result = f(*f._new_args, **f._new_kwargs)

            if isinstance(result, TailRecArguments) and result.wrapped_func == f:
                f._new_args = result.args
                f._new_kwargs = result.kwargs
            else:
                break

        return result
    finally:
        f._first_call = True
else:
    return TailRecArguments(f, args, kwargs)

这个版本的尾递归如何工作?在while 循环中,在第一次调用装饰函数后,使用返回的新参数不断调用该函数。

我什么时候可以退出循环?一旦返回值不是TailRecArguments类型,这意味着最后一次函数调用并没有递归调用自身,而是返回了一个实际值。在这种情况下,我只需要返回结果并设置f._first_call = True。不幸的是,它比这复杂一点,因为它不适用于相互递归。这里的解决方法是将调用的函数存储在TailRecArguments 中。通过这种方式,我可以检查用于下一个循环的参数是用于同一函数(result.wrapped_func == f)还是用于另一个尾递归函数。在后一种情况下,我不想处理这些参数,因为它们与另一个函数相关,而是我可以返回它们,因为它们肯定会在遇到的第一个尾递归函数的while 循环中执行。唯一的缺点是每次参数属于另一个函数时,f._first_call 都会重置。


缓存

while True:
    cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
    if cache_key in f._cache:
        return f._cache[cache_key]

    result = f(*f._new_args, **f._new_kwargs)

    if not isinstance(result, TailRecArguments):
        f._cache[cache_key] = result

在评论缓存机制(这是非常流行的记忆技术)之前,正确放置缓存代码很重要:注意我将它放在while 循环中。不可能,因为只有在 while 循环内,函数才会被持续调用,我可以检查缓存命中。

我在创建cache_key 时有点作弊,因为我使用了functools 模块的内部函数。它是同一模块中@cache 装饰器使用的那个,您可以使用

import inspect
import functools
print(inspect.getsource(functools._make_key))

还有其他方法可以从*args**kwargs 创建缓存键,例如this one,这又指向_make_key 的实现。为了让你的代码更稳定,当然要避免使用私有成员。

正如我所说,剩下的就是记忆,还有一个额外的检查:if not isinstance(result, TailRecArguments): ...。我想缓存值,而不是尾递归调用的参数。

(实际上,我认为您可以将所有TailRecArguments 临时存储在一个列表中,并在递归调用返回实际值时在缓存中添加与该列表大小一样多的条目。这会使解决方案复杂化,但如果您有性能问题,仍然可以接受。这可能会在相互递归的情况下引发一些错误,如果需要,我将继续处理。


测试

这些是我用来测试装饰器的几个基本函数:

@tail_recursive
def even(n):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> even(100)
    True
    >>> even(101)
    False
    """
    return True if n == 0 else odd(n - 1)

@tail_recursive
def odd(n):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> odd(100)
    False
    >>> odd(101)
    True
    """
    return False if n == 0 else even(n - 1)

@tail_recursive
def fact(n, acc=1):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> fact(30)
    265252859812191058636308480000000
    """
    return acc if n <= 1 else fact(n - 1, acc * n)

@tail_recursive
def fib(n, a = 0, b = 1):
    """
    >>> import sys
    >>> sys.setrecursionlimit(20)
    >>> fib(30)
    832040
    """
    return a if n == 0 else b if n == 1 else fib(n - 1, b, a + b)

if __name__ == '__main__':
    import doctest
    doctest.testmod()

请注意,在这些示例中缓存不是很有用,以阶乘为例:fact(10) 实际上永远不会使用fact(8)

fact(8) fact(10)
fact(10, 1)
fact(9, 10)
fact(8, 1) fact(8, 90)
... ...

累加器是缓存键的一部分,因此您应该通过自定义要缓存的参数来更改缓存策略(同样,如果需要,我也可以为此提出解决方案)。


更新 - 缓存优化

这是对原始答案中使用的缓存策略的部分修复。主要问题是考虑到通用尾递归算法的工作原理,在缓存键中包含所有参数效率低下(参见阶乘示例)。

第一个可能的优化是让用户选择哪些参数用于键,哪些参数用于值。由于类型提示,它的可读性要差得多,但测试让一切都变得更加清晰:

class Logger:
    def __init__(self, name):
        self._name = name
        self._entries = []
    
    def log(self, s):
        self._entries.append(s)

    def print(self):
        log_prefix = f"[{self._name}] - "
        print(log_prefix + f"\n{log_prefix}".join(self._entries))

TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
default_logger = Logger('default')
def tail_recursive(logger: Logger = default_logger, \
        get_cache_key: Callable[[Iterable, Dict], Hashable] = lambda args, kwargs: \
            functools._make_key(args, kwargs, False),\
        get_result_after_cache_hit: Callable[[Any, Iterable, Dict], Any] = lambda value, args, kwargs: \
            value):
    def decorator(f):
        f._first_call = True
        f._cache = {}

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if f._first_call:
                f._new_args = args
                f._new_kwargs = kwargs
            
                try:
                    f._first_call = False
                    f._initial_key = get_cache_key(f._new_args, f._new_kwargs)
                    while True:
                        cache_key = get_cache_key(f._new_args, f._new_kwargs)
                        if cache_key in f._cache:
                            logger.log('cache hit for ' + str(cache_key))
                            return get_result_after_cache_hit(f._cache[cache_key], f._new_args, f._new_kwargs)

                        result = f(*f._new_args, **f._new_kwargs)

                        if not isinstance(result, TailRecArguments):
                            f._cache[f._initial_key] = result

                        if isinstance(result, TailRecArguments) and result.wrapped_func == f:
                            f._new_args = result.args
                            f._new_kwargs = result.kwargs
                        else:
                            break

                    return result
                finally:
                    f._first_call = True
            else:
                return TailRecArguments(f, args, kwargs)

        return wrapper
    return decorator

除了用于确认缓存命中的Logger 类之外,主要区别在于每个函数现在都有一个名为_initial_key 的新成员,它存储第一次调用的键。这样,如果我调用fact(5)5就变成了_initial_key,结果放到f._cache[5]中。

这可以优化相互递归和尾递归函数,但在某些情况下无效。让我们从最好的情况开始:

fact_logger = Logger('fact')
@tail_recursive(logger=fact_logger, get_cache_key=lambda args, kwargs: args[0],\
    get_result_after_cache_hit=lambda value, args, kwargs: value * args[1])
def fact(n, acc=1):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> fact(5)
    120
    >>> fact(30)
    265252859812191058636308480000000
    >>> fact_logger.print()
    [fact] - cache hit for 5
    """
    return acc if n <= 1 else fact(n - 1, acc * n)

@tail_recursive 装饰器初始化包括(记录器)get_cache_key,它指定只有第一个参数 n 应该是缓存键的一部分,get_result_after_cache_hit 指定如何在一个缓存命中。在上述情况下,当fact(30) 达到fact(5, &lt;partial_factorial&gt;) 时,结果立即计算为&lt;partial_factorial&gt; * f._cache[5]

even-odd 也是如此,只是在这种情况下tail_recursive 的默认参数绰绰有余:

even_logger = Logger('even')
@tail_recursive(logger=even_logger)
def even(n):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> even(100)
    True
    >>> even(101)
    False
    >>> even(104)
    True
    >>> even_logger.print()
    [even] - cache hit for 100
    """
    return True if n == 0 else odd(n - 1)

不幸的是,这不适用于例如斐波那契函数。您应该通过在每次调用期间打印参数来轻松说服自己,结果如下:

30 0 1
29 1 1
28 1 2
27 2 3
26 3 5
25 5 8
...

建立缓存键规则需要一个更复杂的逻辑,这可能会使tail_recursive 装饰器变得非常不可读且不易移植。

【讨论】:

  • 好作品!。正如您在两个不同的地方所说,我很感激您可以提供更完美的解决方案。 (顺便说一句,cps 版本 fib 的缓存很慢。)我希望有一种方法不要更改 fib_cps continue 代码但仍然很快。
  • @jiamo 很高兴您对此表示赞赏。我已经更新了答案,对缓存策略进行了可能的优化,当然,这是兼容的。
猜你喜欢
  • 2021-06-01
  • 1970-01-01
  • 1970-01-01
  • 2018-11-02
  • 2016-03-24
  • 1970-01-01
  • 1970-01-01
  • 2011-12-27
  • 2020-10-03
相关资源
最近更新 更多