diff --git a/src/main/java/com/williamfiset/algorithms/graphtheory/Boruvkas.java b/src/main/java/com/williamfiset/algorithms/graphtheory/Boruvkas.java index 1be3a5f7d..e132017ae 100644 --- a/src/main/java/com/williamfiset/algorithms/graphtheory/Boruvkas.java +++ b/src/main/java/com/williamfiset/algorithms/graphtheory/Boruvkas.java @@ -1,11 +1,10 @@ -/** WIP */ package com.williamfiset.algorithms.graphtheory; import java.util.*; public class Boruvkas { - static class Edge { + static class Edge implements Comparable { int u, v, cost; public Edge(int u, int v, int cost) { @@ -14,11 +13,12 @@ public Edge(int u, int v, int cost) { this.cost = cost; } + @Override public String toString() { return String.format("%d %d, cost: %d", u, v, cost); } - // @Override + @Override public int compareTo(Edge other) { int cmp = cost - other.cost; // Break ties by picking lexicographically smallest edge pair. @@ -32,7 +32,7 @@ public int compareTo(Edge other) { } // Inputs - private final int n, m; // Num nodes, num edges + private final int n; // Number of nodes private final Edge[] graph; // Edge list // Internal @@ -43,11 +43,10 @@ public int compareTo(Edge other) { private long minCostSum; private List mst; - public Boruvkas(int n, int m, Edge[] graph) { + public Boruvkas(int n, Edge[] graph) { if (graph == null) throw new IllegalArgumentException(); this.graph = graph; this.n = n; - this.m = m; } // Returns the edges used in finding the minimum spanning tree, or returns @@ -69,62 +68,43 @@ private void solve() { if (solved) return; mst = new ArrayList<>(); - UnionFind uf = new UnionFind(n); - int[] cheapest = new int[n]; - Arrays.fill(cheapest, -1); - - // Repeat at most log(n) times or until we have a complete spanning tree. - // for(int t = 1; t < N && index < n - 1; t = t + t) { - // for(long t = 1; t <= n && mst.size() != n-1; t = t << 1) { - for (; mst.size() != n - 1; ) { - - // TODO: Remove - Arrays.fill(cheapest, -1); - boolean stop = true; - - for (int i = 0; i < graph.length; i++) { - Edge e = graph[i]; - if (e.u == e.v) continue; - int uc = uf.id[e.u], vc = uf.id[e.v]; - if (uc == vc) continue; - // if (cheapest[vc] == -1 || e.compareTo(graph[cheapest[vc]]) < 0) { stop = false; - // cheapest[vc] = i; } - // if (cheapest[uc] == -1 || e.compareTo(graph[cheapest[uc]]) < 0) { stop = false; - // cheapest[uc] = i; } - if (cheapest[vc] == -1 || e.cost < graph[cheapest[vc]].cost) { - stop = false; - cheapest[vc] = i; + while (uf.components > 1) { + Edge[] cheapest = new Edge[n]; + + // Find the cheapest edge for each component + for (Edge e : graph) { + int root1 = uf.find(e.u); + int root2 = uf.find(e.v); + if (root1 == root2) continue; + + if (cheapest[root1] == null || e.cost < cheapest[root1].cost) { + cheapest[root1] = e; } - if (cheapest[uc] == -1 || e.cost < graph[cheapest[uc]].cost) { - stop = false; - cheapest[uc] = i; + if (cheapest[root2] == null || e.cost < cheapest[root2].cost) { + cheapest[root2] = e; } } - if (stop) break; - + // Add the cheapest edges to the MST for (int i = 0; i < n; i++) { - if (cheapest[i] == -1) continue; - Edge e = graph[cheapest[i]]; - // cheapest[i] = -1; - if (uf.connected(e.u, e.v)) continue; - - mst.add(e); - minCostSum += e.cost; - uf.union(e.u, e.v); - - // TODO(williamfiset): Optimization is to remove e from graph. + Edge e = cheapest[i]; + if (e == null) { + continue; + } + int root1 = uf.find(e.u); + int root2 = uf.find(e.v); + if (root1 != root2) { + uf.union(root1, root2); + mst.add(e); + minCostSum += e.cost; + } } } - // if ( (index==n-1) != (uf.size(0) == n) ) throw new NullPointerException(); - - mstExists = (mst.size() == n - 1); // (uf.size(0) == n); + mstExists = (mst.size() == n - 1); solved = true; - - // if (!check()) throw new IllegalStateException(); } private boolean check() { @@ -200,7 +180,7 @@ public static void main(String[] args) { g[i++] = new Edge(7, 8, 6); g[i++] = new Edge(9, 8, 0); - Boruvkas solver = new Boruvkas(n, m, g); + Boruvkas solver = new Boruvkas(n, g); Long ans = solver.getMstCost(); if (ans != null) { @@ -211,9 +191,6 @@ public static void main(String[] args) { } else { System.out.println("No MST exists"); } - - // System.out.println(solver.solve(g, n)); - } // Union find data structure