Skip to content

Commit fa54fdc

Browse files
committed
spf with early termination
1 parent 6508bc0 commit fa54fdc

File tree

6 files changed

+10246
-89
lines changed

6 files changed

+10246
-89
lines changed

dev/bench_pairwise_maxflow.py

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
"""Benchmark pairwise max-flow on a scenario graph.
2+
3+
This script loads a scenario YAML, builds its `StrictMultiDiGraph`, identifies
4+
datacenter nodes by a regex pattern, and measures the runtime of three modes:
5+
6+
- bare: direct pairwise calls to `calc_max_flow` on graph copies
7+
- solver: `ngraph.solver.maxflow.max_flow(..., mode="pairwise")`
8+
- fm: FailureManager `run_max_flow_monte_carlo` with iterations=1
9+
10+
Use this to compare backbone vs clos scenarios and isolate overheads.
11+
12+
Run examples:
13+
14+
python -m dev.bench_pairwise_maxflow \
15+
--scenario scenarios/backbone.yml --limit-pairs 50
16+
17+
python -m dev.bench_pairwise_maxflow \
18+
--scenario scenarios/clos_scenario.yml --limit-pairs 50
19+
20+
Notes:
21+
- The script does not modify the repository state and writes no files by default.
22+
- For large scenarios, consider `--limit-pairs` or `--max-metros` to constrain work.
23+
"""
24+
25+
from __future__ import annotations
26+
27+
import argparse
28+
import statistics
29+
import time
30+
from itertools import islice
31+
from pathlib import Path
32+
from typing import Iterable, Iterator, Sequence, Tuple
33+
34+
from ngraph.algorithms.max_flow import calc_max_flow
35+
from ngraph.failure.manager.manager import FailureManager
36+
from ngraph.graph.strict_multidigraph import StrictMultiDiGraph
37+
from ngraph.scenario import Scenario
38+
from ngraph.solver.maxflow import max_flow as solver_max_flow
39+
40+
41+
def _pairwise(iterable: Sequence[str]) -> Iterator[Tuple[str, str]]:
42+
for i, a in enumerate(iterable):
43+
for j, b in enumerate(iterable):
44+
if i == j:
45+
continue
46+
yield a, b
47+
48+
49+
def _take(it: Iterable[Tuple[str, str]], n: int | None) -> list[Tuple[str, str]]:
50+
if n is None or n <= 0:
51+
return list(it)
52+
return list(islice(it, n))
53+
54+
55+
def _select_dc_nodes(scenario: Scenario, group_pattern: str) -> list[Tuple[str, str]]:
56+
"""Return list of (label, node_name) for DC groups.
57+
58+
Args:
59+
scenario: Loaded scenario instance.
60+
group_pattern: Regex used to group/select DC nodes; labels are group labels.
61+
62+
Returns:
63+
List of tuples where each tuple is (group_label, node_name). If a group
64+
contains multiple nodes, each node is returned with the same label.
65+
"""
66+
groups = scenario.network.select_node_groups_by_path(group_pattern)
67+
items: list[Tuple[str, str]] = []
68+
for label, nodes in groups.items():
69+
for node in nodes:
70+
items.append((label, node.name))
71+
# Stable by label then node name
72+
items.sort(key=lambda p: (p[0], p[1]))
73+
return items
74+
75+
76+
def _build_base_graph(scenario: Scenario) -> StrictMultiDiGraph:
77+
"""Build the base graph from the scenario network (bidirectional links)."""
78+
return scenario.network.to_strict_multidigraph(add_reverse=True)
79+
80+
81+
def _bench_bare_pairwise(
82+
base_graph: StrictMultiDiGraph,
83+
dc_nodes: list[Tuple[str, str]],
84+
limit_pairs: int | None,
85+
) -> dict:
86+
"""Benchmark direct calc_max_flow calls with pseudo source/sink per pair.
87+
88+
For each (src_node, dst_node), copy the base_graph, attach pseudo nodes with
89+
infinite capacity, then call `calc_max_flow(copy_graph=False)`.
90+
"""
91+
# Deduplicate labels by node selection order
92+
node_names: list[str] = [n for (_, n) in dc_nodes]
93+
pairs: list[Tuple[str, str]] = _take(_pairwise(node_names), limit_pairs)
94+
95+
times_ms: list[float] = []
96+
last_flow: float | None = None
97+
t_start = time.perf_counter()
98+
99+
for src, dst in pairs:
100+
g = base_graph.copy()
101+
g.add_node("source")
102+
g.add_node("sink")
103+
g.add_edge("source", src, capacity=float("inf"), cost=0)
104+
g.add_edge(dst, "sink", capacity=float("inf"), cost=0)
105+
t0 = time.perf_counter()
106+
last_flow = float(
107+
calc_max_flow(
108+
g,
109+
"source",
110+
"sink",
111+
copy_graph=False,
112+
)
113+
)
114+
t1 = time.perf_counter()
115+
times_ms.append(1000.0 * (t1 - t0))
116+
117+
t_end = time.perf_counter()
118+
return {
119+
"pairs": len(pairs),
120+
"last_flow": last_flow,
121+
"elapsed_s": (t_end - t_start),
122+
"min_ms": min(times_ms) if times_ms else 0.0,
123+
"mean_ms": statistics.mean(times_ms) if times_ms else 0.0,
124+
"median_ms": statistics.median(times_ms) if times_ms else 0.0,
125+
"max_ms": max(times_ms) if times_ms else 0.0,
126+
}
127+
128+
129+
def _bench_solver_pairwise(
130+
scenario: Scenario,
131+
group_pattern: str,
132+
) -> dict:
133+
"""Benchmark solver.max_flow pairwise over all groups (no limit)."""
134+
t0 = time.perf_counter()
135+
flows = solver_max_flow(
136+
scenario.network,
137+
group_pattern,
138+
group_pattern,
139+
mode="pairwise",
140+
shortest_path=False,
141+
)
142+
t1 = time.perf_counter()
143+
return {
144+
"pairs": len(flows),
145+
"elapsed_s": (t1 - t0),
146+
"nonzero": sum(1 for v in flows.values() if v > 0.0),
147+
}
148+
149+
150+
def _bench_failure_manager(
151+
scenario: Scenario,
152+
group_pattern: str,
153+
) -> dict:
154+
"""Benchmark FailureManager with iterations=1, pairwise mode.
155+
156+
This mirrors the CapacityEnvelopeAnalysis step configuration used in scenarios.
157+
"""
158+
fm = FailureManager(
159+
network=scenario.network,
160+
failure_policy_set=scenario.failure_policy_set,
161+
policy_name=None,
162+
)
163+
t0 = time.perf_counter()
164+
res = fm.run_max_flow_monte_carlo(
165+
source_path=group_pattern,
166+
sink_path=group_pattern,
167+
mode="pairwise",
168+
iterations=1,
169+
parallelism=1,
170+
shortest_path=False,
171+
flow_placement="PROPORTIONAL",
172+
baseline=False,
173+
seed=scenario.seed,
174+
store_failure_patterns=False,
175+
include_flow_summary=False,
176+
)
177+
t1 = time.perf_counter()
178+
meta = getattr(res, "metadata", {})
179+
envs = getattr(res, "envelopes", {})
180+
return {
181+
"pairs": len(envs),
182+
"elapsed_s": (t1 - t0),
183+
"meta_time_s": float(meta.get("execution_time", 0.0))
184+
if isinstance(meta, dict)
185+
else 0.0,
186+
}
187+
188+
189+
def main(argv: list[str] | None = None) -> int:
190+
"""CLI entry point.
191+
192+
Args:
193+
argv: Optional argument vector.
194+
195+
Returns:
196+
Process exit code (0 on success).
197+
"""
198+
parser = argparse.ArgumentParser(
199+
description="Benchmark pairwise max-flow for a scenario"
200+
)
201+
parser.add_argument(
202+
"--scenario",
203+
type=Path,
204+
required=True,
205+
help="Path to scenario YAML",
206+
)
207+
parser.add_argument(
208+
"--pattern",
209+
type=str,
210+
default=r"(metro[0-9]+/dc[0-9]+)",
211+
help="Regex to group/select DC nodes (use capturing group to label)",
212+
)
213+
parser.add_argument(
214+
"--limit-pairs",
215+
type=int,
216+
default=0,
217+
help="Limit number of pairwise computations in bare mode (0 = all)",
218+
)
219+
parser.add_argument(
220+
"--skip-modes",
221+
type=str,
222+
default="",
223+
help="Comma-separated modes to skip: bare,solver,fm",
224+
)
225+
226+
args = parser.parse_args(argv)
227+
skip = {s.strip() for s in args.skip_modes.split(",") if s.strip()}
228+
229+
yaml_text = args.scenario.read_text()
230+
scenario = Scenario.from_yaml(yaml_text)
231+
232+
# Build graph and DC node list
233+
base_graph = _build_base_graph(scenario)
234+
dc_items = _select_dc_nodes(scenario, args.pattern)
235+
dc_labels = sorted({lbl for (lbl, _) in dc_items})
236+
237+
print(f"scenario: {args.scenario}")
238+
print(
239+
f"graph: nodes={len(base_graph)}, edges={base_graph.number_of_edges()} | dcs={len(dc_labels)}"
240+
)
241+
242+
# Mode: bare
243+
if "bare" not in skip:
244+
print("[bench] bare pairwise calc_max_flow ...")
245+
bare_stats = _bench_bare_pairwise(
246+
base_graph,
247+
dc_items,
248+
None if args.limit_pairs <= 0 else int(args.limit_pairs),
249+
)
250+
print(
251+
f"[bare ] pairs={bare_stats['pairs']} elapsed={bare_stats['elapsed_s']:.3f}s "
252+
f"min/mean/med/max={bare_stats['min_ms']:.2f}/{bare_stats['mean_ms']:.2f}/"
253+
f"{bare_stats['median_ms']:.2f}/{bare_stats['max_ms']:.2f} ms"
254+
)
255+
256+
# Mode: solver (full pairwise)
257+
if "solver" not in skip:
258+
print("[bench] solver.max_flow pairwise ...")
259+
sol_stats = _bench_solver_pairwise(scenario, args.pattern)
260+
print(
261+
f"[solve] pairs={sol_stats['pairs']} elapsed={sol_stats['elapsed_s']:.3f}s "
262+
f"nonzero={sol_stats['nonzero']}"
263+
)
264+
265+
# Mode: FailureManager (iterations=1)
266+
if "fm" not in skip:
267+
print("[bench] FailureManager iterations=1 pairwise ...")
268+
fm_stats = _bench_failure_manager(scenario, args.pattern)
269+
print(
270+
f"[fm ] pairs={fm_stats['pairs']} elapsed={fm_stats['elapsed_s']:.3f}s "
271+
f"meta_time={fm_stats['meta_time_s']:.3f}s"
272+
)
273+
274+
return 0
275+
276+
277+
if __name__ == "__main__": # pragma: no cover - manual utility
278+
raise SystemExit(main())

ngraph/algorithms/max_flow.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,10 @@ def calc_max_flow(
233233

234234
# First path-finding iteration.
235235
costs, pred = spf(
236-
flow_graph, src_node, edge_select=EdgeSelect.ALL_MIN_COST_WITH_CAP_REMAINING
236+
flow_graph,
237+
src_node,
238+
edge_select=EdgeSelect.ALL_MIN_COST_WITH_CAP_REMAINING,
239+
dst_node=dst_node,
237240
)
238241
flow_meta = place_flow_on_graph(
239242
flow_graph,
@@ -271,7 +274,10 @@ def calc_max_flow(
271274
# Otherwise, repeatedly find augmenting paths until no new flow can be placed.
272275
while True:
273276
costs, pred = spf(
274-
flow_graph, src_node, edge_select=EdgeSelect.ALL_MIN_COST_WITH_CAP_REMAINING
277+
flow_graph,
278+
src_node,
279+
edge_select=EdgeSelect.ALL_MIN_COST_WITH_CAP_REMAINING,
280+
dst_node=dst_node,
275281
)
276282
if dst_node not in pred:
277283
# No path found; we've reached max flow.

0 commit comments

Comments
 (0)