【问题标题】:Efficient generic Python memoize高效的通用 Python memoize
【发布时间】:2012-12-28 18:47:57
【问题描述】:

我有一个通用的 Python 记忆器:

cache = {}

def memoize(f): 
    """Memoize any function."""

    def decorated(*args):
        key = (f, str(args))
        result = cache.get(key, None)
        if result is None:
            result = f(*args)
            cache[key] = result
        return result

    return decorated

它有效,但我对此并不满意,因为有时它效率不高。最近,我将它与一个将列表作为参数的函数一起使用,显然用整个列表制作键会减慢一切。最好的方法是什么? (即,有效地计算键,无论参数是什么,无论它们有多长或多复杂)

我想这个问题实际上是关于如何有效地从 args 和通用记忆器的函数生成密钥 - 我在一个程序中观察到糟糕的密钥(生成成本太高)对运行时产生了重大影响。我的 prog 使用 'str(args)' 需要 45 秒,但我可以使用手工制作的键将其减少到 3 秒。不幸的是,手工制作的密钥是特定于这个程序的,但我想要一个快速的记忆器,我不必每次都为缓存推出特定的手工制作的密钥。

【问题讨论】:

  • 显然用整个列表制作键会减慢一切。不,不会的。 dict 无论您使用什么键,存储都一样高效,并且查找时间为 O(1)。恐怕你的应用出了点完全不同的问题。
  • Pypi 上有很多食谱和模块。 Python 3.2+ 还附带 functools.lru_cache
  • 真的是这样吗?在 Python 中,列表不被认为是可散列的,dict 存储键必须是可散列的。也许OP实际上有一点。
  • 是的,它不适用于列表,所以我实际上在 args 上使用了 str。我忘了提。
  • 对于列表和其他可变容器,恐怕没有一种方法至少不是 O(n) (至少如果你想按值区分这些容器,而不是按身份区分) )。要区分两个 n 元素容器,您需要考虑每个元素(即使对于散列,尽管元组和字符串可以并且确实缓存它们的散列)。为什么需要使用不可散列的参数进行记忆化?

标签: python memoization


【解决方案1】:

首先,如果您非常确定 O(N) 散列在这里是合理且必要的,并且您只是想使用比 hash(str(x)) 更快的算法来加快速度,试试这个:

def hash_seq(iterable):
    result = hash(type(iterable))
    for element in iterable:
        result ^= hash(element)
    return result

当然,这不适用于可能很深的序列,但有一个明显的解决方法:

def hash_seq(iterable):
    result = hash(type(iterable))
    for element in iterable:
        try:
            result ^= hash(element)
        except TypeError:
            result ^= hash_seq(element)
    return result

我不认为这是一个足够好的哈希算法,因为它会为同一个列表的不同排列返回相同的值。但我很确定没有足够好的哈希算法会更快。至少如果它是用 C 或 Cython 编写的,如果这是你的方向,你最终可能会想要这样做。

另外,值得注意的是,这在str(或marshal)不正确的许多情况下是正确的,例如,如果您的list 可能有一些可变元素,其repr 涉及其@987654334 @ 而不是它的值。但是,它仍然不是在所有情况下都是正确的。特别是,它假设“迭代相同的元素”意味着任何可迭代类型的“相等”,这显然不能保证是真的。误报不是什么大问题,但误报是(例如,两个 dicts 具有相同的键但不同的值可能会虚假地比较相等并共享一个备忘录)。

此外,它不使用额外的空间,而不是使用相当大的乘数的 O(N)。

无论如何,值得先尝试一下,然后再决定是否值得分析它是否值得分析以确保足够好并进行微调以进行微优化。

这是一个简单的 Cython 版本的浅层实现:

def test_cy_xor(iterable):
    cdef int result = hash(type(iterable))
    cdef int h
    for element in iterable:
        h = hash(element)
        result ^= h
    return result

通过快速测试,纯 Python 实现非常慢(正如您所料,所有 Python 循环,与 strmarshal 中的 C 循环相比),但 Cython 版本很容易获胜:

    test_str(    3):  0.015475
test_marshal(    3):  0.008852
    test_xor(    3):  0.016770
 test_cy_xor(    3):  0.004613
    test_str(10000):  8.633486
test_marshal(10000):  2.735319
    test_xor(10000): 24.895457
 test_cy_xor(10000):  0.716340

只是在 Cython 中迭代序列并且什么都不做(实际上只是对 PyIter_Next 的 N 次调用和一些引用计数,所以你不会在本机 C 中做得更好)是 @ 的 70% 987654341@。您可能可以通过要求实际序列而不是可迭代来使其更快,甚至通过要求list 来加快速度,尽管无论哪种方式都可能需要编写显式 C 而不是 Cython 才能获得好处。

无论如何,我们如何解决订购问题?显而易见的 Python 解决方案是散列 (i, element) 而不是 element,但所有这些元组操作都会将 Cython 版本的速度降低 12 倍。标准解决方案是在每个异或之间乘以某个数字。但是,当您使用它时,值得尝试让这些值很好地分布在短序列、小int 元素和其他非常常见的边缘情况下。选择正确的数字很棘手,所以……我只是从tuple 借了所有东西。这是完整的测试。

_hashtest.pyx:

cdef _test_xor(seq):
    cdef long result = 0x345678
    cdef long mult = 1000003
    cdef long h
    cdef long l = 0
    try:
        l = len(seq)
    except TypeError:
        # NOTE: This probably means very short non-len-able sequences
        # will not be spread as well as they should, but I'm not
        # sure what else to do.
        l = 100
    for element in seq:
        try:
            h = hash(element)
        except TypeError:
            h = _test_xor(element)
        result ^= h
        result *= mult
        mult += 82520 + l + l
    result += 97531
    return result

def test_xor(seq):
    return _test_xor(seq) ^ hash(type(seq))

hashtest.py:

import marshal
import random
import timeit
import pyximport
pyximport.install()
import _hashtest

def test_str(seq):
    return hash(str(seq))

def test_marshal(seq):
    return hash(marshal.dumps(seq))

def test_cy_xor(seq):
    return _hashtest.test_xor(seq)

# This one is so slow that I don't bother to test it...
def test_xor(seq):
    result = hash(type(seq))
    for i, element in enumerate(seq):
        try:
            result ^= hash((i, element))
        except TypeError:
            result ^= hash(i, hash_seq(element))
    return result

smalltest = [1,2,3]
bigtest = [random.randint(10000, 20000) for _ in range(10000)]

def run():
    for seq in smalltest, bigtest:
        for f in test_str, test_marshal, test_cy_xor:
            print('%16s(%5d): %9f' % (f.func_name, len(seq),
                                      timeit.timeit(lambda: f(seq), number=10000)))

if __name__ == '__main__':
    run()

输出:

    test_str(    3):  0.014489
test_marshal(    3):  0.008746
 test_cy_xor(    3):  0.004686
    test_str(10000):  8.563252
test_marshal(10000):  2.744564
 test_cy_xor(10000):  0.904398

以下是一些可以加快速度的潜在方法:

  • 如果您有很多深度序列,不要在 hash 周围使用 try,而是调用 PyObject_Hash 并检查 -1。
  • 如果你知道你有一个序列(或者,甚至更好,特别是一个list),而不仅仅是一个可迭代的,PySequence_ITEM(或PyList_GET_ITEM)可能会隐含地比PyIter_Next更快上面用过。

在任何一种情况下,一旦您开始调用 C API 调用,通常更容易放弃 Cython 并用 C 编写函数。(您仍然可以使用 Cython 围绕该 C 函数编写一个简单的包装器,而不是手动编码扩展模块。)此时,只需直接借用tuplehash 代码,而不是重新实现相同的算法。

如果您首先要寻找避免O(N) 的方法,那是不可能的。如果你看看tuple.__hash__frozenset.__hash__ImmutableSet.__hash__ 是如何工作的(顺便说一句,最后一个是纯Python 并且非常易读),它们都采用O(N)。但是,它们都缓存了哈希值。因此,如果您经常对 same tuple(而不是不相同但相等的)进行散列处理,它会接近恒定时间。 (它是O(N/M),其中M 是您与每个tuple 通话的次数。)

如果您可以假设您的 list 对象在调用之间永远不会发生变化,那么您显然可以做同样的事情,例如,将 dict 映射为 idhash 作为外部缓存。但总的来说,这显然不是一个合理的假设。 (如果您的 list 对象永远不会发生变异,那么只需切换到 tuple 对象而不用担心所有这些复杂性会更容易。)

但是您可以将 list 对象包装在一个添加缓存哈希值成员(或槽)的子类中,并在收到变异调用时使缓存无效(append__setitem____delitem__ , ETC。)。然后你的hash_seq 可以检查。

最终结果与tuples 的正确性和性能相同:摊销O(N/M),除了tuple M 是您使用每个相同的tuple 调用的次数,而对于@ 987654383@ 是您调用每个相同的 list 的次数,而不会在两者之间发生变异。

【讨论】:

  • +1 有趣且详细。使用您的hash_seq,2 个不同的输入序列是否有可能产生相同的输出值?
  • 纯 Python 版本在我的实际 prog 上给出了 87s,v 最初是 45s,而 marshall.dumps 是 15s。我还没有尝试过 Cython,但我很想尝试一下。
  • @Gerrat:当然,它总是有可能——你不能在没有冲突的情况下将可能无限数量的值散列到单个int 中。我的第一个实现的问题在于,即使在一些合理的用例中也会发生这种情况——例如,[1, 2, 3][3, 2, 1] 的哈希值相同!第二个版本(带有*= mult 位)修复了这个问题,但我仍然不会发誓这是一个“足够好”的哈希函数。
  • @Frank:你知道如何使用 Cython 的pyximport 进行简单的测试吗?如果没有,如果您有任何问题,请问我。
  • @abarnert 我试图实现你上面的 Cython 代码。我设法编译它。它似乎非常快,但我出错了,因为我的 prog 的结果是错误的。我想我会以某种方式发生冲突,这会导致缓存错误。您是否有一个完整的版本,其中包含您可以发布的素数值? (PS:这是我第一次使用 Cython - 但我已经喜欢它了!)
【解决方案2】:

你可以尝试几件事:

使用 marshal.dumps 而不是 str 可能会稍微快一些(至少在我的机器上):

>>> timeit.timeit("marshal.dumps([1,2,3])","import marshal", number=10000)
0.008287056301007567
>>> timeit.timeit("str([1,2,3])",number=10000)
0.01709315717356219

此外,如果您的函数计算成本很高,并且可能自己返回 None,那么您的记忆函数每次都会重新计算它们(我可能会到达这里,但不知道更多,我只能猜测)。 结合这两件事给出:

import marshal
cache = {}

def memoize(f): 
    """Memoize any function."""

    def decorated(*args):
        key = (f, marshal.dumps(args))
        if key in cache:
            return cache[key]

        cache[key] = f(*args)
        return cache[key]

    return decorated

【讨论】:

  • 谢谢!重写代码并没有做任何重要的事情。使用 marshal.dumps 将运行时间从 45 秒缩短到 15 秒。 - 我们能做得更好吗?
  • 您可以查看 Raymond Hettinger 在活动状态下的各种记忆装饰器。我认为他负责 Python 3.2 上的 LRU 缓存。您可能还可以做其他事情,但可能不是“一般”。我看到你已经修改了你的问题以指定你想要一些通用的东西。 ...在这种情况下,Raymond 的东西可能会尽可能好。
  • 在 len=3 上测试似乎有点不公平——但我用各种不同的长度重复了你的测试;对于大多数中等范围的值,marshall 的速度大约是 str 的两倍,对于非常小的和非常大的列表来说要快得多(在后一种情况下,我猜这是因为内存使用或分配?),所以这绝对是赢家。至于我们是否可以做得更好,请参阅我的帖子——但这可能取决于您是否想继续使用纯 Python。
  • 我 45 岁的结果是我在程序中的列表,可以是 > 1000 个字符。
猜你喜欢
  • 1970-01-01
  • 2018-07-19
  • 2012-05-15
  • 2016-09-08
  • 1970-01-01
  • 1970-01-01
  • 2010-09-12
  • 2012-05-25
  • 2011-10-12
相关资源
最近更新 更多