From 45d796d384fa35716f04b81ca1e05ec9c4c5742f Mon Sep 17 00:00:00 2001 From: Andrey Golovanov Date: Thu, 11 Dec 2025 22:56:15 -0800 Subject: [PATCH 1/2] Update to v0.12.3 - Implemented SPF caching in demand placement to optimize shortest path computations, reducing redundant calculations for cacheable policies (ECMP, WCMP, TE_WCMP_UNLIM). --- CHANGELOG.md | 6 + docs/reference/api-full.md | 17 +- docs/reference/design.md | 12 +- ngraph/_version.py | 2 +- ngraph/exec/analysis/flow.py | 382 ++++++++-- pyproject.toml | 2 +- tests/exec/analysis/test_spf_caching.py | 972 ++++++++++++++++++++++++ 7 files changed, 1340 insertions(+), 53 deletions(-) create mode 100644 tests/exec/analysis/test_spf_caching.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bd2b36..d930934 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.12.3] - 2025-12-11 + +### Changed + +- **SPF caching in demand placement**: `demand_placement_analysis()` caches SPF results by (source, policy_preset) for ECMP, WCMP, and TE_WCMP_UNLIM policies; TE policies recompute when capacity constraints require alternate paths + ## [0.12.2] - 2025-12-08 ### Fixed diff --git a/docs/reference/api-full.md b/docs/reference/api-full.md index 73dafbe..aa77927 100644 --- a/docs/reference/api-full.md +++ b/docs/reference/api-full.md @@ -12,7 +12,7 @@ Quick links: - [CLI Reference](cli.md) - [DSL Reference](dsl.md) -Generated from source code on: December 07, 2025 at 00:13 UTC +Generated from source code on: December 11, 2025 at 22:53 UTC Modules auto-discovered: 44 @@ -2345,6 +2345,10 @@ with FailureManager's caching and multiprocessing systems. Graph caching enables efficient repeated analysis with different exclusion sets by building the graph once and using O(|excluded|) masks for exclusions. +SPF caching enables efficient demand placement by computing shortest paths once +per unique source node rather than once per demand. For networks with many demands +sharing the same sources, this can reduce SPF computations by an order of magnitude. + ### build_demand_context(network: "'Network'", demands_config: 'list[dict[str, Any]]') -> 'AnalysisContext' Build an AnalysisContext for repeated demand placement analysis. @@ -2383,8 +2387,15 @@ This function: 1. Builds Core infrastructure (graph, algorithms, flow_graph) or uses cached 2. Expands demands into concrete (src, dst, volume) tuples -3. Places each demand using Core's FlowPolicy with exclusion masks -4. Aggregates results into FlowIterationResult +3. Places each demand using SPF caching for cacheable policies +4. Falls back to FlowPolicy for complex multi-flow policies +5. Aggregates results into FlowIterationResult + +SPF Caching Optimization: + For cacheable policies (ECMP, WCMP, TE_WCMP_UNLIM), SPF results are + cached by source node. This reduces SPF computations from O(demands) + to O(unique_sources), typically a 5-10x reduction for workloads with + many demands sharing the same sources. Args: network: Network instance. diff --git a/docs/reference/design.md b/docs/reference/design.md index ac02316..c38e4c7 100644 --- a/docs/reference/design.md +++ b/docs/reference/design.md @@ -654,7 +654,7 @@ Managers handle scenario dynamics and prepare inputs for algorithmic steps. - Deterministic expansion: source/sink node lists sorted alphabetically; no randomization - Supports `combine` mode (aggregate via pseudo nodes) and `pairwise` mode (individual (src,dst) pairs with volume split) - Demands sorted by ascending priority before placement (lower value = higher priority) -- Placement handled by Core's FlowPolicy with configurable presets (ECMP, WCMP, TE modes) +- Placement uses SPF caching for simple policies (ECMP, WCMP, TE_WCMP_UNLIM), FlowPolicy for complex multi-flow policies - Non-mutating: operates on Core flow graphs with exclusions; Network remains unmodified **Failure Manager** (`ngraph.exec.failure.manager`): Applies a `FailurePolicy` to compute exclusion sets and runs analyses with those exclusions. @@ -737,6 +737,16 @@ For Monte Carlo analysis with many failure iterations, graph construction is amo This optimization is critical for performance: graph construction involves Python processing, NumPy array creation, and C++ object initialization. Building the graph once eliminates this overhead from the per-iteration critical path, enabling the GIL-releasing C++ algorithms to execute with minimal Python overhead. +**SPF Caching for Demand Placement:** + +For demand placement with cacheable policies (ECMP, WCMP, TE_WCMP_UNLIM), SPF results are cached by (source_node, policy_preset): + +- Initial SPF computed once per unique source; subsequent demands from the same source reuse the cached DAG +- For TE policies, DAG is recomputed when capacity constraints require alternate paths +- Complex multi-flow policies (TE_ECMP_16_LSP, TE_ECMP_UP_TO_256_LSP) use FlowPolicy directly + +This reduces SPF computations from O(demands) to O(unique_sources) for workloads where many demands share the same source nodes. + **Monte Carlo Deduplication:** FailureManager collapses identical failure patterns into single executions. Runtime diff --git a/ngraph/_version.py b/ngraph/_version.py index 5f3c2fa..538d531 100644 --- a/ngraph/_version.py +++ b/ngraph/_version.py @@ -2,4 +2,4 @@ __all__ = ["__version__"] -__version__ = "0.12.2" +__version__ = "0.12.3" diff --git a/ngraph/exec/analysis/flow.py b/ngraph/exec/analysis/flow.py index fdc3433..565b6f7 100644 --- a/ngraph/exec/analysis/flow.py +++ b/ngraph/exec/analysis/flow.py @@ -9,16 +9,22 @@ Graph caching enables efficient repeated analysis with different exclusion sets by building the graph once and using O(|excluded|) masks for exclusions. + +SPF caching enables efficient demand placement by computing shortest paths once +per unique source node rather than once per demand. For networks with many demands +sharing the same sources, this can reduce SPF computations by an order of magnitude. """ from __future__ import annotations +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Optional, Set import netgraph_core +import numpy as np from ngraph.analysis import AnalysisContext, analyze -from ngraph.exec.demand.expand import expand_demands +from ngraph.exec.demand.expand import ExpandedDemand, expand_demands from ngraph.model.demand.spec import TrafficDemand from ngraph.model.flow.policy_config import FlowPolicyPreset, create_flow_policy from ngraph.results.flow import FlowEntry, FlowIterationResult, FlowSummary @@ -27,6 +33,257 @@ if TYPE_CHECKING: from ngraph.model.network import Network +# Minimum flow threshold for placement decisions +_MIN_FLOW = 1e-9 + +# Policies that support SPF caching with simple single-flow placement +_CACHEABLE_SIMPLE: frozenset[FlowPolicyPreset] = frozenset( + { + FlowPolicyPreset.SHORTEST_PATHS_ECMP, + FlowPolicyPreset.SHORTEST_PATHS_WCMP, + } +) + +# Policies that support SPF caching with fallback for capacity-aware routing +_CACHEABLE_TE: frozenset[FlowPolicyPreset] = frozenset( + { + FlowPolicyPreset.TE_WCMP_UNLIM, + } +) + +# All cacheable policies +_CACHEABLE_PRESETS: frozenset[FlowPolicyPreset] = _CACHEABLE_SIMPLE | _CACHEABLE_TE + + +def _get_selection_for_preset( + preset: FlowPolicyPreset, +) -> netgraph_core.EdgeSelection: + """Get EdgeSelection configuration for a cacheable policy preset. + + Args: + preset: Flow policy preset. + + Returns: + EdgeSelection configured for the preset. + + Raises: + ValueError: If preset is not cacheable. + """ + if preset in _CACHEABLE_SIMPLE: + return netgraph_core.EdgeSelection( + multi_edge=True, + require_capacity=False, + tie_break=netgraph_core.EdgeTieBreak.DETERMINISTIC, + ) + elif preset == FlowPolicyPreset.TE_WCMP_UNLIM: + return netgraph_core.EdgeSelection( + multi_edge=True, + require_capacity=True, + tie_break=netgraph_core.EdgeTieBreak.PREFER_HIGHER_RESIDUAL, + ) + raise ValueError(f"Preset {preset} is not cacheable") + + +def _get_placement_for_preset(preset: FlowPolicyPreset) -> netgraph_core.FlowPlacement: + """Get FlowPlacement strategy for a cacheable policy preset. + + Args: + preset: Flow policy preset. + + Returns: + FlowPlacement strategy for the preset. + """ + if preset == FlowPolicyPreset.SHORTEST_PATHS_ECMP: + return netgraph_core.FlowPlacement.EQUAL_BALANCED + # WCMP and TE policies use PROPORTIONAL + return netgraph_core.FlowPlacement.PROPORTIONAL + + +@dataclass +class _CachedPlacementResult: + """Result of placing a demand using cached SPF.""" + + total_placed: float + next_flow_idx: int + cost_distribution: dict[float, float] + used_edges: set[str] + flow_indices: list[netgraph_core.FlowIndex] = field(default_factory=list) + + +def _place_demand_cached( + demand: ExpandedDemand, + src_id: int, + dst_id: int, + dag_cache: dict[tuple[int, FlowPolicyPreset], tuple[np.ndarray, Any]], + algorithms: netgraph_core.Algorithms, + handle: netgraph_core.Graph, + flow_graph: netgraph_core.FlowGraph, + node_mask: np.ndarray, + edge_mask: np.ndarray, + flow_idx_start: int, + include_flow_details: bool, + include_used_edges: bool, + edge_mapper: Any, + multidigraph: netgraph_core.StrictMultiDiGraph, +) -> _CachedPlacementResult: + """Place a demand using cached SPF DAG with fallback for TE policies. + + This function implements SPF caching to reduce redundant shortest path + computations. For demands sharing the same source node and policy preset, + the SPF result is computed once and reused. + + For simple policies (ECMP/WCMP), the cached DAG is always valid. + For TE policies, the DAG may become stale as edges saturate. In this case, + a fallback loop recomputes SPF with current residuals until the demand is + placed or no more progress can be made. + + Args: + demand: Expanded demand to place. + src_id: Source node ID. + dst_id: Destination node ID. + dag_cache: Cache mapping (src_id, preset) to (distances, DAG). + algorithms: Core Algorithms instance. + handle: Core Graph handle. + flow_graph: FlowGraph for placement. + node_mask: Node inclusion mask. + edge_mask: Edge inclusion mask. + flow_idx_start: Starting flow index counter. + include_flow_details: Whether to collect cost distribution. + include_used_edges: Whether to collect used edges. + edge_mapper: Edge ID mapper for edge name resolution. + multidigraph: Graph for edge lookup. + + Returns: + _CachedPlacementResult with placement details. + """ + cache_key = (src_id, demand.policy_preset) + selection = _get_selection_for_preset(demand.policy_preset) + placement = _get_placement_for_preset(demand.policy_preset) + is_te = demand.policy_preset in _CACHEABLE_TE + + flow_indices: list[netgraph_core.FlowIndex] = [] + flow_costs: list[tuple[float, float]] = [] # (cost, placed_amount) + flow_idx_counter = flow_idx_start + demand_placed = 0.0 + remaining = demand.volume + + # Get or compute initial DAG + if cache_key not in dag_cache: + # Initial computation without residual - on a fresh graph all edges + # have full capacity, so residual-aware selection is not needed yet + dists, dag = algorithms.spf( + handle, + src=src_id, + dst=None, # Full DAG to all destinations + selection=selection, + node_mask=node_mask, + edge_mask=edge_mask, + multipath=True, + dtype="float64", + ) + dag_cache[cache_key] = (dists, dag) + + dists, dag = dag_cache[cache_key] + + # Check if destination is reachable + if dists[dst_id] == float("inf"): + # Destination unreachable - return zero placement + return _CachedPlacementResult( + total_placed=0.0, + next_flow_idx=flow_idx_counter, + cost_distribution={}, + used_edges=set(), + flow_indices=[], + ) + + cost = float(dists[dst_id]) + + # First placement attempt with cached DAG + flow_idx = netgraph_core.FlowIndex( + src_id, dst_id, demand.priority, flow_idx_counter + ) + flow_idx_counter += 1 + placed = flow_graph.place(flow_idx, src_id, dst_id, dag, remaining, placement) + + if placed > _MIN_FLOW: + flow_indices.append(flow_idx) + flow_costs.append((cost, placed)) + demand_placed += placed + remaining -= placed + + # For TE policies, use fallback loop if partial placement + if is_te and remaining > _MIN_FLOW: + max_fallback_iterations = 100 + iterations = 0 + + while remaining > _MIN_FLOW and iterations < max_fallback_iterations: + iterations += 1 + + # Recompute DAG with current residuals + residual = np.ascontiguousarray( + flow_graph.residual_view(), dtype=np.float64 + ) + fresh_dists, fresh_dag = algorithms.spf( + handle, + src=src_id, + dst=None, + selection=selection, + residual=residual, + node_mask=node_mask, + edge_mask=edge_mask, + multipath=True, + dtype="float64", + ) + + # Update cache with fresh DAG + dag_cache[cache_key] = (fresh_dists, fresh_dag) + + # Check if destination still reachable + if fresh_dists[dst_id] == float("inf"): + break # No more paths available + + fresh_cost = float(fresh_dists[dst_id]) + + flow_idx = netgraph_core.FlowIndex( + src_id, dst_id, demand.priority, flow_idx_counter + ) + flow_idx_counter += 1 + additional = flow_graph.place( + flow_idx, src_id, dst_id, fresh_dag, remaining, placement + ) + + if additional < _MIN_FLOW: + break # No progress, stop + + flow_indices.append(flow_idx) + flow_costs.append((fresh_cost, additional)) + demand_placed += additional + remaining -= additional + + # Collect cost distribution if requested + cost_distribution: dict[float, float] = {} + if include_flow_details: + for c, amount in flow_costs: + cost_distribution[c] = cost_distribution.get(c, 0.0) + amount + + # Collect used edges if requested + used_edges: set[str] = set() + if include_used_edges: + for fidx in flow_indices: + edges = flow_graph.get_flow_edges(fidx) + for edge_id, _ in edges: + edge_ref = edge_mapper.to_ref(edge_id, multidigraph) + if edge_ref is not None: + used_edges.add(f"{edge_ref.link_id}:{edge_ref.direction}") + + return _CachedPlacementResult( + total_placed=demand_placed, + next_flow_idx=flow_idx_counter, + cost_distribution=cost_distribution, + used_edges=used_edges, + flow_indices=flow_indices, + ) + def max_flow_analysis( network: "Network", @@ -164,8 +421,15 @@ def demand_placement_analysis( This function: 1. Builds Core infrastructure (graph, algorithms, flow_graph) or uses cached 2. Expands demands into concrete (src, dst, volume) tuples - 3. Places each demand using Core's FlowPolicy with exclusion masks - 4. Aggregates results into FlowIterationResult + 3. Places each demand using SPF caching for cacheable policies + 4. Falls back to FlowPolicy for complex multi-flow policies + 5. Aggregates results into FlowIterationResult + + SPF Caching Optimization: + For cacheable policies (ECMP, WCMP, TE_WCMP_UNLIM), SPF results are + cached by source node. This reduces SPF computations from O(demands) + to O(unique_sources), typically a 5-10x reduction for workloads with + many demands sharing the same sources. Args: network: Network instance. @@ -221,62 +485,86 @@ def demand_placement_analysis( flow_graph = netgraph_core.FlowGraph(multidigraph) - # Phase 3: Place demands using Core FlowPolicy + # Phase 3: Place demands with SPF caching for cacheable policies flow_entries: list[FlowEntry] = [] total_demand = 0.0 total_placed = 0.0 + # SPF cache: (src_id, policy_preset) -> (distances, DAG) + dag_cache: dict[tuple[int, FlowPolicyPreset], tuple[np.ndarray, Any]] = {} + flow_idx_counter = 0 + for demand in expansion.demands: # Resolve node names to IDs (includes pseudo nodes from augmentations) src_id = node_mapper.to_id(demand.src_name) dst_id = node_mapper.to_id(demand.dst_name) - # Create FlowPolicy for this demand with masks - policy = create_flow_policy( - algorithms, - handle, - demand.policy_preset, - node_mask=node_mask, - edge_mask=edge_mask, - ) + # Use cached placement for cacheable policies, FlowPolicy for others + if demand.policy_preset in _CACHEABLE_PRESETS: + result = _place_demand_cached( + demand=demand, + src_id=src_id, + dst_id=dst_id, + dag_cache=dag_cache, + algorithms=algorithms, + handle=handle, + flow_graph=flow_graph, + node_mask=node_mask, + edge_mask=edge_mask, + flow_idx_start=flow_idx_counter, + include_flow_details=include_flow_details, + include_used_edges=include_used_edges, + edge_mapper=edge_mapper, + multidigraph=multidigraph, + ) + flow_idx_counter = result.next_flow_idx + placed = result.total_placed + cost_distribution = result.cost_distribution + used_edges = result.used_edges + else: + # Complex policies (multi-flow LSP variants): use FlowPolicy + policy = create_flow_policy( + algorithms, + handle, + demand.policy_preset, + node_mask=node_mask, + edge_mask=edge_mask, + ) - # Place demand using Core - placed, flow_count = policy.place_demand( - flow_graph, - src_id, - dst_id, - demand.priority, # flowClass - demand.volume, - ) + placed, flow_count = policy.place_demand( + flow_graph, + src_id, + dst_id, + demand.priority, + demand.volume, + ) - # Collect flow details if requested - cost_distribution: dict[float, float] = {} - used_edges: set[str] = set() - - if include_flow_details or include_used_edges: - # Get flows from policy - flows_dict = policy.flows - for flow_key, flow_data in flows_dict.items(): - # flow_key is (src, dst, flowClass, flowId) - # flow_data is (src, dst, cost, placed_flow) - if include_flow_details: - cost = float(flow_data[2]) - flow_vol = float(flow_data[3]) - if flow_vol > 0: - cost_distribution[cost] = ( - cost_distribution.get(cost, 0.0) + flow_vol + # Collect flow details if requested + cost_distribution = {} + used_edges = set() + + if include_flow_details or include_used_edges: + flows_dict = policy.flows + for flow_key, flow_data in flows_dict.items(): + if include_flow_details: + cost = float(flow_data[2]) + flow_vol = float(flow_data[3]) + if flow_vol > 0: + cost_distribution[cost] = ( + cost_distribution.get(cost, 0.0) + flow_vol + ) + + if include_used_edges: + flow_idx = netgraph_core.FlowIndex( + flow_key[0], flow_key[1], flow_key[2], flow_key[3] ) - - if include_used_edges: - # Get edges for this flow - flow_idx = netgraph_core.FlowIndex( - flow_key[0], flow_key[1], flow_key[2], flow_key[3] - ) - edges = flow_graph.get_flow_edges(flow_idx) - for edge_id, _ in edges: - edge_ref = edge_mapper.to_ref(edge_id, multidigraph) - if edge_ref is not None: - used_edges.add(f"{edge_ref.link_id}:{edge_ref.direction}") + edges = flow_graph.get_flow_edges(flow_idx) + for edge_id, _ in edges: + edge_ref = edge_mapper.to_ref(edge_id, multidigraph) + if edge_ref is not None: + used_edges.add( + f"{edge_ref.link_id}:{edge_ref.direction}" + ) # Build entry data entry_data: dict[str, Any] = {} diff --git a/pyproject.toml b/pyproject.toml index 0fb25cf..4fe281e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" # --------------------------------------------------------------------- [project] name = "ngraph" -version = "0.12.2" +version = "0.12.3" description = "A tool and a library for network modeling and analysis." readme = "README.md" authors = [{ name = "Andrey Golovanov" }] diff --git a/tests/exec/analysis/test_spf_caching.py b/tests/exec/analysis/test_spf_caching.py new file mode 100644 index 0000000..4bc1da9 --- /dev/null +++ b/tests/exec/analysis/test_spf_caching.py @@ -0,0 +1,972 @@ +"""Tests for SPF caching in demand_placement_analysis. + +This module tests the SPF caching optimization that reduces redundant shortest +path computations when placing demands from the same source nodes. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from ngraph.exec.analysis.flow import ( + _CACHEABLE_PRESETS, + _CACHEABLE_SIMPLE, + _CACHEABLE_TE, + _get_placement_for_preset, + _get_selection_for_preset, + demand_placement_analysis, +) +from ngraph.model.flow.policy_config import FlowPolicyPreset +from ngraph.model.network import Link, Network, Node +from ngraph.results.flow import FlowIterationResult + + +class TestHelperFunctions: + """Test helper functions for SPF caching.""" + + def test_get_selection_for_ecmp(self) -> None: + """Test EdgeSelection for ECMP preset.""" + selection = _get_selection_for_preset(FlowPolicyPreset.SHORTEST_PATHS_ECMP) + assert selection.multi_edge is True + assert selection.require_capacity is False + + def test_get_selection_for_wcmp(self) -> None: + """Test EdgeSelection for WCMP preset.""" + selection = _get_selection_for_preset(FlowPolicyPreset.SHORTEST_PATHS_WCMP) + assert selection.multi_edge is True + assert selection.require_capacity is False + + def test_get_selection_for_te_wcmp_unlim(self) -> None: + """Test EdgeSelection for TE_WCMP_UNLIM preset.""" + selection = _get_selection_for_preset(FlowPolicyPreset.TE_WCMP_UNLIM) + assert selection.multi_edge is True + assert selection.require_capacity is True + + def test_get_selection_for_invalid_preset(self) -> None: + """Test that invalid preset raises ValueError.""" + with pytest.raises(ValueError, match="not cacheable"): + _get_selection_for_preset(FlowPolicyPreset.TE_ECMP_16_LSP) + + def test_get_placement_for_ecmp(self) -> None: + """Test FlowPlacement for ECMP preset.""" + import netgraph_core + + placement = _get_placement_for_preset(FlowPolicyPreset.SHORTEST_PATHS_ECMP) + assert placement == netgraph_core.FlowPlacement.EQUAL_BALANCED + + def test_get_placement_for_wcmp(self) -> None: + """Test FlowPlacement for WCMP preset.""" + import netgraph_core + + placement = _get_placement_for_preset(FlowPolicyPreset.SHORTEST_PATHS_WCMP) + assert placement == netgraph_core.FlowPlacement.PROPORTIONAL + + def test_get_placement_for_te_wcmp_unlim(self) -> None: + """Test FlowPlacement for TE_WCMP_UNLIM preset.""" + import netgraph_core + + placement = _get_placement_for_preset(FlowPolicyPreset.TE_WCMP_UNLIM) + assert placement == netgraph_core.FlowPlacement.PROPORTIONAL + + +class TestCacheablePresets: + """Test that cacheable preset sets are correctly defined.""" + + def test_cacheable_simple_presets(self) -> None: + """Test that simple cacheable presets are defined correctly.""" + assert FlowPolicyPreset.SHORTEST_PATHS_ECMP in _CACHEABLE_SIMPLE + assert FlowPolicyPreset.SHORTEST_PATHS_WCMP in _CACHEABLE_SIMPLE + # TE policies should not be in simple set + assert FlowPolicyPreset.TE_WCMP_UNLIM not in _CACHEABLE_SIMPLE + + def test_cacheable_te_presets(self) -> None: + """Test that TE cacheable presets are defined correctly.""" + assert FlowPolicyPreset.TE_WCMP_UNLIM in _CACHEABLE_TE + # Simple policies should not be in TE set + assert FlowPolicyPreset.SHORTEST_PATHS_ECMP not in _CACHEABLE_TE + + def test_cacheable_presets_is_union(self) -> None: + """Test that _CACHEABLE_PRESETS is the union of simple and TE.""" + assert _CACHEABLE_PRESETS == _CACHEABLE_SIMPLE | _CACHEABLE_TE + + def test_lsp_policies_not_cacheable(self) -> None: + """Test that LSP policies are not in cacheable set.""" + assert FlowPolicyPreset.TE_ECMP_16_LSP not in _CACHEABLE_PRESETS + assert FlowPolicyPreset.TE_ECMP_UP_TO_256_LSP not in _CACHEABLE_PRESETS + + +class TestSPFCachingBasic: + """Test basic SPF caching behavior.""" + + @pytest.fixture + def diamond_network(self) -> Network: + """Create a diamond network: A -> B,C -> D.""" + network = Network() + for node in ["A", "B", "C", "D"]: + network.add_node(Node(node)) + + # Two equal-cost paths of capacity 60 each + network.add_link(Link("A", "B", capacity=60.0, cost=1.0)) + network.add_link(Link("A", "C", capacity=60.0, cost=1.0)) + network.add_link(Link("B", "D", capacity=60.0, cost=1.0)) + network.add_link(Link("C", "D", capacity=60.0, cost=1.0)) + + return network + + @pytest.fixture + def multi_source_network(self) -> Network: + """Create a network with multiple sources sharing paths to destinations. + + Topology: + S1 --+ + | + S2 --+--> R1 --> D1 + | | + S3 --+ +--> D2 + """ + network = Network() + for node in ["S1", "S2", "S3", "R1", "D1", "D2"]: + network.add_node(Node(node)) + + # Sources to router + network.add_link(Link("S1", "R1", capacity=100.0, cost=1.0)) + network.add_link(Link("S2", "R1", capacity=100.0, cost=1.0)) + network.add_link(Link("S3", "R1", capacity=100.0, cost=1.0)) + + # Router to destinations + network.add_link(Link("R1", "D1", capacity=200.0, cost=1.0)) + network.add_link(Link("R1", "D2", capacity=200.0, cost=1.0)) + + return network + + def test_single_demand_ecmp(self, diamond_network: Network) -> None: + """Test that single demand with ECMP works correctly with caching.""" + demands_config = [ + { + "source_path": "A", + "sink_path": "D", + "demand": 50.0, + "mode": "pairwise", + "priority": 0, + }, + ] + + result = demand_placement_analysis( + network=diamond_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + ) + + assert isinstance(result, FlowIterationResult) + assert len(result.flows) == 1 + flow = result.flows[0] + assert flow.source == "A" + assert flow.destination == "D" + assert flow.placed == 50.0 + assert flow.dropped == 0.0 + + def test_multiple_demands_same_source_reuses_cache( + self, multi_source_network: Network + ) -> None: + """Test that multiple demands from same source benefit from caching.""" + # Multiple demands from S1 to different destinations + demands_config = [ + { + "source_path": "S1", + "sink_path": "D1", + "demand": 30.0, + "mode": "pairwise", + "priority": 0, + }, + { + "source_path": "S1", + "sink_path": "D2", + "demand": 30.0, + "mode": "pairwise", + "priority": 0, + }, + ] + + result = demand_placement_analysis( + network=multi_source_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + ) + + assert len(result.flows) == 2 + # Both demands should be fully placed + assert result.summary.total_placed == 60.0 + assert result.summary.overall_ratio == 1.0 + + def test_demands_from_multiple_sources(self, multi_source_network: Network) -> None: + """Test that demands from multiple sources each get their own cache entry.""" + demands_config = [ + { + "source_path": "S1", + "sink_path": "D1", + "demand": 50.0, + "mode": "pairwise", + }, + { + "source_path": "S2", + "sink_path": "D1", + "demand": 50.0, + "mode": "pairwise", + }, + { + "source_path": "S3", + "sink_path": "D2", + "demand": 50.0, + "mode": "pairwise", + }, + ] + + result = demand_placement_analysis( + network=multi_source_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + ) + + assert len(result.flows) == 3 + assert result.summary.total_placed == 150.0 + assert result.summary.overall_ratio == 1.0 + + +class TestSPFCachingEquivalence: + """Test that cached placement produces equivalent results to non-cached.""" + + @pytest.fixture + def mesh_network(self) -> Network: + """Create a mesh network for equivalence testing. + + Topology (2x2 mesh): + A -- B + | | + C -- D + """ + network = Network() + for node in ["A", "B", "C", "D"]: + network.add_node(Node(node)) + + # Horizontal links + network.add_link(Link("A", "B", capacity=100.0, cost=1.0)) + network.add_link(Link("C", "D", capacity=100.0, cost=1.0)) + + # Vertical links + network.add_link(Link("A", "C", capacity=100.0, cost=1.0)) + network.add_link(Link("B", "D", capacity=100.0, cost=1.0)) + + return network + + def _run_demand_placement_without_cache( + self, + network: Network, + demands_config: list[dict[str, Any]], + include_flow_details: bool = False, + include_used_edges: bool = False, + ) -> FlowIterationResult: + """Run demand placement using only FlowPolicy (no caching). + + This provides a reference implementation for equivalence testing. + """ + import netgraph_core + + from ngraph.analysis import AnalysisContext + from ngraph.exec.demand.expand import expand_demands + from ngraph.model.demand.spec import TrafficDemand + from ngraph.model.flow.policy_config import ( + FlowPolicyPreset, + create_flow_policy, + ) + from ngraph.results.flow import FlowEntry, FlowSummary + + # Reconstruct TrafficDemand objects + traffic_demands = [] + for config in demands_config: + demand = TrafficDemand( + source_path=config["source_path"], + sink_path=config["sink_path"], + demand=config["demand"], + mode=config.get("mode", "pairwise"), + flow_policy_config=config.get("flow_policy_config"), + priority=config.get("priority", 0), + ) + traffic_demands.append(demand) + + # Expand demands + expansion = expand_demands( + network, + traffic_demands, + default_policy_preset=FlowPolicyPreset.SHORTEST_PATHS_ECMP, + ) + + # Build context + ctx = AnalysisContext.from_network( + network, augmentations=expansion.augmentations + ) + + handle = ctx.handle + multidigraph = ctx.multidigraph + node_mapper = ctx.node_mapper + edge_mapper = ctx.edge_mapper + algorithms = ctx.algorithms + node_mask = ctx._build_node_mask(set()) + edge_mask = ctx._build_edge_mask(set()) + + flow_graph = netgraph_core.FlowGraph(multidigraph) + + # Place demands using ONLY FlowPolicy (no caching) + flow_entries: list[FlowEntry] = [] + total_demand = 0.0 + total_placed = 0.0 + + for demand in expansion.demands: + src_id = node_mapper.to_id(demand.src_name) + dst_id = node_mapper.to_id(demand.dst_name) + + policy = create_flow_policy( + algorithms, + handle, + demand.policy_preset, + node_mask=node_mask, + edge_mask=edge_mask, + ) + + placed, flow_count = policy.place_demand( + flow_graph, + src_id, + dst_id, + demand.priority, + demand.volume, + ) + + cost_distribution: dict[float, float] = {} + used_edges: set[str] = set() + + if include_flow_details or include_used_edges: + flows_dict = policy.flows + for flow_key, flow_data in flows_dict.items(): + if include_flow_details: + cost = float(flow_data[2]) + flow_vol = float(flow_data[3]) + if flow_vol > 0: + cost_distribution[cost] = ( + cost_distribution.get(cost, 0.0) + flow_vol + ) + + if include_used_edges: + flow_idx = netgraph_core.FlowIndex( + flow_key[0], flow_key[1], flow_key[2], flow_key[3] + ) + edges = flow_graph.get_flow_edges(flow_idx) + for edge_id, _ in edges: + edge_ref = edge_mapper.to_ref(edge_id, multidigraph) + if edge_ref is not None: + used_edges.add( + f"{edge_ref.link_id}:{edge_ref.direction}" + ) + + entry_data: dict[str, Any] = {} + if include_used_edges and used_edges: + entry_data["edges"] = sorted(used_edges) + entry_data["edges_kind"] = "used" + + entry = FlowEntry( + source=demand.src_name, + destination=demand.dst_name, + priority=demand.priority, + demand=demand.volume, + placed=placed, + dropped=demand.volume - placed, + cost_distribution=cost_distribution if include_flow_details else {}, + data=entry_data, + ) + flow_entries.append(entry) + total_demand += demand.volume + total_placed += placed + + overall_ratio = (total_placed / total_demand) if total_demand > 0 else 1.0 + dropped_flows = sum(1 for e in flow_entries if e.dropped > 0.0) + summary = FlowSummary( + total_demand=total_demand, + total_placed=total_placed, + overall_ratio=overall_ratio, + dropped_flows=dropped_flows, + num_flows=len(flow_entries), + ) + + return FlowIterationResult( + flows=flow_entries, + summary=summary, + data={}, + ) + + def test_equivalence_ecmp_single_demand(self, mesh_network: Network) -> None: + """Test that ECMP placement is equivalent with and without caching.""" + demands_config = [ + { + "source_path": "A", + "sink_path": "D", + "demand": 80.0, + "mode": "pairwise", + }, + ] + + # Run with caching (default) + cached_result = demand_placement_analysis( + network=mesh_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + ) + + # Run without caching (reference) + reference_result = self._run_demand_placement_without_cache( + network=mesh_network, + demands_config=demands_config, + ) + + # Compare results + assert len(cached_result.flows) == len(reference_result.flows) + assert ( + cached_result.summary.total_demand == reference_result.summary.total_demand + ) + assert ( + cached_result.summary.total_placed == reference_result.summary.total_placed + ) + assert cached_result.summary.overall_ratio == pytest.approx( + reference_result.summary.overall_ratio, rel=1e-9 + ) + + def test_equivalence_ecmp_multiple_demands(self, mesh_network: Network) -> None: + """Test ECMP placement equivalence with multiple demands.""" + demands_config = [ + {"source_path": "A", "sink_path": "B", "demand": 30.0, "mode": "pairwise"}, + {"source_path": "A", "sink_path": "D", "demand": 40.0, "mode": "pairwise"}, + {"source_path": "C", "sink_path": "B", "demand": 25.0, "mode": "pairwise"}, + {"source_path": "C", "sink_path": "D", "demand": 35.0, "mode": "pairwise"}, + ] + + cached_result = demand_placement_analysis( + network=mesh_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + ) + + reference_result = self._run_demand_placement_without_cache( + network=mesh_network, + demands_config=demands_config, + ) + + # Compare summaries + assert ( + cached_result.summary.total_demand == reference_result.summary.total_demand + ) + assert cached_result.summary.total_placed == pytest.approx( + reference_result.summary.total_placed, rel=1e-9 + ) + + # Compare individual flows + for cached_flow, ref_flow in zip( + cached_result.flows, reference_result.flows, strict=True + ): + assert cached_flow.source == ref_flow.source + assert cached_flow.destination == ref_flow.destination + assert cached_flow.demand == ref_flow.demand + assert cached_flow.placed == pytest.approx(ref_flow.placed, rel=1e-9) + + def test_equivalence_with_flow_details(self, mesh_network: Network) -> None: + """Test equivalence when include_flow_details is True.""" + demands_config = [ + {"source_path": "A", "sink_path": "D", "demand": 50.0, "mode": "pairwise"}, + ] + + cached_result = demand_placement_analysis( + network=mesh_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + include_flow_details=True, + ) + + reference_result = self._run_demand_placement_without_cache( + network=mesh_network, + demands_config=demands_config, + include_flow_details=True, + ) + + # Both should have cost distribution + for cached_flow, ref_flow in zip( + cached_result.flows, reference_result.flows, strict=True + ): + # Cost distribution should be non-empty for both + if ref_flow.cost_distribution: + assert cached_flow.cost_distribution + # Total volume in cost distribution should match placed + cached_total = sum(cached_flow.cost_distribution.values()) + ref_total = sum(ref_flow.cost_distribution.values()) + assert cached_total == pytest.approx(ref_total, rel=1e-9) + + def test_equivalence_with_used_edges(self, mesh_network: Network) -> None: + """Test equivalence when include_used_edges is True.""" + demands_config = [ + {"source_path": "A", "sink_path": "D", "demand": 50.0, "mode": "pairwise"}, + ] + + cached_result = demand_placement_analysis( + network=mesh_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + include_used_edges=True, + ) + + reference_result = self._run_demand_placement_without_cache( + network=mesh_network, + demands_config=demands_config, + include_used_edges=True, + ) + + # Both should have used edges + for cached_flow, ref_flow in zip( + cached_result.flows, reference_result.flows, strict=True + ): + cached_edges = set(cached_flow.data.get("edges", [])) + ref_edges = set(ref_flow.data.get("edges", [])) + # Edges should be the same + assert cached_edges == ref_edges + + +class TestSPFCachingTEPolicy: + """Test SPF caching with TE_WCMP_UNLIM policy including fallback behavior.""" + + @pytest.fixture + def constrained_network(self) -> Network: + """Create a network with limited capacity to test fallback. + + Topology: + A --> B --> D + | ^ + +--> C -----+ + """ + network = Network() + for node in ["A", "B", "C", "D"]: + network.add_node(Node(node)) + + # Primary path (cost 2, capacity 50) + network.add_link(Link("A", "B", capacity=50.0, cost=1.0)) + network.add_link(Link("B", "D", capacity=50.0, cost=1.0)) + + # Secondary path (cost 4, capacity 50) + network.add_link(Link("A", "C", capacity=50.0, cost=2.0)) + network.add_link(Link("C", "D", capacity=50.0, cost=2.0)) + + return network + + def test_te_wcmp_basic_placement(self, constrained_network: Network) -> None: + """Test TE_WCMP_UNLIM basic placement without fallback.""" + demands_config = [ + { + "source_path": "A", + "sink_path": "D", + "demand": 40.0, + "mode": "pairwise", + "flow_policy_config": FlowPolicyPreset.TE_WCMP_UNLIM, + }, + ] + + result = demand_placement_analysis( + network=constrained_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + ) + + assert len(result.flows) == 1 + flow = result.flows[0] + # Should be able to place 40 on primary path (capacity 50) + assert flow.placed == 40.0 + assert flow.dropped == 0.0 + + def test_te_wcmp_fallback_on_saturation(self, constrained_network: Network) -> None: + """Test TE_WCMP_UNLIM fallback when primary path saturates.""" + demands_config = [ + { + "source_path": "A", + "sink_path": "D", + "demand": 80.0, # Exceeds primary path capacity + "mode": "pairwise", + "flow_policy_config": FlowPolicyPreset.TE_WCMP_UNLIM, + }, + ] + + result = demand_placement_analysis( + network=constrained_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + include_flow_details=True, + ) + + assert len(result.flows) == 1 + flow = result.flows[0] + # Should place 80 using both paths (50 + 30 or similar distribution) + assert flow.placed == pytest.approx(80.0, rel=1e-6) + assert flow.dropped == pytest.approx(0.0, abs=1e-6) + + # Should have multiple cost tiers in distribution (primary + secondary path) + if flow.cost_distribution: + assert len(flow.cost_distribution) >= 1 + total_in_dist = sum(flow.cost_distribution.values()) + assert total_in_dist == pytest.approx(80.0, rel=1e-6) + + def test_te_wcmp_multiple_demands_same_source( + self, constrained_network: Network + ) -> None: + """Test TE_WCMP_UNLIM with multiple demands sharing source.""" + demands_config = [ + { + "source_path": "A", + "sink_path": "D", + "demand": 30.0, + "mode": "pairwise", + "flow_policy_config": FlowPolicyPreset.TE_WCMP_UNLIM, + }, + { + "source_path": "A", + "sink_path": "D", + "demand": 30.0, + "mode": "pairwise", + "priority": 1, # Different priority = different demand + "flow_policy_config": FlowPolicyPreset.TE_WCMP_UNLIM, + }, + ] + + result = demand_placement_analysis( + network=constrained_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + ) + + assert len(result.flows) == 2 + # First demand should use primary path + assert result.flows[0].placed == pytest.approx(30.0, rel=1e-6) + # Second demand should also be placed (may need secondary path) + assert result.flows[1].placed == pytest.approx(30.0, rel=1e-6) + # Total should be 60 + assert result.summary.total_placed == pytest.approx(60.0, rel=1e-6) + + +class TestSPFCachingEdgeCases: + """Test edge cases and error handling for SPF caching.""" + + @pytest.fixture + def disconnected_network(self) -> Network: + """Create a network with disconnected components.""" + network = Network() + # First component: A -> B + network.add_node(Node("A")) + network.add_node(Node("B")) + network.add_link(Link("A", "B", capacity=100.0, cost=1.0)) + + # Second component: C -> D (disconnected from first) + network.add_node(Node("C")) + network.add_node(Node("D")) + network.add_link(Link("C", "D", capacity=100.0, cost=1.0)) + + return network + + def test_unreachable_destination(self, disconnected_network: Network) -> None: + """Test placement to unreachable destination returns zero.""" + demands_config = [ + { + "source_path": "A", + "sink_path": "D", # Unreachable from A + "demand": 50.0, + "mode": "pairwise", + }, + ] + + result = demand_placement_analysis( + network=disconnected_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + ) + + assert len(result.flows) == 1 + flow = result.flows[0] + assert flow.placed == 0.0 + assert flow.dropped == 50.0 + + def test_zero_demand(self) -> None: + """Test placement of zero demand.""" + network = Network() + network.add_node(Node("A")) + network.add_node(Node("B")) + network.add_link(Link("A", "B", capacity=100.0, cost=1.0)) + + demands_config = [ + { + "source_path": "A", + "sink_path": "B", + "demand": 0.0, + "mode": "pairwise", + }, + ] + + result = demand_placement_analysis( + network=network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + ) + + assert len(result.flows) == 1 + assert result.flows[0].placed == 0.0 + assert result.flows[0].dropped == 0.0 + assert result.summary.overall_ratio == 1.0 + + def test_partial_placement_due_to_capacity(self) -> None: + """Test partial placement when demand exceeds capacity.""" + network = Network() + network.add_node(Node("A")) + network.add_node(Node("B")) + network.add_link(Link("A", "B", capacity=30.0, cost=1.0)) + + demands_config = [ + { + "source_path": "A", + "sink_path": "B", + "demand": 50.0, + "mode": "pairwise", + }, + ] + + result = demand_placement_analysis( + network=network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + ) + + assert len(result.flows) == 1 + flow = result.flows[0] + assert flow.placed == 30.0 # Limited by capacity + assert flow.dropped == 20.0 + assert result.summary.overall_ratio == 0.6 + + def test_empty_cost_distribution_when_not_requested(self) -> None: + """Test that cost_distribution is empty when not requested.""" + network = Network() + network.add_node(Node("A")) + network.add_node(Node("B")) + network.add_link(Link("A", "B", capacity=100.0, cost=1.0)) + + demands_config = [ + { + "source_path": "A", + "sink_path": "B", + "demand": 50.0, + "mode": "pairwise", + }, + ] + + result = demand_placement_analysis( + network=network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + include_flow_details=False, + ) + + assert len(result.flows) == 1 + assert result.flows[0].cost_distribution == {} + + def test_empty_edges_when_not_requested(self) -> None: + """Test that edges data is empty when not requested.""" + network = Network() + network.add_node(Node("A")) + network.add_node(Node("B")) + network.add_link(Link("A", "B", capacity=100.0, cost=1.0)) + + demands_config = [ + { + "source_path": "A", + "sink_path": "B", + "demand": 50.0, + "mode": "pairwise", + }, + ] + + result = demand_placement_analysis( + network=network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + include_used_edges=False, + ) + + assert len(result.flows) == 1 + assert result.flows[0].data == {} + + +class TestSPFCachingWithExclusions: + """Test SPF caching with node and link exclusions.""" + + @pytest.fixture + def triangle_network(self) -> Network: + """Create a triangle network: A -- B -- C -- A.""" + network = Network() + for node in ["A", "B", "C"]: + network.add_node(Node(node)) + + network.add_link(Link("A", "B", capacity=100.0, cost=1.0)) + network.add_link(Link("B", "C", capacity=100.0, cost=1.0)) + network.add_link(Link("A", "C", capacity=100.0, cost=2.0)) # Longer path + + return network + + def test_placement_with_excluded_link(self, triangle_network: Network) -> None: + """Test that excluded links are respected in cached placement.""" + demands_config = [ + { + "source_path": "A", + "sink_path": "C", + "demand": 50.0, + "mode": "pairwise", + }, + ] + + # Exclude direct A-C link, forcing traffic through B + result = demand_placement_analysis( + network=triangle_network, + excluded_nodes=set(), + excluded_links={"link_A_C"}, # Link ID format + demands_config=demands_config, + include_flow_details=True, + ) + + assert len(result.flows) == 1 + flow = result.flows[0] + assert flow.placed == 50.0 + # Should use path A -> B -> C (cost 2) instead of A -> C (cost 2) + if flow.cost_distribution: + # Cost should be 2 (through B) not 2 (direct, which is excluded) + assert 2.0 in flow.cost_distribution + + def test_placement_with_excluded_node(self, triangle_network: Network) -> None: + """Test that excluded nodes are respected in cached placement.""" + demands_config = [ + { + "source_path": "A", + "sink_path": "C", + "demand": 50.0, + "mode": "pairwise", + }, + ] + + # Exclude node B, forcing traffic through direct A-C link + result = demand_placement_analysis( + network=triangle_network, + excluded_nodes={"B"}, + excluded_links=set(), + demands_config=demands_config, + include_flow_details=True, + ) + + assert len(result.flows) == 1 + flow = result.flows[0] + assert flow.placed == 50.0 + # Should use direct A -> C path (cost 2) + if flow.cost_distribution: + assert 2.0 in flow.cost_distribution + + +class TestSPFCachingCostDistribution: + """Test cost distribution correctness with SPF caching.""" + + @pytest.fixture + def multi_tier_network(self) -> Network: + """Create a network with multiple cost tiers. + + A --[cost=1]--> B --[cost=1]--> D (cost 2, capacity 30) + A --[cost=2]--> C --[cost=2]--> D (cost 4, capacity 30) + """ + network = Network() + for node in ["A", "B", "C", "D"]: + network.add_node(Node(node)) + + # Tier 1: cost 2, capacity 30 + network.add_link(Link("A", "B", capacity=30.0, cost=1.0)) + network.add_link(Link("B", "D", capacity=30.0, cost=1.0)) + + # Tier 2: cost 4, capacity 30 + network.add_link(Link("A", "C", capacity=30.0, cost=2.0)) + network.add_link(Link("C", "D", capacity=30.0, cost=2.0)) + + return network + + def test_cost_distribution_single_tier(self, multi_tier_network: Network) -> None: + """Test cost distribution when only one tier is used.""" + demands_config = [ + { + "source_path": "A", + "sink_path": "D", + "demand": 25.0, # Fits in tier 1 + "mode": "pairwise", + }, + ] + + result = demand_placement_analysis( + network=multi_tier_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + include_flow_details=True, + ) + + flow = result.flows[0] + assert flow.placed == 25.0 + # Should all be on tier 1 (cost 2) + assert flow.cost_distribution == {2.0: 25.0} + + def test_cost_distribution_multiple_tiers_te_policy( + self, multi_tier_network: Network + ) -> None: + """Test cost distribution with TE policy using multiple tiers.""" + demands_config = [ + { + "source_path": "A", + "sink_path": "D", + "demand": 50.0, # Exceeds tier 1 capacity + "mode": "pairwise", + "flow_policy_config": FlowPolicyPreset.TE_WCMP_UNLIM, + }, + ] + + result = demand_placement_analysis( + network=multi_tier_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + include_flow_details=True, + ) + + flow = result.flows[0] + # Should place 50 total (30 on tier 1 + 20 on tier 2) + assert flow.placed == pytest.approx(50.0, rel=1e-6) + + # Cost distribution should show both tiers + if flow.cost_distribution: + total = sum(flow.cost_distribution.values()) + assert total == pytest.approx(50.0, rel=1e-6) + # Should have cost 2 (tier 1) and cost 4 (tier 2) + assert len(flow.cost_distribution) >= 1 From beca5f99e0ef3c7fef1cab24c0623027c26fbf96 Mon Sep 17 00:00:00 2001 From: Andrey Golovanov Date: Thu, 11 Dec 2025 23:22:34 -0800 Subject: [PATCH 2/2] Enhance TrafficDemand handling and context caching --- CHANGELOG.md | 5 + docs/reference/api-full.md | 11 +- ngraph/exec/analysis/flow.py | 52 +-- ngraph/exec/demand/expand.py | 2 +- ngraph/exec/failure/manager.py | 19 +- ngraph/model/demand/spec.py | 11 +- ngraph/model/failure/policy.py | 6 - ngraph/results/artifacts.py | 111 ------- ngraph/results/snapshot.py | 1 + ngraph/types/base.py | 21 ++ ngraph/workflow/base.py | 17 +- ngraph/workflow/max_flow_step.py | 24 +- .../workflow/maximum_supported_demand_step.py | 292 +++++++++-------- .../workflow/traffic_matrix_placement_step.py | 50 ++- tests/exec/analysis/test_functions.py | 131 ++++++++ tests/exec/demand/test_expand.py | 304 ++++++++++++++++++ tests/model/demand/test_spec.py | 39 +++ .../workflow/test_maximum_supported_demand.py | 34 +- 18 files changed, 776 insertions(+), 354 deletions(-) create mode 100644 tests/exec/demand/test_expand.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d930934..d94e089 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - **SPF caching in demand placement**: `demand_placement_analysis()` caches SPF results by (source, policy_preset) for ECMP, WCMP, and TE_WCMP_UNLIM policies; TE policies recompute when capacity constraints require alternate paths +- **MSD AnalysisContext caching**: `MaximumSupportedDemand` builds `AnalysisContext` once and reuses it across all binary search probes + +### Fixed + +- **TrafficDemand ID preservation**: Fixed context caching with `mode: combine` by ensuring `TrafficDemand.id` is preserved through serialization; pseudo node names now remain consistent across context build and analysis ## [0.12.2] - 2025-12-08 diff --git a/docs/reference/api-full.md b/docs/reference/api-full.md index aa77927..8587d66 100644 --- a/docs/reference/api-full.md +++ b/docs/reference/api-full.md @@ -12,7 +12,7 @@ Quick links: - [CLI Reference](cli.md) - [DSL Reference](dsl.md) -Generated from source code on: December 11, 2025 at 22:53 UTC +Generated from source code on: December 11, 2025 at 23:43 UTC Modules auto-discovered: 44 @@ -460,12 +460,12 @@ Attributes: demand: Total demand volume. demand_placed: Portion of this demand placed so far. flow_policy_config: Policy preset (FlowPolicyPreset enum) used to build - a `FlowPolicy` if ``flow_policy`` is not provided. + a `FlowPolicy`` if ``flow_policy`` is not provided. flow_policy: Concrete policy instance. If set, it overrides ``flow_policy_config``. mode: Expansion mode, ``"combine"`` or ``"pairwise"``. attrs: Arbitrary user metadata. - id: Unique identifier assigned at initialization. + id: Unique identifier. Auto-generated if empty or not provided. **Attributes:** @@ -1238,6 +1238,9 @@ placeable for a given matrix. Stores results under `data` as: - `base_demands`: serialized base demand specs - `probes`: bracket/bisect evaluations with feasibility +Performance: AnalysisContext is built once at search start and reused across +all binary search probes. Only demand volumes change per probe. + ### MaximumSupportedDemand MaximumSupportedDemand(name: 'str' = '', seed: 'Optional[int]' = None, _seed_source: 'str' = '', matrix_name: 'str' = 'default', acceptance_rule: 'str' = 'hard', alpha_start: 'float' = 1.0, growth_factor: 'float' = 2.0, alpha_min: 'float' = 1e-06, alpha_max: 'float' = 1000000000.0, resolution: 'float' = 0.01, max_bracket_iters: 'int' = 32, max_bisect_iters: 'int' = 32, seeds_per_alpha: 'int' = 1, placement_rounds: 'int | str' = 'auto') @@ -2520,7 +2523,7 @@ Attributes: volume: Traffic volume to place. priority: Priority class (lower is higher priority). policy_preset: FlowPolicy configuration preset. - demand_id: Parent TrafficDemand ID (for tracking). + demand_id: Parent TrafficDemand ID for tracking. **Attributes:** diff --git a/ngraph/exec/analysis/flow.py b/ngraph/exec/analysis/flow.py index 565b6f7..02f9ca7 100644 --- a/ngraph/exec/analysis/flow.py +++ b/ngraph/exec/analysis/flow.py @@ -30,6 +30,32 @@ from ngraph.results.flow import FlowEntry, FlowIterationResult, FlowSummary from ngraph.types.base import FlowPlacement, Mode + +def _reconstruct_traffic_demands( + demands_config: list[dict[str, Any]], +) -> list[TrafficDemand]: + """Reconstruct TrafficDemand objects from serialized config. + + Args: + demands_config: List of demand configurations. + + Returns: + List of TrafficDemand objects with preserved IDs. + """ + return [ + TrafficDemand( + id=config.get("id") or "", + source_path=config["source_path"], + sink_path=config["sink_path"], + demand=config["demand"], + mode=config.get("mode", "pairwise"), + flow_policy_config=config.get("flow_policy_config"), + priority=config.get("priority", 0), + ) + for config in demands_config + ] + + if TYPE_CHECKING: from ngraph.model.network import Network @@ -445,18 +471,7 @@ def demand_placement_analysis( Returns: FlowIterationResult describing this iteration. """ - # Reconstruct TrafficDemand objects from config - traffic_demands = [] - for config in demands_config: - demand = TrafficDemand( - source_path=config["source_path"], - sink_path=config["sink_path"], - demand=config["demand"], - mode=config.get("mode", "pairwise"), - flow_policy_config=config.get("flow_policy_config"), - priority=config.get("priority", 0), - ) - traffic_demands.append(demand) + traffic_demands = _reconstruct_traffic_demands(demands_config) # Phase 1: Expand demands (pure logic, returns names + augmentations) expansion = expand_demands( @@ -714,18 +729,7 @@ def build_demand_context( Returns: AnalysisContext ready for use with demand_placement_analysis. """ - # Reconstruct TrafficDemand objects - traffic_demands = [] - for config in demands_config: - demand = TrafficDemand( - source_path=config["source_path"], - sink_path=config["sink_path"], - demand=config["demand"], - mode=config.get("mode", "pairwise"), - flow_policy_config=config.get("flow_policy_config"), - priority=config.get("priority", 0), - ) - traffic_demands.append(demand) + traffic_demands = _reconstruct_traffic_demands(demands_config) # Expand demands to get augmentations expansion = expand_demands( diff --git a/ngraph/exec/demand/expand.py b/ngraph/exec/demand/expand.py index 48076e0..cb3f015 100644 --- a/ngraph/exec/demand/expand.py +++ b/ngraph/exec/demand/expand.py @@ -31,7 +31,7 @@ class ExpandedDemand: volume: Traffic volume to place. priority: Priority class (lower is higher priority). policy_preset: FlowPolicy configuration preset. - demand_id: Parent TrafficDemand ID (for tracking). + demand_id: Parent TrafficDemand ID for tracking. """ src_name: str diff --git a/ngraph/exec/failure/manager.py b/ngraph/exec/failure/manager.py index 97604ee..866c713 100644 --- a/ngraph/exec/failure/manager.py +++ b/ngraph/exec/failure/manager.py @@ -942,14 +942,7 @@ def run_max_flow_monte_carlo( # Convert string flow_placement to enum if needed if isinstance(flow_placement, str): - try: - flow_placement = FlowPlacement[flow_placement.upper()] - except KeyError as exc: - valid_values = ", ".join([e.name for e in FlowPlacement]) - raise ValueError( - f"Invalid flow_placement '{flow_placement}'. " - f"Valid values are: {valid_values}" - ) from exc + flow_placement = FlowPlacement.from_string(flow_placement) # Run Monte Carlo analysis raw_results = self.run_monte_carlo_analysis( @@ -1069,6 +1062,7 @@ def run_demand_placement_monte_carlo( for demand in td_iter: # type: ignore[assignment] serializable_demands.append( { + "id": getattr(demand, "id", None), "source_path": getattr(demand, "source_path", ""), "sink_path": getattr(demand, "sink_path", ""), "demand": float(getattr(demand, "demand", 0.0)), @@ -1140,14 +1134,7 @@ def run_sensitivity_monte_carlo( # Convert string flow_placement to enum if needed if isinstance(flow_placement, str): - try: - flow_placement = FlowPlacement[flow_placement.upper()] - except KeyError as exc: - valid_values = ", ".join([e.name for e in FlowPlacement]) - raise ValueError( - f"Invalid flow_placement '{flow_placement}'. " - f"Valid values are: {valid_values}" - ) from exc + flow_placement = FlowPlacement.from_string(flow_placement) raw_results = self.run_monte_carlo_analysis( analysis_func=sensitivity_analysis, diff --git a/ngraph/model/demand/spec.py b/ngraph/model/demand/spec.py index c3c3805..2c1ef65 100644 --- a/ngraph/model/demand/spec.py +++ b/ngraph/model/demand/spec.py @@ -30,12 +30,12 @@ class TrafficDemand: demand: Total demand volume. demand_placed: Portion of this demand placed so far. flow_policy_config: Policy preset (FlowPolicyPreset enum) used to build - a `FlowPolicy` if ``flow_policy`` is not provided. + a `FlowPolicy`` if ``flow_policy`` is not provided. flow_policy: Concrete policy instance. If set, it overrides ``flow_policy_config``. mode: Expansion mode, ``"combine"`` or ``"pairwise"``. attrs: Arbitrary user metadata. - id: Unique identifier assigned at initialization. + id: Unique identifier. Auto-generated if empty or not provided. """ source_path: str = "" @@ -47,8 +47,9 @@ class TrafficDemand: flow_policy: Optional["FlowPolicy"] = None # type: ignore[valid-type] mode: str = "combine" attrs: Dict[str, Any] = field(default_factory=dict) - id: str = field(init=False) + id: str = "" def __post_init__(self) -> None: - """Assign a unique id from source, sink, and a Base64 UUID.""" - self.id = f"{self.source_path}|{self.sink_path}|{new_base64_uuid()}" + """Generate id if not provided.""" + if not self.id: + self.id = f"{self.source_path}|{self.sink_path}|{new_base64_uuid()}" diff --git a/ngraph/model/failure/policy.py b/ngraph/model/failure/policy.py index 0087538..ecb140d 100644 --- a/ngraph/model/failure/policy.py +++ b/ngraph/model/failure/policy.py @@ -16,7 +16,6 @@ from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple from .conditions import FailureCondition as EvalCondition -from .conditions import evaluate_condition as _shared_evaluate_condition from .conditions import evaluate_conditions as _shared_evaluate_conditions @@ -595,8 +594,3 @@ def to_dict(self) -> Dict[str, Any]: for mode in self.modes ] return data - - -def _evaluate_condition(entity_attrs: Dict[str, Any], cond: FailureCondition) -> bool: - """Wrapper using the shared evaluator.""" - return _shared_evaluate_condition(entity_attrs, cond) diff --git a/ngraph/results/artifacts.py b/ngraph/results/artifacts.py index 64363ac..4db5e6b 100644 --- a/ngraph/results/artifacts.py +++ b/ngraph/results/artifacts.py @@ -6,7 +6,6 @@ - `CapacityEnvelope`: frequency-based capacity distributions and optional aggregated flow statistics - `FailurePatternResult`: capacity results for specific failure patterns -- `PlacementEnvelope`: per-demand placement envelopes """ from __future__ import annotations @@ -336,113 +335,3 @@ def from_dict(cls, data: Dict[str, Any]) -> "FailurePatternResult": count=int(data.get("count", 0)), is_baseline=bool(data.get("is_baseline", False)), ) - - -@dataclass -class PlacementEnvelope: - """Per-demand placement envelope keyed like capacity envelopes. - - Each envelope captures frequency distribution of placement ratio for a - specific demand definition across Monte Carlo iterations. - - Attributes: - source: Source selection regex or node label. - sink: Sink selection regex or node label. - mode: Demand expansion mode ("combine" or "pairwise"). - priority: Demand priority class. - frequencies: Mapping of placement ratio to occurrence count. - min: Minimum observed placement ratio. - max: Maximum observed placement ratio. - mean: Mean placement ratio. - stdev: Standard deviation of placement ratio. - total_samples: Number of iterations represented. - """ - - source: str - sink: str - mode: str - priority: int - frequencies: Dict[float, int] - min: float - max: float - mean: float - stdev: float - total_samples: int - - @staticmethod - def _compute_stats(values: List[float]) -> tuple[float, float, float, float]: - n = len(values) - total = sum(values) - mean = total / n - sum_squares = sum(v * v for v in values) - variance = (sum_squares / n) - (mean * mean) - stdev = variance**0.5 - return (min(values), max(values), mean, stdev) - - @classmethod - def from_values( - cls, - source: str, - sink: str, - mode: str, - priority: int, - ratios: List[float], - rounding_decimals: int = 4, - ) -> "PlacementEnvelope": - if not ratios: - raise ValueError("Cannot create placement envelope from empty ratios list") - freqs: Dict[float, int] = {} - quantized: List[float] = [] - for r in ratios: - q = round(float(r), rounding_decimals) - quantized.append(q) - freqs[q] = freqs.get(q, 0) + 1 - mn, mx, mean, stdev = cls._compute_stats(quantized) - return cls( - source=source, - sink=sink, - mode=mode, - priority=int(priority), - frequencies=freqs, - min=mn, - max=mx, - mean=mean, - stdev=stdev, - total_samples=len(quantized), - ) - - def to_dict(self) -> Dict[str, Any]: - return { - "source": self.source, - "sink": self.sink, - "mode": self.mode, - "priority": self.priority, - "frequencies": self.frequencies, - "min": self.min, - "max": self.max, - "mean": self.mean, - "stdev": self.stdev, - "total_samples": self.total_samples, - } - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "PlacementEnvelope": - """Construct a PlacementEnvelope from a dictionary.""" - freqs_raw = data.get("frequencies", {}) or {} - freqs: Dict[float, int] = {} - for k, v in freqs_raw.items(): - key_f = float(k) - freqs[key_f] = int(v) - - return cls( - source=str(data.get("source", "")), - sink=str(data.get("sink", "")), - mode=str(data.get("mode", "pairwise")), - priority=int(data.get("priority", 0)), - frequencies=freqs, - min=float(data.get("min", 0.0)), - max=float(data.get("max", 0.0)), - mean=float(data.get("mean", 0.0)), - stdev=float(data.get("stdev", 0.0)), - total_samples=int(data.get("total_samples", 0)), - ) diff --git a/ngraph/results/snapshot.py b/ngraph/results/snapshot.py index 388268c..52dfa01 100644 --- a/ngraph/results/snapshot.py +++ b/ngraph/results/snapshot.py @@ -54,6 +54,7 @@ def build_scenario_snapshot( for d in demands: entries.append( { + "id": getattr(d, "id", None), "source_path": getattr(d, "source_path", ""), "sink_path": getattr(d, "sink_path", ""), "demand": float(getattr(d, "demand", 0.0)), diff --git a/ngraph/types/base.py b/ngraph/types/base.py index 0072c0e..0b80575 100644 --- a/ngraph/types/base.py +++ b/ngraph/types/base.py @@ -34,6 +34,27 @@ class FlowPlacement(IntEnum): PROPORTIONAL = 1 # Flow is split proportional to capacity (Dinic-like approach) EQUAL_BALANCED = 2 # Flow is equally divided among parallel paths of equal cost + @classmethod + def from_string(cls, value: str) -> "FlowPlacement": + """Parse a string into a FlowPlacement enum value. + + Args: + value: Case-insensitive string name (e.g., "proportional", "EQUAL_BALANCED"). + + Returns: + The corresponding FlowPlacement enum member. + + Raises: + ValueError: If the string doesn't match any enum member. + """ + try: + return cls[value.upper()] + except KeyError: + valid = ", ".join(e.name for e in cls) + raise ValueError( + f"Invalid flow_placement '{value}'. Valid values are: {valid}" + ) from None + class Mode(IntEnum): """Analysis mode for source/sink group handling. diff --git a/ngraph/workflow/base.py b/ngraph/workflow/base.py index 769820c..11fb189 100644 --- a/ngraph/workflow/base.py +++ b/ngraph/workflow/base.py @@ -7,10 +7,11 @@ from __future__ import annotations +import os import time from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, Optional, Type +from typing import TYPE_CHECKING, Dict, Optional, Type, Union from ngraph.logging import get_logger @@ -45,6 +46,20 @@ def decorator(cls: Type["WorkflowStep"]) -> Type["WorkflowStep"]: return decorator +def resolve_parallelism(parallelism: Union[int, str]) -> int: + """Resolve parallelism setting to a concrete worker count. + + Args: + parallelism: Either an integer worker count or "auto" for CPU count. + + Returns: + Positive integer worker count (minimum 1). + """ + if isinstance(parallelism, str): + return max(1, int(os.cpu_count() or 1)) + return max(1, int(parallelism)) + + @dataclass class WorkflowStep(ABC): """Base class for all workflow steps. diff --git a/ngraph/workflow/max_flow_step.py b/ngraph/workflow/max_flow_step.py index 815b6fc..2554ab6 100644 --- a/ngraph/workflow/max_flow_step.py +++ b/ngraph/workflow/max_flow_step.py @@ -26,7 +26,6 @@ from __future__ import annotations -import os import time from dataclasses import dataclass from typing import TYPE_CHECKING @@ -35,7 +34,11 @@ from ngraph.logging import get_logger from ngraph.results.flow import FlowIterationResult from ngraph.types.base import FlowPlacement -from ngraph.workflow.base import WorkflowStep, register_workflow_step +from ngraph.workflow.base import ( + WorkflowStep, + register_workflow_step, + resolve_parallelism, +) if TYPE_CHECKING: from ngraph.scenario import Scenario @@ -97,20 +100,7 @@ def __post_init__(self) -> None: "(first iteration is baseline, remaining are with failures)" ) if isinstance(self.flow_placement, str): - try: - self.flow_placement = FlowPlacement[self.flow_placement.upper()] - except KeyError: - valid_values = ", ".join([e.name for e in FlowPlacement]) - raise ValueError( - f"Invalid flow_placement '{self.flow_placement}'. " - f"Valid values are: {valid_values}" - ) from None - - @staticmethod - def _resolve_parallelism(parallelism: int | str) -> int: - if isinstance(parallelism, str): - return max(1, int(os.cpu_count() or 1)) - return max(1, int(parallelism)) + self.flow_placement = FlowPlacement.from_string(self.flow_placement) def run(self, scenario: "Scenario") -> None: t0 = time.perf_counter() @@ -134,7 +124,7 @@ def run(self, scenario: "Scenario") -> None: failure_policy_set=scenario.failure_policy_set, policy_name=self.failure_policy, ) - effective_parallelism = self._resolve_parallelism(self.parallelism) + effective_parallelism = resolve_parallelism(self.parallelism) raw = fm.run_max_flow_monte_carlo( source_path=self.source_path, sink_path=self.sink_path, diff --git a/ngraph/workflow/maximum_supported_demand_step.py b/ngraph/workflow/maximum_supported_demand_step.py index 4a8e77d..47bb379 100644 --- a/ngraph/workflow/maximum_supported_demand_step.py +++ b/ngraph/workflow/maximum_supported_demand_step.py @@ -7,25 +7,49 @@ - `context`: parameters used for the search - `base_demands`: serialized base demand specs - `probes`: bracket/bisect evaluations with feasibility + +Performance: AnalysisContext is built once at search start and reused across +all binary search probes. Only demand volumes change per probe. """ from __future__ import annotations import time from dataclasses import dataclass -from typing import Any +from typing import TYPE_CHECKING, Any import netgraph_core +import numpy as np -from ngraph.exec.demand.expand import expand_demands +from ngraph.exec.demand.expand import ExpandedDemand, expand_demands from ngraph.logging import get_logger from ngraph.model.demand.spec import TrafficDemand from ngraph.model.flow.policy_config import FlowPolicyPreset, create_flow_policy from ngraph.workflow.base import WorkflowStep, register_workflow_step +if TYPE_CHECKING: + from ngraph.analysis import AnalysisContext + logger = get_logger(__name__) +@dataclass +class _MSDCache: + """Cache for MSD binary search. + + Attributes: + ctx: Pre-built AnalysisContext with augmentations. + node_mask: Pre-built node mask (no exclusions during MSD). + edge_mask: Pre-built edge mask (no exclusions during MSD). + base_expanded: Expanded demands with base volumes. + """ + + ctx: "AnalysisContext" + node_mask: np.ndarray + edge_mask: np.ndarray + base_expanded: list[ExpandedDemand] + + @dataclass class MaximumSupportedDemand(WorkflowStep): matrix_name: str = "default" @@ -62,6 +86,7 @@ def __post_init__(self) -> None: def run(self, scenario: "Any") -> None: if self.acceptance_rule != "hard": raise ValueError("Only 'hard' acceptance_rule is implemented") + t0 = time.perf_counter() logger.info( "Starting MSD: name=%s matrix=%s alpha_start=%.6g growth=%.3f seeds=%d resolution=%.6g", @@ -72,12 +97,14 @@ def run(self, scenario: "Any") -> None: int(self.seeds_per_alpha), float(self.resolution), ) - base_tds = scenario.traffic_matrix_set.get_matrix(self.matrix_name) + # Serialize base demands for result output from ngraph.model.flow.policy_config import serialize_policy_preset + base_tds = scenario.traffic_matrix_set.get_matrix(self.matrix_name) base_demands: list[dict[str, Any]] = [ { + "id": getattr(td, "id", None), "source_path": getattr(td, "source_path", ""), "sink_path": getattr(td, "sink_path", ""), "demand": float(getattr(td, "demand", 0.0)), @@ -90,37 +117,71 @@ def run(self, scenario: "Any") -> None: for td in base_tds ] - # Validation: Ensure traffic matrix contains demands if not base_demands: raise ValueError( f"Traffic matrix '{self.matrix_name}' contains no demands. " "Cannot compute maximum supported demand without traffic specifications." ) - start_alpha = float(self.alpha_start) - g = float(self.growth_factor) - if not (g > 1.0): - raise ValueError("growth_factor must be > 1.0") - if self.resolution <= 0.0: - raise ValueError("resolution must be positive") + # Build cache once for all probes + cache = self._build_cache(scenario, self.matrix_name) + logger.debug( + "MSD cache built: %d expanded demands", + len(cache.base_expanded), + ) + # Binary search probes: list[dict[str, Any]] = [] def probe(alpha: float) -> tuple[bool, dict[str, Any]]: - feasible, details = self._evaluate_alpha( - alpha=alpha, - scenario=scenario, - matrix_name=self.matrix_name, - placement_rounds=self.placement_rounds, - seeds=self.seeds_per_alpha, - ) - probe_entry = {"alpha": alpha, "feasible": bool(feasible)} | details - probes.append(probe_entry) + feasible, details = self._evaluate_alpha(cache, alpha, self.seeds_per_alpha) + probes.append({"alpha": alpha, "feasible": bool(feasible)} | details) return feasible, details + alpha_star = self._binary_search(probe) + + # Store results + context = { + "acceptance_rule": self.acceptance_rule, + "alpha_start": self.alpha_start, + "growth_factor": self.growth_factor, + "alpha_min": self.alpha_min, + "alpha_max": self.alpha_max, + "resolution": self.resolution, + "max_bracket_iters": self.max_bracket_iters, + "max_bisect_iters": self.max_bisect_iters, + "seeds_per_alpha": self.seeds_per_alpha, + "matrix_name": self.matrix_name, + "placement_rounds": self.placement_rounds, + } + scenario.results.put("metadata", {}) + scenario.results.put( + "data", + { + "alpha_star": float(alpha_star), + "context": context, + "base_demands": base_demands, + "probes": probes, + }, + ) + logger.info( + "MSD completed: name=%s matrix=%s alpha_star=%.6g probes=%d duration=%.3fs", + self.name or self.__class__.__name__, + self.matrix_name, + float(alpha_star), + len(probes), + time.perf_counter() - t0, + ) + + def _binary_search(self, probe: "Any") -> float: + """Bracket and bisect to find alpha_star.""" + start_alpha = float(self.alpha_start) + g = float(self.growth_factor) + feasible0, _ = probe(start_alpha) lower: float | None = None upper: float | None = None + if feasible0: lower = start_alpha alpha = start_alpha @@ -151,147 +212,100 @@ def probe(alpha: float) -> tuple[bool, dict[str, Any]]: raise ValueError("No feasible alpha found above alpha_min") assert lower is not None and upper is not None and lower < upper - left = lower - right = upper - iters = 0 - while (right - left) > self.resolution and iters < self.max_bisect_iters: + + left, right = lower, upper + for _ in range(self.max_bisect_iters): + if (right - left) <= self.resolution: + break mid = (left + right) / 2.0 feas, _ = probe(mid) if feas: left = mid else: right = mid - iters += 1 - alpha_star = left - context = { - "acceptance_rule": self.acceptance_rule, - "alpha_start": self.alpha_start, - "growth_factor": self.growth_factor, - "alpha_min": self.alpha_min, - "alpha_max": self.alpha_max, - "resolution": self.resolution, - "max_bracket_iters": self.max_bracket_iters, - "max_bisect_iters": self.max_bisect_iters, - "seeds_per_alpha": self.seeds_per_alpha, - "matrix_name": self.matrix_name, - "placement_rounds": self.placement_rounds, - } - scenario.results.put("metadata", {}) - scenario.results.put( - "data", - { - "alpha_star": float(alpha_star), - "context": context, - "base_demands": base_demands, - "probes": probes, - }, - ) - logger.info( - "MSD completed: name=%s matrix=%s alpha_star=%.6g iterations=%d duration=%.3fs", - self.name or self.__class__.__name__, - self.matrix_name, - float(alpha_star), - int(self.max_bisect_iters), - time.perf_counter() - t0, - ) + return left @staticmethod - def _build_scaled_demands( - base_demands: list[dict[str, Any]], alpha: float - ) -> list[TrafficDemand]: - """Build scaled traffic demands for alpha probe.""" - demands: list[TrafficDemand] = [] - for d in base_demands: - demands.append( - TrafficDemand( - source_path=str(d["source_path"]), - sink_path=str(d["sink_path"]), - priority=int(d["priority"]), - demand=float(d["demand"]) * alpha, - flow_policy_config=d.get("flow_policy_config"), - mode=str(d.get("mode", "pairwise")), - ) - ) - return demands - - @classmethod - def _evaluate_alpha( - cls, - *, - alpha: float, - scenario: Any, - matrix_name: str, - placement_rounds: int | str, - seeds: int, - ) -> tuple[bool, dict[str, Any]]: - """Evaluate if alpha is feasible using Core-based placement. + def _build_cache(scenario: Any, matrix_name: str) -> _MSDCache: + """Build cache for MSD binary search. - Args: - alpha: Scale factor to test. - scenario: Scenario containing network and traffic matrix. - matrix_name: Name of traffic matrix to use. - placement_rounds: Placement rounds (unused - Core handles internally). - seeds: Number of seeds to test. - - Returns: - Tuple of (feasible, details_dict). + Creates stable TrafficDemand objects, expands them once, and builds + AnalysisContext. Called once at search start. """ + from ngraph.analysis import AnalysisContext + base_tds = scenario.traffic_matrix_set.get_matrix(matrix_name) - base_demands: list[dict[str, Any]] = [ - { - "source_path": getattr(td, "source_path", ""), - "sink_path": getattr(td, "sink_path", ""), - "demand": float(getattr(td, "demand", 0.0)), - "mode": getattr(td, "mode", "pairwise"), - "priority": int(getattr(td, "priority", 0)), - "flow_policy_config": getattr(td, "flow_policy_config", None), - } + + # Create stable TrafficDemand objects (same IDs for all probes) + stable_demands: list[TrafficDemand] = [ + TrafficDemand( + id=getattr(td, "id", "") or "", + source_path=str(getattr(td, "source_path", "")), + sink_path=str(getattr(td, "sink_path", "")), + priority=int(getattr(td, "priority", 0)), + demand=float(getattr(td, "demand", 0.0)), + flow_policy_config=getattr(td, "flow_policy_config", None), + mode=str(getattr(td, "mode", "pairwise")), + ) for td in base_tds ] - # Build scaled demands - scaled_demands = cls._build_scaled_demands(base_demands, alpha) - - # Phase 1: Expand demands (get names + augmentations) + # Expand once (augmentations depend on td.id, now stable) expansion = expand_demands( scenario.network, - scaled_demands, + stable_demands, default_policy_preset=FlowPolicyPreset.SHORTEST_PATHS_ECMP, ) - # Phase 2: Build Core infrastructure with augmentations - from ngraph.analysis import AnalysisContext - + # Build AnalysisContext once ctx = AnalysisContext.from_network( scenario.network, augmentations=expansion.augmentations, ) - # Build masks for disabled nodes/links (using internal methods) + # Build masks once (no exclusions during MSD) node_mask = ctx._build_node_mask(excluded_nodes=None) edge_mask = ctx._build_edge_mask(excluded_links=None) + return _MSDCache( + ctx=ctx, + node_mask=node_mask, + edge_mask=edge_mask, + base_expanded=expansion.demands, + ) + + @staticmethod + def _evaluate_alpha( + cache: _MSDCache, + alpha: float, + seeds: int, + ) -> tuple[bool, dict[str, Any]]: + """Evaluate if alpha is feasible. + + Uses pre-built cache; only scales demand volumes by alpha. + """ + ctx = cache.ctx + node_mask = cache.node_mask + edge_mask = cache.edge_mask + decisions: list[bool] = [] min_ratios: list[float] = [] for _ in range(max(1, int(seeds))): - # Create fresh FlowGraph for each seed flow_graph = netgraph_core.FlowGraph(ctx.multidigraph) - - # Phase 3: Place demands using Core total_demand = 0.0 total_placed = 0.0 - for demand in expansion.demands: - # Resolve node names to IDs (includes pseudo nodes) - src_id = ctx.node_mapper.to_id(demand.src_name) - dst_id = ctx.node_mapper.to_id(demand.dst_name) + for base_demand in cache.base_expanded: + scaled_volume = base_demand.volume * alpha + src_id = ctx.node_mapper.to_id(base_demand.src_name) + dst_id = ctx.node_mapper.to_id(base_demand.dst_name) policy = create_flow_policy( ctx.algorithms, ctx.handle, - demand.policy_preset, + base_demand.policy_preset, node_mask=node_mask, edge_mask=edge_mask, ) @@ -300,37 +314,55 @@ def _evaluate_alpha( flow_graph, src_id, dst_id, - demand.priority, - demand.volume, + base_demand.priority, + scaled_volume, ) - total_demand += demand.volume + total_demand += scaled_volume total_placed += placed - # Validation: Ensure we have non-zero demand to evaluate if total_demand == 0.0: raise ValueError( - f"Cannot evaluate feasibility for alpha={alpha:.6g}: total demand is zero. " - "This indicates that no demands were successfully expanded or all demand volumes are zero." + f"Cannot evaluate feasibility for alpha={alpha:.6g}: " + "total demand is zero." ) - # Check feasibility ratio = total_placed / total_demand is_feasible = ratio >= 1.0 - 1e-12 decisions.append(is_feasible) min_ratios.append(ratio) - # Majority vote across seeds yes = sum(1 for d in decisions if d) required = (len(decisions) // 2) + 1 feasible = yes >= required - details = { + return feasible, { "seeds": len(decisions), "feasible_seeds": yes, "min_placement_ratio": min(min_ratios) if min_ratios else 1.0, } - return feasible, details + + @staticmethod + def _build_scaled_demands( + base_demands: list[dict[str, Any]], alpha: float + ) -> list[TrafficDemand]: + """Build scaled TrafficDemand objects from serialized demands. + + Utility for tests to verify results at specific alpha values. + Preserves ID if present for context caching compatibility. + """ + return [ + TrafficDemand( + id=d.get("id") or "", + source_path=str(d["source_path"]), + sink_path=str(d["sink_path"]), + priority=int(d["priority"]), + demand=float(d["demand"]) * alpha, + flow_policy_config=d.get("flow_policy_config"), + mode=str(d.get("mode", "pairwise")), + ) + for d in base_demands + ] register_workflow_step("MaximumSupportedDemand")(MaximumSupportedDemand) diff --git a/ngraph/workflow/traffic_matrix_placement_step.py b/ngraph/workflow/traffic_matrix_placement_step.py index 9fef413..7ea2a58 100644 --- a/ngraph/workflow/traffic_matrix_placement_step.py +++ b/ngraph/workflow/traffic_matrix_placement_step.py @@ -6,7 +6,6 @@ from __future__ import annotations -import os import time from dataclasses import dataclass from typing import TYPE_CHECKING, Any @@ -14,7 +13,11 @@ from ngraph.exec.failure.manager import FailureManager from ngraph.logging import get_logger from ngraph.results.flow import FlowIterationResult -from ngraph.workflow.base import WorkflowStep, register_workflow_step +from ngraph.workflow.base import ( + WorkflowStep, + register_workflow_step, + resolve_parallelism, +) if TYPE_CHECKING: from ngraph.scenario import Scenario @@ -68,12 +71,6 @@ def __post_init__(self) -> None: if not (float(self.alpha) > 0.0): raise ValueError("alpha must be > 0.0") - @staticmethod - def _resolve_parallelism(parallelism: int | str) -> int: - if isinstance(parallelism, str): - return max(1, int(os.cpu_count() or 1)) - return max(1, int(parallelism)) - def run(self, scenario: "Scenario") -> None: if not self.matrix_name: raise ValueError("'matrix_name' is required for TrafficMatrixPlacement") @@ -105,20 +102,6 @@ def run(self, scenario: "Scenario") -> None: from ngraph.model.flow.policy_config import serialize_policy_preset - base_demands: list[dict[str, Any]] = [ - { - "source_path": getattr(td, "source_path", ""), - "sink_path": getattr(td, "sink_path", ""), - "demand": float(getattr(td, "demand", 0.0)), - "mode": getattr(td, "mode", "pairwise"), - "priority": int(getattr(td, "priority", 0)), - "flow_policy_config": serialize_policy_preset( - getattr(td, "flow_policy_config", None) - ), - } - for td in td_list - ] - # Resolve alpha effective_alpha = self._resolve_alpha(scenario) alpha_src = getattr(self, "_alpha_source", None) or "explicit" @@ -128,10 +111,14 @@ def run(self, scenario: "Scenario") -> None: str(alpha_src), ) + # Build demands_config with scaled demands (used for analysis) + # Also build base_demands for output (with serialized policy, unscaled) demands_config: list[dict[str, Any]] = [] + base_demands: list[dict[str, Any]] = [] for td in td_list: demands_config.append( { + "id": td.id, "source_path": td.source_path, "sink_path": td.sink_path, "demand": float(td.demand) * float(effective_alpha), @@ -140,6 +127,19 @@ def run(self, scenario: "Scenario") -> None: "priority": getattr(td, "priority", 0), } ) + base_demands.append( + { + "id": td.id, + "source_path": getattr(td, "source_path", ""), + "sink_path": getattr(td, "sink_path", ""), + "demand": float(getattr(td, "demand", 0.0)), + "mode": getattr(td, "mode", "pairwise"), + "priority": int(getattr(td, "priority", 0)), + "flow_policy_config": serialize_policy_preset( + getattr(td, "flow_policy_config", None) + ), + } + ) # Run via FailureManager fm = FailureManager( @@ -147,7 +147,7 @@ def run(self, scenario: "Scenario") -> None: failure_policy_set=scenario.failure_policy_set, policy_name=self.failure_policy, ) - effective_parallelism = self._resolve_parallelism(self.parallelism) + effective_parallelism = resolve_parallelism(self.parallelism) raw = fm.run_demand_placement_monte_carlo( demands_config=demands_config, @@ -219,9 +219,7 @@ def run(self, scenario: "Scenario") -> None: baseline_str = str(step_metadata.get("baseline", self.baseline)) iterations = int(step_metadata.get("iterations", self.iterations)) workers = int( - step_metadata.get( - "parallelism", self._resolve_parallelism(self.parallelism) - ) + step_metadata.get("parallelism", resolve_parallelism(self.parallelism)) ) logger.info( ( diff --git a/tests/exec/analysis/test_functions.py b/tests/exec/analysis/test_functions.py index 79741a9..2b3edb7 100644 --- a/tests/exec/analysis/test_functions.py +++ b/tests/exec/analysis/test_functions.py @@ -200,6 +200,137 @@ def test_demand_placement_analysis_zero_total_demand( assert summary.overall_ratio == 1.0 +class TestDemandPlacementWithContextCaching: + """Test demand_placement_analysis with pre-built context caching.""" + + @pytest.fixture + def diamond_network(self) -> Network: + """Create a diamond network for testing.""" + network = Network() + for node in ["A", "B", "C", "D"]: + network.add_node(Node(node)) + network.add_link(Link("A", "B", capacity=60.0, cost=1.0)) + network.add_link(Link("A", "C", capacity=60.0, cost=1.0)) + network.add_link(Link("B", "D", capacity=60.0, cost=1.0)) + network.add_link(Link("C", "D", capacity=60.0, cost=1.0)) + return network + + def test_context_caching_pairwise_mode(self, diamond_network: Network) -> None: + """Context caching works with pairwise mode.""" + from ngraph.exec.analysis.flow import build_demand_context + + demands_config = [ + { + "id": "stable-pairwise-id", + "source_path": "A", + "sink_path": "D", + "demand": 50.0, + "mode": "pairwise", + }, + ] + + # Build context once + ctx = build_demand_context(diamond_network, demands_config) + + # Use context for analysis + result = demand_placement_analysis( + network=diamond_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + context=ctx, + ) + + assert result.summary.total_placed == 50.0 + assert result.summary.overall_ratio == 1.0 + + def test_context_caching_combine_mode(self, diamond_network: Network) -> None: + """Context caching works with combine mode (uses pseudo nodes).""" + from ngraph.exec.analysis.flow import build_demand_context + + demands_config = [ + { + "id": "stable-combine-id", + "source_path": "[AB]", + "sink_path": "[CD]", + "demand": 50.0, + "mode": "combine", + }, + ] + + # Build context once + ctx = build_demand_context(diamond_network, demands_config) + + # Use context for analysis - this is where the bug manifested + result = demand_placement_analysis( + network=diamond_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + context=ctx, + ) + + assert result.summary.total_placed == 50.0 + assert result.summary.overall_ratio == 1.0 + + def test_context_caching_combine_multiple_iterations( + self, diamond_network: Network + ) -> None: + """Context can be reused for multiple analysis iterations.""" + from ngraph.exec.analysis.flow import build_demand_context + + demands_config = [ + { + "id": "reusable-id", + "source_path": "[AB]", + "sink_path": "[CD]", + "demand": 50.0, + "mode": "combine", + }, + ] + + ctx = build_demand_context(diamond_network, demands_config) + + # Run multiple iterations with different exclusions + for excluded in [set(), {"B"}, {"C"}]: + result = demand_placement_analysis( + network=diamond_network, + excluded_nodes=excluded, + excluded_links=set(), + demands_config=demands_config, + context=ctx, + ) + assert isinstance(result, FlowIterationResult) + + def test_context_caching_without_id_raises(self, diamond_network: Network) -> None: + """Context caching without stable ID raises KeyError for combine mode.""" + from ngraph.exec.analysis.flow import build_demand_context + + # Config without explicit ID - each reconstruction generates new ID + demands_config = [ + { + "source_path": "[AB]", + "sink_path": "[CD]", + "demand": 50.0, + "mode": "combine", + }, + ] + + # Build context - creates pseudo nodes with auto-generated ID (uuid1) + ctx = build_demand_context(diamond_network, demands_config) + + # Analysis reconstructs TrafficDemand without ID → generates new ID (uuid2) + # Tries to find pseudo nodes _src_...|uuid2 which don't exist → KeyError + with pytest.raises(KeyError): + demand_placement_analysis( + network=diamond_network, + excluded_nodes=set(), + excluded_links=set(), + demands_config=demands_config, + context=ctx, + ) + + class TestSensitivityAnalysis: """Test sensitivity_analysis function.""" diff --git a/tests/exec/demand/test_expand.py b/tests/exec/demand/test_expand.py new file mode 100644 index 0000000..5c24e8c --- /dev/null +++ b/tests/exec/demand/test_expand.py @@ -0,0 +1,304 @@ +"""Tests for demand expansion and TrafficDemand round-trip serialization.""" + +import pytest + +from ngraph.exec.demand.expand import expand_demands +from ngraph.model.demand.spec import TrafficDemand +from ngraph.model.network import Link, Network, Node + + +@pytest.fixture +def simple_network() -> Network: + """Create a simple 4-node network for testing.""" + network = Network() + for name in ["A", "B", "C", "D"]: + network.add_node(Node(name)) + network.add_link(Link("A", "B", capacity=100.0, cost=1.0)) + network.add_link(Link("B", "C", capacity=100.0, cost=1.0)) + network.add_link(Link("C", "D", capacity=100.0, cost=1.0)) + return network + + +class TestTrafficDemandIdRoundTrip: + """Test TrafficDemand ID preservation through serialization.""" + + def test_explicit_id_preserved(self) -> None: + """TrafficDemand with explicit ID preserves it.""" + td = TrafficDemand( + id="my-stable-id", + source_path="A", + sink_path="B", + demand=100.0, + ) + assert td.id == "my-stable-id" + + def test_auto_generated_id_when_none(self) -> None: + """TrafficDemand without explicit ID auto-generates one.""" + td = TrafficDemand(source_path="A", sink_path="B", demand=100.0) + assert td.id is not None + assert "|" in td.id # Format: source|sink|uuid + + def test_id_round_trip_through_dict(self) -> None: + """TrafficDemand ID survives dict serialization round-trip.""" + original = TrafficDemand( + source_path="A", + sink_path="B", + demand=100.0, + mode="combine", + priority=1, + ) + original_id = original.id + + # Serialize to dict (as done in workflow steps) + config = { + "id": original.id, + "source_path": original.source_path, + "sink_path": original.sink_path, + "demand": original.demand, + "mode": original.mode, + "priority": original.priority, + } + + # Reconstruct (as done in flow.py) + reconstructed = TrafficDemand( + id=config.get("id"), + source_path=config["source_path"], + sink_path=config["sink_path"], + demand=config["demand"], + mode=config.get("mode", "pairwise"), + priority=config.get("priority", 0), + ) + + assert reconstructed.id == original_id + + def test_id_mismatch_without_explicit_id(self) -> None: + """Two TrafficDemands from same config get different IDs if id not passed.""" + config = { + "source_path": "A", + "sink_path": "B", + "demand": 100.0, + } + + td1 = TrafficDemand( + source_path=config["source_path"], + sink_path=config["sink_path"], + demand=config["demand"], + ) + td2 = TrafficDemand( + source_path=config["source_path"], + sink_path=config["sink_path"], + demand=config["demand"], + ) + + # Without explicit ID, each gets a different auto-generated ID + assert td1.id != td2.id + + +class TestExpandDemandsPairwise: + """Test expand_demands with pairwise mode.""" + + def test_pairwise_single_pair(self, simple_network: Network) -> None: + """Pairwise mode with single source-sink creates one demand.""" + td = TrafficDemand( + source_path="A", sink_path="D", demand=100.0, mode="pairwise" + ) + expansion = expand_demands(simple_network, [td]) + + assert len(expansion.demands) == 1 + assert len(expansion.augmentations) == 0 # No pseudo nodes for pairwise + + demand = expansion.demands[0] + assert demand.src_name == "A" + assert demand.dst_name == "D" + assert demand.volume == 100.0 + + def test_pairwise_multiple_sources(self, simple_network: Network) -> None: + """Pairwise mode with regex creates demand per (src, dst) pair.""" + td = TrafficDemand( + source_path="[AB]", # A and B + sink_path="[CD]", # C and D + demand=100.0, + mode="pairwise", + ) + expansion = expand_demands(simple_network, [td]) + + # 2 sources x 2 sinks = 4 pairs + assert len(expansion.demands) == 4 + assert len(expansion.augmentations) == 0 + + # Volume distributed evenly + for demand in expansion.demands: + assert demand.volume == 25.0 # 100 / 4 + + def test_pairwise_no_self_loops(self, simple_network: Network) -> None: + """Pairwise mode excludes self-loops.""" + td = TrafficDemand( + source_path="[AB]", + sink_path="[AB]", # Same as sources + demand=100.0, + mode="pairwise", + ) + expansion = expand_demands(simple_network, [td]) + + # A->B and B->A only (no A->A or B->B) + assert len(expansion.demands) == 2 + for demand in expansion.demands: + assert demand.src_name != demand.dst_name + + +class TestExpandDemandsCombine: + """Test expand_demands with combine mode.""" + + def test_combine_creates_pseudo_nodes(self, simple_network: Network) -> None: + """Combine mode creates pseudo source and sink nodes.""" + td = TrafficDemand( + source_path="[AB]", + sink_path="[CD]", + demand=100.0, + mode="combine", + ) + expansion = expand_demands(simple_network, [td]) + + # One aggregated demand + assert len(expansion.demands) == 1 + + # Augmentations: pseudo_src->A, pseudo_src->B, C->pseudo_snk, D->pseudo_snk + assert len(expansion.augmentations) == 4 + + demand = expansion.demands[0] + assert demand.src_name.startswith("_src_") + assert demand.dst_name.startswith("_snk_") + assert demand.volume == 100.0 + + def test_combine_pseudo_node_names_use_id(self, simple_network: Network) -> None: + """Combine mode pseudo node names include TrafficDemand.id.""" + td = TrafficDemand( + id="stable-id-123", + source_path="A", + sink_path="D", + demand=100.0, + mode="combine", + ) + expansion = expand_demands(simple_network, [td]) + + demand = expansion.demands[0] + assert demand.src_name == "_src_stable-id-123" + assert demand.dst_name == "_snk_stable-id-123" + + def test_combine_augmentations_structure(self, simple_network: Network) -> None: + """Combine mode augmentations connect pseudo nodes to real nodes.""" + td = TrafficDemand( + id="test-id", + source_path="[AB]", + sink_path="[CD]", + demand=100.0, + mode="combine", + ) + expansion = expand_demands(simple_network, [td]) + + # Check augmentation edges + aug_edges = [(a.source, a.target) for a in expansion.augmentations] + + # Pseudo source -> real sources + assert ("_src_test-id", "A") in aug_edges + assert ("_src_test-id", "B") in aug_edges + + # Real sinks -> pseudo sink + assert ("C", "_snk_test-id") in aug_edges + assert ("D", "_snk_test-id") in aug_edges + + +class TestExpandDemandsIdConsistency: + """Test that expansion uses consistent IDs for pseudo node naming.""" + + def test_same_id_produces_same_pseudo_nodes(self, simple_network: Network) -> None: + """Same TrafficDemand ID produces identical pseudo node names.""" + td1 = TrafficDemand( + id="shared-id", + source_path="A", + sink_path="D", + demand=100.0, + mode="combine", + ) + td2 = TrafficDemand( + id="shared-id", + source_path="A", + sink_path="D", + demand=200.0, # Different demand + mode="combine", + ) + + exp1 = expand_demands(simple_network, [td1]) + exp2 = expand_demands(simple_network, [td2]) + + # Same pseudo node names + assert exp1.demands[0].src_name == exp2.demands[0].src_name + assert exp1.demands[0].dst_name == exp2.demands[0].dst_name + + def test_different_ids_produce_different_pseudo_nodes( + self, simple_network: Network + ) -> None: + """Different TrafficDemand IDs produce different pseudo node names.""" + td1 = TrafficDemand( + id="id-alpha", + source_path="A", + sink_path="D", + demand=100.0, + mode="combine", + ) + td2 = TrafficDemand( + id="id-beta", + source_path="A", + sink_path="D", + demand=100.0, + mode="combine", + ) + + exp1 = expand_demands(simple_network, [td1]) + exp2 = expand_demands(simple_network, [td2]) + + # Different pseudo node names + assert exp1.demands[0].src_name != exp2.demands[0].src_name + assert exp1.demands[0].dst_name != exp2.demands[0].dst_name + + +class TestExpandDemandsEdgeCases: + """Test edge cases for expand_demands.""" + + def test_empty_demands_raises(self, simple_network: Network) -> None: + """Empty demands list raises ValueError.""" + with pytest.raises(ValueError, match="No demands could be expanded"): + expand_demands(simple_network, []) + + def test_no_matching_nodes_raises(self, simple_network: Network) -> None: + """Demand with no matching nodes raises ValueError.""" + td = TrafficDemand( + source_path="nonexistent", + sink_path="also_nonexistent", + demand=100.0, + ) + with pytest.raises(ValueError, match="No demands could be expanded"): + expand_demands(simple_network, [td]) + + def test_multiple_demands_mixed_modes(self, simple_network: Network) -> None: + """Multiple demands with different modes expand correctly.""" + td_pairwise = TrafficDemand( + source_path="A", + sink_path="B", + demand=50.0, + mode="pairwise", + ) + td_combine = TrafficDemand( + source_path="[CD]", + sink_path="[AB]", + demand=100.0, + mode="combine", + ) + + expansion = expand_demands(simple_network, [td_pairwise, td_combine]) + + # 1 pairwise + 1 combined = 2 demands + assert len(expansion.demands) == 2 + + # Only combine mode creates augmentations + assert len(expansion.augmentations) == 4 # 2 sources + 2 sinks diff --git a/tests/model/demand/test_spec.py b/tests/model/demand/test_spec.py index 7c35270..f50f019 100644 --- a/tests/model/demand/test_spec.py +++ b/tests/model/demand/test_spec.py @@ -24,6 +24,45 @@ def test_defaults_and_id_generation() -> None: assert demand2.id != demand.id +def test_explicit_id_preserved() -> None: + """TrafficDemand with explicit ID preserves it unchanged.""" + demand = TrafficDemand( + id="my-explicit-id", + source_path="Src", + sink_path="Dst", + demand=100.0, + ) + assert demand.id == "my-explicit-id" + + +def test_explicit_id_round_trip() -> None: + """TrafficDemand ID survives serialization to dict and reconstruction.""" + original = TrafficDemand(source_path="A", sink_path="B", demand=50.0) + original_id = original.id + + # Simulate serialization (as done in workflow steps) + config = { + "id": original.id, + "source_path": original.source_path, + "sink_path": original.sink_path, + "demand": original.demand, + "mode": original.mode, + "priority": original.priority, + } + + # Simulate reconstruction (as done in flow.py) + reconstructed = TrafficDemand( + id=config.get("id"), + source_path=config["source_path"], + sink_path=config["sink_path"], + demand=config["demand"], + mode=config.get("mode", "pairwise"), + priority=config.get("priority", 0), + ) + + assert reconstructed.id == original_id + + def test_attrs_isolation_between_instances() -> None: """Each instance gets its own attrs dict; mutating one does not affect others.""" d1 = TrafficDemand(source_path="A", sink_path="B") diff --git a/tests/workflow/test_maximum_supported_demand.py b/tests/workflow/test_maximum_supported_demand.py index 3ea4029..86d35d3 100644 --- a/tests/workflow/test_maximum_supported_demand.py +++ b/tests/workflow/test_maximum_supported_demand.py @@ -21,12 +21,16 @@ def _mock_scenario_with_matrix() -> MagicMock: return mock_scenario -@patch( - "ngraph.workflow.maximum_supported_demand_step.MaximumSupportedDemand._evaluate_alpha" -) -def test_msd_basic_bracket_and_bisect(mock_eval: MagicMock) -> None: +@patch.object(MaximumSupportedDemand, "_evaluate_alpha") +@patch.object(MaximumSupportedDemand, "_build_cache") +def test_msd_basic_bracket_and_bisect( + mock_build_cache: MagicMock, mock_eval: MagicMock +) -> None: + """Test binary search logic with mocked evaluation.""" + mock_build_cache.return_value = MagicMock() + # Feasible if alpha <= 1.3, infeasible otherwise - def _eval(*, alpha, scenario, matrix_name, placement_rounds, seeds): # type: ignore[no-redef] + def _eval(cache, alpha, seeds): feasible = alpha <= 1.3 return feasible, { "seeds": 1, @@ -37,7 +41,6 @@ def _eval(*, alpha, scenario, matrix_name, placement_rounds, seeds): # type: ig mock_eval.side_effect = _eval scenario = _mock_scenario_with_matrix() - step = MaximumSupportedDemand( name="msd_step", matrix_name="default", @@ -60,18 +63,21 @@ def _eval(*, alpha, scenario, matrix_name, placement_rounds, seeds): # type: ig assert base and base[0]["source_path"] == "A" -@patch( - "ngraph.workflow.maximum_supported_demand_step.MaximumSupportedDemand._evaluate_alpha" -) -def test_msd_no_feasible_raises(mock_eval: MagicMock) -> None: +@patch.object(MaximumSupportedDemand, "_evaluate_alpha") +@patch.object(MaximumSupportedDemand, "_build_cache") +def test_msd_no_feasible_raises( + mock_build_cache: MagicMock, mock_eval: MagicMock +) -> None: + """Test that MSD raises when no feasible alpha is found.""" + mock_build_cache.return_value = MagicMock() + # Always infeasible - def _eval(*, alpha, scenario, matrix_name, placement_rounds, seeds): # type: ignore[no-redef] + def _eval(cache, alpha, seeds): return False, {"seeds": 1, "feasible_seeds": 0, "min_placement_ratio": 0.0} mock_eval.side_effect = _eval scenario = _mock_scenario_with_matrix() - step = MaximumSupportedDemand( name="msd_step", matrix_name="default", @@ -124,10 +130,11 @@ def test_msd_end_to_end_single_link() -> None: base_demands = data.get("base_demands") assert isinstance(base_demands, list) and base_demands - # Verify feasibility at alpha* using new Core-based API + # Verify feasibility at alpha* using demand_placement_analysis scaled_demands = MSD._build_scaled_demands(base_demands, float(alpha_star)) demands_config = [ { + "id": d.id, "source_path": d.source_path, "sink_path": d.sink_path, "demand": d.demand, @@ -154,6 +161,7 @@ def test_msd_end_to_end_single_link() -> None: scaled_demands_above = MSD._build_scaled_demands(base_demands, alpha_above) demands_config_above = [ { + "id": d.id, "source_path": d.source_path, "sink_path": d.sink_path, "demand": d.demand,