From LeetCode:
Design a matrix-like data structure that allows you to update elements and query submatrix sums efficiently.
Input:
[
  [3, 0, 1, 4, 2],
  [5, 6, 3, 2, 1],
  [1, 2, 0, 1, 5],
  [4, 1, 0, 1, 7],
  [1, 0, 3, 0, 5]
]
Output:
sumRegion(2, 1, 4, 3) # 8
update(3, 2, 2)
sumRegion(2, 1, 4, 3) # 10
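As a reference point, a brute-force version of this interface could look like the sketch below. The class name NaiveNumMatrix is made up for illustration and is not part of the solution developed in this post. Updates are O(1), but every query scans the whole region, which is the cost the segment tree approach is meant to reduce.

```python
class NaiveNumMatrix:
    """Brute-force baseline (hypothetical name, for comparison only)."""

    def __init__(self, matrix):
        # Copy the rows so later updates don't mutate the caller's data
        self.matrix = [row[:] for row in matrix]

    def update(self, row, col, val):
        self.matrix[row][col] = val

    def sumRegion(self, row1, col1, row2, col2):
        # Scan every cell in the inclusive rectangle
        return sum(
            self.matrix[r][c]
            for r in range(row1, row2 + 1)
            for c in range(col1, col2 + 1)
        )


matrix = [
    [3, 0, 1, 4, 2],
    [5, 6, 3, 2, 1],
    [1, 2, 0, 1, 5],
    [4, 1, 0, 1, 7],
    [1, 0, 3, 0, 5],
]
naive = NaiveNumMatrix(matrix)
print(naive.sumRegion(2, 1, 4, 3))  # 8
naive.update(3, 2, 2)
print(naive.sumRegion(2, 1, 4, 3))  # 10
```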
Previously, we tackled the Querying Range Sums problem and used a segment tree (i.e. augmented binary trees with range labels on nodes) to support logarithmic-time operations. In this post, we’ll see how we can do the same for a 2D array / matrix.
I recommend going through the first version of this problem before continuing this article.
Before we proceed, it’s also worth mentioning that the solution I’m about to share may not be producible within a 30-45 minute timeframe. In other words, this doesn’t seem like the kind of question you’re likely to encounter in a live coding exercise. Nevertheless, I think it’s still good practice to 1) get the intuition down and 2) get our hands dirty with complex data structures.
Tree of Trees
We know that we can use a segment tree to represent a 1D array, but what if we needed to represent two dimensions of data? It turns out that we can store pointers to segment trees as ordinary node values within an overarching segment tree:
In this diagram, each (square) node contains its own segment tree. For clarity, we’ll refer to the top-level segment tree as the “parent tree” and refer to the other segment trees as “nested trees”.
The range labels displayed in the parent tree correspond to the row indices of our matrix. [0, 0], for instance, refers to row 0. The nested tree for row 0 would appear like this:
Row_0: [5, 6, 3, 2]
The range labels in this nested tree correspond to the column indices of our matrix. If we were specifically interested in the node representing the value at row 0 and column 0, we’d look up the node labelled [0, 0] in the parent tree and then find the node labelled [0, 0] in the nested tree. As seen in the nested tree above, this gives us 5.
What about non-leaf nodes on the parent tree, such as [0, 1]? The nested tree stored at [0, 1] will contain the pairwise sums of the nodes in [0, 0] and [1, 1]:
Row_0: [5, 6, 3, 2]
Row_1: [1, 2, 0, 1]
Row_0 + Row_1: [6, 8, 3, 3]
You can imagine this nested tree to be a merged version of the nested trees for row 0 and row 1.
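At the leaf level, this merge boils down to an element-wise sum of the two rows. A minimal sketch using the example rows above (the variable names are only for illustration):

```python
row_0 = [5, 6, 3, 2]
row_1 = [1, 2, 0, 1]

# Element-wise (pairwise) sum of the two rows
merged = [a + b for a, b in zip(row_0, row_1)]
print(merged)  # [6, 8, 3, 3]
```

The internal nodes of the merged nested tree then hold range sums over this combined row.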
Let’s see how this nested tree structure supports our queries. Suppose we’re interested in the submatrix from (r, c) == (0, 0) to (r, c) == (1, 1) (inclusive).
- From the parent tree’s root [0, 3], we navigate to the left node [0, 1].
- We explore the nested tree contained in [0, 1].
- From the nested tree’s root [0, 3], we navigate to the left node [0, 1].
The value found at [0, 1] (within the nested tree) will be the sum of (0, 0) to (1, 1).
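To check this walk-through by hand, the sketch below recomputes the value that the nested node [0, 1] should hold and compares it with a direct sum over the submatrix. It uses plain lists rather than the tree itself, and the variable names are assumptions made for the example.

```python
row_0 = [5, 6, 3, 2]
row_1 = [1, 2, 0, 1]

# Leaves of the nested tree stored at parent node [0, 1]
merged = [a + b for a, b in zip(row_0, row_1)]

# Nested node [0, 1] holds the sum of columns 0..1 of the merged row
nested_0_1 = sum(merged[0:2])

# Direct sum over the submatrix (0, 0)..(1, 1) for comparison
direct = sum(row[c] for row in (row_0, row_1) for c in range(2))

print(nested_0_1, direct)  # 14 14
```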
From Intuition to Code
Our class definition for a segment tree node stays largely the same:
class Node:
    def __init__(self, start, last):
        self.start, self.last = start, last
        self.left = self.right = None
        # Points to integer OR segment tree
        self.v = None
Let’s walk through each of the major segment tree operations one by one. Most of these are adapted from the operations we defined in the 1D variant of our current problem.
Tree Construction
class SegmentTree:
    def __init__(self):
        self.root = None

    def create(self, ref_data):
        def helper(i, j):
            if i > j:
                return None
            elif i == j:
                node = Node(i, i)
                if isinstance(ref_data[i], list):
                    node.v = SegmentTree()
                    node.v.create(ref_data[i])
                else:
                    node.v = ref_data[i]
                return node
            m = (i + j) // 2
            node = Node(i, j)
            node.left = helper(i, m)
            node.right = helper(m+1, j)
            node.v = merge_vals(node.left.v, node.right.v)
            return node
        self.root = helper(0, len(ref_data)-1)
ref_data simply describes the reference data we’re using for our tree construction. If we’re building out the parent tree, then this data should point to the input matrix we’ve been given. If we’re building out the nested tree, then ref_data should be one of the rows in the matrix.
The utility function merge_vals merges two “values” together. Recall that values can be either integers or entire segment trees, so this function accounts for that:
def merge_vals(v1, v2):
    def merge_nodes(n1, n2):
        if not n1 and not n2:
            return None
        merged_node = Node(n1.start, n1.last)
        merged_node.left = merge_nodes(n1.left, n2.left)
        merged_node.right = merge_nodes(n1.right, n2.right)
        merged_node.v = n1.v + n2.v
        return merged_node

    if (
        isinstance(v1, SegmentTree) and
        isinstance(v2, SegmentTree)
    ):
        st = SegmentTree()
        st.root = merge_nodes(v1.root, v2.root)
        return st
    else:
        return v1 + v2
It’s worth noting here that two nested trees will share the same structure. Thus, when we merge them, there’s no need to worry about a node being present in one nested tree and missing in another.
Tree Lookups
Looking up an element entails navigating the parent tree from root to leaf and navigating the nested tree from root to leaf.
def lookup(self, r, c):
    def helper(node, i):
        nonlocal c
        # Leaf case
        if node.start == i and node.last == i:
            if isinstance(node.v, SegmentTree):
                return helper(node.v.root, c)
            else:
                return node.v
        # General case
        m = (node.start + node.last) // 2
        if 0 <= i <= m:
            return helper(node.left, i)
        return helper(node.right, i)
    return helper(self.root, r)
We didn’t define an element lookup function in the 1D variant of the problem, but we’ll quickly see why this is useful for our current setup.
Tree Updates
Updating an element consists of three steps:
- Looking up the current value
- Calculating the difference
- Capturing this difference in all affected nested trees
def update(self, r, c, val):
    def helper(node, i, diff):
        nonlocal c
        if node:
            if isinstance(node.v, SegmentTree):
                helper(node.v.root, c, diff)
            else:
                node.v += diff
            m = (node.start + node.last) // 2
            if 0 <= i <= m:
                helper(node.left, i, diff)
            else:
                helper(node.right, i, diff)
    diff = val - self.lookup(r, c)
    return helper(self.root, r, diff)
This is why we defined a lookup() method on our SegmentTree in the first place: it enables us to compute the difference before passing the delta value on to affected nodes.
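The same delta idea can be seen in isolation on a single row. The sketch below is a simplified 1D illustration, not the class used in this post: it stores nodes as dicts, and the helper names build and apply_diff are made up for the example.

```python
def build(data, lo, hi):
    # Each node is a dict holding its range, its sum, and its children
    if lo == hi:
        return {"lo": lo, "hi": hi, "sum": data[lo], "left": None, "right": None}
    mid = (lo + hi) // 2
    left, right = build(data, lo, mid), build(data, mid + 1, hi)
    return {"lo": lo, "hi": hi, "sum": left["sum"] + right["sum"],
            "left": left, "right": right}


def apply_diff(node, i, diff):
    # Add diff to every node whose range contains index i
    while node is not None:
        node["sum"] += diff
        if node["left"] is None:  # reached the leaf
            break
        mid = (node["lo"] + node["hi"]) // 2
        node = node["left"] if i <= mid else node["right"]


data = [5, 6, 3, 2]
root = build(data, 0, len(data) - 1)

new_val, i = 10, 1
diff = new_val - data[i]  # 10 - 6 = 4
apply_diff(root, i, diff)

print(root["sum"])           # 20 (was 16)
print(root["left"]["sum"])   # 15 (was 11)
print(root["right"]["sum"])  # 5 (unchanged)
```

Only the nodes on the path from the root to the updated leaf change. In the 2D version, each parent node on that path repeats this walk inside its own nested tree.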
Tree Queries
Our logic for tree querying is somewhat similar to our logic for looking up elements. Querying for range sums in a 2D matrix becomes the same as querying for range sums in a 1D array, except that we traverse two trees instead of one.
def query(self, r1, c1, r2, c2):
    def helper(node, i, j):
        nonlocal c1
        nonlocal c2
        # Base case: End-to-end match
        if node.start == i and node.last == j:
            if isinstance(node.v, SegmentTree):
                return helper(node.v.root, c1, c2)
            else:
                return node.v
        m = (node.start + node.last) // 2
        # Case 1: Query falls on left
        if i <= m and j <= m:
            return helper(node.left, i, j)
        # Case 2: Query falls on right
        elif m < i and m < j:
            return helper(node.right, i, j)
        # Case 3: Query falls on both sides
        elif i <= m and m < j:
            left = helper(node.left, i, m)
            right = helper(node.right, m+1, j)
            return left + right
    return helper(self.root, r1, r2)
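To see which nodes a query ends on, the following sketch mirrors the three cases above using index ranges only, with no tree involved. The function name decompose is hypothetical and exists only for this illustration.

```python
def decompose(lo, hi, i, j):
    """Return the node ranges a query for [i, j] ends on, in a tree over [lo, hi]."""
    # Base case: end-to-end match
    if lo == i and hi == j:
        return [(lo, hi)]
    mid = (lo + hi) // 2
    # Case 1: query falls on the left
    if j <= mid:
        return decompose(lo, mid, i, j)
    # Case 2: query falls on the right
    if i > mid:
        return decompose(mid + 1, hi, i, j)
    # Case 3: query falls on both sides
    return decompose(lo, mid, i, mid) + decompose(mid + 1, hi, mid + 1, j)


# sumRegion(2, 1, 4, 3) on a 5 x 5 matrix
print(decompose(0, 4, 2, 4))  # rows:    [(2, 2), (3, 4)]
print(decompose(0, 4, 1, 3))  # columns: [(1, 1), (2, 2), (3, 3)]
```

For this query, the parent tree lands on two nodes, and each of them runs its own column query that lands on three nested nodes.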
Applying a Segment Tree to the Matrix
With our SegmentTree class defined, we can easily configure it to support our NumMatrix data structure:
from typing import List

class NumMatrix:
    def __init__(self, matrix: List[List[int]]):
        self.st = SegmentTree()
        self.st.create(matrix)

    def update(self, row: int, col: int, val: int) -> None:
        self.st.update(row, col, val)

    def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
        return self.st.query(row1, col1, row2, col2)
Given a matrix of n rows and m columns, our NumMatrix is initialised in O(nm) space and O(nm) time. Updates require O(log n + log m) space and O(log n * log m) time. Queries also run in O(log n + log m) space and O(log n * log m) time, because each of the O(log n) parent nodes that a query lands on triggers its own O(log m) query in a nested tree.
With that, we’ve solved the Querying Range Sums (2D) problem 📸🧮📸.
Full Solution
from typing import List


class Node:
    def __init__(self, start, last):
        self.start, self.last = start, last
        self.left = self.right = None
        # Points to integer OR segment tree
        self.v = None


# Utils
def merge_vals(v1, v2):
    def merge_nodes(n1, n2):
        if not n1 and not n2:
            return None
        merged_node = Node(n1.start, n1.last)
        merged_node.left = merge_nodes(n1.left, n2.left)
        merged_node.right = merge_nodes(n1.right, n2.right)
        merged_node.v = n1.v + n2.v
        return merged_node

    if (
        isinstance(v1, SegmentTree) and
        isinstance(v2, SegmentTree)
    ):
        st = SegmentTree()
        st.root = merge_nodes(v1.root, v2.root)
        return st
    else:
        return v1 + v2


class SegmentTree:
    def __init__(self):
        self.root = None

    def create(self, ref_data):
        def helper(i, j):
            if i > j:
                return None
            elif i == j:
                node = Node(i, i)
                if isinstance(ref_data[i], list):
                    node.v = SegmentTree()
                    node.v.create(ref_data[i])
                else:
                    node.v = ref_data[i]
                return node
            m = (i + j) // 2
            node = Node(i, j)
            node.left = helper(i, m)
            node.right = helper(m+1, j)
            node.v = merge_vals(node.left.v, node.right.v)
            return node
        self.root = helper(0, len(ref_data)-1)

    def lookup(self, r, c):
        def helper(node, i):
            nonlocal c
            # Leaf case
            if node.start == i and node.last == i:
                if isinstance(node.v, SegmentTree):
                    return helper(node.v.root, c)
                else:
                    return node.v
            # General case
            m = (node.start + node.last) // 2
            if 0 <= i <= m:
                return helper(node.left, i)
            return helper(node.right, i)
        return helper(self.root, r)

    def update(self, r, c, val):
        def helper(node, i, diff):
            nonlocal c
            if node:
                if isinstance(node.v, SegmentTree):
                    helper(node.v.root, c, diff)
                else:
                    node.v += diff
                m = (node.start + node.last) // 2
                if 0 <= i <= m:
                    helper(node.left, i, diff)
                else:
                    helper(node.right, i, diff)
        diff = val - self.lookup(r, c)
        return helper(self.root, r, diff)

    def query(self, r1, c1, r2, c2):
        def helper(node, i, j):
            nonlocal c1
            nonlocal c2
            # Base case: End-to-end match
            if node.start == i and node.last == j:
                if isinstance(node.v, SegmentTree):
                    return helper(node.v.root, c1, c2)
                else:
                    return node.v
            m = (node.start + node.last) // 2
            # Case 1: Query falls on left
            if i <= m and j <= m:
                return helper(node.left, i, j)
            # Case 2: Query falls on right
            elif m < i and m < j:
                return helper(node.right, i, j)
            # Case 3: Query falls on both sides
            elif i <= m and m < j:
                left = helper(node.left, i, m)
                right = helper(node.right, m+1, j)
                return left + right
        return helper(self.root, r1, r2)


class NumMatrix:
    def __init__(self, matrix: List[List[int]]):
        self.st = SegmentTree()
        self.st.create(matrix)

    def update(self, row: int, col: int, val: int) -> None:
        self.st.update(row, col, val)

    def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
        return self.st.query(row1, col1, row2, col2)