【问题标题】:DFS recursion issues when deleting the rightmost node删除最右边节点时的 DFS 递归问题
【发布时间】:2021-04-17 21:43:51
【问题描述】:

我现在正在研究 DFS 方法来计算总和的路径。问题陈述是:

给定一棵二叉树和一个数字“S”,找到树中的所有路径,使得每条路径的所有节点值之和等于“S”。请注意,路径可以在任何节点开始或结束,但所有路径必须遵循从父节点到子节点(从上到下)的方向。

我的做法是:

def all_sum_path(root, target):
    global count
    count = 0
    find_sum_path(root, target, [])
    return count

def find_sum_path(root, target, allPath):
    global count
    if not root:
        return 0
    # add a space for current node
    allPath.append(0)
    # add current node values to all path
    allPath = [i+root.value for i in allPath]
    print(allPath)
    # check if current path == target
    for j in allPath:
        if j == target:
            count += 1
    # recursive
    find_sum_path(root.left, target, allPath)
    find_sum_path(root.right, target, allPath)
    # remove the current path
    print('after', allPath)
    allPath.pop()
    print('after pop', allPath)

class TreeNode():
    def __init__(self, _value):
        self.value = _value
        self.left, self.right, self.next = None, None, None

def main():
    root = TreeNode(12)
    root.left = TreeNode(7)
    root.right = TreeNode(1)
    root.left.left = TreeNode(4)
    root.right.left = TreeNode(10)
    root.right.right = TreeNode(5)

    print(all_sum_path(root, 11))

main()

返回:

[1]
[8, 7]
[14, 13, 6]
after [14, 13, 6]
after pop [14, 13]
[13, 12, 5, 5]
after [13, 12, 5, 5]
after pop [13, 12, 5]
after [8, 7, 0, 0]
after pop [8, 7, 0]
[10, 9, 9]
[12, 11, 11, 2]
after [12, 11, 11, 2]
after pop [12, 11, 11]
[13, 12, 12, 3, 3]
after [13, 12, 12, 3, 3]
after pop [13, 12, 12, 3]
after [10, 9, 9, 0, 0]
after pop [10, 9, 9, 0]
after [1, 0, 0]
after pop [1, 0]
4

我认为问题在于我没有成功删除列表中最右边的节点。然后我更新了我的代码如下,我删除了allPath最右边的节点并创建了一个名为newAllPath的新列表来记录已经加上当前节点值的节点。

def all_sum_path(root, target):
    global count
    count = 0
    find_sum_path(root, target, [])
    return count

def find_sum_path(root, target, allPath):
    global count
    if not root:
        return 0
    # add a space for current node
    allPath.append(0)
    # add current node values to all path
    newAllPath = [i+root.value for i in allPath]
    print(allPath, newAllPath)
    # check if current path == target
    for j in newAllPath:
        if j == target:
            count += 1
    # recursive
    find_sum_path(root.left, target, newAllPath)
    find_sum_path(root.right, target, newAllPath)
    # remove the current path
    print('after', allPath, newAllPath)
    allPath.pop()
    print('after pop', allPath, newAllPath)

class TreeNode():
    def __init__(self, _value):
        self.value = _value
        self.left, self.right, self.next = None, None, None

def main():
    root = TreeNode(1)
    root.left = TreeNode(7)
    root.right = TreeNode(9)
    root.left.left = TreeNode(6)
    root.left.right = TreeNode(5)
    root.right.left = TreeNode(2)
    root.right.right = TreeNode(3)

    print(all_sum_path(root, 12))

    root = TreeNode(12)
    root.left = TreeNode(7)
    root.right = TreeNode(1)
    root.left.left = TreeNode(4)
    root.right.left = TreeNode(10)
    root.right.right = TreeNode(5)

    print(all_sum_path(root, 11))

main()

返回:

[0] [1]
[1, 0] [8, 7]
[8, 7, 0] [14, 13, 6]
after [8, 7, 0] [14, 13, 6]
after pop [8, 7] [14, 13, 6]
[8, 7, 0] [13, 12, 5]
after [8, 7, 0] [13, 12, 5]
after pop [8, 7] [13, 12, 5]
after [1, 0] [8, 7]
after pop [1] [8, 7]
[1, 0] [10, 9]
[10, 9, 0] [12, 11, 2]
after [10, 9, 0] [12, 11, 2]
after pop [10, 9] [12, 11, 2]
[10, 9, 0] [13, 12, 3]
after [10, 9, 0] [13, 12, 3]
after pop [10, 9] [13, 12, 3]
after [1, 0] [10, 9]
after pop [1] [10, 9]
after [0] [1]
after pop [] [1]
3

我不确定为什么我无法在第一种方法中成功删除最正确的节点。但是,在我的第二种方法中,一旦我删除了allPath 中最右边的节点,它也会删除newAllPath 中的节点。

感谢您的帮助。我很困惑,一整天都被困在这里。

【问题讨论】:

    标签: python recursion depth-first-search


    【解决方案1】:

    除非您的树的深度超过 1000 个节点,否则您可以使用递归来获得更简单的代码:

    def findSums(node,target):
        if not node : return
        if node.value == target: yield [node.value]           # target reached, return path
        for child in (node.left,node.right):                  # traverse tree DFS
            yield from findSums(child,target)                 # paths skipping this node 
            for subPath in findSums(child,target-node.value): # paths with remainder
                yield [node.value]+subPath                    # value + sub-path
    
    for sp in findSums(root,11):
        print(sp)
    
    # [7, 4]
    # [1, 10]
    

    要打印您的二叉树,请参阅:https://stackoverflow.com/a/49844237/5237560

    【讨论】:

    • 喜欢你使用for child in (node.left, node.right) 去重复代码的方式
    • 太棒了!让我尽力理解算法。我太绿了。哈哈。
    【解决方案2】:

    功能原则

    这是一个重要的问题,我不会尝试调试您的程序,因为它违背了递归的使用方式。重申一下,递归是一种函数式遗产,因此将其与函数式风格一起使用会产生最佳结果。这意味着避免 -

    • .append+= 1.pop这样的突变
    • left = ...right = ...allPath = ... 等重新分配
    • count这样的全局变量
    • 其他副作用,如print

    分解

    尝试将任务的所有关注点包装到一个函数中是个坏主意。将问题分解成不同的部分有很多好处 -

    • 较小的函数更易于阅读、编写、测试和调试
    • 单一用途的函数更容易重用

    首先,我们将使用我们在您之前的@​​987654321@ 中已经写过的find_sum -

    def find_sum(t, q, path = []):
      if not t:
        return
      elif t.value == q:
        yield [*path, t.value]
      else:
        yield from find_sum(t.left, q - t.value, [*path, t.value])
        yield from find_sum(t.right, q - t.value, [*path, t.value])
    

    使用find_sum,我们可以轻松写出all_sum -

    def all_sum(t, q):
      for n in traverse(t):
        yield from find_sum(n, q)
    

    这需要我们写一个通用的traverse -

    def traverse(t):
      if not t:
        return
      else:
        yield from traverse(t.left)
        yield t
        yield from traverse(t.right)
    

    让我们看看它在示例树上的工作 -

                   12
                 /    \
                /      \
               7        1
              / \      / \
             4   3    10  5
                /    /
               1    1
    

    我们使用您的 TreeNode 构造函数来表示 -

    t1 = TreeNode \
      ( 12
      , TreeNode(7, TreeNode(4), TreeNode(3, TreeNode(1)))
      , TreeNode(1, TreeNode(10, TreeNode(1)), TreeNode(5))
      )
    
    print(list(all_sum(t1, 11)))
    
    [[7, 4], [7, 3, 1], [10, 1], [1, 10]]
    

    计算所有总数

    如果唯一的目标是计算总和,我们可以将count_all_sum 写成all_sum 的简单包装 -

    def count_all_sum (t, q):
      return len(list(all_sum(t, q)))
    

    【讨论】:

    • 感谢您的帮助。我会尽力应用你的算法来解决 DFS 的问题。我来自统计学,正在努力成为 IT 公司的 DS。
    猜你喜欢
    • 2022-07-08
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2012-07-29
    • 2023-02-23
    • 1970-01-01
    • 1970-01-01
    • 2021-03-15
    相关资源
    最近更新 更多