问题描述
以下示例来自在线来源,我不确定为什么我们需要将 allPaths
的新副本附加到 currentPath
。我认为在我们通过执行 del currentPath[-1]
返回递归调用堆栈时删除最后一个元素可确保我们不会将先前的路径添加到新路径
class TreeNode:
def __init__(self,val,left=None,right=None):
self.val = val
self.left = left
self.right = right
def find_paths(root,required_sum):
allPaths = []
find_paths_recursive(root,required_sum,[],allPaths)
return allPaths
def find_paths_recursive(currentNode,currentPath,allPaths):
if currentNode is None:
return
# add the current node to the path
currentPath.append(currentNode.val)
# if the current node is a leaf and its value is equal to required_sum,save the current path
if currentNode.val == required_sum and currentNode.left is None and currentNode.right is None:
allPaths.append(list(currentPath))
else:
# traverse the left sub-tree
find_paths_recursive(currentNode.left,required_sum -
currentNode.val,allPaths)
# traverse the right sub-tree
find_paths_recursive(currentNode.right,allPaths)
# remove the current node from the path to backtrack,# we need to remove the current node while we are going up the recursive call stack.
del currentPath[-1]
def main():
root = TreeNode(12)
root.left = TreeNode(7)
root.right = TreeNode(1)
root.left.left = TreeNode(4)
root.right.left = TreeNode(10)
root.right.right = TreeNode(5)
required_sum = 23
print("Tree paths with required_sum " + str(required_sum) +
": " + str(find_paths(root,required_sum)))
main()
解决方法
重要的是要意识到在整个过程中只有一个 currentPath
列表。它是在初始调用中创建的:
find_paths_recursive(root,required_sum,[],allPaths)
# ^^---here!
发生在单个列表上的只是元素被附加到它上面,然后再次从它中删除(回溯时)。它在它的一生中不断地增长和缩小,增长和缩小,……。但它是相同的,单个列表实例。
如果您将该列表附加到 allPaths
而不复制,例如:
allPaths.append(currentPath)
...然后意识到虽然该列表是 allPaths
的成员,但它将被未来的 append
和 delete
操作改变!甚至更多:因为上面的语句稍后再次执行:
allPaths.append(currentPath)
...完全相同的列表被追加,它已经在allPaths
...因为只有一个currentPath
列表!因此,您最终会得到 allPaths
对同一个列表的重复引用。
结论:获取 currentPath
的副本很重要,它不会再被 currentPath
上的未来突变所改变。这就像对 currentPath
的当前状态进行快照。
find_paths_recursive
函数的设计使得附加到 allPaths
是将结果返回给调用者的方式。
def find_paths(root,required_sum):
allPaths = []
find_paths_recursive(root,allPaths)
return allPaths
在 find_paths
中,allPaths
作为一个空列表传递给 find_paths_recursive
,完成后,它将包含结果(从根到叶的路径满足所描述的条件).
我建议将问题分解为不同的部分。首先我们编写一个通用的 paths
函数 -
def paths (t = None,p = ()):
if not t:
return
elif t.left or t.right:
yield from paths(t.left,(*p,t.val))
yield from paths(t.right,t.val))
else:
yield (*p,t.val)
mytree = TreeNode \
( 12,TreeNode(7,TreeNode(4)),TreeNode(1,TreeNode(10))
)
现在我们可以看到 paths
是如何工作的 -
for p in paths(mytree):
print(p)
(12,7,4)
(12,1,10)
现在我们可以写出专门化solver
的{{1}} -
paths
def solver (t = None,q = 0):
for p in paths(t):
if sum(p) == q:
yield p
是一个生成所有可能解决方案的生成器。对于此类程序来说,这是一个不错的选择,因为您可以在找到所需的解决方案后立即暂停/取消求解 -
solver
输出并不是特别有趣,因为 for sln in solver(mytree,23):
print(sln)
中的每个分支总和为 23 -
mytree
如果我们让 (12,10)
具有不同的值,我们可以看到更有趣的输出 -
anothertree
anothertree = TreeNode \
( 1,TreeNode(4),TreeNode(5)),TreeNode(9,TreeNode(2),TreeNode(7))
)
for sln in solver(anothertree,12):
print(sln)