Yesterday, we dealt with binary trees. Today, we’ll look at a special subset: binary search trees.
From EPI:
Merge two binary search trees into a single binary search tree.
Input:
"""
  5
 / \
3   7

  4
 / \
2   6
"""
Output:
"""
     5
   /   \
  3     7
 / \   /
2   4 6
"""
This question is a little bit tedious, but don’t worry, we’ll break it up into manageable chunks.
Your first intuition might be to:
- Initialise two pointers that start from the roots of both BSTs
- Run some comparison function to determine which node should be the new root
- Continue recursing downwards… in some manner?
You’ll find that this strategy falls apart very quickly. There are no actual guarantees that the resultant root will be one of the roots from the original BSTs.
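To make this concrete, here’s a small sketch with two hypothetical trees (the values and roots below are made up for illustration). If we want the merged tree to be balanced, its root should be the middle element of the combined sorted values, and that element need not be either of the original roots:

```python
# Sorted contents and roots of two hypothetical BSTs (illustrative values only).
t1_values, t1_root = [1, 2, 3], 2
t2_values, t2_root = [8, 9, 10], 9

merged = sorted(t1_values + t2_values)

# For a balanced result, the root is the middle element of the sorted values.
balanced_root = merged[len(merged) // 2]

print(balanced_root)                        # 8
print(balanced_root in (t1_root, t2_root))  # False
```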
Transforming BSTs
Instead, you’ll need to transform both BSTs into a different kind of data structure. This data structure isn’t terribly different from a BST: its nodes also hold two pointers each. As you might’ve guessed, I’m speaking about doubly-linked lists.
Converting both BSTs into doubly-linked lists is beneficial as it allows us to merge both sets of data together (using a trivial, mergesort-style operation) while preserving the ordering of our data.
Let’s first pin down this conversion subroutine:
def convert_bst_to_list(tree):
    def traverse(subtree):
        if not subtree:
            return None, None
        head = tail = subtree
        if subtree.left:
            left_head, left_tail = traverse(subtree.left)
            head = left_head
            # left_tail <-> subtree
            left_tail.right, subtree.left = subtree, left_tail
        if subtree.right:
            right_head, right_tail = traverse(subtree.right)
            tail = right_tail
            # subtree <-> right_head
            subtree.right, right_head.left = right_head, subtree
        head.left = tail.right = None
        return head, tail
    head, tail = traverse(tree)
    return head
In this algorithm, notice how we’re returning head, tail from each invocation of the traverse() helper.
You can think of your leaf nodes as doubly-linked lists of length 1. When you move up one level to your leaf’s parent, your goal becomes to connect this parent to both leaf nodes below it, resulting in a doubly-linked list of length 3:
"""
  4
 / \
3   5

becomes 3 <-> 4 <-> 5
"""
Merging Linked Lists
Our next step will be to combine the two lists while maintaining their sorted order.
If you’re familiar with the merge() operation in mergesort, this step is easy. We’ll:
- Use two runners (i.e. running pointers) to move through the linked lists
- Let a third runner collect the nodes in sorted order
def merge_lists(L1, L2):
    runner1 = L1
    runner2 = L2
    head = runner3 = Node()  # Sentinel
    while runner1 and runner2:
        if runner1.data <= runner2.data:
            runner1_right = runner1.right
            runner3.right = runner1
            runner1.left, runner1.right = runner3, None
            runner1 = runner1_right
        else:
            runner2_right = runner2.right
            runner3.right = runner2
            runner2.left, runner2.right = runner3, None
            runner2 = runner2_right
        runner3 = runner3.right
    if runner1:
        runner3.right = runner1
        runner1.left = runner3
    elif runner2:
        runner3.right = runner2
        runner2.left = runner3
    return head.right
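If the pointer juggling obscures the idea, the same two-runner pattern can be sketched on plain Python lists. This is only an analogy (merge_sorted is a hypothetical helper, not part of the solution): the linked-list version above does the same comparisons, but re-links existing nodes instead of appending to a new list.

```python
def merge_sorted(a, b):
    """Merge two sorted lists into one sorted list (mergesort-style)."""
    i = j = 0          # the two runners
    out = []           # plays the role of the third, collecting runner
    while i < len(a) and j < len(b):
        if a[i] <= b[j]:
            out.append(a[i])
            i += 1
        else:
            out.append(b[j])
            j += 1
    # At most one of these still has elements left; append the remainder.
    out.extend(a[i:])
    out.extend(b[j:])
    return out

# In-order values of the two example trees from the problem statement.
print(merge_sorted([3, 5, 7], [2, 4, 6]))  # [2, 3, 4, 5, 6, 7]
```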
Transforming a Linked List
The last part of our algorithm is quite possibly the trickiest.
We know that our intended output is another BST, so it almost goes without saying that our final step involves converting a doubly-linked list into a BST.
A natural way to go about this is to let a runner move (in 0.5n steps) to the middle of the list to mark the BST’s root. We then walk through the left and right sublists (in 2 × 0.25n steps) to find the respective subtree roots.
This turns out to be quite expensive: each of the roughly log n levels of recursion costs about 0.5n steps, for O(n log n) in total, as you’ll be repeatedly moving through the same nodes in order to access the midpoint of each list and sublist.
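For reference, here’s roughly what that midpoint-seeking approach could look like, with a counter added so you can see the repeated walking. Everything in this sketch is an assumption made for illustration: the Node class, the build_list() and inorder() helpers, and the steps counter are not part of the solution.

```python
class Node:
    # Assumed minimal node type: a value plus `left` and `right` pointers.
    def __init__(self, data=None, left=None, right=None):
        self.data, self.left, self.right = data, left, right

def build_list(values):
    """Build a doubly-linked list from `values` and return its head."""
    head = prev = None
    for value in values:
        node = Node(value)
        if prev:
            prev.right, node.left = node, prev
        else:
            head = node
        prev = node
    return head

def naive_list_to_bst(head, length, steps):
    """Convert the first `length` nodes from `head` into a BST.

    `steps[0]` counts every hop the runner makes while seeking a midpoint.
    """
    if length <= 0:
        return None
    mid = length // 2
    node = head
    for _ in range(mid):       # walk to the middle of this sublist
        node = node.right
        steps[0] += 1
    right_head = node.right    # save before re-linking `node`
    left = naive_list_to_bst(head, mid, steps)
    right = naive_list_to_bst(right_head, length - mid - 1, steps)
    node.left, node.right = left, right
    return node

def inorder(root):
    if not root:
        return []
    return inorder(root.left) + [root.data] + inorder(root.right)

steps = [0]
root = naive_list_to_bst(build_list(range(7)), 7, steps)
print(root.data, steps[0])  # 3 5
```

With 7 nodes the runner makes 5 hops (3 for the root, then 1 for each half), and the hop count grows faster than linearly as the list gets longer.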
Instead, you’ll want to construct your BST incrementally as you move down the linked list. By the time you’ve reached the midpoint of the list, the resultant BST’s left subtree should already be completely built:
def convert_list_to_bst(L):
    # First pass: Get length
    length = 0
    runner = L
    while runner:
        length += 1
        runner = runner.right
    # Second pass: Build BST
    runner = L
    def traverse(start, last):
        nonlocal runner
        if not start <= last:
            return None
        # Note: `+1` below is optional, depends on
        # whether you prefer to select the left
        # middle node or the right middle node
        mid = (start + last + 1) // 2
        left = traverse(start, mid-1)
        node = runner
        runner = runner.right
        right = traverse(mid+1, last)
        node.left, node.right = left, right
        return node
    return traverse(0, length-1)
You’ll want to pass through the list once to get its full length. This allows you to determine the number of iterations required in your recursive calls. If you’re having trouble wrapping your head around the logic in the traverse() helper, consider the following linked list:
"""
Index:  0     1     2     3     4
List:   1 <-> 3 <-> 5 <-> 7 <-> 9
Runner: ^
"""
Let’s walk through our first step, returning the node at index 0 as a standalone subtree (i.e. a subtree of size 1):
- traverse(0, 4) invokes left = traverse(0, 1)
- traverse(0, 1) invokes left = traverse(0, 0)
- traverse(0, 0) invokes left = traverse(0, -1), which immediately returns None
- We use runner and save a reference to the node at index 0
- We update runner and move it to the node at index 1
- traverse(0, 0) invokes right = traverse(1, 0), which also immediately returns None
- We assign node.left and node.right both to None
- We return node as a standalone subtree
The node we’ve returned here is propagated through the call stack and linked to its parent, which we subsequently return in traverse(0, 1).
Extending this logic forward, we see that as the runner moves from one list node to the next, the nodes on its left get reformed into BST nodes.
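If you’d like to see the whole trace rather than just the first step, here’s an instrumented variant of the same recursion. To keep it short, it works on a plain Python list, uses an index (pos) in place of the runner, and returns nested (left, value, right) tuples instead of nodes. The names build_from_sorted, pos and order are hypothetical and exist only for this sketch.

```python
def build_from_sorted(values):
    pos = 0      # plays the role of `runner`
    order = []   # records (start, last, value) each time a node is consumed

    def traverse(start, last):
        nonlocal pos
        if start > last:
            return None
        mid = (start + last + 1) // 2
        left = traverse(start, mid - 1)
        value = values[pos]              # consume the node under the runner
        order.append((start, last, value))
        pos += 1                         # advance the runner
        right = traverse(mid + 1, last)
        return (left, value, right)

    return traverse(0, len(values) - 1), order

tree, order = build_from_sorted([1, 3, 5, 7, 9])
for entry in order:
    print(entry)
# (0, 0, 1)
# (0, 1, 3)
# (0, 4, 5)
# (3, 3, 7)
# (3, 4, 9)
```

The values come out in exactly the list’s order (1, 3, 5, 7, 9), which is why a single forward-moving runner is enough: every node is visited once, at the moment its left subtree is complete.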
Putting All Three Steps Together
Putting our helper functions together gives us a clear picture of our final algorithm:
def merge_two_bsts(T1, T2):
    L1 = convert_bst_to_list(T1)
    L2 = convert_bst_to_list(T2)
    L3 = merge_lists(L1, L2)
    return convert_list_to_bst(L3)
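Given n total nodes across both trees, our algorithm runs in O(n) time. The extra space is the recursion stack: O(h) in convert_bst_to_list(), where h is the height of the taller input tree, and O(log n) in convert_list_to_bst(). That comes to O(log n) space when the input trees are balanced, and O(n) in the worst case of a degenerate (list-like) input tree.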
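If you want to sanity-check a merged tree, a pair of small helpers goes a long way. The sketch below is one possible way to do it. The Node class and the inorder() and is_bst() helpers are assumptions for illustration, and the tree being checked is the expected output from the problem statement, built by hand.

```python
class Node:
    # Assumed minimal node type: a value plus `left` and `right` pointers.
    def __init__(self, data=None, left=None, right=None):
        self.data, self.left, self.right = data, left, right

def inorder(root):
    """Yield node values left-to-right; sorted if `root` is a valid BST."""
    if root:
        yield from inorder(root.left)
        yield root.data
        yield from inorder(root.right)

def is_bst(root, low=float("-inf"), high=float("inf")):
    """Check the BST ordering property (duplicates allowed on either side)."""
    if root is None:
        return True
    if not (low <= root.data <= high):
        return False
    return (is_bst(root.left, low, root.data)
            and is_bst(root.right, root.data, high))

# The expected output tree from the problem statement, built by hand.
expected = Node(5, Node(3, Node(2), Node(4)), Node(7, Node(6)))

print(is_bst(expected))          # True
print(list(inorder(expected)))   # [2, 3, 4, 5, 6, 7]
```

Running both checks on the result of merge_two_bsts() should confirm that the ordering holds and that no node was dropped along the way.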
With that, we’ve solved the Merging Binary Search Trees problem 🌲🌲.
Full Solution
# Minimal node type assumed by the functions below:
# a value plus `left` and `right` pointers.
class Node:
    def __init__(self, data=None, left=None, right=None):
        self.data = data
        self.left = left
        self.right = right


def convert_bst_to_list(tree):
    def traverse(subtree):
        if not subtree:
            return None, None
        head = tail = subtree
        if subtree.left:
            left_head, left_tail = traverse(subtree.left)
            head = left_head
            # left_tail <-> subtree
            left_tail.right, subtree.left = subtree, left_tail
        if subtree.right:
            right_head, right_tail = traverse(subtree.right)
            tail = right_tail
            # subtree <-> right_head
            subtree.right, right_head.left = right_head, subtree
        head.left = tail.right = None
        return head, tail
    head, tail = traverse(tree)
    return head


def merge_lists(L1, L2):
    runner1 = L1
    runner2 = L2
    head = runner3 = Node()  # Sentinel
    while runner1 and runner2:
        if runner1.data <= runner2.data:
            runner1_right = runner1.right
            runner3.right = runner1
            runner1.left, runner1.right = runner3, None
            runner1 = runner1_right
        else:
            runner2_right = runner2.right
            runner3.right = runner2
            runner2.left, runner2.right = runner3, None
            runner2 = runner2_right
        runner3 = runner3.right
    if runner1:
        runner3.right = runner1
        runner1.left = runner3
    elif runner2:
        runner3.right = runner2
        runner2.left = runner3
    return head.right


def convert_list_to_bst(L):
    # First pass: Get length
    length = 0
    runner = L
    while runner:
        length += 1
        runner = runner.right
    # Second pass: Build BST
    runner = L
    def traverse(start, last):
        nonlocal runner
        if not start <= last:
            return None
        mid = (start + last + 1) // 2
        left = traverse(start, mid-1)
        node = runner
        runner = runner.right
        right = traverse(mid+1, last)
        node.left, node.right = left, right
        return node
    return traverse(0, length-1)


def merge_two_bsts(T1, T2):
    L1 = convert_bst_to_list(T1)
    L2 = convert_bst_to_list(T2)
    L3 = merge_lists(L1, L2)
    return convert_list_to_bst(L3)