
Connecting Cities At Minimum Cost

Image from Unsplash by Colin Watts

From LeetCode:

There are N cities numbered from 1 to N.

You are given an array of connections. Each connection, [c1, c2, cost], describes the cost of connecting city c1 and city c2 together. A connection is bidirectional: connecting c1 and c2 is the same as connecting c2 and c1.

Return the minimum cost such that for every pair of cities, there exists a path of connections (possibly of length 1) that connects them together. The cost is the sum of the connection costs used. If the task is impossible, return -1.

Input: N = 3, connections = [[1, 2, 5], [1, 3, 6], [2, 3, 1]]
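Output: 6
Explanation: Any 2 of the 3 connections link all three cities, so we pick the two cheapest: 1 + 5 = 6.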

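This question invites us to think of “cities” as graph nodes and “connections” as graph edges. Essentially, the challenge here is to find the total weight of the edges in a minimum spanning tree (MST): the subset of edges that connects every node at the lowest possible total cost. If the graph is disconnected, no spanning tree exists, which is the -1 case.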

This isn’t a hard problem. There are plenty of template solutions for finding an MST in a weighted graph, but I wanted to use this article to 1) highlight some finer details about working with priority queues and 2) ease into other graph problems that we’ll look at later this month.
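One of those template solutions is Kruskal’s algorithm, which is a different technique from the priority-queue approach this article builds up: sort every connection by cost and keep each one that joins two previously separate groups of cities, tracked with a union-find structure. A minimal sketch for comparison, assuming cities are labelled 1 to n (the function name min_cost_kruskal is hypothetical):

```python
def min_cost_kruskal(n, connections):
    """Kruskal's algorithm: take edges cheapest-first, skipping any edge
    whose endpoints are already in the same component (union-find)."""
    parent = list(range(n + 1))  # cities are labelled 1..n; index 0 unused

    def find(x):
        # Path halving keeps the union-find trees shallow
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total_cost, edges_used = 0, 0
    for c1, c2, cost in sorted(connections, key=lambda c: c[2]):
        root1, root2 = find(c1), find(c2)
        if root1 != root2:  # this edge joins two separate components
            parent[root1] = root2
            total_cost += cost
            edges_used += 1
            if edges_used == n - 1:
                break
    # A spanning tree over n cities needs exactly n - 1 edges
    return total_cost if edges_used == n - 1 else -1

print(min_cost_kruskal(3, [[1, 2, 5], [1, 3, 6], [2, 3, 1]]))  # 6
```

Kruskal’s algorithm sorts all edges up front, while the approach below grows a single tree outward from one city.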

As usual, let’s first define the node class we’ll be using for our graph:

class Node:
    def __init__(self, label):
        self.label = label
        # A set of Edge(cost, label) entries rather than a dict keyed
        # by neighbour, because two cities can have more than one
        # connection (a dict would keep only one of them)
        self.edges = set()

You’ll want to organise the connections between cities such that they are easily accessible by their respective nodes (on both ends):

import collections

Edge = collections.namedtuple('Edge', ('cost', 'label'))

def min_MST_cost(n, connections):
    nodes = {i: Node(i) for i in range(1, n+1)}

    for c1, c2, cost in connections:
        nodes[c1].edges.add(Edge(cost, c2))
        nodes[c2].edges.add(Edge(cost, c1))
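To see what this structure holds, here is a sketch that pulls the same construction into a standalone helper (build_graph is a hypothetical name, not part of the solution) and inspects city 1 on the LeetCode example:

```python
import collections

class Node:
    def __init__(self, label):
        self.label = label
        self.edges = set()  # Edge(cost, label) entries for neighbouring cities

Edge = collections.namedtuple('Edge', ('cost', 'label'))

def build_graph(n, connections):
    # Hypothetical helper: same construction as the top of min_MST_cost
    nodes = {i: Node(i) for i in range(1, n + 1)}
    for c1, c2, cost in connections:
        # Store each bidirectional connection on both endpoints
        nodes[c1].edges.add(Edge(cost, c2))
        nodes[c2].edges.add(Edge(cost, c1))
    return nodes

nodes = build_graph(3, [[1, 2, 5], [1, 3, 6], [2, 3, 1]])
print(sorted(nodes[1].edges))  # [Edge(cost=5, label=2), Edge(cost=6, label=3)]
```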

When constructing an MST, we can think of the graph as split into two sets: the nodes we’ve visited and the nodes we’ve yet to visit. At each step, our goal is to extend the first set with the node in the second set that is cheapest to reach from it. (This greedy, grow-one-tree strategy is Prim’s algorithm.)
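The growth step itself does not need a heap. A deliberately slow sketch (the helper name cheapest_crossing_edge is hypothetical) scans every connection for the cheapest one with exactly one endpoint in the visited set, then repeats until no connection crosses between the two sets:

```python
def cheapest_crossing_edge(visited, connections):
    """Hypothetical helper: among connections with exactly one endpoint in
    `visited`, return (cost, new_city) for the cheapest one, or None."""
    best = None
    for c1, c2, cost in connections:
        if (c1 in visited) != (c2 in visited):  # edge crosses the boundary
            new_city = c2 if c1 in visited else c1
            if best is None or cost < best[0]:
                best = (cost, new_city)
    return best

connections = [[1, 2, 5], [1, 3, 6], [2, 3, 1]]
visited, total = {1}, 0
step = cheapest_crossing_edge(visited, connections)
while step is not None:
    cost, city = step
    visited.add(city)
    total += cost
    step = cheapest_crossing_edge(visited, connections)
print(visited, total)  # {1, 2, 3} 6
```

Each step rescans all |E| connections, so this costs O(|V| · |E|) overall; keeping candidate edges in a min heap avoids the rescan.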

Priority Queues with Min Heaps

How do we determine what is most accessible? We use a min heap as a priority queue of edges and greedily pick the edge with the smallest cost. To state it differently, we’re finding the cheapest way to extend our set of visited nodes.

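import heapq

# Init priority queue with a zero-cost entry for city 1
pq = [Edge(0, 1)]

while pq:
    popped = heapq.heappop(pq)  # cheapest edge seen so far
    node = nodes[popped.label]

    # node.edges is a set of Edge(cost, label) tuples, so there is no .items()
    for cost, adj_label in node.edges:
        heapq.heappush(pq, Edge(cost, adj_label))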

There are a few things I should call out here.

First, our use of heapq requires that we order our entries as (cost, label) rather than (label, cost). While it might feel more intuitive to put the label (our identifier for each node) first, heapq compares entries with <, and tuples compare element by element: by their first value, then by the next value to break ties. We can work around this by writing a class definition for our edges that compares by cost:

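class Edge:
    def __init__(self, label, cost):
        self.label, self.cost = label, cost

    # heapq only needs <, so comparing by cost is enough for ordering
    def __lt__(self, other):
        return self.cost < other.cost

    # Define equality alongside __hash__ so that set() treats edges
    # with the same label and cost as duplicates
    def __eq__(self, other):
        return (self.label, self.cost) == (other.label, other.cost)

    def __hash__(self):
        return hash((self.label, self.cost))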

However, doing this takes too much time in a live coding environment and isn’t quite worth the effort, so we’ll stick with either raw or named tuples.
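A quick sketch comparing the two options with heapq: the cost-first namedtuple pops in cost order and breaks ties by label, while a label-first class (the one above, renamed LabelFirstEdge here so both can coexist, with an __eq__ that agrees with its __hash__) still pops in cost order because of its custom __lt__:

```python
import collections
import heapq

# Tuple-style edge: cost first, so heapq's default ordering is by cost
Edge = collections.namedtuple('Edge', ('cost', 'label'))

class LabelFirstEdge:
    """Label first, but ordered by cost via __lt__."""
    def __init__(self, label, cost):
        self.label, self.cost = label, cost

    def __lt__(self, other):
        return self.cost < other.cost

    def __eq__(self, other):
        return (self.label, self.cost) == (other.label, other.cost)

    def __hash__(self):
        return hash((self.label, self.cost))

tuple_pq = [Edge(6, 3), Edge(5, 2), Edge(1, 3), Edge(5, 1)]
heapq.heapify(tuple_pq)
tuple_order = [heapq.heappop(tuple_pq) for _ in range(4)]
# Ties on cost 5 fall back to comparing labels, so label 1 comes before label 2
print(tuple_order)
# [Edge(cost=1, label=3), Edge(cost=5, label=1), Edge(cost=5, label=2), Edge(cost=6, label=3)]

class_pq = [LabelFirstEdge(3, 6), LabelFirstEdge(2, 5), LabelFirstEdge(3, 1)]
heapq.heapify(class_pq)
class_order = [heapq.heappop(class_pq).cost for _ in range(3)]
print(class_order)  # [1, 5, 6]
```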

Second, consider the fact that we’re working specifically with edges in our priority queue. Depending on your use case, the types of entries in your queue might differ. For a shortest-path algorithm such as Dijkstra’s, you might want to track nodes instead, ordered by their current distance from the source node.
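For contrast, here is a textbook-style sketch of Dijkstra’s algorithm (not code from this article). The adjacency format, a dict mapping each node to a list of (cost, neighbour) pairs with non-negative costs, is an assumption for illustration. Note that the heap holds (distance_from_source, node) entries, and the priority is the total path length rather than a single edge’s cost:

```python
import heapq

def dijkstra(adj, source):
    """Shortest distances from `source`; adj maps node -> [(cost, nbr), ...]."""
    dist = {source: 0}
    pq = [(0, source)]  # entries are (distance from source, node)
    done = set()
    while pq:
        d, node = heapq.heappop(pq)
        if node in done:  # stale entry: this node was already settled
            continue
        done.add(node)
        for cost, nbr in adj.get(node, []):
            new_d = d + cost  # total path length, not just this edge's cost
            if new_d < dist.get(nbr, float('inf')):
                dist[nbr] = new_d
                heapq.heappush(pq, (new_d, nbr))
    return dist

adj = {1: [(5, 2), (6, 3)], 2: [(5, 1), (1, 3)], 3: [(6, 1), (1, 2)]}
print(dijkstra(adj, 1))  # {1: 0, 2: 5, 3: 6}
```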

Third, graph algorithms that rely on priority queues typically initialise the queue with a single element. This is because we expect to push more entries onto pq while this first element is being processed (unless the starting node has no connections at all). For an MST, any starting city works, since the tree has to reach every city anyway.

Fourth, understand that a min heap is just one way of implementing a priority queue! Any collection that can insert entries and remove its minimum in O(log n) time is suitable. In particular, SortedList (from the third-party sortedcontainers package) also removes its minimum element in O(log n) time, though with more overhead in practice than heapq. In exchange, it grants us the flexibility of finding and removing arbitrary entries, not just the minimum, in O(log n) time, whereas a heap needs O(n) time just to locate an arbitrary entry.
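To make that trade-off concrete, here is a small standard-library sketch (the values are arbitrary). With heapq alone the minimum comes off cheaply, but removing any other entry falls back to O(n) list operations:

```python
import heapq

pq = [7, 3, 9, 1, 5]
heapq.heapify(pq)

# Extract-min is O(log n): the smallest entry always sits at pq[0]
smallest = heapq.heappop(pq)

# heapq has no public helper for removing an arbitrary entry:
# list.remove scans the list in O(n), and heapify restores the
# heap invariant in O(n)
pq.remove(7)
heapq.heapify(pq)

remaining = [heapq.heappop(pq) for _ in range(len(pq))]
print(smallest, remaining)  # 1 [3, 5, 9]
```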

Constructing a Minimum Spanning Tree

Our algorithm is not complete yet. In fact, it will run into an infinite loop: every connection between cities is bidirectional, so we’ll keep walking back and forth along the same edges! Let’s prevent this by maintaining a visited set. If a node has already been added to our MST, we’ll ignore any edges leading to it.

We’ll also want to track the cost we’ve incurred from adding an edge (and the node it leads to) at each step:

# Continuing inside min_MST_cost(n, connections)
pq = [Edge(0, 1)]
visited = set()  # Track seen nodes
total_cost = 0  # Track cost of adding edges

# Terminate when:
# - Queue has been flushed, or
# - All nodes have been visited
while pq and len(visited) < n:
    popped = heapq.heappop(pq)

    # Skip edges that lead back into the tree; they would form a cycle
    if popped.label not in visited:
        visited.add(popped.label)
        node = nodes[popped.label]
        total_cost += popped.cost

        for edge in node.edges:
            heapq.heappush(pq, edge)

# If the queue ran dry first, some city is unreachable
if len(visited) != n:
    return -1
return total_cost

Once we’ve visited all nodes, we return the accumulated total_cost as our answer. If the queue empties before that happens, some city cannot be reached from city 1, so we return -1.

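Given |V| nodes and |E| edges, our algorithm runs in O(|V| + |E|) space and O(|V| + |E| log |E|) time, since every edge is pushed onto the heap at most once from each endpoint and each push or pop costs O(log |E|).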
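The space bound rests on how often we push onto the heap: each node’s edges are pushed exactly once, when that node is first visited, so there are at most 2|E| + 1 pushes (counting the initial Edge(0, 1)). A hypothetical instrumented copy of the loop (count_heap_pushes and its plain-dict adjacency are assumptions for illustration) checks this on the LeetCode example:

```python
import collections
import heapq

Edge = collections.namedtuple('Edge', ('cost', 'label'))

def count_heap_pushes(n, connections):
    """Hypothetical instrumented copy of the MST loop: returns
    (total_cost, pushes), counting the initial Edge(0, 1) as a push."""
    adj = {i: set() for i in range(1, n + 1)}
    for c1, c2, cost in connections:
        adj[c1].add(Edge(cost, c2))
        adj[c2].add(Edge(cost, c1))

    pq, pushes = [Edge(0, 1)], 1
    visited, total_cost = set(), 0
    while pq and len(visited) < n:
        popped = heapq.heappop(pq)
        if popped.label not in visited:
            visited.add(popped.label)
            total_cost += popped.cost
            for edge in adj[popped.label]:
                heapq.heappush(pq, edge)
                pushes += 1
    return (total_cost if len(visited) == n else -1), pushes

connections = [[1, 2, 5], [1, 3, 6], [2, 3, 1]]
cost, pushes = count_heap_pushes(3, connections)
print(cost, pushes, pushes <= 2 * len(connections) + 1)  # 6 7 True
```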

With that, we’ve solved the Connecting Cities At Minimum Cost problem 🌆🌉🌆.

Full Solution

import heapq
import collections

class Node:
    def __init__(self, label):
        self.label = label
        self.edges = set()


Edge = collections.namedtuple('Edge', ('cost', 'label'))


def min_MST_cost(n, connections):
    nodes = {i: Node(i) for i in range(1, n+1)}
    for c1, c2, cost in connections:
        nodes[c1].edges.add(Edge(cost, c2))
        nodes[c2].edges.add(Edge(cost, c1))

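    # Start from city 1 with a zero-cost entry
    pq = [Edge(0, 1)]
    visited = set()
    total_cost = 0

    while pq and len(visited) < n:
        popped = heapq.heappop(pq)  # cheapest edge leaving the tree

        # Ignore edges that lead back into the tree
        if popped.label not in visited:
            visited.add(popped.label)
            node = nodes[popped.label]
            total_cost += popped.cost

            for edge in node.edges:
                heapq.heappush(pq, edge)

    # Some city was unreachable from city 1
    if len(visited) != n:
        return -1
    return total_cost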