1010from ngraph .lib .graph import NodeID , StrictMultiDiGraph
1111
1212
13+ # Use @overload to provide precise static type safety for conditional return types.
14+ # The function returns different types based on boolean flags: float, tuple[float, FlowSummary],
15+ # tuple[float, StrictMultiDiGraph], or tuple[float, FlowSummary, StrictMultiDiGraph].
1316@overload
1417def calc_max_flow (
1518 graph : StrictMultiDiGraph ,
@@ -25,6 +28,7 @@ def calc_max_flow(
2528 flow_attr : str = "flow" ,
2629 flows_attr : str = "flows" ,
2730 copy_graph : bool = True ,
31+ tolerance : float = 1e-10 ,
2832) -> float : ...
2933
3034
@@ -43,6 +47,7 @@ def calc_max_flow(
4347 flow_attr : str = "flow" ,
4448 flows_attr : str = "flows" ,
4549 copy_graph : bool = True ,
50+ tolerance : float = 1e-10 ,
4651) -> tuple [float , FlowSummary ]: ...
4752
4853
@@ -61,6 +66,7 @@ def calc_max_flow(
6166 flow_attr : str = "flow" ,
6267 flows_attr : str = "flows" ,
6368 copy_graph : bool = True ,
69+ tolerance : float = 1e-10 ,
6470) -> tuple [float , StrictMultiDiGraph ]: ...
6571
6672
@@ -79,6 +85,7 @@ def calc_max_flow(
7985 flow_attr : str = "flow" ,
8086 flows_attr : str = "flows" ,
8187 copy_graph : bool = True ,
88+ tolerance : float = 1e-10 ,
8289) -> tuple [float , FlowSummary , StrictMultiDiGraph ]: ...
8390
8491
@@ -96,6 +103,7 @@ def calc_max_flow(
96103 flow_attr : str = "flow" ,
97104 flows_attr : str = "flows" ,
98105 copy_graph : bool = True ,
106+ tolerance : float = 1e-10 ,
99107) -> Union [float , tuple ]:
100108 """Compute the maximum flow between two nodes in a directed multi-graph,
101109 using an iterative shortest-path augmentation approach.
@@ -141,6 +149,9 @@ def calc_max_flow(
141149 copy_graph (bool):
142150 If True, work on a copy of the original graph so it remains unmodified.
143151 Defaults to True.
152+ tolerance (float):
153+ Tolerance for floating-point comparisons when determining saturated edges
154+ and residual capacity. Defaults to 1e-10.
144155
145156 Returns:
146157 Union[float, tuple]:
@@ -196,6 +207,7 @@ def calc_max_flow(
196207 return_graph ,
197208 capacity_attr ,
198209 flow_attr ,
210+ tolerance ,
199211 )
200212 else :
201213 return 0.0
@@ -234,6 +246,7 @@ def calc_max_flow(
234246 return_graph ,
235247 capacity_attr ,
236248 flow_attr ,
249+ tolerance ,
237250 )
238251
239252 # Otherwise, repeatedly find augmenting paths until no new flow can be placed.
@@ -255,8 +268,8 @@ def calc_max_flow(
255268 flow_attr = flow_attr ,
256269 flows_attr = flows_attr ,
257270 )
258- if flow_meta .placed_flow <= 0 :
259- # No additional flow could be placed; at capacity.
271+ if flow_meta .placed_flow <= tolerance :
272+ # No significant additional flow could be placed; at capacity.
260273 break
261274
262275 max_flow += flow_meta .placed_flow
@@ -269,6 +282,7 @@ def calc_max_flow(
269282 return_graph ,
270283 capacity_attr ,
271284 flow_attr ,
285+ tolerance ,
272286 )
273287
274288
@@ -280,6 +294,7 @@ def _build_return_value(
280294 return_graph : bool ,
281295 capacity_attr : str ,
282296 flow_attr : str ,
297+ tolerance : float ,
283298) -> Union [float , tuple ]:
284299 """Build the appropriate return value based on the requested flags."""
285300 if not (return_summary or return_graph ):
@@ -288,7 +303,7 @@ def _build_return_value(
288303 summary = None
289304 if return_summary :
290305 summary = _build_flow_summary (
291- max_flow , flow_graph , src_node , capacity_attr , flow_attr
306+ max_flow , flow_graph , src_node , capacity_attr , flow_attr , tolerance
292307 )
293308
294309 ret : list = [max_flow ]
@@ -306,6 +321,7 @@ def _build_flow_summary(
306321 src_node : NodeID ,
307322 capacity_attr : str ,
308323 flow_attr : str ,
324+ tolerance : float ,
309325) -> FlowSummary :
310326 """Build a FlowSummary from the flow graph state."""
311327 edge_flow = {}
@@ -327,7 +343,10 @@ def _build_flow_summary(
327343 continue
328344 reachable .add (n )
329345 for _ , nbr , _ , d in flow_graph .out_edges (n , data = True , keys = True ):
330- if d [capacity_attr ] - d .get (flow_attr , 0.0 ) > 0 and nbr not in reachable :
346+ if (
347+ d [capacity_attr ] - d .get (flow_attr , 0.0 ) > tolerance
348+ and nbr not in reachable
349+ ):
331350 stack .append (nbr )
332351
333352 # Find min-cut edges (saturated edges crossing the cut)
@@ -336,7 +355,7 @@ def _build_flow_summary(
336355 for u , v , k , d in flow_graph .edges (data = True , keys = True )
337356 if u in reachable
338357 and v not in reachable
339- and d [capacity_attr ] - d .get (flow_attr , 0.0 ) == 0
358+ and d [capacity_attr ] - d .get (flow_attr , 0.0 ) <= tolerance
340359 ]
341360
342361 return FlowSummary (
0 commit comments