
Smaller Numbers After Self

Image from Unsplash by Chuttersnap

From LeetCode:

You are given an integer array A and you have to return a new counts array. The counts array has the property where counts[i] is the number of smaller elements to the right of A[i].

Input: [5, 2, 6, 1]
Output: [2, 1, 1, 0]

# Explanation:
# To the right of 5 there are 2 smaller elements (2 and 1).
# To the right of 2 there is only 1 smaller element (1).
# To the right of 6 there is 1 smaller element (1).
# To the right of 1 there are 0 smaller elements.
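
Before reaching for anything clever, the definition can be turned directly into a quadratic baseline. The following is a minimal sketch (the name count_smaller_brute is ours, not LeetCode's) that is mainly useful for checking faster solutions against:

```python
def count_smaller_brute(A):
    """O(n^2) reference: compare each element with everything to its right."""
    n = len(A)
    counts = [0] * n
    for i in range(n):
        for j in range(i + 1, n):
            if A[j] < A[i]:
                counts[i] += 1
    return counts


print(count_smaller_brute([5, 2, 6, 1]))  # [2, 1, 1, 0]
```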

Intuitively, returning the number of smaller elements on the right of A[i] means returning the number of “out of place” elements. Elements might be “out of place” in the sense that they aren’t already positioned where they will be when A is sorted. It seems probable that we’ll need some kind of sorting algorithm to solve this question.

Inspecting Shifts in Indices

One might try to solve this problem by sorting A and then checking how much A[i] has shifted. To do this, one could modify the array by latching the original index (prev_i) onto each of the values in the array:

import collections

Entry = collections.namedtuple('Entry', ('val', 'prev_i'))

def count_smaller(A):
    n = len(A)
    counts = [0 for _ in range(n)]
    A = [Entry(val, i) for i, val in enumerate(A)]
    A.sort(key=lambda e: e.val)

    for i, e in enumerate(A):
        # i > e.prev_i means the element has shifted right,
        # implying that originally, there were
        # smaller elements on its right
        counts[e.prev_i] = max(i - e.prev_i, 0)

    return counts

This line of thinking is on the right track, but it won’t always give us the correct answer. Consider the example we’ve been given:

A =  [5, 2, 6, 1]
# sorted(A) = [1, 2, 5, 6]

count_smaller(A)
# counts = [2, 0, 1, 0] 

# A[0] shifted right by 2 positions to index 2.
# A[1] remained at index 1.
# A[2] shifted right by 1 position to index 3.
# A[3] shifted left by 3 positions to index 0,
#      which means it did not shift right.

We get [2, 0, 1, 0] as our result, but the correct answer is [2, 1, 1, 0]. This algorithm doesn’t enable us to deduce that the element A[1] = 2 has one smaller element, A[3] = 1, on its right.

This is because we haven’t properly accounted for the elements that precede each A[i]: larger elements that start on A[i]’s left end up on A[i]’s right after sorting, and each one pulls A[i] back to the left by one position. Thus, we’ll need to modify our strategy to solve this problem.
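
To make the failure concrete, the sketch below computes, for each element, how far sorting moves it, alongside the number of smaller elements on its right and larger elements on its left. The helper name shift_breakdown is hypothetical, and the code assumes all values are distinct. Under that assumption the shift is exactly the difference of the two counts, so the index shift alone cannot recover the number we want:

```python
def shift_breakdown(A):
    """For distinct values, return (shift, smaller_right, larger_left) per index."""
    order = sorted(range(len(A)), key=lambda i: A[i])
    # new_pos[i] is the index where A[i] lands after sorting
    new_pos = {orig_i: sorted_i for sorted_i, orig_i in enumerate(order)}
    rows = []
    for i, val in enumerate(A):
        smaller_right = sum(1 for x in A[i + 1:] if x < val)
        larger_left = sum(1 for x in A[:i] if x > val)
        rows.append((new_pos[i] - i, smaller_right, larger_left))
    return rows


for shift, smaller_right, larger_left in shift_breakdown([5, 2, 6, 1]):
    # The shift mixes both quantities together
    assert shift == smaller_right - larger_left
```

For A[1] = 2 the breakdown is (0, 1, 1): one smaller element on the right, cancelled out by one larger element on the left.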

Using Mergesort

It turns out that a mergesort-style approach is very suitable for our problem. Mergesort, if you recall, is composed of three steps: We split the array into two partitions, sort each partition, and then merge both partitions in a way that maintains their sortedness.

def msort(start, last):
    if start == last:
        return [A[start]]

    # Partition
    mid = (start + last) // 2

    # Sort partitions
    left = msort(start, mid)
    right = msort(mid+1, last)

    # Merge partitions
    return merge(left, right)

It’s the merge step, in particular, that will help us compute the smaller elements on the right of each A[i]:

def merge(L1, L2):
    i = j = 0
    L3 = [] # Resultant array

    while i < len(L1) and j < len(L2):
        if L1[i] <= L2[j]:
            L3.append(L1[i])
            i += 1
        else:
            L3.append(L2[j])
            j += 1

    while i < len(L1):
        L3.append(L1[i])
        i += 1

    while j < len(L2):
        L3.append(L2[j])
        j += 1
    
    return L3

Let’s suppose we’ve already generated two sorted halves from our input array:

A = [2, 4, 6, 1, 3, 5]
# left = [2, 4, 6], right = [1, 3, 5]
# merged = []

Our goal is to produce a resultant array that merges both left and right. Off the bat, we can tell that the elements in right will have to be appended into this resultant array before some (or all) of the elements in left. The first element of right, 1, for instance, should be added before 2, 4, and 6.

A = [2, 4, 6, 1, 3, 5]

# left = [2, 4, 6], right = [3, 5]
# merged = [1]

If our left and right arrays were [1, 2, 3] and [4, 5, 6] respectively, there would be no need for us to pick from the right subarray before picking from the left one. We arrive at an important conclusion: Every time we select an element from the right subarray, it indicates that we’ve found a smaller element on the right of every element still remaining in the left subarray.

To continue with our example: When we append A[0] = 2 to the resultant array, we can increment its count (i.e. counts[0]) by 1, because we know that we previously selected one element from the right subarray (i.e. A[3] = 1).

A = [2, 4, 6, 1, 3, 5]

# left = [4, 6], right = [3, 5]
# merged = [1, 2]

# ...

# left = [6], right = []
# merged = [1, 2, 3, 4, 5]

Likewise, when the time comes for us to append A[2] = 6 to the resultant array, we know we’ll need to increment its count (i.e. counts[2]) by 3, because we will have selected all three elements from the right subarray beforehand (i.e. A[3:]).
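
The walk-through above can be replayed in code. This is a sketch on plain values rather than the final merge(), and the helper name trace_merge is made up for illustration. It records, for each element taken from left, how many elements had already been taken from right at that moment:

```python
def trace_merge(left, right):
    """Merge two sorted lists; return (merged, picks).

    picks[k] is the number of elements already taken from `right`
    at the moment left[k] is appended to the merged list.
    """
    i = j = 0
    merged, picks = [], []
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            picks.append(j)
            merged.append(left[i])
            i += 1
        else:
            merged.append(right[j])
            j += 1
    # Leftovers from `left`: everything taken from `right` so far is smaller
    for k in range(i, len(left)):
        picks.append(j)
        merged.append(left[k])
    # Leftovers from `right` need no bookkeeping
    merged.extend(right[j:])
    return merged, picks


merged, picks = trace_merge([2, 4, 6], [1, 3, 5])
print(merged)  # [1, 2, 3, 4, 5, 6]
print(picks)   # [1, 2, 3]
```

The picks of 1, 2 and 3 are exactly the increments for 2, 4 and 6 described above.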

This approach turns out to be more robust than the earlier one, because it bakes the counting of smaller elements on the right into every step of the sorting process.

Since msort() runs recursively, we can guarantee that each of the halves we’re dealing with has already been sorted. By extension, the smaller elements on the right of each A[i] within those halves have already been accounted for.

From Intuition to Code

Let’s look at how we can tweak merge() specifically to suit our needs. In merge(), we use two index variables i and j for traversing through both the left and right subarrays (labelled L1 and L2).

Each time we add an element from the left subarray to the resultant array, how do we know the number of elements we previously selected from the right? We’ll find the answer in the value of j:

Entry = collections.namedtuple('Entry', ('val', 'prev_i'))

# ...
A = [Entry(val, i) for i, val in enumerate(A)]
counts = [0 for _ in range(n)]

def merge(L1, L2):
    i = j = 0
    L3 = [] # Resultant array

    while i < len(L1) and j < len(L2):
        if L1[i].val <= L2[j].val:        
            L3.append(L1[i])
            # Increment count
            counts[L1[i].prev_i] += j
            i += 1
        else:
            L3.append(L2[j])
            j += 1

    while i < len(L1):
        L3.append(L1[i])
        # Increment count
        counts[L1[i].prev_i] += j
        i += 1

    while j < len(L2):
        L3.append(L2[j])
        j += 1
    
    return L3

As more elements are picked from the right subarray, j gets incremented, so we can use this value to update our counter (i.e. counts) for the elements in the left subarray.
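
One detail worth spelling out is the <= comparison. Because ties are resolved in favour of the left subarray, an equal value on the right is never counted as smaller. The following self-contained sketch restates the merge above in compact form and runs it on halves that contain duplicates. The name counting_merge, and passing counts in as an argument, are choices made here for illustration only:

```python
import collections

Entry = collections.namedtuple('Entry', ('val', 'prev_i'))


def counting_merge(L1, L2, counts):
    """Merge sorted Entry lists, adding to counts[] for entries taken from L1."""
    i = j = 0
    L3 = []
    while i < len(L1) or j < len(L2):
        # Take from L1 when L2 is exhausted or L1's head is <= L2's head
        take_left = j == len(L2) or (i < len(L1) and L1[i].val <= L2[j].val)
        if take_left:
            counts[L1[i].prev_i] += j  # the j entries taken from L2 are strictly smaller
            L3.append(L1[i])
            i += 1
        else:
            L3.append(L2[j])
            j += 1
    return L3


# Original array [2, 3, 2, 2], already split into two sorted halves
left = [Entry(2, 0), Entry(3, 1)]
right = [Entry(2, 2), Entry(2, 3)]
counts = [0, 0, 0, 0]
counting_merge(left, right, counts)
print(counts)  # [0, 2, 0, 0]
```

Had the comparison been a strict <, both 2s from the right would have been taken before the 2 from the left, and counts[0] would wrongly become 2.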

At this point, returning the array of counts becomes trivial:

import collections

Entry = collections.namedtuple('Entry', ('val', 'prev_i'))

def count_smaller(A):
    n = len(A)
    counts = [0 for _ in range(n)]
    A = [Entry(val, i) for i, val in enumerate(A)]
    
    def merge(L1, L2):
        ...  # counting merge, as defined above

    def msort(s, l):
        ...  # mergesort driver, as outlined earlier

    msort(0, n-1) 
    return counts

Given an array of n numbers, our algorithm runs in O(n) space and O(n log n) time.
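
The time bound follows from the usual mergesort argument: each level of recursion merges n elements in total, and there are about log2(n) levels. As a rough sanity check, the sketch below counts how many elements pass through merge() for an input of size n, using the same midpoint split as msort(). The helper merge_work is hypothetical and not part of the solution:

```python
def merge_work(n):
    """Total number of elements appended across all merge() calls for size n."""
    if n <= 1:
        return 0
    left = (n + 1) // 2   # msort(s, m) covers ceil(n / 2) elements
    right = n // 2        # msort(m + 1, l) covers the rest
    return merge_work(left) + merge_work(right) + n


print(merge_work(8))     # 24, i.e. 8 * log2(8)
print(merge_work(1024))  # 10240, i.e. 1024 * log2(1024)
```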

With that, we’ve solved the Smaller Numbers After Self problem 🤟✋🤘.

Full Solution

import collections

Entry = collections.namedtuple('Entry', ('val', 'prev_i'))

def count_smaller(A):

    n = len(A)

    # Output array, keeps track of counts for each index
    counts = [0 for _ in range(n)]

    # Attach original index to each value
    A = [Entry(val, i) for i, val in enumerate(A)]
    
    # Define merge algo
    def merge(L1, L2):
        i = j = 0
        L3 = []

        while i < len(L1) and j < len(L2):
            if L1[i].val <= L2[j].val:
                # Index j captures the number of smaller
                # elements found on the right
                counts[L1[i].prev_i] += j
                L3.append(L1[i])
                i += 1
                
            else: # L1[i].val > L2[j].val
                L3.append(L2[j])
                j += 1

        while i < len(L1):
            counts[L1[i].prev_i] += j
            L3.append(L1[i])
            i += 1
        
        while j < len(L2):
            L3.append(L2[j])
            j += 1
                            
        return L3
    
    # Define mergesort algo
    def msort(s, l):
        if s > l:
            return []
        elif s == l:
            return [A[s]]

        m = (s + l) // 2
        return merge(msort(s, m), msort(m+1, l))

    # Run mergesort
    msort(0, n-1) 
    return counts