Skip to content

Commit 87c4455

Browse files
authored
Fix floating point rounding in Demand.place (#65)
1 parent 3bf9aed commit 87c4455

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

ngraph/lib/demand.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

3+
import math
34
from dataclasses import dataclass, field
45
from typing import Optional, Tuple
56

7+
from ngraph.lib.algorithms.base import MIN_FLOW
68
from ngraph.lib.flow_policy import FlowPolicy
79
from ngraph.lib.graph import NodeID, StrictMultiDiGraph
810

@@ -21,6 +23,16 @@ class Demand:
2123
flow_policy: Optional[FlowPolicy] = None
2224
placed_demand: float = field(default=0.0, init=False)
2325

26+
@staticmethod
27+
def _round_float(value: float) -> float:
28+
"""Round ``value`` to avoid tiny floating point drift."""
29+
if math.isfinite(value):
30+
rounded = round(value, 12)
31+
if abs(rounded) < MIN_FLOW:
32+
return 0.0
33+
return rounded
34+
return value
35+
2436
def __lt__(self, other: Demand) -> bool:
2537
"""
2638
Compare Demands by their demand_class (priority). A lower demand_class
@@ -94,7 +106,10 @@ def place(
94106

95107
# placed_now is the difference from the old placed_demand
96108
placed_now = self.flow_policy.placed_demand - self.placed_demand
97-
self.placed_demand = self.flow_policy.placed_demand
109+
self.placed_demand = self._round_float(self.flow_policy.placed_demand)
98110
remaining = to_place - placed_now
99111

112+
placed_now = self._round_float(placed_now)
113+
remaining = self._round_float(remaining)
114+
100115
return placed_now, remaining

0 commit comments

Comments
 (0)