
Merging Binary Search Trees

Image from Unsplash by Colin Watts

Yesterday, we dealt with binary trees. Today, we’ll look at a special subset –– binary search trees.

From EPI (Elements of Programming Interviews):

Merge two binary search trees into a single binary search tree.

Input:
"""
  5
 / \ 
3   7

  4
 / \ 
2   6
"""

Output:
"""
     5
   /   \ 
  3     7
 / \   / 
2   4 6
"""

This question is a little bit tedious, but don’t worry, we’ll break it up into manageable chunks.

Your first intuition might be to:

  • Initialise two pointers that start from the roots of both BSTs
  • Run some comparison function to determine which node should be the new root
  • Continue recursing downwards… in some manner?

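You'll find that this strategy falls apart very quickly, because there is no guarantee that the merged BST's root is one of the roots from the original BSTs. For instance, merging a tree holding 1, 2 and 3 (rooted at 2) with one holding 10, 11 and 12 (rooted at 11) into a balanced BST, like the expected output above, requires 3 or 10 at the root.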

Transforming BSTs

Instead, you'll need to transform both BSTs into a different kind of data structure. This data structure isn't terribly different from a BST: it is also built from nodes that each carry two pointers. As you might've guessed, I'm speaking about doubly-linked lists.

An in-order traversal of a BST visits its keys in ascending order, so flattening each BST in in-order yields a sorted doubly-linked list. We can even reuse the existing nodes, treating left as the "previous" pointer and right as the "next" pointer. Two sorted lists can then be merged with a trivial, mergesort-style operation that preserves the ordering of our data.
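The snippets in this article use a node type they never define. Here is a minimal sketch of one, assuming a hypothetical Node class with data, left and right fields that all default to None, so that Node() with no arguments can serve as the sentinel in merge_lists() below. The inorder() helper (also hypothetical) shows why flattening preserves order: visiting left subtree, node, right subtree lists a BST's keys in ascending order.

```python
from typing import Optional


class Node:
    """Hypothetical node type shared by the BST and the doubly-linked list.

    In a tree, left/right are the children.
    In a list, left is the previous node and right is the next node.
    """

    def __init__(self, data=None, left: Optional["Node"] = None,
                 right: Optional["Node"] = None):
        self.data = data
        self.left = left
        self.right = right


def inorder(tree: Optional[Node]) -> list:
    """Collect a BST's keys by visiting left subtree, node, right subtree."""
    if tree is None:
        return []
    return inorder(tree.left) + [tree.data] + inorder(tree.right)


# The first input tree from the problem statement: 5 with children 3 and 7
t1 = Node(5, Node(3), Node(7))
print(inorder(t1))  # [3, 5, 7]
```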

Let’s first pin down this conversion subroutine:

def convert_bst_to_list(tree):
    def traverse(subtree):
        if not subtree:
            return None, None

        head = tail = subtree
        if subtree.left:
            left_head, left_tail = traverse(subtree.left)
            head = left_head
            # left_tail <-> subtree
            left_tail.right, subtree.left = subtree, left_tail

        if subtree.right:
            right_head, right_tail = traverse(subtree.right)
            tail = right_tail
            # subtree <-> right_head
            subtree.right, right_head.left = right_head, subtree

        head.left = tail.right = None
        return head, tail

    head, tail = traverse(tree)
    return head

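In this algorithm, notice how we're returning head, tail from each invocation of the traverse() helper. A parent needs the tail of its left sublist and the head of its right sublist to splice itself in between them, and it needs the outermost head and tail to pass the combined list back up to its own parent.

You can think of your leaf nodes as doubly-linked lists of length 1. When you move up one level to a leaf's parent, your goal becomes connecting this parent to both leaf nodes below it, resulting in a doubly-linked list of length 3: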

"""
   4
  / \ 
 3   5 
 
becomes 3 <-> 4 <-> 5
"""

Merging Linked Lists

Our next step will be to combine the two lists while maintaining their sorted order.

If you're familiar with the merge() operation in mergesort, this step is easy. We'll:

  • Use two runners (i.e. running pointers) to move through the linked lists
  • Let a third runner, starting from a sentinel node, collect the nodes in sorted order
  • Once either list runs out, splice whatever remains of the other list onto the end

def merge_lists(L1, L2):
    runner1 = L1
    runner2 = L2
    head = runner3 = Node() # Sentinel

    while runner1 and runner2:
        # Take the smaller front node (ties go to runner1)
        if runner1.data <= runner2.data:
            runner1_right = runner1.right

            runner3.right = runner1
            runner1.left, runner1.right = runner3, None

            runner1 = runner1_right

        else:
            runner2_right = runner2.right

            runner3.right = runner2
            runner2.left, runner2.right = runner3, None

            runner2 = runner2_right

        runner3 = runner3.right

    # Append the leftover nodes of the list that did not run out
    if runner1:
        runner3.right = runner1
        runner1.left = runner3

    elif runner2:
        runner3.right = runner2
        runner2.left = runner3

    # Detach the merged list's first node from the sentinel
    result = head.right
    if result:
        result.left = None

    return result
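Here is a self-contained sketch that exercises the merge on the example's keys. It redefines a minimal, hypothetical Node class, plus two illustrative helpers: from_values() builds a doubly-linked list from a Python list, and to_values() reads one back while asserting that every left pointer points at the previous node. The merge follows merge_lists() in a slightly condensed form and sets the merged head's left pointer to None so the result does not point back at the sentinel.

```python
class Node:
    # Hypothetical node type: left = previous node, right = next node
    def __init__(self, data=None, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def from_values(values):
    """Build a doubly-linked list from a Python list (illustrative helper)."""
    head = prev = None
    for v in values:
        node = Node(v, left=prev)
        if prev:
            prev.right = node
        else:
            head = node
        prev = node
    return head


def to_values(head):
    """Read a doubly-linked list back, checking every back-pointer."""
    values, prev = [], None
    while head:
        assert head.left is prev, "broken back-pointer"
        values.append(head.data)
        prev, head = head, head.right
    return values


def merge_lists(L1, L2):
    head = tail = Node()  # sentinel; tail plays the "third runner"
    while L1 and L2:
        # Take the smaller front node (ties go to L1)
        if L1.data <= L2.data:
            node, L1 = L1, L1.right
        else:
            node, L2 = L2, L2.right
        # Append node after tail and cut its old forward link
        tail.right, node.left, node.right = node, tail, None
        tail = node
    # Splice on whatever remains of the list that did not run out
    rest = L1 or L2
    tail.right = rest
    if rest:
        rest.left = tail
    merged = head.right
    if merged:
        merged.left = None  # detach from the sentinel
    return merged


print(to_values(merge_lists(from_values([3, 5, 7]), from_values([2, 4, 6]))))
# [2, 3, 4, 5, 6, 7]
```

Detaching the head only matters when the merged list is used on its own: the final list-to-BST step overwrites every node's left pointer anyway.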

Transforming a Linked List

The last part of our algorithm is quite possibly the trickiest.

We know that our intended output is another BST, so it almost goes without saying that our final step involves converting a doubly-linked list into a BST.

A natural way to go about this is to let a runner walk n/2 steps to the middle of the list to mark the BST's root. We then walk through the left and right sublists (another 2 * (n/4) = n/2 steps) to find the respective subtree roots, and so on.

This turns out to be quite expensive, O(n log n), because you repeatedly walk over the same nodes to reach the midpoint of each list and sublist: every level of the recursion costs about n/2 steps, and there are about log n levels.
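To make that cost concrete, assume each call walks to the midpoint of its own sublist. A call on n nodes then spends about n/2 steps on the walk and recurses on two halves of about n/2 nodes each:

```latex
T(n) = 2\,T\!\left(\frac{n}{2}\right) + \frac{n}{2}
\;\;\Longrightarrow\;\;
T(n) = \frac{n}{2}\log_2 n + O(n) = O(n \log n)
```

Each of the roughly log₂ n levels of the recursion walks about n/2 nodes in total. Consuming each list node exactly once, instead of walking to midpoints, removes that per-level cost and brings the conversion down to O(n).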

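Instead, you'll want to construct your BST incrementally in a single pass down the linked list: build the left subtree from the first half of the list, use the next list node as the root, then build the right subtree from the rest. By the time you've reached the midpoint of the list, the resultant BST's left subtree is already completely built: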

def convert_list_to_bst(L):

    # First pass: Get length    
    length = 0
    runner = L
    while runner:
        length += 1
        runner = runner.right

    # Second pass: Build BST
    runner = L

    def traverse(start, last):
        nonlocal runner

        if start > last:
            return None

        # Note: `+1` below is optional, depends on
        # whether you prefer to select the left
        # middle node or the right middle node
        mid = (start + last + 1) // 2
        left = traverse(start, mid-1)

        node = runner
        runner = runner.right

        right = traverse(mid+1, last)

        node.left, node.right = left, right
        return node

    return traverse(0, length-1)

You'll want to pass through the list once to get its full length. Knowing the length lets traverse() work with index ranges: it computes each midpoint arithmetically instead of walking the list to find it, while runner advances exactly one node for every BST node created. If you're having trouble wrapping your head around the logic in the traverse() helper, consider the following linked list:

"""
Index:  0     1     2     3     4
List:   1 <-> 3 <-> 5 <-> 7 <-> 9
Runner: ^
"""

Let’s walk through our first step –– returning the node at index 0 as a standalone subtree (i.e. a subtree of size 1):

  • traverse(0, 4) invokes left = traverse(0, 1)
  • traverse(0, 1) invokes left = traverse(0, 0)
  • traverse(0, 0) invokes left = traverse(0, -1), which immediately returns None
  • We save the node runner currently points to (index 0) as node
  • We advance runner to the node at index 1
  • traverse(0, 0) invokes right = traverse(1, 0), which also immediately returns None
  • We assign node.left and node.right both to None
  • We return node as a standalone subtree

The node we’ve returned here is propagated through the call stack and linked to its parent, which we subsequently return in traverse(0, 1).

Extending this logic forward, we see that as the runner moves from one list node to the next, every node behind it has already been rewired into a finished BST subtree. This is why the whole conversion touches each list node only once.
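To check the walkthrough, the sketch below runs the same convert_list_to_bst() logic on the list 1 <-> 3 <-> 5 <-> 7 <-> 9. It redefines a minimal, hypothetical Node class so it runs on its own; from_values() and shape() are illustrative helpers, the latter rendering a tree as nested (data, left, right) tuples. Because mid rounds up, index 2 (value 5) becomes the root, 3 roots the left subtree with 1 as its left child, and 9 roots the right subtree with 7 as its left child.

```python
class Node:
    # Hypothetical node type: left = previous node, right = next node
    def __init__(self, data=None, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def from_values(values):
    """Build a doubly-linked list from a Python list (illustrative helper)."""
    head = prev = None
    for v in values:
        node = Node(v, left=prev)
        if prev:
            prev.right = node
        else:
            head = node
        prev = node
    return head


def convert_list_to_bst(L):
    # First pass: count the nodes
    length, runner = 0, L
    while runner:
        length += 1
        runner = runner.right

    runner = L  # second pass: consume list nodes in order

    def traverse(start, last):
        nonlocal runner
        if start > last:
            return None
        mid = (start + last + 1) // 2
        left = traverse(start, mid - 1)   # uses the nodes before index mid
        node = runner                     # runner now sits on index mid
        runner = runner.right
        right = traverse(mid + 1, last)   # uses the nodes after index mid
        node.left, node.right = left, right
        return node

    return traverse(0, length - 1)


def shape(tree):
    """Render a tree as nested (data, left, right) tuples; None if empty."""
    if tree is None:
        return None
    return (tree.data, shape(tree.left), shape(tree.right))


print(shape(convert_list_to_bst(from_values([1, 3, 5, 7, 9]))))
# (5, (3, (1, None, None), None), (9, (7, None, None), None))
```

Dropping the + 1 from mid picks the left middle node instead: the same list then becomes a tree rooted at 5 with 1 and 7 as its children, and 3 and 9 hanging off their right sides.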

Putting All Three Steps Together

Putting our helper functions together gives us a clear picture of our final algorithm:

def merge_two_bsts(T1, T2):
    L1 = convert_bst_to_list(T1)
    L2 = convert_bst_to_list(T2)
    L3 = merge_lists(L1, L2)
    return convert_list_to_bst(L3)
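As a sanity check, the sketch below runs the whole pipeline on the two input trees from the problem statement and compares the result with the expected output. It bundles condensed copies of the three helpers, a minimal hypothetical Node class and an illustrative shape() helper that renders a tree as nested (data, left, right) tuples, so it runs on its own.

```python
class Node:
    # Hypothetical node type: children in a tree, prev/next in a list
    def __init__(self, data=None, left=None, right=None):
        self.data, self.left, self.right = data, left, right


def convert_bst_to_list(tree):
    def traverse(subtree):
        if not subtree:
            return None, None
        head = tail = subtree
        if subtree.left:
            left_head, left_tail = traverse(subtree.left)
            head = left_head
            left_tail.right, subtree.left = subtree, left_tail
        if subtree.right:
            right_head, right_tail = traverse(subtree.right)
            tail = right_tail
            subtree.right, right_head.left = right_head, subtree
        head.left = tail.right = None
        return head, tail
    return traverse(tree)[0]


def merge_lists(L1, L2):
    head = tail = Node()  # sentinel
    while L1 and L2:
        if L1.data <= L2.data:
            node, L1 = L1, L1.right
        else:
            node, L2 = L2, L2.right
        tail.right, node.left, node.right = node, tail, None
        tail = node
    rest = L1 or L2
    tail.right = rest
    if rest:
        rest.left = tail
    if head.right:
        head.right.left = None  # detach from the sentinel
    return head.right


def convert_list_to_bst(L):
    length, runner = 0, L
    while runner:
        length, runner = length + 1, runner.right
    runner = L  # second pass consumes the list in order

    def traverse(start, last):
        nonlocal runner
        if start > last:
            return None
        mid = (start + last + 1) // 2
        left = traverse(start, mid - 1)
        node, runner = runner, runner.right
        right = traverse(mid + 1, last)
        node.left, node.right = left, right
        return node

    return traverse(0, length - 1)


def merge_two_bsts(T1, T2):
    L1 = convert_bst_to_list(T1)
    L2 = convert_bst_to_list(T2)
    return convert_list_to_bst(merge_lists(L1, L2))


def shape(tree):
    """Render a tree as nested (data, left, right) tuples; None if empty."""
    if tree is None:
        return None
    return (tree.data, shape(tree.left), shape(tree.right))


T1 = Node(5, Node(3), Node(7))
T2 = Node(4, Node(2), Node(6))
print(shape(merge_two_bsts(T1, T2)))
# (5, (3, (2, None, None), (4, None, None)), (7, (6, None, None), None))
```

The printed shape matches the expected output: 5 at the root, 3 with children 2 and 4 on the left, and 7 with left child 6 on the right. Note that the input trees are consumed, since their nodes are rewired in place.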

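Given n nodes across both trees, our algorithm runs in O(n) time, since each of the three steps visits every node a constant number of times. The extra space is the recursion stack: O(h1 + h2) for flattening input trees of heights h1 and h2, plus O(log n) for building the balanced output. That is O(log n) overall when both inputs are balanced, but it degrades to O(n) for skewed inputs. Apart from the sentinel, no new nodes are allocated.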

With that, we’ve solved the Merging Binary Search Trees problem 🌲🌲.

Full Solution

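class Node:
    # Minimal node type assumed throughout: left/right are the children
    # in a BST and the previous/next pointers in a doubly-linked list
    def __init__(self, data=None, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def convert_bst_to_list(tree):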
    def traverse(subtree):
        if not subtree:
            return None, None

        head = tail = subtree
        if subtree.left:
            left_head, left_tail = traverse(subtree.left)
            head = left_head
            # left_tail <-> subtree
            left_tail.right, subtree.left = subtree, left_tail

        if subtree.right:
            right_head, right_tail = traverse(subtree.right)
            tail = right_tail
            # subtree <-> right_head
            subtree.right, right_head.left = right_head, subtree

        head.left = tail.right = None
        return head, tail

    head, tail = traverse(tree)
    return head


def merge_lists(L1, L2):

    runner1 = L1
    runner2 = L2
    head = runner3 = Node()

    while runner1 and runner2:
        if runner1.data <= runner2.data:
            runner1_right = runner1.right

            runner3.right = runner1
            runner1.left, runner1.right = runner3, None

            runner1 = runner1_right

        else:
            runner2_right = runner2.right

            runner3.right = runner2
            runner2.left, runner2.right = runner3, None

            runner2 = runner2_right

        runner3 = runner3.right

    if runner1:
        runner3.right = runner1
        runner1.left = runner3

    elif runner2:
        runner3.right = runner2
        runner2.left = runner3

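    # Detach the merged list's first node from the sentinel
    result = head.right
    if result:
        result.left = None
    return result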


def convert_list_to_bst(L):

    length = 0
    runner = L
    while runner:
        length += 1
        runner = runner.right

    runner = L

    def traverse(start, last):
        nonlocal runner

        if start > last:
            return None

        mid = (start + last + 1) // 2
        left = traverse(start, mid-1)

        node = runner
        runner = runner.right

        right = traverse(mid+1, last)

        node.left, node.right = left, right
        return node

    return traverse(0, length-1)


def merge_two_bsts(T1, T2):
    L1 = convert_bst_to_list(T1)
    L2 = convert_bst_to_list(T2)
    L3 = merge_lists(L1, L2)
    return convert_list_to_bst(L3)