Skip to content

Commit b2d53bc

Browse files
committed
Add include_disabled parameter to NetworkStats for optional inclusion of disabled nodes and links in statistics. Update related logic and tests to ensure backward compatibility and correct behavior.
1 parent ef51696 commit b2d53bc

File tree

2 files changed

+172
-13
lines changed

2 files changed

+172
-13
lines changed

ngraph/workflow/network_stats.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,14 @@
1414

1515
@dataclass
1616
class NetworkStats(WorkflowStep):
17-
"""Compute basic node and link statistics for the network."""
17+
"""Compute basic node and link statistics for the network.
18+
19+
Attributes:
20+
include_disabled (bool): If True, include disabled nodes and links in statistics.
21+
If False, only consider enabled entities. Defaults to False.
22+
"""
23+
24+
include_disabled: bool = False
1825

1926
def run(self, scenario: Scenario) -> None:
2027
"""Collect capacity and degree statistics.
@@ -25,9 +32,14 @@ def run(self, scenario: Scenario) -> None:
2532

2633
network = scenario.network
2734

28-
link_caps = [
29-
link.capacity for link in network.links.values() if not link.disabled
30-
]
35+
# Collect link capacity statistics - filter based on include_disabled setting
36+
if self.include_disabled:
37+
link_caps = [link.capacity for link in network.links.values()]
38+
else:
39+
link_caps = [
40+
link.capacity for link in network.links.values() if not link.disabled
41+
]
42+
3143
link_caps_sorted = sorted(link_caps)
3244
link_stats = {
3345
"values": link_caps_sorted,
@@ -37,17 +49,29 @@ def run(self, scenario: Scenario) -> None:
3749
"median": median(link_caps_sorted) if link_caps_sorted else 0.0,
3850
}
3951

52+
# Collect per-node statistics and aggregate data for distributions
4053
node_stats: Dict[str, Dict[str, List[float] | float]] = {}
4154
node_capacities = []
4255
node_degrees = []
4356
for node_name, node in network.nodes.items():
44-
if node.disabled:
57+
# Skip disabled nodes unless include_disabled is True
58+
if not self.include_disabled and node.disabled:
4559
continue
46-
outgoing = [
47-
link.capacity
48-
for link in network.links.values()
49-
if link.source == node_name and not link.disabled
50-
]
60+
61+
# Calculate node degree and capacity - filter links based on include_disabled setting
62+
if self.include_disabled:
63+
outgoing = [
64+
link.capacity
65+
for link in network.links.values()
66+
if link.source == node_name
67+
]
68+
else:
69+
outgoing = [
70+
link.capacity
71+
for link in network.links.values()
72+
if link.source == node_name and not link.disabled
73+
]
74+
5175
degree = len(outgoing)
5276
cap_sum = sum(outgoing)
5377

@@ -60,6 +84,7 @@ def run(self, scenario: Scenario) -> None:
6084
"capacities": sorted(outgoing),
6185
}
6286

87+
# Create aggregate distributions for network-wide analysis
6388
node_caps_sorted = sorted(node_capacities)
6489
node_degrees_sorted = sorted(node_degrees)
6590

@@ -73,10 +98,10 @@ def run(self, scenario: Scenario) -> None:
7398

7499
node_degree_dist = {
75100
"values": node_degrees_sorted,
76-
"min": min(node_degrees_sorted) if node_degrees_sorted else 0,
77-
"max": max(node_degrees_sorted) if node_degrees_sorted else 0,
101+
"min": min(node_degrees_sorted) if node_degrees_sorted else 0.0,
102+
"max": max(node_degrees_sorted) if node_degrees_sorted else 0.0,
78103
"mean": mean(node_degrees_sorted) if node_degrees_sorted else 0.0,
79-
"median": median(node_degrees_sorted) if node_degrees_sorted else 0,
104+
"median": median(node_degrees_sorted) if node_degrees_sorted else 0.0,
80105
}
81106

82107
scenario.results.put(self.name, "link_capacity", link_stats)

tests/workflow/test_network_stats.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,31 @@ def mock_scenario():
2323
return scenario
2424

2525

26+
@pytest.fixture
27+
def mock_scenario_with_disabled():
28+
"""Scenario with disabled nodes and links for testing include_disabled parameter."""
29+
scenario = MagicMock()
30+
scenario.network = Network()
31+
scenario.results = MagicMock()
32+
scenario.results.put = MagicMock()
33+
34+
# Add nodes - some enabled, some disabled
35+
scenario.network.add_node(Node("A")) # enabled
36+
scenario.network.add_node(Node("B")) # enabled
37+
scenario.network.add_node(Node("C", disabled=True)) # disabled
38+
scenario.network.add_node(Node("D")) # enabled
39+
40+
# Add links - some enabled, some disabled
41+
scenario.network.add_link(Link("A", "B", capacity=10)) # enabled
42+
scenario.network.add_link(Link("A", "C", capacity=5)) # enabled (to disabled node)
43+
scenario.network.add_link(
44+
Link("C", "A", capacity=7)
45+
) # enabled (from disabled node)
46+
scenario.network.add_link(Link("B", "D", capacity=15, disabled=True)) # disabled
47+
scenario.network.add_link(Link("D", "B", capacity=20)) # enabled
48+
return scenario
49+
50+
2651
def test_network_stats_collects_statistics(mock_scenario):
2752
step = NetworkStats(name="stats")
2853

@@ -50,3 +75,112 @@ def test_network_stats_collects_statistics(mock_scenario):
5075
if call.args[1] == "per_node"
5176
)
5277
assert set(per_node.keys()) == {"A", "B", "C"}
78+
79+
80+
def test_network_stats_excludes_disabled_by_default(mock_scenario_with_disabled):
81+
"""Test that disabled nodes and links are excluded by default."""
82+
step = NetworkStats(name="stats")
83+
84+
step.run(mock_scenario_with_disabled)
85+
86+
# Get the collected data
87+
calls = {
88+
call.args[1]: call.args[2]
89+
for call in mock_scenario_with_disabled.results.put.call_args_list
90+
}
91+
92+
# Link capacity should exclude disabled link (capacity=15)
93+
link_data = calls["link_capacity"]
94+
# Should include capacities: 10, 5, 7, 20 (excluding disabled link with capacity=15)
95+
assert sorted(link_data["values"]) == [5, 7, 10, 20]
96+
assert link_data["min"] == 5
97+
assert link_data["max"] == 20
98+
assert link_data["mean"] == pytest.approx((5 + 7 + 10 + 20) / 4)
99+
100+
# Per-node stats should exclude disabled node C
101+
per_node = calls["per_node"]
102+
# Should only include enabled nodes: A, B, D (excluding disabled node C)
103+
assert set(per_node.keys()) == {"A", "B", "D"}
104+
105+
# Node A should have degree 2 (links to B and C, both enabled)
106+
assert per_node["A"]["degree"] == 2
107+
assert per_node["A"]["capacity_sum"] == 15 # 10 + 5
108+
109+
# Node B should have degree 0 (link to D is disabled)
110+
assert per_node["B"]["degree"] == 0
111+
assert per_node["B"]["capacity_sum"] == 0
112+
113+
# Node D should have degree 1 (link to B is enabled)
114+
assert per_node["D"]["degree"] == 1
115+
assert per_node["D"]["capacity_sum"] == 20
116+
117+
118+
def test_network_stats_includes_disabled_when_enabled(mock_scenario_with_disabled):
119+
"""Test that disabled nodes and links are included when include_disabled=True."""
120+
step = NetworkStats(name="stats", include_disabled=True)
121+
122+
step.run(mock_scenario_with_disabled)
123+
124+
# Get the collected data
125+
calls = {
126+
call.args[1]: call.args[2]
127+
for call in mock_scenario_with_disabled.results.put.call_args_list
128+
}
129+
130+
# Link capacity should include all links including disabled one
131+
link_data = calls["link_capacity"]
132+
# Should include all capacities: 10, 5, 7, 15, 20
133+
assert sorted(link_data["values"]) == [5, 7, 10, 15, 20]
134+
assert link_data["min"] == 5
135+
assert link_data["max"] == 20
136+
assert link_data["mean"] == pytest.approx((5 + 7 + 10 + 15 + 20) / 5)
137+
138+
# Per-node stats should include disabled node C
139+
per_node = calls["per_node"]
140+
# Should include all nodes: A, B, C, D
141+
assert set(per_node.keys()) == {"A", "B", "C", "D"}
142+
143+
# Node A should have degree 2 (links to B and C)
144+
assert per_node["A"]["degree"] == 2
145+
assert per_node["A"]["capacity_sum"] == 15 # 10 + 5
146+
147+
# Node B should have degree 1 (link to D, now included)
148+
assert per_node["B"]["degree"] == 1
149+
assert per_node["B"]["capacity_sum"] == 15 # disabled link now included
150+
151+
# Node C should have degree 1 (link to A)
152+
assert per_node["C"]["degree"] == 1
153+
assert per_node["C"]["capacity_sum"] == 7
154+
155+
# Node D should have degree 1 (link to B)
156+
assert per_node["D"]["degree"] == 1
157+
assert per_node["D"]["capacity_sum"] == 20
158+
159+
160+
def test_network_stats_parameter_backward_compatibility(mock_scenario):
161+
"""Test that the new parameter maintains backward compatibility."""
162+
# Test with explicit default
163+
step_explicit = NetworkStats(name="stats", include_disabled=False)
164+
step_explicit.run(mock_scenario)
165+
166+
# Capture results from explicit test
167+
explicit_calls = {
168+
call.args[1]: call.args[2] for call in mock_scenario.results.put.call_args_list
169+
}
170+
171+
# Reset mock for second test
172+
mock_scenario.results.put.reset_mock()
173+
174+
# Test with implicit default
175+
step_implicit = NetworkStats(name="stats")
176+
step_implicit.run(mock_scenario)
177+
178+
# Capture results from implicit test
179+
implicit_calls = {
180+
call.args[1]: call.args[2] for call in mock_scenario.results.put.call_args_list
181+
}
182+
183+
# Results should be identical
184+
assert explicit_calls.keys() == implicit_calls.keys()
185+
for key in explicit_calls:
186+
assert explicit_calls[key] == implicit_calls[key]

0 commit comments

Comments
 (0)