11from __future__ import annotations
22
3+ import math
34from dataclasses import dataclass , field
45from typing import Optional , Tuple
56
7+ from ngraph .lib .algorithms .base import MIN_FLOW
68from ngraph .lib .flow_policy import FlowPolicy
79from ngraph .lib .graph import NodeID , StrictMultiDiGraph
810
@@ -21,6 +23,16 @@ class Demand:
2123 flow_policy : Optional [FlowPolicy ] = None
2224 placed_demand : float = field (default = 0.0 , init = False )
2325
26+ @staticmethod
27+ def _round_float (value : float ) -> float :
28+ """Round ``value`` to avoid tiny floating point drift."""
29+ if math .isfinite (value ):
30+ rounded = round (value , 12 )
31+ if abs (rounded ) < MIN_FLOW :
32+ return 0.0
33+ return rounded
34+ return value
35+
2436 def __lt__ (self , other : Demand ) -> bool :
2537 """
2638 Compare Demands by their demand_class (priority). A lower demand_class
@@ -94,7 +106,10 @@ def place(
94106
95107 # placed_now is the difference from the old placed_demand
96108 placed_now = self .flow_policy .placed_demand - self .placed_demand
97- self .placed_demand = self .flow_policy .placed_demand
109+ self .placed_demand = self ._round_float ( self . flow_policy .placed_demand )
98110 remaining = to_place - placed_now
99111
112+ placed_now = self ._round_float (placed_now )
113+ remaining = self ._round_float (remaining )
114+
100115 return placed_now , remaining
0 commit comments