
Binary Tree Path Sums

Image from Unsplash by Annie Spratt

You could say that today’s date is a palindrome (02-02-2020). In light of the number of twos 👯‍♂️, I’ll be writing about a couple of binary tree problems (one today, another tomorrow).

Let’s get started on today’s problem.

From LeetCode:

You are given a binary tree in which each node contains an integer value. Find the number of paths that sum to a given value. The path does not need to start or end at the root or a leaf, but it must go downwards (traveling only from parent nodes to child nodes).

Input:
"""
      10
     /  \ 
    5   -3
   / \    \ 
  3   2   11
 / \   \ 
3  -2   1
"""
8

Output:
3
# 1: 5 -> 3
# 2: 5 -> 2 -> 1
# 3: -3 -> 11

If you’re wondering why this question’s been labelled as “Easy” on LeetCode, you’re definitely not alone! I’d consider a question of this sort to be of moderate difficulty.

Preliminary Considerations

This question asks us to work with path sums, so we know that we’ll probably need to:

  • Run some standard tree traversal (DFS/BFS)
  • Maintain some kind of accumulated state

If you jump into coding right away (a terrible idea), your first few lines might look something like this:

def traverse(subtree, curr_sum=0):
    nonlocal target_sum
    nonlocal res

    curr_sum += subtree.val
    if curr_sum == target_sum:
        res += 1  # Increment the result

traverse(tree)

The above is far from complete: it only counts paths that start at the root, and we know that a matching path doesn’t necessarily have to start there.

Our overarching approach will still be to analyse all paths from the tree’s root to its leaves, but in addition to accumulating the sum of the nodes along a path, we’ll also need to maintain a pathsums hashmap to track the sums of subpaths (i.e. from the root to each node above the current one):

import collections

# <sum of path>: <num of paths with this sum>
pathsums = collections.defaultdict(int)

Why do we track seen path sums? Storing binary tree path sums is useful in the same way that storing array prefix sums is.

Similarities to Array Prefix Sums

Array:       [5, 10, 15, 20]
Prefix Sums: [5, 15, 30, 50]

To find the sum of a subarray which starts at index i and ends at index j (inclusive), we take prefix_sums[j] and subtract prefix_sums[i-1] from it (when i is 0, there is nothing to subtract). Similarly, if we need to find a subarray ending at index j whose sum equals a given value x, we can simply check if we’ve already seen an earlier prefix sum whose value equals prefix_sums[j] - x.
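
The lookup described above can be sketched for plain arrays as follows. This is a minimal illustration, not part of the tree solution: the function name count_subarrays_with_sum is hypothetical, and it seeds the hashmap with an empty prefix of 0 (an assumption of this sketch) so that subarrays starting at index 0 are counted without a separate check.

```python
import collections


def count_subarrays_with_sum(nums, x):
    """Count contiguous subarrays of `nums` whose elements sum to `x`."""
    seen = collections.defaultdict(int)
    seen[0] = 1  # the empty prefix, so subarrays starting at index 0 count
    count = 0
    prefix = 0
    for num in nums:
        prefix += num
        # Every earlier prefix equal to (prefix - x) marks the start
        # of a subarray that ends here and sums to x.
        count += seen.get(prefix - x, 0)
        seen[prefix] += 1
    return count


print(count_subarrays_with_sum([5, 10, 15, 20], 35))  # 1 (15 + 20)
print(count_subarrays_with_sum([5, 10, 15, 20], 15))  # 2 ([5, 10] and [15])
```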

Each element in a numerical array is analogous to a node along a binary tree path. Suppose our target sum is 35:

"""
Tree Path: root -> subtree -> subtree -> leaf
Values:     5        10         15        20
Path Sums:  5        15         30        50
"""
  • As we move down the tree and reach the leaf node, our curr_sum is raised to a value of 50.
  • At this point, we would have stored entries for pathsums[5], pathsums[15], and pathsums[30].
  • We can look up pathsums[50 - 35] (i.e. pathsums[15]) and conclude that a path summing to 35 must exist since we’ve seen pathsums for 50 and 15 along our current path.

Traversing the Tree

Let’s rewrite traverse() to match the logic we’ve described:

def traverse(subtree, curr_sum=0):
    nonlocal target_sum
    nonlocal res
 
    if not subtree:
        return

    curr_sum += subtree.val

    if curr_sum == target_sum:
        res += 1

    if curr_sum - target_sum in pathsums:
        # If a pathsum for curr_sum - target_sum
        # exists, then we know that a path which
        # sums to target_sum also exists
        res += pathsums[curr_sum - target_sum]

    pathsums[curr_sum] += 1

Notice that pathsums isn’t just a simple set of keys; it’s a mapping between path sums and their number of occurrences. We maintain a counter for every path sum because the same path sum can occur multiple times along a single path when node values are zero or negative:

"""
Tree Path: root -> subtree -> subtree -> subtree -> leaf
Values:     5         3          2          0        -2
Path Sums:  5         8         10         10         8
"""

In the above path, the path sums of 8 and 10 both occur twice.
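
To see the repeated sums concretely, here is a small sketch that computes the running sums for the values above and counts them. It uses itertools.accumulate and collections.Counter purely for illustration; the solution itself keeps using a defaultdict.

```python
import collections
import itertools

values = [5, 3, 2, 0, -2]

# Running sums from the root down to each node on the path
path_sums = list(itertools.accumulate(values))
counts = collections.Counter(path_sums)

print(path_sums)               # [5, 8, 10, 10, 8]
print(counts[8], counts[10])   # 2 2
```

A plain set would collapse the two occurrences of 8 (and of 10) into one, which would undercount the paths that can start just below either occurrence.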

Let’s flesh out the remainder of our traversal function. Specifically, we need to:

  • Recurse on the lower levels of the tree
  • Decrement the counter of the current path sum as we backtrack, since it won’t apply to the next paths that we analyse

def traverse(subtree, curr_sum=0):
    nonlocal target_sum
    nonlocal res
 
    if not subtree:
        return

    curr_sum += subtree.val

    if curr_sum == target_sum:
        res += 1

    if curr_sum - target_sum in pathsums:
        res += pathsums[curr_sum - target_sum]

    pathsums[curr_sum] += 1

    traverse(subtree.left, curr_sum)
    traverse(subtree.right, curr_sum)

    pathsums[curr_sum] -= 1
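
To illustrate why the final decrement matters, the sketch below wraps the same traversal in a hypothetical count_paths function with a backtrack flag (the flag and the Node defaults are assumptions made for this demonstration only). With backtracking disabled, a path sum recorded in the left branch leaks into the right branch and produces a false match.

```python
import collections


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_paths(tree, target_sum, backtrack=True):
    """Same traversal as above; backtrack=False skips the final decrement."""
    res = 0
    pathsums = collections.defaultdict(int)

    def traverse(subtree, curr_sum=0):
        nonlocal res
        if not subtree:
            return
        curr_sum += subtree.val
        if curr_sum == target_sum:
            res += 1
        res += pathsums.get(curr_sum - target_sum, 0)
        pathsums[curr_sum] += 1
        traverse(subtree.left, curr_sum)
        traverse(subtree.right, curr_sum)
        if backtrack:
            pathsums[curr_sum] -= 1

    traverse(tree)
    return res


#   1
#  / \
# 1   3
tree = Node(1, Node(1), Node(3))
print(count_paths(tree, 2))                   # 1 (the path 1 -> 1)
print(count_paths(tree, 2, backtrack=False))  # 2 (left branch's sum leaks into the right)
```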

Our traversal helper function is now complete. Applying it to the main algorithm is pretty straightforward:

def get_num_of_matching_paths(tree, target_sum):
    res = 0
    pathsums = collections.defaultdict(int)

    def traverse(subtree, curr_sum=0):
        # Update pathsums, recurse down the tree
        ...

    traverse(tree)
    return res

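Given a tree with n nodes, our algorithm runs in O(n) time, since each node is visited once and does a constant amount of hashmap work on average. It uses O(n) space in the worst case, for the pathsums hashmap and the recursion stack.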

With that, we’ve solved the Binary Tree Path Sums problem 🌲🧮.

Full Solution

import collections


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def get_num_of_matching_paths(tree, target_sum):
    res = 0
    pathsums = collections.defaultdict(int)

    def traverse(subtree, curr_sum=0):
        nonlocal target_sum
        nonlocal res

        if not subtree:
            return

        curr_sum += subtree.val

        if curr_sum == target_sum:
            res += 1

        if curr_sum - target_sum in pathsums:
            res += pathsums[curr_sum-target_sum]

        pathsums[curr_sum] += 1

        traverse(subtree.left, curr_sum)
        traverse(subtree.right, curr_sum)

        pathsums[curr_sum] -= 1

    traverse(tree)
    return res
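
As a quick sanity check, the sketch below rebuilds the example tree from the problem statement and counts matching paths with a deliberately naive brute force that tries every node as a starting point. The helpers count_from and count_paths_brute_force are hypothetical and not part of the solution above, and the Node class is redefined with default children so that the block runs on its own. On any tree, the result should agree with get_num_of_matching_paths.

```python
class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_from(node, remaining):
    """Count downward paths that start exactly at `node` and sum to `remaining`."""
    if not node:
        return 0
    remaining -= node.val
    here = 1 if remaining == 0 else 0
    return here + count_from(node.left, remaining) + count_from(node.right, remaining)


def count_paths_brute_force(tree, target_sum):
    """Try every node in the tree as the start of a path."""
    if not tree:
        return 0
    return (
        count_from(tree, target_sum)
        + count_paths_brute_force(tree.left, target_sum)
        + count_paths_brute_force(tree.right, target_sum)
    )


# The example tree from the problem statement
example = Node(
    10,
    Node(5, Node(3, Node(3), Node(-2)), Node(2, None, Node(1))),
    Node(-3, None, Node(11)),
)

print(count_paths_brute_force(example, 8))  # 3
```

The brute force revisits nodes once per ancestor, so it can take O(n²) time on a skewed tree, which is the cost the pathsums hashmap avoids.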