From LeetCode:
Given an array of integers nums and a positive integer k, find whether it’s possible to divide this array into k non-empty subsets whose sums are all equal.
Input: nums = [4, 3, 2, 3, 5, 2, 1], k = 4
Output: True  # Example: (5), (1, 4), (2, 3), (2, 3) have equal sums
Note:
- 1 <= k <= len(nums) <= 16.
- 0 < nums[i] < 10000.
Filtering Out Invalid Inputs
First, let’s sieve out any inputs to this problem that are clearly invalid. We’ll deal with the less obvious cases afterward.
Identifying an invalid input is easy. We’ll obtain the sum of the array and divide it by k. If it turns out that k doesn’t cleanly divide the sum (i.e. there’s a remainder value), then we’ll immediately know that it’s impossible for us to form equal sum subsets.
def can_partition(nums, k):
    subset_tgt, remainder = divmod(sum(nums), k)
    if remainder != 0:
        return False
    # ...
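As a quick illustration of this filter, here is a minimal sketch that isolates the divisibility check. The helper name passes_divisibility_check is made up for this example and is not part of the solution. Note that passing the check is necessary but not sufficient: the last input passes even though 5 cannot fit into a bucket of size 4.

```python
def passes_divisibility_check(nums, k):
    """Return (ok, subset_tgt); ok is False when k does not divide sum(nums)."""
    subset_tgt, remainder = divmod(sum(nums), k)
    return remainder == 0, subset_tgt


# The LeetCode example: the sum is 20, so each of the 4 subsets must sum to 5.
print(passes_divisibility_check([4, 3, 2, 3, 5, 2, 1], 4))  # (True, 5)

# The sum is 11, which 2 does not divide, so this input is rejected.
print(passes_divisibility_check([1, 2, 3, 5], 2))  # (False, 5)

# Passes the filter (sum 8, target 4), yet no valid partition exists.
print(passes_divisibility_check([1, 1, 1, 5], 2))  # (True, 4)
```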
Depth-First Search vs Dynamic Programming
From this point onward, there are a couple of good approaches to this problem.
A popular strategy is to use DFS and recurse with an accumulator sum until it reaches subset_tgt. Doing this k times gives you the answer. This type of algorithm runs with an upper bound of approximately O(k * 2ⁿ) and should already be somewhat familiar to you if you’ve practiced backtracking a few times.
I won’t dive into the DFS-based solution because I’d like to write about the DP-based one instead, which also runs in exponential time but relies on more interesting code.
Thus far, most of my February articles have been written without the need for significant cross-referencing. To familiarise myself with the DP-based approach, however, I’ve had to go through this discussion thread on LeetCode. I must credit the author, maverick009, for enabling my learning.
Representing Subsets as Integers
Understand that every subset in the nums array can be represented by an integer. For instance, given:
nums = [2, 4, 6, 8, 10]
The subset [2, 4, 8] can be represented as 1011 in binary. 2, 4, and 8 are positioned at the 0th, 1st, and 3rd indices, so we set the corresponding bits (counting from the least significant bit) to 1 to indicate that these nums are present in the subset.
1011, in decimal form, is simply 11. Checking if 4 (at index 1) exists, for instance, can be achieved by doing 11 & (1 << 1), which is non-zero when the bit is set. As a general rule, we’d do x & (1 << i) to see if the i-th bit has been set on x.
Since every subset can be represented as an integer, how many subsets are there? Another way to ask this is: What’s the cardinality of the power set? The answer turns out to be 2ⁿ, ranging from binary representations 0000... to 1111...
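To make the encoding concrete, here is a small sketch that converts between index lists and integers. The helper names mask_of and members are invented for this illustration.

```python
nums = [2, 4, 6, 8, 10]


def mask_of(indices):
    """Build the integer whose i-th bit is set for every index i given."""
    mask = 0
    for i in indices:
        mask |= 1 << i
    return mask


def members(mask, nums):
    """List the elements of nums whose index bit is set in mask."""
    return [num for i, num in enumerate(nums) if mask & (1 << i)]


x = mask_of([0, 1, 3])
print(x, bin(x))            # 11 0b1011
print(members(x, nums))     # [2, 4, 8]
print(bool(x & (1 << 1)))   # True: 4 (index 1) is in the subset
print(bool(x & (1 << 2)))   # False: 6 (index 2) is not
print(1 << len(nums))       # 32 subsets in total, from 0b00000 to 0b11111
```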
Building Towards an Answer
We’ll initialise two DP arrays, DP and subset_sums, to store some information about each subset. Using this stored data, we’ll work our way up towards understanding the complete set.
def can_partition(nums, k):
    subset_tgt, remainder = divmod(sum(nums), k)
    if remainder != 0:
        return False
    # Set up DP arrays
    num_subsets = 1 << len(nums)
    subset_sums = [None for _ in range(num_subsets)]
    subset_sums[0] = 0
    DP = [False for _ in range(num_subsets)]
    DP[0] = True
subset_sums[x] gives us the sum of the subset represented by x. To use an example:
nums = [2, 4, 6, 8, 10]
# 1. x == 11, so bin(x) == "1011"
# 2. "1011" represents [2, 4, 8]
# 3. sum([2, 4, 8]) == 14, so subset_sums[11] == 14
DP[x] is trickier to explain. It holds a boolean value which tells us whether the subset represented by x can be divided into buckets that each sum to subset_tgt, with the exception of at most one bucket not being completely filled.
nums = [2, 4, 6, 8, 10]
k = 3
# 1. sum(nums) == 30, so subset_tgt == 10
# 2. 11 (or "1011") represents [2, 4, 8]
#    [2, 4, 8] splits into:
#    - [2, 8] (filled bucket)
#    - [4] (unfilled bucket)
#    DP[11] = True
# 3. 26 (or "11010") represents [4, 8, 10]
#    [4, 8, 10] splits into:
#    - [10] (filled bucket)
#    - [4, 8] (overflowing bucket)
#    DP[26] = False
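One way to sanity-check this definition on small inputs is an exhaustive search. This is only a sketch: the function is_viable is a made-up helper, separate from the DP itself, that tries every ordering of a subset and fills buckets greedily. It is far too slow for real inputs.

```python
from itertools import permutations


def is_viable(mask, nums, subset_tgt):
    """Brute-force check of the DP[x] definition for the subset encoded by mask."""
    chosen = [num for i, num in enumerate(nums) if mask & (1 << i)]
    for order in permutations(chosen):
        fill = 0  # how full the current bucket is
        for num in order:
            if fill + num > subset_tgt:
                break  # this ordering overflows the current bucket
            # A bucket that reaches subset_tgt is closed and a new one starts.
            fill = (fill + num) % subset_tgt
        else:
            return True  # every element was placed without overflow
    return False


nums = [2, 4, 6, 8, 10]
print(is_viable(0b01011, nums, 10))  # True:  [2, 8] filled, [4] partly filled
print(is_viable(0b11010, nums, 10))  # False: 4 and 8 cannot share a bucket
```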
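With these two arrays, we can eventually build up an answer for DP[-1], which tells us if the complete set represented by 1111... can be divided into k buckets. Because sum(nums) == k * subset_tgt, a viable arrangement of the complete set cannot leave any bucket partly filled.
We also expect subset_sums[-1] == sum(nums).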
def has_ith_bit(x, i):
    return x & (1 << i)

def set_ith_bit(x, i):
    return x | (1 << i)

for x in range(num_subsets):
    # For every viable subset
    if DP[x]:
        for i, num in enumerate(nums):
            if not has_ith_bit(x, i):
                # Consider new subset with `num` added
                x_with_num = set_ith_bit(x, i)
                subset_sums[x_with_num] = subset_sums[x] + num
                # Mark the new subset viable if adding `num` does not
                # overflow the current bucket. Never reset an entry to
                # False: another predecessor may already have shown
                # that this subset is viable.
                if (subset_sums[x] % subset_tgt) + num <= subset_tgt:
                    DP[x_with_num] = True
Here’s one way to think about this double loop: We’re examining all subsets and considering candidate elements that can be added to each of them. A good candidate element should add to the subset without causing the current bucket to overflow.
For example, if subset_tgt == 10 and our currently-examined subset can be formed as ([5, 5], [6]), we’ll want to add a number that lies within the range of 1 to 4. ([5, 5], [6, 1]) allows us to continue looking for more candidates for the current bucket, while ([5, 5], [6, 4]) allows us to move on to the next bucket.
In other words, a good candidate num gives us a viable bucket arrangement and tells us that we can mark DP[x_with_num] = True.
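The candidate test can be tried in isolation. Below is a minimal sketch of the same comparison used in the loop, wrapped in a hypothetical helper named fits_current_bucket, applied to the ([5, 5], [6]) example.

```python
def fits_current_bucket(subset_sum, num, subset_tgt):
    """True if num can join the partly filled bucket without overflowing it."""
    return (subset_sum % subset_tgt) + num <= subset_tgt


subset_tgt = 10
subset_sum = sum([5, 5, 6])  # 16: one filled bucket, 6 in the current bucket

good = [n for n in range(1, 10) if fits_current_bucket(subset_sum, n, subset_tgt)]
print(good)  # [1, 2, 3, 4]

# Once the current bucket is exactly full, the modulo resets the fill to 0,
# so any number up to subset_tgt can start the next bucket.
print(fits_current_bucket(20, 9, subset_tgt))  # True
```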
Once we’re done iterating through all possible subsets, we’ll check the last item in the DP array for our answer:
def can_partition(nums, k):
    subset_tgt, remainder = divmod(sum(nums), k)
    # Check for remainder...
    num_subsets = 1 << len(nums)
    subset_sums = [None for _ in range(num_subsets)]
    subset_sums[0] = 0
    DP = [False for _ in range(num_subsets)]
    DP[0] = True
    for x in range(num_subsets):
        # Go through every possible subset...
    return DP[-1]
Given n integers in the array, our algorithm runs in O(2ⁿ) space and O(n * 2ⁿ) time.
With that, we’ve solved the Partition Array into K Equal Sum Subsets problem 🏢🏢🏢.
Full Solution
def can_partition(nums, k):
    def has_ith_bit(x, i):
        return x & (1 << i)

    def set_ith_bit(x, i):
        return x | (1 << i)

    subset_tgt, remainder = divmod(sum(nums), k)
    if remainder != 0:
        return False

    num_subsets = 1 << len(nums)
    DP = [False for _ in range(num_subsets)]
    DP[0] = True
    subset_sums = [None for _ in range(num_subsets)]
    subset_sums[0] = 0

    for x in range(num_subsets):
        if DP[x]:
            for i, num in enumerate(nums):
                if not has_ith_bit(x, i):
                    x_with_num = set_ith_bit(x, i)
                    subset_sums[x_with_num] = subset_sums[x] + num
                    # Only ever flip False -> True, so a viable entry is
                    # never overwritten by a later, non-viable predecessor.
                    if (subset_sums[x] % subset_tgt) + num <= subset_tgt:
                        DP[x_with_num] = True

    return DP[-1]
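A reasonable way to gain confidence in a solution like this is to compare it against a slower, more obviously correct search on small random inputs. The sketch below is an assumption-laden test aid rather than part of the solution: can_partition_reference is a plain backtracking search and find_disagreement is a hypothetical harness, both named for this example only.

```python
import random


def can_partition_reference(nums, k):
    """Backtracking reference: place each number into one of k buckets."""
    subset_tgt, remainder = divmod(sum(nums), k)
    if remainder != 0:
        return False
    buckets = [0] * k
    ordered = sorted(nums, reverse=True)  # placing big numbers first prunes early

    def place(idx):
        if idx == len(ordered):
            # No bucket exceeds subset_tgt and the total is k * subset_tgt,
            # so every bucket must be exactly full.
            return True
        num = ordered[idx]
        tried = set()
        for b in range(k):
            # Skip overflowing buckets and buckets with a fill level already tried.
            if buckets[b] + num > subset_tgt or buckets[b] in tried:
                continue
            tried.add(buckets[b])
            buckets[b] += num
            if place(idx + 1):
                return True
            buckets[b] -= num
        return False

    return place(0)


def find_disagreement(candidate, trials=200, seed=0):
    """Return the first (nums, k) where candidate and the reference differ, else None."""
    rng = random.Random(seed)
    for _ in range(trials):
        nums = [rng.randint(1, 9) for _ in range(rng.randint(1, 8))]
        k = rng.randint(1, len(nums))
        if bool(candidate(nums, k)) != can_partition_reference(nums, k):
            return nums, k
    return None


print(can_partition_reference([4, 3, 2, 3, 5, 2, 1], 4))  # True
print(can_partition_reference([1, 1, 1, 5], 2))           # False
```

Calling find_disagreement(can_partition) with the function from the full solution should return None if the two implementations agree on every sampled input.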