You could say that today’s date is a palindrome (02-02-2020). In light of the number of twos 👯♂️, I’ll be writing about a couple of binary tree problems (one today, another tomorrow).
Let’s get started on today’s problem.
From LeetCode:
You are given a binary tree in which each node contains an integer value. Find the number of paths that sum to a given value. The path does not need to start or end at the root or a leaf, but it must go downwards (traveling only from parent nodes to child nodes).
```
Input:
"""
      10
     /  \
    5   -3
   / \    \
  3   2   11
 / \   \
3  -2   1
"""
8

Output: 3
# 1: 5 -> 3
# 2: 5 -> 2 -> 1
# 3: -3 -> 11
```
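To double-check the example, here is a minimal brute-force sketch. It is not the approach this post builds towards. It hand-builds the example tree with a hypothetical `Node` class and, starting from every node, walks every downward path, recording those that sum to the target. Its running time is quadratic in the worst case (a completely skewed tree).

```python
from typing import List, Optional


class Node:
    """Hypothetical binary tree node used for illustration."""

    def __init__(self, val: int, left: "Optional[Node]" = None, right: "Optional[Node]" = None):
        self.val = val
        self.left = left
        self.right = right


def brute_force_paths(root: Optional[Node], target: int) -> List[List[int]]:
    """Return every downward path (as a list of values) that sums to target."""
    matches: List[List[int]] = []

    def extend(node: Optional[Node], path: List[int], total: int) -> None:
        # Walk down from a fixed start node, keeping a running total.
        if node is None:
            return
        path.append(node.val)
        total += node.val
        if total == target:
            matches.append(list(path))
        extend(node.left, path, total)
        extend(node.right, path, total)
        path.pop()  # Backtrack before returning to the parent

    def visit(node: Optional[Node]) -> None:
        # Try every node as the start of a path.
        if node is None:
            return
        extend(node, [], 0)
        visit(node.left)
        visit(node.right)

    visit(root)
    return matches


# The example tree from the problem statement.
example = Node(
    10,
    Node(5, Node(3, Node(3), Node(-2)), Node(2, None, Node(1))),
    Node(-3, None, Node(11)),
)

for path in brute_force_paths(example, 8):
    print(" -> ".join(map(str, path)))
```

This prints the three paths `5 -> 3`, `5 -> 2 -> 1`, and `-3 -> 11`, matching the expected output of 3.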
If you’re wondering why this question’s been labelled as “Easy” on LeetCode, you’re definitely not alone! I’d consider a question of this sort to be of moderate difficulty.
Preliminary Considerations
This question is asking us to work with path sums, so we'll probably need to:
- Run some standard tree traversal (DFS/BFS)
- Maintain some kind of accumulated state
If you jump into coding right away (a terrible idea), your first few lines might look something like this:
```python
def traverse(subtree, curr_sum=0):
    nonlocal target_sum
    nonlocal res
    curr_sum += subtree.val
    if curr_sum == target_sum:
        res += 1  # Increment the result

traverse(tree)
```
The above is far from complete: a matching path doesn't necessarily have to start from the root.
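To see concretely why a root-only approach falls short, here is a minimal sketch that completes the first attempt (the `Node` class and the `count_root_paths` name are hypothetical). It only counts paths that begin at the root, so on the example tree it finds nothing for a target of 8, even though the expected answer is 3.

```python
from typing import Optional


class Node:
    """Hypothetical binary tree node used for illustration."""

    def __init__(self, val: int, left: "Optional[Node]" = None, right: "Optional[Node]" = None):
        self.val = val
        self.left = left
        self.right = right


def count_root_paths(tree: Optional[Node], target_sum: int) -> int:
    """Count downward paths that START AT THE ROOT and sum to target_sum."""
    res = 0

    def traverse(subtree: Optional[Node], curr_sum: int = 0) -> None:
        nonlocal res
        if subtree is None:
            return
        curr_sum += subtree.val
        if curr_sum == target_sum:
            res += 1
        traverse(subtree.left, curr_sum)
        traverse(subtree.right, curr_sum)

    traverse(tree)
    return res


example = Node(
    10,
    Node(5, Node(3, Node(3), Node(-2)), Node(2, None, Node(1))),
    Node(-3, None, Node(11)),
)

print(count_root_paths(example, 8))  # 0, since every root path includes the 10
```

With a target of 18 it does find three paths (`10 -> 5 -> 3`, `10 -> 5 -> 2 -> 1`, `10 -> -3 -> 11`), but it can never see paths that start lower down.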
Our overarching approach will still be to analyse all paths from the tree's root to its leaves, but in addition to accumulating the sum of the nodes along a path, we'll also need to maintain a `pathsums` hashmap to track the sums of subpaths (i.e. paths from the root down to each earlier node on the current path):

```python
import collections

# <sum of path>: <num of paths with this sum>
pathsums = collections.defaultdict(int)
```
Why do we track seen path sums? Storing binary tree path sums is useful in the same way that storing array prefix sums is.
Similarities to Array Prefix Sums
Array: [5, 10, 15, 20]
Prefix Sums: [5, 15, 30, 50]
To find the sum of a subarray that starts at index `i` and ends at index `j` (inclusive), we take `prefix_sums[j]` and subtract `prefix_sums[i-1]` from it (taking the sum before index 0 to be 0). Similarly, if we need to find a subarray whose sum equals a given value `x`, we can simply check whether we've seen some `prefix_sums[i]` whose value is the same as `prefix_sums[j] - x`; if so, the subarray from index `i + 1` to `j` sums to `x`.
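The array version of this idea can be sketched as follows (the helper names `subarray_sum` and `count_subarrays_with_sum` are hypothetical). `itertools.accumulate` builds the prefix sums, and a `Counter` of previously seen prefix sums lets us count, in one pass, the subarrays ending at each index whose sum is `x`. Seeding the counter with `{0: 1}` stands in for the empty prefix before index 0, playing the same role as the `curr_sum == target_sum` check in the tree version.

```python
from collections import Counter
from itertools import accumulate
from typing import List


def subarray_sum(prefix_sums: List[int], i: int, j: int) -> int:
    """Sum of arr[i..j] (inclusive), computed from prefix sums."""
    return prefix_sums[j] - (prefix_sums[i - 1] if i > 0 else 0)


def count_subarrays_with_sum(arr: List[int], x: int) -> int:
    """Count contiguous subarrays of arr whose elements sum to x."""
    seen = Counter({0: 1})  # The empty prefix before index 0 sums to 0
    count = 0
    for prefix in accumulate(arr):
        # arr[i+1..j] sums to x iff an earlier prefix equals prefix - x.
        count += seen[prefix - x]  # Counter returns 0 for missing keys
        seen[prefix] += 1
    return count


arr = [5, 10, 15, 20]
prefix_sums = list(accumulate(arr))
print(prefix_sums)                        # [5, 15, 30, 50]
print(subarray_sum(prefix_sums, 1, 2))    # 25 (10 + 15)
print(count_subarrays_with_sum(arr, 35))  # 1 (15 + 20)
```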
Each element in a numerical array is analogous to a node along a binary tree path. Suppose our target sum is 35:
```
"""
Tree Path:   root -> subtree -> subtree -> leaf
Values:        5        10         15       20
Path Sums:     5        15         30       50
"""
```
- As we move down the tree and reach the leaf node, our `curr_sum` rises to `50`.
- At this point, we would have stored entries for `pathsums[5]`, `pathsums[15]`, and `pathsums[30]`.
- We can look up `pathsums[50 - 35]` (i.e. `pathsums[15]`) and conclude that a path summing to `35` must exist (here, `15 -> 20`), since we've seen path sums of `50` and `15` along our current path.
Traversing the Tree
Let's rewrite `traverse()` to match the logic we've described:
```python
def traverse(subtree, curr_sum=0):
    nonlocal target_sum
    nonlocal res
    if not subtree:
        return
    curr_sum += subtree.val
    if curr_sum == target_sum:
        res += 1
    if curr_sum - target_sum in pathsums:
        # If a pathsum for curr_sum - target_sum
        # exists, then we know that a path which
        # sums to target_sum also exists
        res += pathsums[curr_sum - target_sum]
    pathsums[curr_sum] += 1
```
Notice that `pathsums` isn't just a simple `set` of keys; it's a mapping between path sums and their number of occurrences. We maintain a counter for every path sum because the same path sum can occur multiple times along a single path (for example, when the path contains zeros or negative values):
```
"""
Tree Path:   root -> subtree -> subtree -> subtree -> leaf
Values:        5        3          2          0        -2
Path Sums:     5        8         10         10         8
"""
```
In the above path, the path sums of 8 and 10 both occur twice.
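To see why the counts matter, this sketch (both helper names are hypothetical) runs the lookup along the path above twice: once with a counter and once with a plain set. With a target sum of -2, the leaf's `curr_sum` is 8 and `curr_sum - target_sum = 10`, which was seen twice, so there are two matching subpaths (`0 -> -2` and `-2` on its own). The set-based version can only count one of them.

```python
from collections import Counter
from typing import List


def count_with_counter(values: List[int], target_sum: int) -> int:
    """Count downward subpaths of one path using a sum -> occurrences mapping."""
    pathsums = Counter()
    curr_sum, res = 0, 0
    for val in values:
        curr_sum += val
        if curr_sum == target_sum:
            res += 1
        res += pathsums[curr_sum - target_sum]  # Counter returns 0 for missing keys
        pathsums[curr_sum] += 1
    return res


def count_with_set(values: List[int], target_sum: int) -> int:
    """Same walk, but only remembers WHETHER a sum was seen (so it undercounts)."""
    seen = set()
    curr_sum, res = 0, 0
    for val in values:
        curr_sum += val
        if curr_sum == target_sum:
            res += 1
        if curr_sum - target_sum in seen:
            res += 1
        seen.add(curr_sum)
    return res


path = [5, 3, 2, 0, -2]  # root -> subtree -> subtree -> subtree -> leaf
print(count_with_counter(path, -2))  # 2
print(count_with_set(path, -2))      # 1
```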
Let’s flesh out the remainder of our traversal function. Specifically, we need to:
- Recurse on the lower levels of the tree
- Decrement the count for the current path sum as we backtrack, since it won't apply to the next paths that we analyse
```python
def traverse(subtree, curr_sum=0):
    nonlocal target_sum
    nonlocal res
    if not subtree:
        return
    curr_sum += subtree.val
    if curr_sum == target_sum:
        res += 1
    if curr_sum - target_sum in pathsums:
        res += pathsums[curr_sum - target_sum]
    pathsums[curr_sum] += 1
    # Recurse on the lower levels of the tree
    traverse(subtree.left, curr_sum)
    traverse(subtree.right, curr_sum)
    # Backtrack: this path sum no longer applies to other branches
    pathsums[curr_sum] -= 1
```
Our traversal helper function is now complete. Applying it to the main algorithm is pretty straightforward:
```python
import collections

def get_num_of_matching_paths(tree, target_sum):
    res = 0
    pathsums = collections.defaultdict(int)

    def traverse(subtree, curr_sum=0):
        # Update pathsums, recurse down the tree
        # ...
        ...

    traverse(tree)
    return res
```
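Given a tree with `n` nodes, our algorithm runs in O(n) time and O(n) space: each node is visited once and does a constant amount of (average-case) hashmap work, while the `pathsums` hashmap holds at most `n` keys and the recursion stack can grow to `n` frames for a completely skewed tree.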
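One practical caveat: CPython's default recursion limit (reported by `sys.getrecursionlimit()`) is typically 1000, so a deeply skewed tree can raise `RecursionError` with the recursive version. A minimal iterative sketch, assuming the same node shape (`val`, `left`, `right`) and a hypothetical function name, replaces the call stack with an explicit stack and schedules the backtracking decrement as a separate "exit" entry:

```python
import collections
from typing import Optional


class Node:
    """Hypothetical binary tree node used for illustration."""

    def __init__(self, val: int, left: "Optional[Node]" = None, right: "Optional[Node]" = None):
        self.val = val
        self.left = left
        self.right = right


def get_num_of_matching_paths_iterative(tree: Optional[Node], target_sum: int) -> int:
    res = 0
    pathsums = collections.defaultdict(int)
    # Each entry is (node, sum of the path above it, is_exit). An exit entry
    # undoes the pathsums update once both subtrees have been processed.
    stack = [(tree, 0, False)]
    while stack:
        node, curr_sum, is_exit = stack.pop()
        if is_exit:
            pathsums[curr_sum] -= 1
            continue
        if node is None:
            continue
        curr_sum += node.val
        if curr_sum == target_sum:
            res += 1
        res += pathsums.get(curr_sum - target_sum, 0)
        pathsums[curr_sum] += 1
        # Pushed before the children, so it is popped after both subtrees finish.
        stack.append((None, curr_sum, True))
        stack.append((node.right, curr_sum, False))
        stack.append((node.left, curr_sum, False))
    return res


example = Node(
    10,
    Node(5, Node(3, Node(3), Node(-2)), Node(2, None, Node(1))),
    Node(-3, None, Node(11)),
)
print(get_num_of_matching_paths_iterative(example, 8))  # 3

# A left-skewed chain of 5000 nodes, deeper than the default recursion limit.
skewed = None
for _ in range(5000):
    skewed = Node(1, skewed)
print(get_num_of_matching_paths_iterative(skewed, 1))  # 5000 (each node on its own)
```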
With that, we’ve solved the Binary Tree Path Sums problem 🌲🧮.
Full Solution
```python
import collections


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def get_num_of_matching_paths(tree, target_sum):
    res = 0
    # <sum of path>: <num of paths with this sum>
    pathsums = collections.defaultdict(int)

    def traverse(subtree, curr_sum=0):
        nonlocal target_sum
        nonlocal res
        if not subtree:
            return
        curr_sum += subtree.val
        # The path from the root down to this node matches
        if curr_sum == target_sum:
            res += 1
        # Paths starting below an earlier node on this path that match
        if curr_sum - target_sum in pathsums:
            res += pathsums[curr_sum - target_sum]
        pathsums[curr_sum] += 1
        traverse(subtree.left, curr_sum)
        traverse(subtree.right, curr_sum)
        # Backtrack before returning to the parent
        pathsums[curr_sum] -= 1

    traverse(tree)
    return res
```
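To try the full solution on the example from the problem statement, a small builder for LeetCode's level-order list format (with `None` marking a missing child) is handy. The `build_tree` helper below is a hypothetical convenience, not part of the solution, and the block repeats the solution so that it runs on its own.

```python
import collections
from typing import List, Optional


class Node:
    def __init__(self, val: int, left: "Optional[Node]" = None, right: "Optional[Node]" = None):
        self.val = val
        self.left = left
        self.right = right


def get_num_of_matching_paths(tree: Optional[Node], target_sum: int) -> int:
    res = 0
    pathsums = collections.defaultdict(int)

    def traverse(subtree: Optional[Node], curr_sum: int = 0) -> None:
        nonlocal res
        if not subtree:
            return
        curr_sum += subtree.val
        if curr_sum == target_sum:
            res += 1
        if curr_sum - target_sum in pathsums:
            res += pathsums[curr_sum - target_sum]
        pathsums[curr_sum] += 1
        traverse(subtree.left, curr_sum)
        traverse(subtree.right, curr_sum)
        pathsums[curr_sum] -= 1

    traverse(tree)
    return res


def build_tree(values: List[Optional[int]]) -> Optional[Node]:
    """Hypothetical helper: build a tree from a level-order list (None = no child)."""
    if not values or values[0] is None:
        return None
    root = Node(values[0])
    queue = collections.deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # Consume up to two values per node: left child, then right child.
        if i < len(values) and values[i] is not None:
            node.left = Node(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = Node(values[i])
            queue.append(node.right)
        i += 1
    return root


tree = build_tree([10, 5, -3, 3, 2, None, 11, 3, -2, None, 1])
print(get_num_of_matching_paths(tree, 8))  # 3
```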