From LeetCode:
Design an array-like data structure that allows you to update elements and query subarray sums efficiently.
Input: [1, 3, 5]
Output:
sumRange(0, 2)  # 9
update(1, 2)
sumRange(0, 2)  # 8
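To make the example concrete, here is a minimal baseline that simply stores the list. It is a sketch for comparison only and not part of the segment-tree solution: `NaiveNumArray` is a hypothetical name. Its `update` runs in O(1) time, but every `sumRange` call scans its range in O(n) time, which is the cost a segment tree is meant to avoid.

```python
from typing import List


class NaiveNumArray:
    """Hypothetical baseline: O(1) updates, O(n) range sums."""

    def __init__(self, nums: List[int]):
        # Copy so that later updates don't mutate the caller's list
        self.nums = list(nums)

    def update(self, i: int, val: int) -> None:
        self.nums[i] = val

    def sumRange(self, i: int, j: int) -> int:
        # The range is inclusive, hence the j + 1 slice bound
        return sum(self.nums[i:j + 1])


arr = NaiveNumArray([1, 3, 5])
print(arr.sumRange(0, 2))  # 9
arr.update(1, 2)
print(arr.sumRange(0, 2))  # 8
```

Replaying the example operations gives 9 and then 8, matching the expected output.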
Most (if not all) modern software applications need to persist and retrieve some kind of data. Retrieving data flexibly usually means being able to query for information about a range of values.
In this problem, we’ll see how we can implement range sum queries efficiently using segment trees.
Segment Trees
Segment trees are augmented binary trees. Each node is labelled with a range, and its left and right children are labelled with the left and right halves of that range. Each node also stores the sum of the elements in its range:
Array: [5, 10, 3, 7]

           [0, 3]
            (25)
          /      \
     [0, 1]      [2, 3]
      (15)        (10)
      /  \        /  \
[0, 0] [1, 1] [2, 2] [3, 3]
 (5)    (10)   (3)    (7)
Notice how the leaf nodes match the elements in our original array. Each leaf node represents a single data value. The non-leaf nodes, on the other hand, represent subarrays in the original array and provide us with information about ranges of values.
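To check the labelling rule against the diagram, here is a small sketch that splits each range at its midpoint `(i + j) // 2` (the same split our construction code uses later) and lists every node's range and sum in pre-order. `segment_sums` is a hypothetical helper that works on plain tuples instead of tree nodes, and it assumes a non-empty array.

```python
from typing import List, Tuple


def segment_sums(nums: List[int], i: int, j: int) -> List[Tuple[Tuple[int, int], int]]:
    """Hypothetical helper: ((start, last), total) for every node covering [i, j], pre-order."""
    if i == j:
        # Leaf: a single array element
        return [((i, j), nums[i])]
    m = (i + j) // 2
    left = segment_sums(nums, i, m)
    right = segment_sums(nums, m + 1, j)
    # The first entry of each child's list is that child's own (range, total)
    total = left[0][1] + right[0][1]
    return [((i, j), total)] + left + right


for (start, last), total in segment_sums([5, 10, 3, 7], 0, 3):
    print(f"[{start}, {last}] -> {total}")
```

This prints the seven nodes of the diagram, starting with `[0, 3] -> 25`.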
Let’s suppose we’re querying for the sum of [0, 2].
- From the root [0, 3], we navigate to the left subtree and find the sum of [0, 1].
- From the root [0, 3], we navigate to the right subtree [2, 3].
- From the subtree [2, 3], we navigate to the left leaf and find the sum of [2, 2].
We’ve found the sum of [0, 1] as well as the sum of [2, 2]. Combining them gives us the sum of [0, 2]: 15 + 3 = 18.
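The walkthrough above boils down to a recursive decomposition of the query into node ranges. The sketch below uses a hypothetical `decompose` helper that works on index ranges only (it assumes `start <= i <= j <= last`) and returns the ranges whose stored sums would be added together:

```python
from typing import List, Tuple


def decompose(start: int, last: int, i: int, j: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: node ranges whose sums add up to the query [i, j].

    The current node covers [start, last]; assumes start <= i <= j <= last.
    """
    if start == i and last == j:
        # The node's range matches the query exactly: use its stored sum
        return [(start, last)]
    mid = (start + last) // 2
    if j <= mid:
        # Query lies entirely in the left half
        return decompose(start, mid, i, j)
    if i > mid:
        # Query lies entirely in the right half
        return decompose(mid + 1, last, i, j)
    # Query straddles the midpoint: split it there
    return decompose(start, mid, i, mid) + decompose(mid + 1, last, mid + 1, j)


nums = [5, 10, 3, 7]
parts = decompose(0, len(nums) - 1, 0, 2)
print(parts)  # [(0, 1), (2, 2)]
# Sum each part directly from the array (a real tree would read node.total)
print(sum(sum(nums[a:b + 1]) for a, b in parts))  # 18
```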
There are more complex queries that we can consider, but I think it’ll be more beneficial to your understanding to see the tree traversal logic in code.
From Intuition to Code
Let’s define the data node that we’ll be using in our segment tree:
class Node:
    def __init__(self, start, last):
        # Inclusive index range [start, last] covered by this node
        self.start, self.last = start, last
        self.left = self.right = None
        # Sum of the elements in [start, last]
        self.total = 0
The first operation we’ll handle is the construction of our segment tree, NumArray:
from typing import List


class NumArray:
    def __init__(self, nums: List[int]):
        # O(n) time operation for n data entries
        def create(i, j):
            if i > j:
                # Empty range (only happens when nums is empty)
                return None
            elif i == j:
                # Leaf node
                node = Node(i, j)
                node.total = nums[i]
                return node
            # Find midpoint to split child ranges by
            m = (i + j) // 2
            node = Node(i, j)
            # Construct child nodes and sum their
            # totals as we backtrack
            node.left = create(i, m)
            node.right = create(m+1, j)
            node.total = node.left.total + node.right.total
            return node

        self.root = create(0, len(nums)-1)
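One way to see why construction is O(n): every call to `create` with `i < j` produces exactly two children, so a tree over n leaves has exactly 2n - 1 nodes, and `create` does a constant amount of work per node. The hypothetical `count_nodes` below mirrors the recursion of `create` without allocating anything, as a quick sanity check of that count.

```python
def count_nodes(i: int, j: int) -> int:
    """Hypothetical helper: how many nodes create(i, j) would allocate (assumes i <= j)."""
    if i == j:
        return 1  # leaf
    m = (i + j) // 2
    # One internal node plus both subtrees, split exactly as create() does
    return 1 + count_nodes(i, m) + count_nodes(m + 1, j)


for n in (1, 4, 5, 1000):
    print(n, count_nodes(0, n - 1))  # always 2 * n - 1
```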
Updating a record in the segment tree simply involves navigating from the root to the leaf node that represents the record. Once we’ve updated the leaf’s sum, we propagate the delta back up and update the sums of all its ancestor nodes:
class NumArray:
    def update(self, i: int, val: int) -> None:
        # O(log n) time operation
        def helper(node, i, val):
            if node.start == i and node.last == i:
                # Reached the leaf for index i: overwrite it
                diff = val - node.total
                node.total = val
                return diff
            mid = (node.start + node.last) // 2
            diff = (
                helper(node.left, i, val)
                if node.start <= i <= mid
                else helper(node.right, i, val)
            )
            # Update ancestor sums as we backtrack
            node.total += diff
            return diff

        helper(self.root, i, val)
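To see which nodes an update touches, the hypothetical `update_path` helper below replays `helper`’s branching rule on index ranges alone and records each range it visits, root first. The path always ends at the leaf `(i, i)` and has at most ceil(log2 n) + 1 nodes, one per level of the tree.

```python
from typing import List, Tuple


def update_path(n: int, i: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: ranges visited by update(i, ...) in a tree over n elements."""
    start, last = 0, n - 1
    path = [(start, last)]
    while start != last:
        mid = (start + last) // 2
        # Same branching rule as helper(): go left when i <= mid
        if i <= mid:
            last = mid
        else:
            start = mid + 1
        path.append((start, last))
    return path


print(update_path(4, 1))  # [(0, 3), (0, 1), (1, 1)]
print(update_path(3, 1))  # [(0, 2), (0, 1), (1, 1)]
```

For the example at the top, `update(1, 2)` on [1, 3, 5] visits [0, 2], [0, 1] and [1, 1]. The leaf’s diff is 2 - 3 = -1, so each of those sums drops by 1 as we backtrack, and the root goes from 9 to 8.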
Querying for the range sum seems complex, but it really isn’t once you break it down into a few distinct cases:
class NumArray:
    def sumRange(self, i: int, j: int) -> int:
        # O(log n) time operation
        def helper(node, i, j):
            # Base Case: End-to-end match
            if node.start == i and node.last == j:
                return node.total
            mid = (node.start + node.last) // 2
            # Case 1: Whole range falls on left
            if i <= mid and j <= mid:
                return helper(node.left, i, j)
            # Case 2: Whole range falls on right
            elif mid < i and mid < j:
                return helper(node.right, i, j)
            # Case 3: Range is split between both subtrees
            elif i <= mid and mid < j:
                left = helper(node.left, i, mid)
                right = helper(node.right, mid+1, j)
                return left + right

        return helper(self.root, i, j)
It might not be obvious at first that our range queries run in O(log n) time. After all, in Case 3 the midpoint lies inside our query range, so we can’t rule out either half of the subtree (which is atypical for logarithmic-time tree algorithms). Instead, we have to split our query into two calls, one on each subtree.
Runtime of a Query
Let a numerical suffix on a variable indicate the level of the tree we’re inspecting. Say we’re given a query range (i1, j1) and we hit Case 3: (i1, j1) becomes (i1, mid1) and (mid1+1, j1).
Think about what happens at subsequent levels:
- On the left subtree, we have (i2, j2), and it must be that j2 == last2, since the query ends exactly at the subtree’s right edge.
  - If i2 <= mid2, we’ll hit the Base Case on this subtree’s right child and continue to explore the left side.
  - If i2 > mid2, we’ll ignore this subtree’s left child and continue to explore the right side.
- The right subtree is symmetric: there i2 == start2, so at each level we either hit the Base Case on the left child and continue to the right, or ignore the right child and continue to the left.
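This tells us that after we encounter Case 3, every subsequent level does a constant amount of work on each side: one child is either answered immediately by the Base Case or ruled out entirely. The recursion therefore follows at most two downward paths, which keeps our range sum query within O(log n) time.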
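The argument can also be checked empirically. The hypothetical `count_calls` below mirrors the Base Case and Cases 1 to 3 of `sumRange` on index ranges only and counts how many times `helper` would be invoked. Taking the worst case over every query for a given n, the count stays within a small multiple of the tree height h = ceil(log2 n); the loop prints n, h and that worst case for a few arbitrarily chosen sizes.

```python
def count_calls(start: int, last: int, i: int, j: int) -> int:
    """Hypothetical helper: helper() invocations for sumRange(i, j) on a node covering [start, last]."""
    if start == i and last == j:
        return 1  # Base Case
    mid = (start + last) // 2
    if j <= mid:
        return 1 + count_calls(start, mid, i, j)  # Case 1: left only
    if i > mid:
        return 1 + count_calls(mid + 1, last, i, j)  # Case 2: right only
    # Case 3: split the query at the midpoint
    return 1 + count_calls(start, mid, i, mid) + count_calls(mid + 1, last, mid + 1, j)


for n in (16, 64, 256):
    h = (n - 1).bit_length()  # tree height, i.e. ceil(log2(n)) for n >= 1
    worst = max(count_calls(0, n - 1, i, j) for i in range(n) for j in range(i, n))
    print(n, h, worst)
```

Following the argument above, each side contributes at most two calls per level after the split, so 4h + 1 is a safe (if loose) upper bound on the count, whereas a linear scan would grow with n.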
Given an array of n numbers, our NumArray is initialised in O(n) space and O(n) time. Updates and queries run in O(log n) time and use O(log n) space for the recursion stack.
With that, we’ve solved the Querying Range Sums problem 📸🧮.
Full Solution
from typing import List


class Node:
    def __init__(self, start, last):
        # Inclusive index range [start, last] covered by this node
        self.start, self.last = start, last
        self.left = self.right = None
        # Sum of the elements in [start, last]
        self.total = 0


class NumArray:
    def __init__(self, nums: List[int]):
        # O(n) time operation for n data entries
        def create(i, j):
            if i > j:
                # Empty range (only happens when nums is empty)
                return None
            elif i == j:
                # Leaf node
                node = Node(i, j)
                node.total = nums[i]
                return node
            # Find midpoint to split child ranges by
            m = (i + j) // 2
            node = Node(i, j)
            # Construct child nodes and sum their
            # totals as we backtrack
            node.left = create(i, m)
            node.right = create(m+1, j)
            node.total = node.left.total + node.right.total
            return node

        self.root = create(0, len(nums)-1)

    def update(self, i: int, val: int) -> None:
        # O(log n) time operation
        def helper(node, i, val):
            if node.start == i and node.last == i:
                # Reached the leaf for index i: overwrite it
                diff = val - node.total
                node.total = val
                return diff
            mid = (node.start + node.last) // 2
            diff = (
                helper(node.left, i, val)
                if node.start <= i <= mid
                else helper(node.right, i, val)
            )
            # Update ancestor sums as we backtrack
            node.total += diff
            return diff

        helper(self.root, i, val)

    def sumRange(self, i: int, j: int) -> int:
        # O(log n) time operation
        def helper(node, i, j):
            # Base Case: End-to-end match
            if node.start == i and node.last == j:
                return node.total
            mid = (node.start + node.last) // 2
            # Case 1: Whole range falls on left
            if i <= mid and j <= mid:
                return helper(node.left, i, j)
            # Case 2: Whole range falls on right
            elif mid < i and mid < j:
                return helper(node.right, i, j)
            # Case 3: Range is split between both subtrees
            elif i <= mid and mid < j:
                left = helper(node.left, i, mid)
                right = helper(node.right, mid+1, j)
                return left + right

        return helper(self.root, i, j)