
Querying Range Sums

Image from Unsplash by Suganth

From LeetCode:

Design an array-like data structure that allows you to update elements and query subarray sums efficiently.

Input: [1, 3, 5]

Output:
sumRange(0, 2) # 9
update(1, 2)
sumRange(0, 2) # 8
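
Before reaching for anything clever, it can help to pin down the expected behaviour with a naive baseline. The sketch below is not the approach this article builds; `NaiveNumArray` is a hypothetical name, and it simply stores the list and re-sums the slice on every query, so updates are O(1) but queries are O(n):

```python
from typing import List


class NaiveNumArray:
    """Hypothetical baseline: plain list, slice re-summed on every query."""

    def __init__(self, nums: List[int]):
        self.nums = list(nums)  # copy so the caller's list is not mutated

    def update(self, i: int, val: int) -> None:
        self.nums[i] = val  # O(1)

    def sumRange(self, i: int, j: int) -> int:
        return sum(self.nums[i:j + 1])  # O(n): walks the whole range


arr = NaiveNumArray([1, 3, 5])
print(arr.sumRange(0, 2))  # 9
arr.update(1, 2)
print(arr.sumRange(0, 2))  # 8
```

The linear-time query is the cost a faster structure needs to bring down.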

Most (if not all) modern software applications require us to persist and retrieve some kind of data. Being able to retrieve data in a robust manner usually means being able to query for information that’s relevant to a range of values.

In this problem, we’ll see how we can implement range sum queries efficiently using segment trees.

Segment Trees

Segment trees are augmented binary trees. Each node is labelled with a range, and its left and right children are labelled with the left and right halves of that range. Each node also stores the sum of the range it’s been labelled with:

Array: [5, 10, 3, 7]

"""
           [0, 3]
            (25)
           /     \ 
     [0, 1]       [2, 3]
      (15)         (10)
     /    \       /    \ 
[0, 0] [1, 1]   [2, 2] [3, 3]
 (5)    (10)     (3)     (7)
"""

Notice how the leaf nodes match the elements in our original array. Each leaf node represents a single data value. The non-leaf nodes, on the other hand, represent subarrays in the original array and provide us with information about ranges of values.

Let’s suppose we’re querying for the sum of [0, 2].

  1. From the root [0, 3], we navigate to the left subtree and find the sum of [0, 1]
  2. From the root [0, 3], we navigate to the right subtree [2, 3]
  3. From the subtree [2, 3], we navigate to the left leaf and find the sum of [2, 2]

We’ve found the sum of [0, 1] as well as the sum of [2, 2]. Combining them gives us the sum of [0, 2].
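
The walk above can be sketched as a small function that lists which node ranges get combined for a query. This is only an illustration: `covering_ranges` is a hypothetical helper that works on index ranges alone and does not build a tree.

```python
def covering_ranges(start, last, i, j):
    """List the node ranges combined to answer a query on [i, j],
    for a node labelled [start, last]. Assumes start <= i <= j <= last."""
    if start == i and last == j:
        return [(start, last)]  # this node matches the query exactly

    mid = (start + last) // 2
    if j <= mid:  # query lies entirely in the left half
        return covering_ranges(start, mid, i, j)
    if i > mid:  # query lies entirely in the right half
        return covering_ranges(mid + 1, last, i, j)

    # Query straddles the midpoint: split it between both halves
    return (covering_ranges(start, mid, i, mid)
            + covering_ranges(mid + 1, last, mid + 1, j))


nums = [5, 10, 3, 7]
pieces = covering_ranges(0, len(nums) - 1, 0, 2)
print(pieces)  # [(0, 1), (2, 2)]
print(sum(sum(nums[a:b + 1]) for a, b in pieces))  # 18
```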

There are more complex queries that we can consider, but I think it’ll be more beneficial to your understanding to see the tree traversal logic in code.

From Intuition to Code

Let’s define the data node that we’ll be using in our segment tree:

class Node:
    def __init__(self, start, last):
        self.start, self.last = start, last
        self.left = self.right = None
        self.total = 0

The first operation we’ll handle is the construction of our segment tree, NumArray:

from typing import List

class NumArray:
    def __init__(self, nums: List[int]):
        # O(n) time operation for n data entries
        def create(i, j):
            if i > j:
                return None

            elif i == j:
                # Leaf node
                node = Node(i, j)
                node.total = nums[i]
                return node
            
            # Find midpoint to split child ranges by
            m = (i + j) // 2
            node = Node(i, j)
            
            # Construct child nodes and increment
            # the total as we backtrack
            node.left = create(i, m)
            node.right = create(m+1, j)
            node.total = node.left.total + node.right.total
            return node

        self.root = create(0, len(nums)-1)
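
One way to sanity-check the O(n) comment on the constructor is to count how many nodes the same recursion would create. The sketch below mirrors `create` without allocating anything (`count_nodes` is a hypothetical helper, not part of the solution). Because every non-leaf range splits into two non-empty halves, a tree over n elements should end up with 2n - 1 nodes:

```python
def count_nodes(i, j):
    """Count the nodes that create(i, j) would allocate."""
    if i > j:
        return 0  # empty range: no node
    if i == j:
        return 1  # leaf node

    m = (i + j) // 2
    # One node for [i, j] plus everything in both subtrees
    return 1 + count_nodes(i, m) + count_nodes(m + 1, j)


for n in (1, 4, 5, 1000):
    print(n, count_nodes(0, n - 1))  # second column is 2n - 1
```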

Updating a record in the segment tree simply involves navigating from root to leaf, specifically to the leaf node that represents the record. Once we’ve updated the value stored in the leaf, we propagate the delta back up and add it to the sum on every ancestor node:

class NumArray:
    def update(self, i: int, val: int) -> None:
        # O(log n) time operation
        def helper(node, i, val):
            if node.start == i and node.last == i:
                diff = val - node.total
                node.total = val
                return diff

            mid = (node.start + node.last) // 2
            diff = (
                helper(node.left, i, val)
                if node.start <= i <= mid
                else helper(node.right, i, val)
            )

            # Update ancestor sums as we backtrack
            node.total += diff
            return diff

        helper(self.root, i, val)
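
To make the root-to-leaf walk concrete, here is a rough sketch that lists the ranges an update passes through and applies the delta to each. `update_path` and the `totals` dictionary are hypothetical stand-ins for the real nodes, using the [5, 10, 3, 7] tree from earlier:

```python
def update_path(start, last, i):
    """Ranges visited from the root [start, last] down to the leaf [i, i]."""
    path = [(start, last)]
    while start != last:
        mid = (start + last) // 2
        if i <= mid:
            last = mid  # descend into the left child
        else:
            start = mid + 1  # descend into the right child
        path.append((start, last))
    return path


# Sums stored on the path to index 1 in the tree for [5, 10, 3, 7]
totals = {(0, 3): 25, (0, 1): 15, (1, 1): 10}

path = update_path(0, 3, 1)
print(path)  # [(0, 3), (0, 1), (1, 1)]

# Simulate update(1, 2): the delta is new value minus old leaf value
diff = 2 - totals[(1, 1)]
for node_range in path:
    totals[node_range] += diff
print(totals)  # {(0, 3): 17, (0, 1): 7, (1, 1): 2}
```

Only the nodes on that single path change, which is why the update touches one node per level.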

Querying for the range sum seems complex, but it really isn’t once you break it down into a few distinct cases:

class NumArray:
    def sumRange(self, i: int, j: int) -> int:
        # O(log n) time operation
        def helper(node, i, j):
            # Base Case: End-to-end match
            if node.start == i and node.last == j:
                return node.total

            mid = (node.start + node.last) // 2
            
            # Case 1: Whole range falls on left
            if i <= mid and j <= mid:
                return helper(node.left, i, j)

            # Case 2: Whole range falls on right
            elif mid < i and mid < j:
                return helper(node.right, i, j)

            # Case 3: Range is split between both subtrees
            elif i <= mid and mid < j:
                left = helper(node.left, i, mid)
                right = helper(node.right, mid+1, j)
                return left + right
          
        return helper(self.root, i, j)

It might not be obvious at first that our range queries run in O(log n) time. After all, in Case 3 the midpoint lies inside our query range, so we can’t rule out either subtree (which is atypical for logarithmic-time tree algorithms). Instead, we have to split the query into two calls, one on each subtree.

Runtime of a Query

Let a numerical suffix on a variable indicate the level of the tree we’re inspecting. Say we’re given a query range (i1, j1) and we hit Case 3: (i1, j1) becomes (i1, mid1) and (mid1+1, j1).

Think about what happens at subsequent levels:

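  • On the left subtree, we have (i2, j2) and it must be that j2 == last2, because j2 is mid1, the last index of the left child’s range.
  • If i2 <= mid2, we hit Case 3 again, but the call on this subtree’s right child is an end-to-end match and returns at the Base Case immediately, so we only continue to explore the left side.
  • If i2 > mid2, we’ll ignore this subtree’s left child and continue to explore the right side.
  • The right subtree behaves symmetrically, with i2 == start2.

This tells us that at subsequent levels (after we first encounter Case 3), we maintain a pattern of ruling out subtree halves: each side of the split keeps descending along a single path. This keeps our range sum query within O(log n) time.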
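
The argument above can also be checked empirically. The sketch below mirrors the cases of `sumRange` but only counts how many nodes a query touches (`count_visits` is a hypothetical helper, and no tree is built). If the reasoning holds, at most four nodes are visited per level, so the count should never exceed four times the number of levels:

```python
def count_visits(start, last, i, j):
    """Number of nodes a query on [i, j] touches under a node [start, last]."""
    if start == i and last == j:
        return 1  # Base Case: end-to-end match

    mid = (start + last) // 2
    if j <= mid:  # whole range falls on left
        return 1 + count_visits(start, mid, i, j)
    if i > mid:  # whole range falls on right
        return 1 + count_visits(mid + 1, last, i, j)

    # Range is split between both subtrees
    return (1 + count_visits(start, mid, i, mid)
            + count_visits(mid + 1, last, mid + 1, j))


n = 64
levels = (n - 1).bit_length() + 1  # ceil(log2(n)) + 1 levels in the tree

# Try every possible query range and keep the most expensive one
worst = max(count_visits(0, n - 1, i, j)
            for i in range(n) for j in range(i, n))

assert worst <= 4 * levels
print(f"n={n}, levels={levels}, worst-case visits={worst}")
```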

Given an array of n numbers, our NumArray is initialised in O(n) space and O(n) time. Updates and queries run in O(log n) time and use O(log n) space for the recursion stack.

With that, we’ve solved the Querying Range Sums problem 📸🧮.

Full Solution

from typing import List


class Node:
    def __init__(self, start, last):
        self.start, self.last = start, last
        self.left = self.right = None        
        self.total = 0


class NumArray:
    def __init__(self, nums: List[int]):
        # O(n) time operation for n data entries
        def create(i, j):
            if i > j:
                return None

            elif i == j:
                # Leaf node
                node = Node(i, j)
                node.total = nums[i]
                return node
            
            # Find midpoint to split child ranges by
            m = (i + j) // 2
            node = Node(i, j)
            
            # Construct child nodes and increment
            # the total as we backtrack
            node.left = create(i, m)
            node.right = create(m+1, j)
            node.total = node.left.total + node.right.total
            return node

        self.root = create(0, len(nums)-1)


    def update(self, i: int, val: int) -> None:
        # O(log n) time operation
        def helper(node, i, val):
            if node.start == i and node.last == i:
                diff = val - node.total
                node.total = val
                return diff

            mid = (node.start + node.last) // 2
            diff = (
                helper(node.left, i, val)
                if node.start <= i <= mid
                else helper(node.right, i, val)
            )

            # Update ancestor sums as we backtrack
            node.total += diff
            return diff

        helper(self.root, i, val)


    def sumRange(self, i: int, j: int) -> int:
        # O(log n) time operation
        def helper(node, i, j):
            # Base Case: End-to-end match
            if node.start == i and node.last == j:
                return node.total

            mid = (node.start + node.last) // 2
            
            # Case 1: Whole range falls on left
            if i <= mid and j <= mid:
                return helper(node.left, i, j)

            # Case 2: Whole range falls on right
            elif mid < i and mid < j:
                return helper(node.right, i, j)

            # Case 3: Range is split between both subtrees
            elif i <= mid and mid < j:
                left = helper(node.left, i, mid)
                right = helper(node.right, mid+1, j)
                return left + right

        return helper(self.root, i, j)