diff --git a/src/boruvkas_algorithm/__init__.py b/src/boruvkas_algorithm/__init__.py index e69de29..4b7a38f 100644 --- a/src/boruvkas_algorithm/__init__.py +++ b/src/boruvkas_algorithm/__init__.py @@ -0,0 +1,6 @@ +"""Boruvka's algorithm for finding minimum spanning trees.""" + +from boruvkas_algorithm.boruvka import Graph, find_mst_with_boruvkas_algorithm +from boruvkas_algorithm.union_find import UnionFind + +__all__: list[str] = ["Graph", "UnionFind", "find_mst_with_boruvkas_algorithm"] diff --git a/src/boruvkas_algorithm/boruvka.py b/src/boruvkas_algorithm/boruvka.py index 82e157e..05617a0 100644 --- a/src/boruvkas_algorithm/boruvka.py +++ b/src/boruvkas_algorithm/boruvka.py @@ -5,6 +5,8 @@ import matplotlib.pyplot as plt import networkx as nx +from boruvkas_algorithm.union_find import UnionFind + class Graph: """A graph that contains nodes and edges.""" @@ -78,69 +80,23 @@ def draw_mst(self, mst_edges: list[tuple[int, int, int]]) -> None: def find_mst_with_boruvkas_algorithm( graph: Graph, + union_find: UnionFind | None = None, ) -> tuple[int, list[tuple[int, int, int]]]: """ Finds the minimum spanning tree (MST) of a graph using Boruvka's algorithm. Args: graph: The graph to find the MST of. + union_find: Optional UnionFind instance for tracking components. If not + provided, a new one will be created. Returns: A tuple containing the total weight of the MST and a list of the edges in the MST. """ - - def find(node: int) -> int: - """ - Finds the root parent of the node using path compression. - - Args: - node: The node to find the root parent of. - - Returns: - The root parent of the node. - """ - cur_parent = parent[node] - while cur_parent != parent[cur_parent]: - # Compress the links as we go up the chain of parents to make - # it faster to traverse in the future - amortised O(a(n)) time, - # where a(n) is the inverse Ackermann function. - parent[cur_parent] = parent[parent[cur_parent]] - cur_parent = parent[cur_parent] - return cur_parent - - def union(node1: int, node2: int) -> bool: - """ - Combines the two nodes into the larger segment. - - Args: - node1: The first node to combine. - node2: The second node to combine. - - Returns: - True if the nodes were combined, False if they were already in the - same segment. - """ - root1 = find(node1) - root2 = find(node2) - # If they have the same root parent, they're already connected. - if root1 == root2: - return False - - # Combine the two nodes into the larger segment based on the rank. - if rank[root1] > rank[root2]: - parent[root2] = root1 - rank[root1] += rank[root2] - else: - parent[root1] = root2 - rank[root2] += rank[root1] - return True - num_vertices = len(graph.vertices) - # Each node is its own parent initially. - parent: list[int] = list(range(num_vertices)) - # Each tree has size 1 (itself) initially. - rank: list[int] = [1] * num_vertices + if union_find is None: + union_find = UnionFind(num_vertices) print("\nFinding MST with Boruvka's algorithm:") graph.print_graph_info() @@ -164,7 +120,7 @@ def union(node1: int, node2: int) -> bool: ] * num_vertices for edge in graph.edges: node1, node2, weight = edge - comp1, comp2 = find(node1), find(node2) + comp1, comp2 = union_find.find(node1), union_find.find(node2) if comp1 != comp2: current_min1 = min_edge_per_component[comp1] @@ -178,10 +134,10 @@ def union(node1: int, node2: int) -> bool: for edge in min_edge_per_component: if edge is not None: node1, node2, weight = edge - if find(node1) != find(node2): + if union_find.find(node1) != union_find.find(node2): mst_edges.append(edge) mst_weight += weight - union(node1, node2) + union_find.union(node1, node2) num_components -= 1 print(f"Added edge {node1} - {node2} with weight {weight} to MST.") diff --git a/src/boruvkas_algorithm/union_find.py b/src/boruvkas_algorithm/union_find.py new file mode 100644 index 0000000..949c868 --- /dev/null +++ b/src/boruvkas_algorithm/union_find.py @@ -0,0 +1,76 @@ +class UnionFind: + """ + Union-find (disjoint set union) data structure for tracking connected + components with path compression and union by size. + """ + + def __init__(self, size: int) -> None: + """ + Initialises the Union-Find structure. + + Args: + size: The number of elements in the structure. + """ + # Each node is its own parent initially. + self.parent: list[int] = list(range(size)) + # Each tree has size 1 (itself) initially. + self.rank: list[int] = [1] * size + + def find(self, node: int) -> int: + """ + Finds the root parent of the node using path compression. + + Args: + node: The node to find the root parent of. + + Returns: + The root parent of the node. + """ + cur_parent = self.parent[node] + while cur_parent != self.parent[cur_parent]: + # Compress the links as we go up the chain of parents to make + # it faster to traverse in the future - amortised O(a(n)) time, + # where a(n) is the inverse Ackermann function. + self.parent[cur_parent] = self.parent[self.parent[cur_parent]] + cur_parent = self.parent[cur_parent] + return cur_parent + + def union(self, node1: int, node2: int) -> bool: + """ + Combines the two nodes into the larger segment. + + Args: + node1: The first node to combine. + node2: The second node to combine. + + Returns: + True if the nodes were combined, False if they were already in the + same segment. + """ + root1 = self.find(node1) + root2 = self.find(node2) + # If they have the same root parent, they're already connected. + if root1 == root2: + return False + + # Combine the two nodes into the larger segment based on the rank. + if self.rank[root1] > self.rank[root2]: + self.parent[root2] = root1 + self.rank[root1] += self.rank[root2] + else: + self.parent[root1] = root2 + self.rank[root2] += self.rank[root1] + return True + + def is_connected(self, node1: int, node2: int) -> bool: + """ + Checks if two nodes are in the same component. + + Args: + node1: The first node. + node2: The second node. + + Returns: + True if the nodes are connected, False otherwise. + """ + return self.find(node1) == self.find(node2) diff --git a/tests/test_boruvka.py b/tests/test_boruvka.py index 2c09052..334a8b3 100644 --- a/tests/test_boruvka.py +++ b/tests/test_boruvka.py @@ -1,6 +1,7 @@ import pytest from boruvkas_algorithm.boruvka import Graph, find_mst_with_boruvkas_algorithm +from boruvkas_algorithm.union_find import UnionFind @pytest.fixture @@ -15,6 +16,16 @@ def setup_graph(): return Graph(9) # Example graph with 9 vertices. +def test_graph_initialization(): + """ + Tests that a graph is initialised with the correct number of vertices and + no edges. + """ + graph = Graph(5) + assert len(graph.vertices) == 5, "Graph should have 5 vertices" + assert len(graph.edges) == 0, "Graph should be initialised with no edges" + + def test_add_edge(setup_graph: Graph): """ Tests that edges are correctly added by checking the length of the edge @@ -38,6 +49,131 @@ def test_add_edge_invalid_vertices(setup_graph: Graph): graph.add_edge(10, 11, 5) +# ============================================================================= +# UnionFind Tests +# ============================================================================= + + +def test_union_find_initialization(): + """Tests that UnionFind initialises with correct parent and rank arrays.""" + uf = UnionFind(5) + assert uf.parent == [0, 1, 2, 3, 4], "Each node should be its own parent" + assert uf.rank == [1, 1, 1, 1, 1], "Each node should have rank 1" + + +def test_union_find_find_single_node(): + """Tests that find returns the node itself when it's its own parent.""" + uf = UnionFind(5) + assert uf.find(0) == 0 + assert uf.find(4) == 4 + + +def test_union_find_union_two_nodes(): + """Tests that union correctly combines two nodes.""" + uf = UnionFind(5) + result = uf.union(0, 1) + assert result is True, "Union should return True when nodes are combined" + assert uf.find(0) == uf.find(1), "Nodes should have the same root after union" + + +def test_union_find_union_already_connected(): + """Tests that union returns False when nodes are already connected.""" + uf = UnionFind(5) + uf.union(0, 1) + result = uf.union(0, 1) + assert result is False, "Union should return False when already connected" + + +def test_union_find_union_by_size(): + """Tests that smaller trees are merged into larger trees.""" + uf = UnionFind(5) + # Create a larger tree: 0 <- 1, 0 <- 2 + uf.union(0, 1) + uf.union(0, 2) + # Now union with node 3 - node 3 should be merged into the larger tree. + uf.union(3, 0) + # The root of the larger tree should remain the root. + root = uf.find(0) + assert uf.find(3) == root, "Smaller tree should be merged into larger tree" + + +def test_union_find_path_compression(): + """Tests that path compression flattens the tree structure.""" + uf = UnionFind(5) + # Create a chain: 0 <- 1 <- 2 <- 3 + uf.parent = [0, 0, 1, 2, 4] + uf.rank = [4, 1, 1, 1, 1] + # Find on node 3 should compress the path. + root = uf.find(3) + assert root == 0, "Root should be 0" + # After path compression, intermediate nodes should point closer to root. + assert uf.parent[2] in (0, 1), "Path compression should shorten the path" + + +def test_union_find_multiple_components(): + """Tests UnionFind with multiple separate components.""" + uf = UnionFind(6) + # Create two components: {0, 1, 2} and {3, 4, 5} + uf.union(0, 1) + uf.union(1, 2) + uf.union(3, 4) + uf.union(4, 5) + + # Check components are separate. + assert uf.find(0) == uf.find(1) == uf.find(2) + assert uf.find(3) == uf.find(4) == uf.find(5) + assert uf.find(0) != uf.find(3), "Components should be separate" + + # Merge the two components. + uf.union(2, 3) + assert uf.find(0) == uf.find(5), "Components should be merged" + + +def test_union_find_is_connected(): + """Tests the is_connected convenience method.""" + uf = UnionFind(5) + assert not uf.is_connected(0, 1), "Nodes should not be connected initially" + + uf.union(0, 1) + assert uf.is_connected(0, 1), "Nodes should be connected after union" + assert not uf.is_connected(0, 2), "Unconnected nodes should return False" + + uf.union(1, 2) + assert uf.is_connected(0, 2), "Transitively connected nodes should return True" + + +# ============================================================================= +# MST Algorithm Tests +# ============================================================================= + + +def test_mst_with_injected_union_find(setup_graph: Graph): + """Tests that the algorithm works with an injected UnionFind instance.""" + graph = setup_graph + graph.add_edge(0, 1, 4) + graph.add_edge(0, 6, 7) + graph.add_edge(1, 6, 11) + graph.add_edge(1, 7, 20) + graph.add_edge(1, 2, 9) + graph.add_edge(2, 3, 6) + graph.add_edge(2, 4, 2) + graph.add_edge(3, 4, 10) + graph.add_edge(3, 5, 5) + graph.add_edge(4, 5, 15) + graph.add_edge(4, 7, 1) + graph.add_edge(4, 8, 5) + graph.add_edge(5, 8, 12) + graph.add_edge(6, 7, 1) + graph.add_edge(7, 8, 3) + + # Inject a custom UnionFind instance. + union_find = UnionFind(len(graph.vertices)) + mst_weight, mst_edges = find_mst_with_boruvkas_algorithm(graph, union_find) + + assert mst_weight == 29, "MST weight should be 29" + assert len(mst_edges) == 8, "MST should have 8 edges for 9 vertices" + + def test_mst(setup_graph: Graph): """ Tests that the MST has the correct total weight and structure by comparing @@ -80,11 +216,27 @@ def test_mst(setup_graph: Graph): ) -def test_graph_initialization(): - """ - Test that a graph is initialized with the correct number of vertices and - no edges. - """ - graph = Graph(5) # Initialize a graph with 5 vertices. - assert len(graph.vertices) == 5, "Graph should have 5 vertices" - assert len(graph.edges) == 0, "Graph should be initialized with no edges" +def test_mst_simple_triangle(): + """Tests MST on a simple triangle graph.""" + graph = Graph(3) + graph.add_edge(0, 1, 1) + graph.add_edge(1, 2, 2) + graph.add_edge(0, 2, 3) + + mst_weight, mst_edges = find_mst_with_boruvkas_algorithm(graph) + + assert mst_weight == 3, "MST weight should be 3 (edges 1 + 2)" + assert len(mst_edges) == 2, "MST should have 2 edges for 3 vertices" + + +def test_mst_linear_graph(): + """Tests MST on a linear graph (already a tree).""" + graph = Graph(4) + graph.add_edge(0, 1, 1) + graph.add_edge(1, 2, 2) + graph.add_edge(2, 3, 3) + + mst_weight, mst_edges = find_mst_with_boruvkas_algorithm(graph) + + assert mst_weight == 6, "MST weight should be 6 (1 + 2 + 3)" + assert len(mst_edges) == 3, "MST should have 3 edges for 4 vertices" diff --git a/tests/test_union_find.py b/tests/test_union_find.py new file mode 100644 index 0000000..3d87b56 --- /dev/null +++ b/tests/test_union_find.py @@ -0,0 +1,89 @@ +from boruvkas_algorithm.boruvka import UnionFind + + +def test_union_find_initialization(): + """Tests that UnionFind initialises with correct parent and rank arrays.""" + uf = UnionFind(5) + assert uf.parent == [0, 1, 2, 3, 4], "Each node should be its own parent" + assert uf.rank == [1, 1, 1, 1, 1], "Each node should have rank 1" + + +def test_union_find_find_single_node(): + """Tests that find returns the node itself when it's its own parent.""" + uf = UnionFind(5) + assert uf.find(0) == 0 + assert uf.find(4) == 4 + + +def test_union_find_union_two_nodes(): + """Tests that union correctly combines two nodes.""" + uf = UnionFind(5) + result = uf.union(0, 1) + assert result is True, "Union should return True when nodes are combined" + assert uf.find(0) == uf.find(1), "Nodes should have the same root after union" + + +def test_union_find_union_already_connected(): + """Tests that union returns False when nodes are already connected.""" + uf = UnionFind(5) + uf.union(0, 1) + result = uf.union(0, 1) + assert result is False, "Union should return False when already connected" + + +def test_union_find_union_by_size(): + """Tests that smaller trees are merged into larger trees.""" + uf = UnionFind(5) + # Create a larger tree: 0 <- 1, 0 <- 2 + uf.union(0, 1) + uf.union(0, 2) + # Now union with node 3 - node 3 should be merged into the larger tree. + uf.union(3, 0) + # The root of the larger tree should remain the root. + root = uf.find(0) + assert uf.find(3) == root, "Smaller tree should be merged into larger tree" + + +def test_union_find_path_compression(): + """Tests that path compression flattens the tree structure.""" + uf = UnionFind(5) + # Create a chain: 0 <- 1 <- 2 <- 3 + uf.parent = [0, 0, 1, 2, 4] + uf.rank = [4, 1, 1, 1, 1] + # Find on node 3 should compress the path. + root = uf.find(3) + assert root == 0, "Root should be 0" + # After path compression, intermediate nodes should point closer to root. + assert uf.parent[2] in (0, 1), "Path compression should shorten the path" + + +def test_union_find_multiple_components(): + """Tests UnionFind with multiple separate components.""" + uf = UnionFind(6) + # Create two components: {0, 1, 2} and {3, 4, 5} + uf.union(0, 1) + uf.union(1, 2) + uf.union(3, 4) + uf.union(4, 5) + + # Check components are separate. + assert uf.find(0) == uf.find(1) == uf.find(2) + assert uf.find(3) == uf.find(4) == uf.find(5) + assert uf.find(0) != uf.find(3), "Components should be separate" + + # Merge the two components. + uf.union(2, 3) + assert uf.find(0) == uf.find(5), "Components should be merged" + + +def test_union_find_connected(): + """Tests the connected convenience method.""" + uf = UnionFind(5) + assert not uf.is_connected(0, 1), "Nodes should not be connected initially" + + uf.union(0, 1) + assert uf.is_connected(0, 1), "Nodes should be connected after union" + assert not uf.is_connected(0, 2), "Unconnected nodes should return False" + + uf.union(1, 2) + assert uf.is_connected(0, 2), "Transitively connected nodes should return True"