【问题标题】:Implementing memoization in recursive code break functionality在递归代码中断功能中实现记忆
【发布时间】:2016-12-17 19:00:43
【问题描述】:

我似乎无法弄清楚是什么破坏了我的代码。

我写了一个带金字塔的代码:

[[1],
 [2,3],
 [4,5,6],
 [7,8,9,10]]

从头部开始 (pyramid[0][0]) 计算通过移动到下面的项目或递归到下面的项目并向右移动可能实现的最大总和

在本例中,输出应为 20。

这是没有记忆的代码,可以正常工作:

def max_trail(pyramid,row=0,col=0):
    if row == (len(pyramid)-1):
        return pyramid[row][col]
    else:
        return pyramid[row][col] + max(max_trail(pyramid ,row+1 ,col),
                                       max_trail(pyramid ,row+1, col+1))

但是当我尝试应用 memoization 时,沿途出现了一些问题。 我错过了什么?

这是损坏的代码:

def max_trail_mem(pyramid,row=0,col=0,mem=None):
    if mem==None:
        mem={}
    key = ((row),(col))
    if row == (len(pyramid)-1):
        if key not in mem:
            value = pyramid[row][col]
            mem[key]=value
            return mem[key]
        return mem[key]
    else:
        key = (row+1),(col)
        if key not in mem:
            mem[(row+1),(col)] = max_trail_mem(pyramid ,row+1 ,col,mem)
        key = (row+1),(col+1)
        if key not in mem:
            mem[(row+1),(col+1)]=max_trail_mem(pyramid ,row+1, col+1,mem)
    return max(mem[(row+1),(col)],mem[(row+1),(col+1)])

这让我可怜的学生生活减少了好几个小时。 任何帮助将不胜感激!

【问题讨论】:

  • 欢迎来到 StackOverflow!请花点时间拨打tour 并查看what can I ask here?。正如它所写的那样,您的问题是寻求调试帮助“为什么这段代码不起作用?”这在what can I ask here? 中特别概述为stackoverflow 的题外话。
  • 您是否尝试过测试失败的地方?您可以添加打印语句以进行快速测试或使用pdb

标签: python recursion memoization


【解决方案1】:

您在上次返回时忘记了 pyramid[row][col] + 之前的 max(...。添加它会为您的示例提供20(请参见最后一行代码):

def max_trail_mem(pyramid,row=0,col=0,mem=None):
    if mem==None:
        mem={}
    key = ((row),(col))
    if row == (len(pyramid)-1):
        if key not in mem:
            value = pyramid[row][col]
            mem[key]=value
            return mem[key]
        return mem[key]
    else:
        key = (row+1),(col)
        if key not in mem:
            mem[(row+1),(col)] = max_trail_mem(pyramid ,row+1 ,col,mem)
        key = (row+1),(col+1)
        if key not in mem:
            mem[(row+1),(col+1)] = max_trail_mem(pyramid ,row+1, col+1,mem)
    return pyramid[row][col] + max(mem[(row+1),(col)],mem[(row+1),(col+1)])

【讨论】:

  • 谢谢!这样可行!但是......似乎这里的记忆真的没有改善运行时间......它有什么贡献吗?我尝试了一个大小为 10 的金字塔,它在我的版本中返回了相同的运行时间,没有记忆,并且......这里有什么建议吗?
  • 我得到 41.2 µs 的记忆和 201 µs 的纯递归版本。因此,对于大小为 10 的金字塔,它的速度大约快 5 倍。对于大小为 15 的金字塔,差异为 101 µs 和 6.15 ms,即超过 60 倍。对于大小 20,此因子超过 1000。
  • 谢谢你,迈克,我尝试了一个更大的金字塔,它确实有效!
【解决方案2】:
from timeit import timeit
import math

from repoze.lru import CacheMaker
cache_maker=CacheMaker()


def max_trail(pyramid,row=0,col=0):
    if row == (len(pyramid)-1):
        return pyramid[row][col]
    else:

        mt1 = max_trail(pyramid ,row+1 ,col)
        mt2 = max_trail(pyramid ,row+1, col+1)
        return pyramid[row][col] + max(mt1, mt2)

@cache_maker.lrucache(maxsize='1000', name='pyramid')
def max_trail_with_memoization(pyramid,row=0,col=0):
    if row == (len(pyramid)-1):
        return pyramid[row][col]
    else:

        mt1 = max_trail(pyramid ,row+1 ,col)
        mt2 = max_trail(pyramid ,row+1, col+1)
        return pyramid[row][col] + max(mt1, mt2)

# Build pyramid
pyramid = ()
c = 0
for i in range(20):
    row = ()
    for j in range(i):
        c += 1
        row += (c,)
    if row:
        pyramid += (tuple(row),)

# Repetitions to time
number = 20

# Time it
print('without memoization:  ', timeit('f(t)', 'from __main__ import max_trail as f, pyramid as t', number=number))
print('with memoization      ', timeit('f(t)', 'from __main__ import max_trail_with_memoization as f, pyramid as t', number=number))




print max_trail(pyramid)

返回:

without memoization:   3.9645819664001465
with memoization:      0.18628692626953125

【讨论】:

    猜你喜欢
    • 2023-03-22
    • 2021-06-22
    • 1970-01-01
    • 2018-02-28
    • 2012-11-12
    • 2012-09-27
    • 1970-01-01
    • 1970-01-01
    • 2012-11-26
    相关资源
    最近更新 更多