diff --git a/ngraph/lib/demand.py b/ngraph/lib/demand.py index 37d094c..a2b06da 100644 --- a/ngraph/lib/demand.py +++ b/ngraph/lib/demand.py @@ -1,8 +1,10 @@ from __future__ import annotations +import math from dataclasses import dataclass, field from typing import Optional, Tuple +from ngraph.lib.algorithms.base import MIN_FLOW from ngraph.lib.flow_policy import FlowPolicy from ngraph.lib.graph import NodeID, StrictMultiDiGraph @@ -21,6 +23,16 @@ class Demand: flow_policy: Optional[FlowPolicy] = None placed_demand: float = field(default=0.0, init=False) + @staticmethod + def _round_float(value: float) -> float: + """Round ``value`` to avoid tiny floating point drift.""" + if math.isfinite(value): + rounded = round(value, 12) + if abs(rounded) < MIN_FLOW: + return 0.0 + return rounded + return value + def __lt__(self, other: Demand) -> bool: """ Compare Demands by their demand_class (priority). A lower demand_class @@ -94,7 +106,10 @@ def place( # placed_now is the difference from the old placed_demand placed_now = self.flow_policy.placed_demand - self.placed_demand - self.placed_demand = self.flow_policy.placed_demand + self.placed_demand = self._round_float(self.flow_policy.placed_demand) remaining = to_place - placed_now + placed_now = self._round_float(placed_now) + remaining = self._round_float(remaining) + return placed_now, remaining