【问题标题】:Python: recall cached function result dependent on new function parameterPython:根据新的函数参数调用缓存的函数结果
【发布时间】:2017-10-20 00:09:29
【问题描述】:

我对缓存和记忆的概念相当陌生。我已经阅读了一些其他讨论和资源 hereherehere,但一直无法很好地关注它们。

假设我在一个类中有两个成员函数。 (下面的简化示例。)假设第一个函数total 计算量很大。第二个函数subtotal 在计算上很简单,除了它使用第一个函数的返回值,因此也变得计算量很大,因为它当前需要重新调用total 来获得它的返回结果。

我想缓存第一个函数的结果并将其用作第二个函数的输入,如果输入 ysubtotal 将输入 x 共享给最近的通话的total。那就是:

  • 如果调用 subtotal(),其中 y 等于 a
    x 的值 total 的先前调用,然后使用该缓存结果而不是
    重新调用total
  • 否则,只需使用x = y 调用total()

例子:

class MyObject(object):

    def __init__(self, a, b):
        self.a, self.b = a, b

    def total(self, x):
        return (self.a + self.b) * x     # some time-expensive calculation

    def subtotal(self, y, z):
        return self.total(x=y) + z       # Don't want to have to re-run total() here
                                         # IF y == x from a recent call of total(),
                                         # otherwise, call total().

【问题讨论】:

  • 你试过这个吗:stackoverflow.com/a/18723434/2570677。我已经在我的代码中使用了它,它运行良好。
  • 假设你指的是@functools.lru_cache?
  • 从您链接到的资源中,是什么阻止您仅使用基本缓存功能装饰total?您只需输入@functools.lru_cache(maxsize=N),它就会缓存N 相同参数的结果。为什么这在您的场景中不起作用?
  • @BradSolomon 我指的是包含实现的答案(没有任何外部模块)。它适用于 python 2.7。

标签: python-3.x caching memoization


【解决方案1】:

对于 Python3.2 或更高版本,您可以使用functools.lru_cache。 如果您要直接用functools.lru_cache 装饰total,那么lru_cache 将根据selfx 这两个参数的值缓存total 的返回值。由于 lru_cache 的内部 dict 存储对 self 的引用,因此将 @lru_cache 直接应用于类方法会创建对 self 的循环引用,这使得类的实例不可解除引用(因此内存泄漏)。

Here is a workaround 允许您将lru_cache 与类方法一起使用——它基于除第一个参数self 之外的所有参数缓存结果,并使用weakref 来避免循环引用问题:

import functools
import weakref

def memoized_method(*lru_args, **lru_kwargs):
    """
    https://stackoverflow.com/a/33672499/190597 (orly)
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapped_func(self, *args, **kwargs):
            # We're storing the wrapped method inside the instance. If we had
            # a strong reference to self the instance would never die.
            self_weak = weakref.ref(self)
            @functools.wraps(func)
            @functools.lru_cache(*lru_args, **lru_kwargs)
            def cached_method(*args, **kwargs):
                return func(self_weak(), *args, **kwargs)
            setattr(self, func.__name__, cached_method)
            return cached_method(*args, **kwargs)
        return wrapped_func
    return decorator


class MyObject(object):

    def __init__(self, a, b):
        self.a, self.b = a, b

    @memoized_method()
    def total(self, x):
        print('Calling total (x={})'.format(x))
        return (self.a + self.b) * x


    def subtotal(self, y, z):
        return self.total(x=y) + z 

mobj = MyObject(1,2)
mobj.subtotal(10, 20)
mobj.subtotal(10, 30)

打印

Calling total (x=10)

只有一次。


或者,您可以通过以下方式使用 dict 滚动自己的缓存:

class MyObject(object):

    def __init__(self, a, b):
        self.a, self.b = a, b
        self._total = dict()

    def total(self, x):
        print('Calling total (x={})'.format(x))
        self._total[x] = t = (self.a + self.b) * x
        return t

    def subtotal(self, y, z):
        t = self._total[y] if y in self._total else self.total(y)
        return t + z 

mobj = MyObject(1,2)
mobj.subtotal(10, 20)
mobj.subtotal(10, 30)

lru_cache 相对于这个基于字典的缓存的一个优势是 lru_cache 是线程安全的。 lru_cache 也有一个 maxsize 参数可以帮助 防止内存使用量无限制地增长(例如,由于 长时间运行的进程多次调用 total 并使用不同的 x 值)。

【讨论】:

    【解决方案2】:

    感谢大家的回复,阅读它们并了解幕后情况很有帮助。正如@Tadhg McDonald-Jensen 所说,似乎我在这里不需要更多的东西,而不是@functools.lru_cache。 (我在 Python 3.5 中。)关于 @unutbu 的评论,我没有收到用 @lru_cache 装饰 total() 的错误。让我纠正我自己的例子,我会在这里为其他初学者保留这个:

    from functools import lru_cache
    from datetime import datetime as dt
    
    class MyObject(object):
        def __init__(self, a, b):
            self.a, self.b = a, b
    
        @lru_cache(maxsize=None)
        def total(self, x):        
            lst = []
            for i in range(int(1e7)):
                val = self.a + self.b + x    # time-expensive loop
                lst.append(val)
            return np.array(lst)     
    
        def subtotal(self, y, z):
            return self.total(x=y) + z       # if y==x from a previous call of
                                             # total(), used cached result.
    
    myobj = MyObject(1, 2)
    
    # Call total() with x=20
    a = dt.now()
    myobj.total(x=20)
    b = dt.now()
    c = (b - a).total_seconds()
    
    # Call subtotal() with y=21
    a2 = dt.now()
    myobj.subtotal(y=21, z=1)
    b2 = dt.now()
    c2 = (b2 - a2).total_seconds()
    
    # Call subtotal() with y=20 - should take substantially less time
    # with x=20 used in previous call of total().
    a3 = dt.now()
    myobj.subtotal(y=20, z=1)
    b3 = dt.now()
    c3 = (b3 - a3).total_seconds()
    
    print('c: {}, c2: {}, c3: {}'.format(c, c2, c3))
    c: 2.469753, c2: 2.355764, c3: 0.016998
    

    【讨论】:

    • self.aself.b 会改变吗?如果是这样,缓存的值应该被清除,因为total 的计算值会改变。您可以通过设置ab 的setter 调用total.cache_clear() 的可设置属性来实现它。
    • PS:我将 lru_cache 应用于引发错误的类方法是错误的。虽然没有报错,不过it does cause a memory leak
    • 这里self.aself.b不会被修改。不过谢谢,知道这很有帮助。
    【解决方案3】:

    在这种情况下,我会做一些简单的事情,也许不是最优雅的方式,但可以解决问题:

    class MyObject(object):
        param_values = {}
        def __init__(self, a, b):
            self.a, self.b = a, b
    
        def total(self, x):
            if x not in MyObject.param_values:
              MyObject.param_values[x] = (self.a + self.b) * x
              print(str(x) + " was never called before")
            return MyObject.param_values[x]
    
        def subtotal(self, y, z):
            if y in MyObject.param_values:
              return MyObject.param_values[y] + z
            else:
              return self.total(y) + z
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2018-07-20
      • 2010-11-13
      • 2021-02-21
      • 2021-10-10
      • 2011-01-14
      • 1970-01-01
      • 2022-11-15
      • 1970-01-01
      相关资源
      最近更新 更多