Querying Range Sums (2D)

Image from Unsplash by Elliot Teo

From LeetCode:

Design a matrix-like data structure that allows you to update elements and query submatrix sums efficiently.

Input: [
    [3, 0, 1, 4, 2],
    [5, 6, 3, 2, 1],
    [1, 2, 0, 1, 5],
    [4, 1, 0, 1, 7],
    [1, 0, 3, 0, 5]
]

Output:
sumRegion(2, 1, 4, 3) # 8
update(3, 2, 2)
sumRegion(2, 1, 4, 3) # 10
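
As a point of reference, the expected outputs above can be reproduced with a brute-force baseline. This is only a sketch for checking answers, not the approach developed in this post. The class name NaiveNumMatrix is made up for illustration, and both corners of the region are assumed to be inclusive, which is what the example suggests.

```python
class NaiveNumMatrix:
    """Brute-force reference: O(1) updates, O(rows * cols) queries."""

    def __init__(self, matrix):
        # Copy each row so updates never mutate the caller's matrix.
        self.grid = [list(row) for row in matrix]

    def update(self, row, col, val):
        self.grid[row][col] = val

    def sumRegion(self, row1, col1, row2, col2):
        # Both corners are treated as inclusive.
        return sum(
            self.grid[r][c]
            for r in range(row1, row2 + 1)
            for c in range(col1, col2 + 1)
        )


matrix = [
    [3, 0, 1, 4, 2],
    [5, 6, 3, 2, 1],
    [1, 2, 0, 1, 5],
    [4, 1, 0, 1, 7],
    [1, 0, 3, 0, 5],
]
naive = NaiveNumMatrix(matrix)
first = naive.sumRegion(2, 1, 4, 3)
naive.update(3, 2, 2)
second = naive.sumRegion(2, 1, 4, 3)
print(first, second)  # 8 10
```

Every query here walks the whole region, which is exactly the cost a segment tree is meant to avoid.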

Previously, we tackled the Querying Range Sums problem and used a segment tree (i.e. an augmented binary tree with range labels on its nodes) to support logarithmic-time operations. In this post, we’ll see how we can do the same for a 2D array / matrix.

I recommend going through the first version of this problem before continuing this article.

Before we proceed, it’s also worth mentioning that the solution I’m about to share may not be producible within a 30 to 45 minute timeframe. In other words, this doesn’t seem like the kind of question you’re likely to encounter in a live coding exercise. Nevertheless, I think it’s still good practice to 1) get the intuition down and 2) get our hands dirty with complex data structures.

Tree of Trees

We know that we can use a segment tree to represent a 1D array, but what if we needed to represent two dimensions of data? It turns out that we can store pointers to segment trees as ordinary node values within an overarching segment tree:

Tree of Segment Trees. Nodes in the parent tree are drawn as squares and nodes in the nested trees are drawn as circles.

In this diagram, each (square) node contains its own segment tree. For clarity, we’ll refer to the top-level segment tree as the “parent tree” and refer to the other segment trees as “nested trees”.

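The range labels displayed in the parent tree correspond to the row indices of our matrix. [0, 0], for instance, refers to row 0. Note that the diagrams in this section use a smaller 4 x 4 matrix rather than the 5 x 5 input above. In that smaller matrix, the nested tree for row 0 would appear like this: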

Row_0: [5, 6, 3, 2]

Nested Tree for Row 0

The range labels in this nested tree correspond to the column indices of our matrix. If we were specifically interested in the node representing the value at row 0 and column 0, we’d look up the node labeled [0, 0] in the parent tree and then find the node labeled [0, 0] in the nested tree. As seen in the nested tree above, this gives us 5.

What about non-leaf nodes on the parent tree, such as [0, 1]? The nested tree stored at [0, 1] will contain the pairwise sums of the nodes in [0, 0] and [1, 1]:

Row_0:         [5, 6, 3, 2]
Row_1:         [1, 2, 0, 1]
Row_0 + Row_1: [6, 8, 3, 3]

Nested Tree for Row 0 to Row 1

You can imagine this nested tree to be a merged version of the nested trees for row 0 and row 1.

Let’s see how this nested tree structure supports our queries. Suppose we’re interested in the submatrix from (r, c) == (0, 0) to (r, c) == (1, 1) (inclusive).

Navigating the tree of trees to find the sum of elements from (0, 0) to (1, 1)

  1. From the parent tree’s root [0, 3], we navigate to the left node [0, 1]
  2. We explore the nested tree contained in [0, 1].
  3. From the nested tree’s root [0, 3], we navigate to the left node [0, 1]

The value found at [0, 1] (within the nested tree) will be the sum of (0, 0) to (1, 1).
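
The walkthrough above can be checked by hand with plain lists standing in for the nested trees. This is a rough sketch rather than the real data structure: merge_rows is a hypothetical helper, and the rows are the ones from the diagrams.

```python
def merge_rows(row_a, row_b):
    # Column-wise sum of two equal-length rows, mirroring how the nested
    # tree at [0, 1] combines the nested trees at [0, 0] and [1, 1].
    return [a + b for a, b in zip(row_a, row_b)]


row_0 = [5, 6, 3, 2]
row_1 = [1, 2, 0, 1]
merged = merge_rows(row_0, row_1)
print(merged)  # [6, 8, 3, 3]

# Node [0, 1] of the nested tree covers columns 0 and 1 of the merged row.
from_tree = sum(merged[0:2])
# The same region summed cell by cell, from (0, 0) to (1, 1).
cell_by_cell = sum(row_0[0:2]) + sum(row_1[0:2])
print(from_tree, cell_by_cell)  # 14 14
```

Both routes land on 14, which is the value we'd expect to find at [0, 1] in the nested tree.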

From Intuition to Code

Our class definition for a segment tree node stays largely the same:

class Node:
    def __init__(self, start, last):
        self.start, self.last = start, last
        self.left = self.right = None
        # Points to integer OR segment tree
        self.v = None 

Let’s walk through each of the major segment tree operations one by one. Most of these are adapted from the operations we defined in the 1D variant of our current problem.

Tree Construction

class SegmentTree:
    def __init__(self):
        self.root = None

    def create(self, ref_data):
        def helper(i, j):
            if i > j:
                return None
            elif i == j:
                node = Node(i, i)
                if isinstance(ref_data[i], list):
                    node.v = SegmentTree()
                    node.v.create(ref_data[i])
                else:
                    node.v = ref_data[i]
                return node

            m = (i + j) // 2
            node = Node(i, j)
            node.left = helper(i, m)
            node.right = helper(m+1, j)
            node.v = merge_vals(node.left.v, node.right.v)
            return node

        self.root = helper(0, len(ref_data)-1)

ref_data simply describes the reference data we’re using for our tree construction. If we’re building out the parent tree, then this data should point to the input matrix we’ve been given. If we’re building out the nested tree, then ref_data should be one of the rows in the matrix.

The utility function merge_vals merges two “values” together. Recall that values can be either integers or entire segment trees, so this function accounts for that:

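def merge_vals(v1, v2):
    def merge_nodes(n1, n2):
        # Both trees share the same shape, so both nodes run out together.
        if not n1 and not n2:
            return None

        # Build a new node over the same column range; its value is the
        # sum of the two source nodes.
        merged_node = Node(n1.start, n1.last)
        merged_node.left = merge_nodes(n1.left, n2.left)
        merged_node.right = merge_nodes(n1.right, n2.right)
        merged_node.v = n1.v + n2.v
        return merged_node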

    if (
        isinstance(v1, SegmentTree) and
        isinstance(v2, SegmentTree)
    ):
        st = SegmentTree()
        st.root = merge_nodes(v1.root, v2.root)
        return st

    else:
        return v1 + v2

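It’s worth noting here that any two nested trees share the same structure, since every row of the matrix has the same number of columns and the split points depend only on the index range. Thus, when we merge them, there’s no need to worry about a node being present in one nested tree and missing in the other.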
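
One way to convince ourselves of this is to compute the shape of a tree without looking at any values. The sketch below uses a hypothetical helper, tree_shape, that applies the same midpoint split as create() and returns each node as a (label, left, right) tuple.

```python
def tree_shape(i, j):
    """Shape of a segment tree over indices i..j as (label, left, right) tuples."""
    if i == j:
        return ((i, j), None, None)
    m = (i + j) // 2  # same midpoint split as the create() helper
    return ((i, j), tree_shape(i, m), tree_shape(m + 1, j))


row_0 = [5, 6, 3, 2]
row_1 = [1, 2, 0, 1]

# The values never enter the computation, only the column count does.
same = tree_shape(0, len(row_0) - 1) == tree_shape(0, len(row_1) - 1)

# A shorter (ragged) row would produce a different shape.
ragged = tree_shape(0, 3) == tree_shape(0, 2)
print(same, ragged)  # True False
```

The flip side is that merge_nodes relies on this assumption: as far as I can tell, feeding it rows of different lengths would run into a missing node rather than produce a sensible result.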

Tree Lookups

Looking up an element entails navigating the parent tree from root to leaf and navigating the nested tree from root to leaf.

def lookup(self, r, c):
    def helper(node, i):
        nonlocal c

        # Leaf case
        if node.start == i and node.last == i:
            if isinstance(node.v, SegmentTree):
                return helper(node.v.root, c)
            else:
                return node.v

        # General case
        m = (node.start + node.last) // 2
        if 0 <= i <= m:
            return helper(node.left, i)
        return helper(node.right, i)

    return helper(self.root, r)

We didn’t define an element lookup function in the 1D variant of the problem, but we’ll quickly see why this is useful for our current setup.

Tree Updates

Updating an element consists of three steps:

  • Looking up the current value
  • Calculating the difference between the new value and the current one
  • Capturing this difference in all affected nested trees

def update(self, r, c, val):
    def helper(node, i, diff):
        nonlocal c
        
        if node:
            if isinstance(node.v, SegmentTree):
                helper(node.v.root, c, diff)
            else:
                node.v += diff

            m = (node.start + node.last) // 2
            if 0 <= i <= m:
                helper(node.left, i, diff)
            else:
                helper(node.right, i, diff)

    diff = val - self.lookup(r, c)
    return helper(self.root, r, diff)

This is why we defined a lookup() method on our segment tree in the first place: it enables us to compute the difference before passing the delta value on to affected nodes.
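
To see the delta idea in isolation, here is a simplified 1D sketch. It stores node sums in a dictionary keyed by (start, last) instead of using Node objects, and the helper names build and apply_diff are made up for this example.

```python
def build(data, i, j, sums):
    """Fill sums[(i, j)] with the sum of data[i..j] for every tree node."""
    if i == j:
        sums[(i, j)] = data[i]
    else:
        m = (i + j) // 2
        build(data, i, m, sums)
        build(data, m + 1, j, sums)
        sums[(i, j)] = sums[(i, m)] + sums[(m + 1, j)]
    return sums


def apply_diff(sums, i, j, idx, diff):
    """Add diff to every node on the root-to-leaf path that covers idx."""
    sums[(i, j)] += diff
    if i < j:
        m = (i + j) // 2
        if idx <= m:
            apply_diff(sums, i, m, idx, diff)
        else:
            apply_diff(sums, m + 1, j, idx, diff)


row = [5, 6, 3, 2]
sums = build(row, 0, len(row) - 1, {})

new_val = 10
diff = new_val - sums[(1, 1)]  # look up the current value, then take the delta
apply_diff(sums, 0, len(row) - 1, 1, diff)

# Rebuilding from the updated row should give identical sums.
row[1] = new_val
print(sums == build(row, 0, len(row) - 1, {}))  # True
```

Only the nodes on one root-to-leaf path change, so nothing needs to be rebuilt. The 2D update applies the same trick once per parent node on the row path.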

Tree Queries

Our logic for tree querying is similar to our logic for looking up elements. Querying for range sums in a 2D matrix works like querying for range sums in a 1D array, except that we traverse two levels of trees instead of one: each parent-tree node that exactly matches part of the row range hands the column range over to its nested tree.

def query(self, r1, c1, r2, c2):
    def helper(node, i, j):
        nonlocal c1
        nonlocal c2
        
        # Base case: End-to-end match
        if node.start == i and node.last == j:
            if isinstance(node.v, SegmentTree):
                return helper(node.v.root, c1, c2)
            else:
                return node.v

        m = (node.start + node.last) // 2

        # Case 1: Query falls on left
        if i <= m and j <= m:
            return helper(node.left, i, j)
        
        # Case 2: Query falls on right
        elif m < i and m < j:
            return helper(node.right, i, j)
        
        # Case 3: Query falls on both sides
        elif i <= m and m < j:
            left = helper(node.left, i, m)
            right = helper(node.right, m+1, j)
            return left + right

    return helper(self.root, r1, r2)
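
It can help to see which nodes a range actually stops at. The sketch below follows the same three cases as the helper above, but only records the matching node ranges instead of summing values. The function name decompose is hypothetical.

```python
def decompose(start, last, i, j):
    """Node ranges where a query for [i, j] stops in a tree over [start, last]."""
    # Base case: end-to-end match
    if start == i and last == j:
        return [(start, last)]
    m = (start + last) // 2
    if j <= m:   # query falls entirely on the left
        return decompose(start, m, i, j)
    if i > m:    # query falls entirely on the right
        return decompose(m + 1, last, i, j)
    # Query straddles the midpoint, so split it in two.
    return decompose(start, m, i, m) + decompose(m + 1, last, m + 1, j)


# sumRegion(2, 1, 4, 3) on the 5 x 5 input: rows 2..4, then columns 1..3.
row_nodes = decompose(0, 4, 2, 4)
col_nodes = decompose(0, 4, 1, 3)
print(row_nodes)  # [(2, 2), (3, 4)]
print(col_nodes)  # [(1, 1), (2, 2), (3, 3)]

# The 4 x 4 walkthrough from earlier: rows 0..1 match a single parent node.
print(decompose(0, 3, 0, 1))  # [(0, 1)]
```

For the example query, two parent nodes match the row range, and each of them runs its own column query that stops at three nested nodes.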

Applying a Segment Tree to the Matrix

With our SegmentTree class defined, we can easily configure it to support our NumMatrix data structure:

class NumMatrix:

    def __init__(self, matrix: List[List[int]]):
        self.st = SegmentTree()
        self.st.create(matrix)

    def update(self, row: int, col: int, val: int) -> None:
        self.st.update(row, col, val)

    def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
        return self.st.query(row1, col1, row2, col2)

Given a matrix of n rows and m columns, our NumMatrix is initialised in O(nm) space and O(nm) time. Updates require O(log n + log m) space and O(log n * log m) time. Queries also run in O(log n + log m) space and O(log n * log m) time, since a row range can split across O(log n) parent nodes and each of them runs an O(log m) query on its nested tree.
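
The O(nm) space figure can be sanity-checked by counting nodes. This is a back-of-the-envelope sketch with a hypothetical helper, count_nodes, that mirrors the recursion in create() without building anything.

```python
def count_nodes(i, j):
    """Number of nodes in a segment tree built over indices i..j."""
    if i > j:
        return 0
    if i == j:
        return 1
    mid = (i + j) // 2
    return 1 + count_nodes(i, mid) + count_nodes(mid + 1, j)


rows, cols = 5, 5
parent_nodes = count_nodes(0, rows - 1)  # 2 * rows - 1
nested_nodes = count_nodes(0, cols - 1)  # 2 * cols - 1

# Every parent node holds one full nested tree.
total = parent_nodes * nested_nodes
print(parent_nodes, nested_nodes, total)  # 9 9 81
```

For the 5 x 5 input that's 81 nested nodes in total, roughly 4nm, which is consistent with O(nm).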

With that, we’ve solved the Querying Range Sums (2D) problem 📸🧮📸.

Full Solution

from typing import List


class Node:
    def __init__(self, start, last):
        self.start, self.last = start, last
        self.left = self.right = None
        # Points to integer OR segment tree
        self.v = None


# Utils

def merge_vals(v1, v2):
    def merge_nodes(n1, n2):
        if not n1 and not n2:
            return None

        merged_node = Node(n1.start, n1.last)
        merged_node.left = merge_nodes(n1.left, n2.left)
        merged_node.right = merge_nodes(n1.right, n2.right)
        merged_node.v = n1.v + n2.v
        return merged_node

    if (
        isinstance(v1, SegmentTree) and
        isinstance(v2, SegmentTree)
    ):
        st = SegmentTree()
        st.root = merge_nodes(v1.root, v2.root)
        return st

    else:
        return v1 + v2


    
class SegmentTree:
    def __init__(self):
        self.root = None

    def create(self, ref_data):
        def helper(i, j):
            if i > j:
                return None
            elif i == j:
                node = Node(i, i)
                if isinstance(ref_data[i], list):
                    node.v = SegmentTree()
                    node.v.create(ref_data[i])
                else:
                    node.v = ref_data[i]
                return node

            m = (i + j) // 2
            node = Node(i, j)
            node.left = helper(i, m)
            node.right = helper(m+1, j)
            node.v = merge_vals(node.left.v, node.right.v)
            return node

        self.root = helper(0, len(ref_data)-1)

    def lookup(self, r, c):
        def helper(node, i):
            nonlocal c

            # Leaf case
            if node.start == i and node.last == i:
                if isinstance(node.v, SegmentTree):
                    return helper(node.v.root, c)
                else:
                    return node.v

            # General case
            m = (node.start + node.last) // 2
            if 0 <= i <= m:
                return helper(node.left, i)
            return helper(node.right, i)

        return helper(self.root, r)

    def update(self, r, c, val):
        def helper(node, i, diff):
            nonlocal c
            
            if node:
                if isinstance(node.v, SegmentTree):
                    helper(node.v.root, c, diff)
                else:
                    node.v += diff

                m = (node.start + node.last) // 2
                if 0 <= i <= m:
                    helper(node.left, i, diff)
                else:
                    helper(node.right, i, diff)

        diff = val - self.lookup(r, c)
        return helper(self.root, r, diff)

    
    def query(self, r1, c1, r2, c2):
        def helper(node, i, j):
            nonlocal c1
            nonlocal c2
            
            # Base case: End-to-end match
            if node.start == i and node.last == j:
                if isinstance(node.v, SegmentTree):
                    return helper(node.v.root, c1, c2)
                else:
                    return node.v

            m = (node.start + node.last) // 2

            # Case 1: Query falls on left
            if i <= m and j <= m:
                return helper(node.left, i, j)
            
            # Case 2: Query falls on right
            elif m < i and m < j:
                return helper(node.right, i, j)
            
            # Case 3: Query falls on both sides
            elif i <= m and m < j:
                left = helper(node.left, i, m)
                right = helper(node.right, m+1, j)
                return left + right

        return helper(self.root, r1, r2)


class NumMatrix:

    def __init__(self, matrix: List[List[int]]):
        self.st = SegmentTree()
        self.st.create(matrix)

    def update(self, row: int, col: int, val: int) -> None:
        self.st.update(row, col, val)

    def sumRegion(self, row1: int, col1: int, row2: int, col2: int) -> int:
        return self.st.query(row1, col1, row2, col2)
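
If you'd like to gain some confidence in a solution like this, one option is to compare it against brute force on random operations. The harness below is a sketch under a few assumptions: the names region_sum and cross_check are made up, and the class under test is expected to expose the same update and sumRegion methods as NumMatrix.

```python
import random


def region_sum(grid, r1, c1, r2, c2):
    # Brute-force inclusive submatrix sum, used as ground truth.
    return sum(grid[r][c] for r in range(r1, r2 + 1) for c in range(c1, c2 + 1))


def cross_check(matrix_cls, rows=4, cols=5, steps=200, seed=0):
    """Return True if matrix_cls agrees with brute force on random operations."""
    rng = random.Random(seed)
    grid = [[rng.randint(-9, 9) for _ in range(cols)] for _ in range(rows)]
    # Hand the candidate its own copy so it cannot share state with grid.
    candidate = matrix_cls([row[:] for row in grid])
    for _ in range(steps):
        if rng.random() < 0.5:
            # Apply the same random update to both sides.
            r, c = rng.randrange(rows), rng.randrange(cols)
            val = rng.randint(-9, 9)
            candidate.update(r, c, val)
            grid[r][c] = val
        else:
            # Compare a random inclusive region.
            r1, r2 = sorted((rng.randrange(rows), rng.randrange(rows)))
            c1, c2 = sorted((rng.randrange(cols), rng.randrange(cols)))
            expected = region_sum(grid, r1, c1, r2, c2)
            if candidate.sumRegion(r1, c1, r2, c2) != expected:
                return False
    return True
```

With the classes above in scope, calling cross_check(NumMatrix) should come back True if the implementation holds up. A False result points to a disagreement worth investigating.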