11package com .thealgorithms .graph ;
22
3+ import java .util .ArrayList ;
4+ import java .util .Arrays ;
35import java .util .List ;
46import java .util .PriorityQueue ;
5- import java .util .Comparator ;
67
78/**
8- * This class provides a method to compute the weight of the
9- * Minimum Spanning Tree (MST) of a graph using Prim's Algorithm.
9+ * Implementation of Prim's algorithm for computing Minimum Spanning Tree (MST).
10+ * This class provides methods to find the MST of a weighted undirected graph
11+ * using a priority queue based approach for optimal performance.
12+ *
13+ * <p>Time Complexity: O(E log V) where E is edges and V is vertices
14+ * <p>Space Complexity: O(V + E) for storing graph and auxiliary structures
15+ *
1016 */
1117public final class PrimAlgorithm {
1218
13- private PrimAlgorithm () {
14- throw new UnsupportedOperationException ("Utility class" );
19+ public PrimAlgorithm (List <List <Edge >> adjacencyList , int vertexCount ) {
20+ this .adjacencyList = adjacencyList ;
21+ this .vertexCount = vertexCount ;
1522 }
1623
1724 /**
18- * Helper record representing an edge with its associated weight and node.
25+ * Represents a weighted edge in the graph structure.
26+ * This inner class encapsulates edge information including destination
27+ * vertex and associated weight for MST computation.
28+ */
29+ private static final class Edge implements Comparable <Edge > {
30+ private final int destination ;
31+ private final int weight ;
32+
33+ /**
34+ * Constructs an edge with specified destination and weight.
35+ *
36+ * @param destination the target vertex of this edge
37+ * @param weight the cost associated with traversing this edge
38+ */
39+ Edge (final int destination , final int weight ) {
40+ this .destination = destination ;
41+ this .weight = weight ;
42+ }
43+
44+ /**
45+ * Retrieves the destination vertex of this edge.
46+ *
47+ * @return the destination vertex identifier
48+ */
49+ public int getDestination () {
50+ return destination ;
51+ }
52+
53+ /**
54+ * Retrieves the weight of this edge.
55+ *
56+ * @return the edge weight value
57+ */
58+ public int getWeight () {
59+ return weight ;
60+ }
61+
62+ /**
63+ * Compares edges based on weight for priority queue ordering.
64+ *
65+ * @param other the edge to compare against
66+ * @return negative if this edge weighs less, positive if more, zero if equal
67+ */
68+ @ Override
69+ public int compareTo (final Edge other ) {
70+ return Integer .compare (this .weight , other .weight );
71+ }
72+ }
73+
74+ /**
75+ * Stores the result of MST computation including total cost and edge list.
76+ */
77+ public static final class MstResult {
78+ private final int totalWeight ;
79+ private final List <int []> edges ;
80+
81+ /**
82+ * Constructs an MST result object.
83+ *
84+ * @param totalWeight the sum of all edge weights in the MST
85+ * @param edges list of edges where each edge is [source, dest, weight]
86+ */
87+ public MstResult (final int totalWeight , final List <int []> edges ) {
88+ this .totalWeight = totalWeight ;
89+ this .edges = new ArrayList <>(edges );
90+ }
91+
92+ /**
93+ * Returns the total weight of the MST.
94+ *
95+ * @return sum of all edge weights in the minimum spanning tree
96+ */
97+ public int getTotalWeight () {
98+ return totalWeight ;
99+ }
100+
101+ /**
102+ * Returns the list of edges in the MST.
103+ *
104+ * @return unmodifiable list of edges, each represented as [source, dest, weight]
105+ */
106+ public List <int []> getEdges () {
107+ return new ArrayList <>(edges );
108+ }
109+ }
110+
111+ private final List <List <Edge >> adjacencyList ;
112+ private final int vertexCount ;
113+
114+ /**
115+ * Constructs a graph representation for Prim's algorithm.
19116 *
20- * @param node the target node connected by the edge
21- * @param weight the weight of the edge
117+ * @param vertexCount the total number of vertices in the graph
118+ * @throws IllegalArgumentException if vertexCount is negative
22119 */
23- private record Pair (int node , int weight ) {}
120+ public PrimAlgorithm (final int vertexCount ) {
121+ if (vertexCount < 0 ) {
122+ throw new IllegalArgumentException ("Vertex count must be non-negative" );
123+ }
124+ this .vertexCount = vertexCount ;
125+ this .adjacencyList = new ArrayList <>(vertexCount );
126+ for (int i = 0 ; i < vertexCount ; i ++) {
127+ adjacencyList .add (new ArrayList <>());
128+ }
129+ }
24130
25131 /**
26- * Computes the total weight of the Minimum Spanning Tree (MST)
27- * for a given undirected, weighted graph .
132+ * Adds an undirected weighted edge between two vertices.
133+ * Since the graph is undirected, edges are added in both directions .
28134 *
29- * @param vertices number of vertices in the graph
30- * @param adj adjacency list representation of the graph
31- * for each node, the adjacency list contains a list of
32- * {adjacentNode, edgeWeight}
33- * @return the sum of the edge weights in the MST
135+ * @param source the starting vertex of the edge
136+ * @param destination the ending vertex of the edge
137+ * @param weight the cost of traversing this edge
138+ * @throws IllegalArgumentException if vertices are out of bounds or weight is negative
139+ */
140+ public void addEdge (final int source , final int destination , final int weight ) {
141+ validateVertex (source );
142+ validateVertex (destination );
143+ if (weight < 0 ) {
144+ throw new IllegalArgumentException ("Edge weight cannot be negative" );
145+ }
146+ adjacencyList .get (source ).add (new Edge (destination , weight ));
147+ adjacencyList .get (destination ).add (new Edge (source , weight ));
148+ }
149+
150+ /**
151+ * Validates that a vertex identifier is within acceptable range.
152+ *
153+ * @param vertex the vertex to validate
154+ * @throws IllegalArgumentException if vertex is out of bounds
155+ */
156+ private void validateVertex (final int vertex ) {
157+ if (vertex < 0 || vertex >= vertexCount ) {
158+ throw new IllegalArgumentException (
159+ "Vertex " + vertex + " is out of bounds [0, " + (vertexCount - 1 ) + "]"
160+ );
161+ }
162+ }
163+
164+ /**
165+ * Computes the Minimum Spanning Tree using Prim's algorithm.
166+ * Starts from vertex 0 and greedily selects minimum weight edges
167+ * that connect unvisited vertices to the growing MST.
34168 *
35- * <p>Time Complexity: O(E log V), where E is the number of edges
36- * and V is the number of vertices.</p>
37- * <p>Space Complexity: O(V + E) due to adjacency list and visited array.</p>
169+ * @return MstResult containing total weight and list of MST edges
170+ * @throws IllegalStateException if graph is empty or disconnected
38171 */
39- public static int spanningTree (int vertices , List <? extends List <? extends List <Integer >>> adj ) {
40- // Min-heap to pick edge with the smallest weight
41- PriorityQueue <Pair > pq = new PriorityQueue <>(Comparator .comparingInt (Pair ::weight ));
172+ public MstResult computeMinimumSpanningTree () {
173+ if (vertexCount == 0 ) {
174+ throw new IllegalStateException ("Cannot compute MST on empty graph" );
175+ }
42176
43- // Array to keep track of visited vertices
44- boolean [] visited = new boolean [vertices ];
177+ final boolean [] visitedVertices = new boolean [vertexCount ];
178+ final int [] parent = new int [vertexCount ];
179+ Arrays .fill (parent , -1 );
45180
46- // Start with node 0 (arbitrary choice), with edge weight = 0
47- pq .add (new Pair (0 , 0 ));
181+ final PriorityQueue <Edge > minHeap = new PriorityQueue <>();
182+ final List <int []> mstEdges = new ArrayList <>();
183+ int totalWeight = 0 ;
48184
49- int mstWeightSum = 0 ;
185+ // Start from vertex 0
186+ visitedVertices [0 ] = true ;
187+ for (final Edge edge : adjacencyList .get (0 )) {
188+ minHeap .offer (new Edge (edge .getDestination (), edge .getWeight ()));
189+ }
50190
51- // Process nodes until the priority queue is empty
52- while (!pq .isEmpty ()) {
53- Pair current = pq .poll ();
54- int node = current .node ();
55- int weight = current .weight ();
191+ int processedVertices = 1 ;
56192
57- // Skip if the node is already included in MST
58- if (visited [node ]) {
193+ while (!minHeap .isEmpty () && processedVertices < vertexCount ) {
194+ final Edge currentEdge = minHeap .poll ();
195+ final int currentVertex = currentEdge .getDestination ();
196+
197+ if (visitedVertices [currentVertex ]) {
59198 continue ;
60199 }
61200
62- // Include the node in MST
63- visited [ node ] = true ;
64- mstWeightSum += weight ;
201+ visitedVertices [ currentVertex ] = true ;
202+ totalWeight += currentEdge . getWeight () ;
203+ processedVertices ++ ;
65204
66- // Traverse all adjacent edges
67- for (List <Integer > edge : adj .get (node )) {
68- int adjNode = edge .get (0 );
69- int edgeWeight = edge .get (1 );
205+ // Find the source vertex for this edge
206+ int sourceVertex = findSourceVertex (currentVertex , visitedVertices );
207+ mstEdges .add (new int []{sourceVertex , currentVertex , currentEdge .getWeight ()});
70208
71- // Only consider unvisited nodes
72- if (!visited [adjNode ]) {
73- pq .add (new Pair (adjNode , edgeWeight ));
209+ // Add all edges from newly visited vertex
210+ for (final Edge neighborEdge : adjacencyList .get (currentVertex )) {
211+ if (!visitedVertices [neighborEdge .getDestination ()]) {
212+ minHeap .offer (neighborEdge );
74213 }
75214 }
76215 }
77216
78- return mstWeightSum ;
217+ if (processedVertices < vertexCount ) {
218+ throw new IllegalStateException ("Graph is disconnected - MST cannot be formed" );
219+ }
220+
221+ return new MstResult (totalWeight , mstEdges );
222+ }
223+
224+ /**
225+ * Finds the source vertex that connects to the given destination in the MST.
226+ *
227+ * @param destination the destination vertex to find source for
228+ * @param visited array tracking which vertices are in the MST
229+ * @return the source vertex identifier
230+ */
231+ private int findSourceVertex (final int destination , final boolean [] visited ) {
232+ for (int vertex = 0 ; vertex < vertexCount ; vertex ++) {
233+ if (visited [vertex ]) {
234+ for (final Edge edge : adjacencyList .get (vertex )) {
235+ if (edge .getDestination () == destination ) {
236+ return vertex ;
237+ }
238+ }
239+ }
240+ }
241+ return -1 ;
242+ }
243+
244+ /**
245+ * Returns the number of vertices in the graph.
246+ *
247+ * @return vertex count
248+ */
249+ public int getVertexCount () {
250+ return vertexCount ;
79251 }
80- }
252+ }
0 commit comments