Skip to content

Commit 37edf77

Browse files
committed
Fixed calc_max_flow function by adding a tolerance parameter for floating-point comparisons instead of eq to zero.
1 parent 87cd142 commit 37edf77

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

docs/reference/api-full.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ For a curated, example-driven API guide, see **[api.md](api.md)**.
1010
> - **[CLI Reference](cli.md)** - Command-line interface
1111
> - **[DSL Reference](dsl.md)** - YAML syntax guide
1212
13-
**Generated from source code on:** June 20, 2025 at 01:48 UTC
13+
**Generated from source code on:** June 20, 2025 at 02:06 UTC
1414

1515
**Modules auto-discovered:** 48
1616

@@ -1638,7 +1638,7 @@ Returns:
16381638

16391639
Maximum flow algorithms and network flow computations.
16401640

1641-
### calc_max_flow(graph: ngraph.lib.graph.StrictMultiDiGraph, src_node: Hashable, dst_node: Hashable, *, return_summary: bool = False, return_graph: bool = False, flow_placement: ngraph.lib.algorithms.base.FlowPlacement = <FlowPlacement.PROPORTIONAL: 1>, shortest_path: bool = False, reset_flow_graph: bool = False, capacity_attr: str = 'capacity', flow_attr: str = 'flow', flows_attr: str = 'flows', copy_graph: bool = True) -> Union[float, tuple]
1641+
### calc_max_flow(graph: ngraph.lib.graph.StrictMultiDiGraph, src_node: Hashable, dst_node: Hashable, *, return_summary: bool = False, return_graph: bool = False, flow_placement: ngraph.lib.algorithms.base.FlowPlacement = <FlowPlacement.PROPORTIONAL: 1>, shortest_path: bool = False, reset_flow_graph: bool = False, capacity_attr: str = 'capacity', flow_attr: str = 'flow', flows_attr: str = 'flows', copy_graph: bool = True, tolerance: float = 1e-10) -> Union[float, tuple]
16421642

16431643
Compute the maximum flow between two nodes in a directed multi-graph,
16441644
using an iterative shortest-path augmentation approach.
@@ -1684,6 +1684,9 @@ Args:
16841684
copy_graph (bool):
16851685
If True, work on a copy of the original graph so it remains unmodified.
16861686
Defaults to True.
1687+
tolerance (float):
1688+
Tolerance for floating-point comparisons when determining saturated edges
1689+
and residual capacity. Defaults to 1e-10.
16871690

16881691
Returns:
16891692
Union[float, tuple]:

ngraph/lib/algorithms/max_flow.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from ngraph.lib.graph import NodeID, StrictMultiDiGraph
1111

1212

13+
# Use @overload to provide precise static type safety for conditional return types.
14+
# The function returns different types based on boolean flags: float, tuple[float, FlowSummary],
15+
# tuple[float, StrictMultiDiGraph], or tuple[float, FlowSummary, StrictMultiDiGraph].
1316
@overload
1417
def calc_max_flow(
1518
graph: StrictMultiDiGraph,
@@ -25,6 +28,7 @@ def calc_max_flow(
2528
flow_attr: str = "flow",
2629
flows_attr: str = "flows",
2730
copy_graph: bool = True,
31+
tolerance: float = 1e-10,
2832
) -> float: ...
2933

3034

@@ -43,6 +47,7 @@ def calc_max_flow(
4347
flow_attr: str = "flow",
4448
flows_attr: str = "flows",
4549
copy_graph: bool = True,
50+
tolerance: float = 1e-10,
4651
) -> tuple[float, FlowSummary]: ...
4752

4853

@@ -61,6 +66,7 @@ def calc_max_flow(
6166
flow_attr: str = "flow",
6267
flows_attr: str = "flows",
6368
copy_graph: bool = True,
69+
tolerance: float = 1e-10,
6470
) -> tuple[float, StrictMultiDiGraph]: ...
6571

6672

@@ -79,6 +85,7 @@ def calc_max_flow(
7985
flow_attr: str = "flow",
8086
flows_attr: str = "flows",
8187
copy_graph: bool = True,
88+
tolerance: float = 1e-10,
8289
) -> tuple[float, FlowSummary, StrictMultiDiGraph]: ...
8390

8491

@@ -96,6 +103,7 @@ def calc_max_flow(
96103
flow_attr: str = "flow",
97104
flows_attr: str = "flows",
98105
copy_graph: bool = True,
106+
tolerance: float = 1e-10,
99107
) -> Union[float, tuple]:
100108
"""Compute the maximum flow between two nodes in a directed multi-graph,
101109
using an iterative shortest-path augmentation approach.
@@ -141,6 +149,9 @@ def calc_max_flow(
141149
copy_graph (bool):
142150
If True, work on a copy of the original graph so it remains unmodified.
143151
Defaults to True.
152+
tolerance (float):
153+
Tolerance for floating-point comparisons when determining saturated edges
154+
and residual capacity. Defaults to 1e-10.
144155
145156
Returns:
146157
Union[float, tuple]:
@@ -196,6 +207,7 @@ def calc_max_flow(
196207
return_graph,
197208
capacity_attr,
198209
flow_attr,
210+
tolerance,
199211
)
200212
else:
201213
return 0.0
@@ -234,6 +246,7 @@ def calc_max_flow(
234246
return_graph,
235247
capacity_attr,
236248
flow_attr,
249+
tolerance,
237250
)
238251

239252
# Otherwise, repeatedly find augmenting paths until no new flow can be placed.
@@ -255,8 +268,8 @@ def calc_max_flow(
255268
flow_attr=flow_attr,
256269
flows_attr=flows_attr,
257270
)
258-
if flow_meta.placed_flow <= 0:
259-
# No additional flow could be placed; at capacity.
271+
if flow_meta.placed_flow <= tolerance:
272+
# No significant additional flow could be placed; at capacity.
260273
break
261274

262275
max_flow += flow_meta.placed_flow
@@ -269,6 +282,7 @@ def calc_max_flow(
269282
return_graph,
270283
capacity_attr,
271284
flow_attr,
285+
tolerance,
272286
)
273287

274288

@@ -280,6 +294,7 @@ def _build_return_value(
280294
return_graph: bool,
281295
capacity_attr: str,
282296
flow_attr: str,
297+
tolerance: float,
283298
) -> Union[float, tuple]:
284299
"""Build the appropriate return value based on the requested flags."""
285300
if not (return_summary or return_graph):
@@ -288,7 +303,7 @@ def _build_return_value(
288303
summary = None
289304
if return_summary:
290305
summary = _build_flow_summary(
291-
max_flow, flow_graph, src_node, capacity_attr, flow_attr
306+
max_flow, flow_graph, src_node, capacity_attr, flow_attr, tolerance
292307
)
293308

294309
ret: list = [max_flow]
@@ -306,6 +321,7 @@ def _build_flow_summary(
306321
src_node: NodeID,
307322
capacity_attr: str,
308323
flow_attr: str,
324+
tolerance: float,
309325
) -> FlowSummary:
310326
"""Build a FlowSummary from the flow graph state."""
311327
edge_flow = {}
@@ -327,7 +343,10 @@ def _build_flow_summary(
327343
continue
328344
reachable.add(n)
329345
for _, nbr, _, d in flow_graph.out_edges(n, data=True, keys=True):
330-
if d[capacity_attr] - d.get(flow_attr, 0.0) > 0 and nbr not in reachable:
346+
if (
347+
d[capacity_attr] - d.get(flow_attr, 0.0) > tolerance
348+
and nbr not in reachable
349+
):
331350
stack.append(nbr)
332351

333352
# Find min-cut edges (saturated edges crossing the cut)
@@ -336,7 +355,7 @@ def _build_flow_summary(
336355
for u, v, k, d in flow_graph.edges(data=True, keys=True)
337356
if u in reachable
338357
and v not in reachable
339-
and d[capacity_attr] - d.get(flow_attr, 0.0) == 0
358+
and d[capacity_attr] - d.get(flow_attr, 0.0) <= tolerance
340359
]
341360

342361
return FlowSummary(

0 commit comments

Comments
 (0)