
Partitioning an Array into K Equal Sum Subsets

Image from Unsplash by Nguyen Duc Thuan

From LeetCode:

Given an array of integers nums and a positive integer k, find whether it’s possible to divide this array into k non-empty subsets whose sums are all equal.

Input: nums = [4, 3, 2, 3, 5, 2, 1], k = 4
Output: True
# Explanation: (5), (1, 4), (2, 3), (2, 3) each sum to 5

Note:

  • 1 <= k <= len(nums) <= 16.
  • 0 < nums[i] < 10000.

Filtering Out Invalid Inputs

First, let’s sieve out any inputs to this problem that are clearly invalid. We’ll deal with the less obvious cases afterward.

Identifying an invalid input is easy. We'll obtain the sum of the array and divide it by k. If it turns out that k doesn't cleanly divide the sum (i.e., there's a remainder), then we'll immediately know that it's impossible to form equal sum subsets, since k equal integer sums always add up to a multiple of k.

def can_partition(nums, k):
    subset_tgt, remainder = divmod(sum(nums), k)
    if remainder != 0:
        return False
    
    # ...

Depth-First Search vs Dynamic Programming

From this point onward, there are a couple of good approaches to this problem.

A popular strategy is to use DFS and recurse with an accumulator sum until it reaches subset_tgt. Doing this k times gives you the answer. This type of algorithm, which runs with an upper bound of approximately O(k * 2ⁿ), should already be somewhat familiar to you if you've practiced backtracking a few times.
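For comparison, here's a minimal backtracking sketch of the DFS strategy just described. It isn't the solution this article develops: the name can_partition_dfs and the sort-descending pruning are assumptions made for illustration, not part of any referenced solution.

```python
def can_partition_dfs(nums, k):
    """Backtracking sketch: fill one bucket at a time until it reaches subset_tgt."""
    total = sum(nums)
    if total % k != 0:
        return False
    subset_tgt = total // k

    # Assumption: trying large numbers first tends to hit dead ends sooner.
    nums = sorted(nums, reverse=True)
    if nums[0] > subset_tgt:
        return False
    used = [False] * len(nums)

    def dfs(start, bucket_sum, buckets_left):
        if buckets_left == 1:
            # k - 1 buckets are full, so the unused numbers must sum to subset_tgt.
            return True
        if bucket_sum == subset_tgt:
            # Current bucket is full: start filling the next one from the beginning.
            return dfs(0, 0, buckets_left - 1)
        for i in range(start, len(nums)):
            if not used[i] and bucket_sum + nums[i] <= subset_tgt:
                used[i] = True
                if dfs(i + 1, bucket_sum + nums[i], buckets_left):
                    return True
                used[i] = False  # backtrack
        return False

    return dfs(0, 0, k)

print(can_partition_dfs([4, 3, 2, 3, 5, 2, 1], 4))  # True
```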

I won’t dive into the DFS-based solution because I’d like to write about the DP-based one instead, which also runs in exponential time but relies on more interesting code.

Thus far, most of my February articles have been written without the need for significant cross-referencing. To familiarise myself with the DP-based approach, however, I’ve had to go through this discussion thread on LeetCode. I must credit the author, maverick009, for enabling my learning.

Representing Subsets as Integers

Understand that every subset of the nums array can be represented by an integer. For instance, given:

nums = [2, 4, 6, 8, 10]

The subset [2, 4, 8] can be represented as 1011 in binary. 2, 4, and 8 are positioned at indices 0, 1, and 3, so we set the corresponding bits to 1 (counting from the rightmost bit, which is bit 0) to indicate that these nums are present in the subset.

1011, in decimal form, is simply 11. Checking if 8 (at index 3) exists, for instance, can be achieved by doing 11 & (1 << 3). As a general rule, we'd do x & (1 << i) to see if the i-th bit has been set on x: a non-zero result means it has.

Since every subset can be represented as an integer, how many subsets are there? Another way to ask this is: what's the cardinality of the power set? The answer turns out to be 2ⁿ, with binary representations ranging from 0000... (the empty set) to 1111... (the complete set).
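Here's a short sketch of this encoding in Python. The helpers to_mask and to_subset are hypothetical names used only for illustration; the bit checks are the same x & (1 << i) test described above.

```python
nums = [2, 4, 6, 8, 10]

def to_mask(indices):
    """Hypothetical helper: set bit i for every index i in the subset."""
    mask = 0
    for i in indices:
        mask |= 1 << i
    return mask

def to_subset(mask, nums):
    """Hypothetical helper: keep nums[i] whenever bit i of mask is set."""
    return [num for i, num in enumerate(nums) if mask & (1 << i)]

mask = to_mask([0, 1, 3])       # indices of 2, 4 and 8
print(mask, bin(mask))          # 11 0b1011
print(to_subset(mask, nums))    # [2, 4, 8]
print(bool(mask & (1 << 3)))    # True: 8 (index 3) is in the subset
print(bool(mask & (1 << 2)))    # False: 6 (index 2) is not

# Every integer from 0 to 2**n - 1 is a distinct subset.
all_subsets = [to_subset(m, nums) for m in range(1 << len(nums))]
print(len(all_subsets))         # 32
```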

Building Towards an Answer

We'll initialise two DP arrays, DP and subset_sums, to store some information about each subset. Using this stored data, we'll work our way up from the empty subset towards an answer for the complete set.

def can_partition(nums, k):
    subset_tgt, remainder = divmod(sum(nums), k)
    if remainder != 0:
        return False

    # Set up DP arrays
    num_subsets = 1 << len(nums)
    subset_sums = [None for _ in range(num_subsets)]
    subset_sums[0] = 0
    DP = [False for _ in range(num_subsets)]
    DP[0] = True

subset_sums[x] gives us the sum of the subset represented by x. To use an example:

nums = [2, 4, 6, 8, 10]
# 1. x == 11, so bin(x) == "0b1011"
# 2. "1011" represents [2, 4, 8]
# 3. sum([2, 4, 8]) == 14, so subset_sums[11] == 14
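The article's own loop fills subset_sums only for the subsets it actually reaches. As a standalone illustration of what subset_sums[x] holds, here's a sketch that computes it for every mask up front. This eager fill is a different technique from the one used in the rest of the article, and all_subset_sums is a hypothetical name: clearing the lowest set bit of x yields a smaller mask whose sum is already known.

```python
def all_subset_sums(nums):
    """Sketch: compute the sum of every subset, indexed by its bitmask."""
    sums = [0] * (1 << len(nums))
    for x in range(1, 1 << len(nums)):
        lowest_bit = x & -x                   # isolate the lowest set bit of x
        i = lowest_bit.bit_length() - 1       # index of the number that bit represents
        # x without that bit is a smaller mask, so its sum is already filled in
        sums[x] = sums[x ^ lowest_bit] + nums[i]
    return sums

sums = all_subset_sums([2, 4, 6, 8, 10])
print(sums[11])   # 14, i.e. sum([2, 4, 8])
print(sums[-1])   # 30, i.e. sum(nums)
```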

DP[x] is trickier to explain. It holds a boolean value that tells us whether the subset represented by x can be divided into buckets that each sum to subset_tgt, with the exception of at most one bucket that isn't completely filled.

nums = [2, 4, 6, 8, 10], k = 3

# 1. sum(nums) == 30, so subset_tgt == 10

# 2. 11 (or "1011") represents [2, 4, 8]
#    [2, 4, 8] splits into:
#      - [2, 8] (filled bucket) 
#      - [4] (unfilled bucket)
#    DP[11] = True

# 3. 26 (or "11010") represents [4, 8, 10]
#    [4, 8, 10] splits into:
#      - [10] (filled bucket)
#      - [4, 8] (overflowing bucket)
#    DP[26] = False

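With these two arrays, we can eventually build up an answer for DP[-1], which tells us whether the complete set, represented by 1111..., can be divided into k buckets. Because sum(nums) is exactly k * subset_tgt, the one bucket that DP allows to be partially filled must in fact be full, so DP[-1] == True means all k buckets sum to subset_tgt.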

We also expect subset_sums[-1] == sum(nums).

def has_ith_bit(x, i):
    return x & (1 << i)

def set_ith_bit(x, i):
    return x | (1 << i)

for x in range(num_subsets):
    # For every viable subset
    if DP[x]:
        for i, num in enumerate(nums):
            if not has_ith_bit(x, i):
                # Consider new subset with `num` added
                x_with_num = set_ith_bit(x, i)
                subset_sums[x_with_num] = subset_sums[x] + num

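                # If adding `num` doesn't overflow the current bucket,
                # the new subset is viable. Only ever switch DP on:
                # x_with_num may be reachable from several subsets,
                # and one bad ordering shouldn't undo a good one.
                if (subset_sums[x] % subset_tgt) + num <= subset_tgt:
                    DP[x_with_num] = True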

Here's one way to think about this double loop: we're examining every viable subset and considering candidate elements that can be added to it. A good candidate element should add to the subset without causing the current bucket to overflow.

For example, if subset_tgt == 10 and our currently-examined subset can be formed as ([5, 5], [6]), we’ll want to add a number that lies within the range of 1 to 4. ([5, 5], [6, 1]) allows us to continue looking for more candidates for the current bucket, while ([5, 5], [6, 4]) allows us to move on to the next bucket.

In other words, a good candidate num gives us a viable bucket arrangement and tells us that we can mark DP[x_with_num] = True.
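The modulo trick in that check can be sketched on its own. fits_current_bucket is a hypothetical helper that mirrors the condition inside the double loop: since every full bucket contributes exactly subset_tgt, subset_sums[x] % subset_tgt is whatever sits in the bucket currently being filled.

```python
def fits_current_bucket(subset_sum, num, subset_tgt):
    """Hypothetical helper mirroring the check inside the double loop."""
    # Full buckets account for a multiple of subset_tgt; the remainder is
    # what sits in the bucket currently being filled.
    partial = subset_sum % subset_tgt
    return partial + num <= subset_tgt

# ([5, 5], [6]) has a sum of 16, so the current bucket already holds 6.
print([num for num in range(1, 11) if fits_current_bucket(16, num, 10)])  # [1, 2, 3, 4]

# ([5, 5], [6, 4]) sums to 20: the remainder is 0, so a fresh bucket starts.
print(fits_current_bucket(20, 10, 10))  # True
```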

Once we’re done iterating through all possible subsets, we’ll check the last item in the DP array for our answer:

def can_partition(nums, k):
    subset_tgt, remainder = divmod(sum(nums), k)
    # Check for remainder...

    num_subsets = 1 << len(nums)
    subset_sums = [None for _ in range(num_subsets)]
    subset_sums[0] = 0
    DP = [False for _ in range(num_subsets)]
    DP[0] = True

    for x in range(num_subsets):
        ...  # Go through every possible subset (the double loop above)

    return DP[-1]

Given n integers in the array, our algorithm runs in O(2ⁿ) space and O(n * 2ⁿ) time, since each of the 2ⁿ subsets tries up to n candidate elements.

With that, we’ve solved the Partition Array into K Equal Sum Subsets problem 🏢🏢🏢.

Full Solution

def can_partition(nums, k):
    def has_ith_bit(x, i):
        return x & (1 << i)
        
    def set_ith_bit(x, i):
        return x | (1 << i)

    subset_tgt, remainder = divmod(sum(nums), k)
    if remainder != 0:
        return False

    num_subsets = 1 << len(nums)
    DP = [False for _ in range(num_subsets)]
    DP[0] = True
    
    subset_sums = [None for _ in range(num_subsets)]
    subset_sums[0] = 0

    for x in range(num_subsets):
        if DP[x]:
            for i, num in enumerate(nums):
                if not has_ith_bit(x, i):
                    x_with_num = set_ith_bit(x, i)
                    subset_sums[x_with_num] = subset_sums[x] + num
                    # Only switch DP on; never overwrite a True with False
                    if (subset_sums[x] % subset_tgt) + num <= subset_tgt:
                        DP[x_with_num] = True

    return DP[-1]
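To check the worked examples from the DP[x] section against real output, here's a minimal sketch of a variant of the full solution that returns both tables instead of just DP[-1]. The name build_tables is hypothetical. Like the full solution, it fills subset_sums lazily, and it only ever switches DP entries on, so a subset stays viable once any ordering of its elements fits.

```python
def build_tables(nums, k):
    """Sketch: like can_partition, but returns the DP and subset_sums tables."""
    subset_tgt, remainder = divmod(sum(nums), k)
    if remainder != 0:
        return None, None

    num_subsets = 1 << len(nums)
    DP = [False] * num_subsets
    DP[0] = True
    subset_sums = [None] * num_subsets
    subset_sums[0] = 0

    for x in range(num_subsets):
        if DP[x]:
            for i, num in enumerate(nums):
                if not (x & (1 << i)):
                    x_with_num = x | (1 << i)
                    subset_sums[x_with_num] = subset_sums[x] + num
                    # Switch DP on only when `num` fits in the current bucket
                    if (subset_sums[x] % subset_tgt) + num <= subset_tgt:
                        DP[x_with_num] = True
    return DP, subset_sums

DP, subset_sums = build_tables([2, 4, 6, 8, 10], 3)
print(DP[11], subset_sums[11])  # True 14 ([2, 4, 8] splits into [2, 8] and [4])
print(DP[26])                   # False ([4, 8, 10] overflows)
print(DP[-1])                   # True ((10), (2, 8), (4, 6))
```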