diff --git a/ngraph/blueprints.py b/ngraph/blueprints.py index d397938..1c13d2f 100644 --- a/ngraph/blueprints.py +++ b/ngraph/blueprints.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Any, Dict, List -from ngraph.network import Network, Node, Link +from ngraph.network import Link, Network, Node @dataclass(slots=True) @@ -16,9 +16,11 @@ class Blueprint: and a name_template), plus adjacency rules describing how those groups connect. Attributes: - name: Unique identifier of this blueprint. - groups: A mapping of group_name -> group definition (e.g. node_count, name_template). - adjacency: A list of adjacency dictionaries describing how groups are linked. + name (str): Unique identifier of this blueprint. + groups (Dict[str, Any]): A mapping of group_name -> group definition + (e.g. node_count, name_template). + adjacency (List[Dict[str, Any]]): A list of adjacency dictionaries + describing how groups are linked. """ name: str @@ -33,8 +35,10 @@ class DSLExpansionContext: to be populated during DSL expansion. Attributes: - blueprints: A dictionary of blueprint name -> Blueprint object. - network: The Network into which expanded nodes/links will be inserted. + blueprints (Dict[str, Blueprint]): A dictionary of blueprint-name -> + Blueprint object. + network (Network): The Network into which expanded nodes/links + will be inserted. """ blueprints: Dict[str, Blueprint] @@ -54,12 +58,14 @@ def expand_network_dsl(data: Dict[str, Any]) -> Network: 4) Process any direct node definitions. 5) Expand adjacency definitions in 'network["adjacency"]'. 6) Process any direct link definitions. + 7) Process link overrides. Args: - data: The YAML-parsed dictionary containing optional "blueprints" + "network". + data (Dict[str, Any]): The YAML-parsed dictionary containing + optional "blueprints" + "network". Returns: - A fully expanded Network object with all nodes and links. + Network: A fully expanded Network object with all nodes and links. 
""" # 1) Parse blueprint definitions blueprint_map: Dict[str, Blueprint] = {} @@ -101,9 +107,31 @@ def expand_network_dsl(data: Dict[str, Any]) -> Network: # 6) Process direct link definitions _process_direct_links(ctx.network, network_data) + # 7) Process link overrides + _process_link_overrides(ctx.network, network_data) + return net +def _process_link_overrides(network: Network, network_data: Dict[str, Any]) -> None: + """ + Processes the 'link_overrides' section of the network DSL, updating + existing links with new parameters. + + Args: + network (Network): The Network whose links will be updated. + network_data (Dict[str, Any]): The overall DSL data for the 'network'. + Expected to contain 'link_overrides' as a list of dicts, each with + 'source', 'target', and 'link_params'. + """ + link_overrides = network_data.get("link_overrides", []) + for link_override in link_overrides: + source = link_override["source"] + target = link_override["target"] + link_params = link_override["link_params"] + _update_links(network, source, target, link_params) + + def _expand_group( ctx: DSLExpansionContext, parent_path: str, @@ -117,11 +145,15 @@ def _expand_group( - Another blueprint's subgroups, or - A direct node group (node_count, name_template). - We do *not* skip the subgroup name even inside blueprint expansion, because - typically the 'group_name' is "leaf"/"spine" etc., not the blueprint’s name. - - So the final path is always 'parent_path + "/" + group_name' if parent_path is non-empty, - otherwise just group_name. + Args: + ctx (DSLExpansionContext): The context containing all blueprint info + and the target Network. + parent_path (str): The parent path in the hierarchy. + group_name (str): The current group's name. + group_def (Dict[str, Any]): The group definition (e.g. {node_count, name_template} + or {use_blueprint, parameters, ...}). + blueprint_expansion (bool): Indicates whether we are expanding within + a blueprint context or not. 
""" # Construct the effective path by appending group_name if parent_path is non-empty if parent_path: @@ -182,7 +214,14 @@ def _expand_blueprint_adjacency( parent_path: str, ) -> None: """ - Expands adjacency definitions from within a blueprint, using parent_path as the local root. + Expands adjacency definitions from within a blueprint, using parent_path + as the local root. + + Args: + ctx (DSLExpansionContext): The context object with blueprint info and the network. + adj_def (Dict[str, Any]): The adjacency definition inside the blueprint, + containing 'source', 'target', 'pattern', etc. + parent_path (str): The path that serves as the base for the blueprint's node paths. """ source_rel = adj_def["source"] target_rel = adj_def["target"] @@ -201,6 +240,11 @@ def _expand_adjacency( ) -> None: """ Expands a top-level adjacency definition from 'network.adjacency'. + + Args: + ctx (DSLExpansionContext): The context containing the target network. + adj_def (Dict[str, Any]): The adjacency definition dict, containing + 'source', 'target', and optional 'pattern', 'link_params'. """ source_path_raw = adj_def["source"] target_path_raw = adj_def["target"] @@ -230,6 +274,13 @@ def _expand_adjacency_pattern( * "one_to_one": Pair each source node with exactly one target node, supporting wrap-around if one side is an integer multiple of the other. Also skips self-loops. + + Args: + ctx (DSLExpansionContext): The context containing the target network. + source_path (str): The path pattern that identifies the source node group(s). + target_path (str): The path pattern that identifies the target node group(s). + pattern (str): The type of adjacency pattern (e.g., "mesh", "one_to_one"). + link_params (Dict[str, Any]): Additional link parameters (capacity, cost, attrs). 
""" source_node_groups = ctx.network.select_node_groups_by_path(source_path) target_node_groups = ctx.network.select_node_groups_by_path(target_path) @@ -281,7 +332,6 @@ def _expand_adjacency_pattern( if pair not in dedup_pairs: dedup_pairs.add(pair) _create_link(ctx.network, sn, tn, link_params) - else: raise ValueError(f"Unknown adjacency pattern: {pattern}") @@ -291,6 +341,13 @@ def _create_link( ) -> None: """ Creates and adds a Link to the network, applying capacity/cost/attrs from link_params. + + Args: + net (Network): The network to which the new link is added. + source (str): Source node name for the link. + target (str): Target node name for the link. + link_params (Dict[str, Any]): A dict possibly containing 'capacity', 'cost', + and 'attrs' keys. """ capacity = link_params.get("capacity", 1.0) cost = link_params.get("cost", 1.0) @@ -306,15 +363,67 @@ def _create_link( net.add_link(link) +def _update_links( + net: Network, + source: str, + target: str, + link_params: Dict[str, Any], + any_direction: bool = True, +) -> None: + """ + Update all Link objects between nodes matching 'source' and 'target' paths + with new parameters. + + Args: + net (Network): The network whose links should be updated. + source (str): A path pattern identifying source node group(s). + target (str): A path pattern identifying target node group(s). + link_params (Dict[str, Any]): New parameter values for the links (capacity, cost, attrs). + any_direction (bool): If True, also update links in the reverse direction. 
+ """ + source_node_groups = net.select_node_groups_by_path(source) + target_node_groups = net.select_node_groups_by_path(target) + + source_nodes = { + node.name for _, nodes in source_node_groups.items() for node in nodes + } + target_nodes = { + node.name for _, nodes in target_node_groups.items() for node in nodes + } + + for link in net.links.values(): + if link.source in source_nodes and link.target in target_nodes: + link.capacity = link_params.get("capacity", link.capacity) + link.cost = link_params.get("cost", link.cost) + link.attrs.update(link_params.get("attrs", {})) + + if ( + any_direction + and link.source in target_nodes + and link.target in source_nodes + ): + link.capacity = link_params.get("capacity", link.capacity) + link.cost = link_params.get("cost", link.cost) + link.attrs.update(link_params.get("attrs", {})) + + def _apply_parameters( subgroup_name: str, subgroup_def: Dict[str, Any], params_overrides: Dict[str, Any] ) -> Dict[str, Any]: """ Applies user-provided parameter overrides to a blueprint subgroup. - E.g.: - if 'spine.node_count' = 6 is in params_overrides, - we set 'node_count'=6 for the 'spine' subgroup. + Example: + If 'spine.node_count'=6 is in params_overrides, + we set 'node_count'=6 for the 'spine' subgroup. + + Args: + subgroup_name (str): Name of the subgroup in the blueprint (e.g. 'spine'). + subgroup_def (Dict[str, Any]): The default definition of the subgroup. + params_overrides (Dict[str, Any]): Overrides in the form of { 'spine.node_count': }. + + Returns: + Dict[str, Any]: A copy of subgroup_def with parameter overrides applied. """ out = dict(subgroup_def) for key, val in params_overrides.items(): @@ -327,22 +436,39 @@ def _apply_parameters( def _join_paths(parent_path: str, rel_path: str) -> str: """ - If rel_path starts with '/', interpret that as relative to 'parent_path'; - otherwise, simply append rel_path to parent_path with '/' if needed. 
+ Joins two path segments according to NetGraph's DSL conventions: + - If rel_path starts with '/', remove the leading slash and treat it + as a relative path appended to parent_path (if present). + - Otherwise, simply append rel_path to parent_path if parent_path is non-empty. + + Args: + parent_path (str): The existing path prefix. + rel_path (str): A relative path that may start with '/'. + + Returns: + str: The combined path as a single string. """ if rel_path.startswith("/"): rel_path = rel_path[1:] if parent_path: return f"{parent_path}/{rel_path}" - else: - return rel_path + return rel_path + if parent_path: return f"{parent_path}/{rel_path}" return rel_path def _process_direct_nodes(net: Network, network_data: Dict[str, Any]) -> None: - """Processes direct node definitions (network_data["nodes"]).""" + """ + Processes direct node definitions (network_data["nodes"]) and adds them to the network + if they do not already exist. + + Args: + net (Network): The network to which nodes are added. + network_data (Dict[str, Any]): DSL data containing a "nodes" dict + keyed by node name -> attributes. + """ for node_name, node_attrs in network_data.get("nodes", {}).items(): if node_name not in net.nodes: new_node = Node(name=node_name, attrs=node_attrs or {}) @@ -352,7 +478,12 @@ def _process_direct_nodes(net: Network, network_data: Dict[str, Any]) -> None: def _process_direct_links(net: Network, network_data: Dict[str, Any]) -> None: """ - Processes direct link definitions (network_data["links"]). + Processes direct link definitions (network_data["links"]) and adds them to the network. + + Args: + net (Network): The network to which links are added. + network_data (Dict[str, Any]): DSL data containing a "links" list, + each item must have "source", "target", and optionally "link_params". 
""" existing_node_names = set(net.nodes.keys()) for link_info in network_data.get("links", []): @@ -360,6 +491,8 @@ def _process_direct_links(net: Network, network_data: Dict[str, Any]) -> None: target = link_info["target"] if source not in existing_node_names or target not in existing_node_names: raise ValueError(f"Link references unknown node(s): {source}, {target}.") + if source == target: + raise ValueError(f"Link cannot have the same source and target: {source}") link_params = link_info.get("link_params", {}) link = Link( source=source, diff --git a/ngraph/lib/algorithms/base.py b/ngraph/lib/algorithms/base.py index 3e52d2a..0b309c3 100644 --- a/ngraph/lib/algorithms/base.py +++ b/ngraph/lib/algorithms/base.py @@ -58,7 +58,7 @@ class EdgeSelect(IntEnum): class FlowPlacement(IntEnum): - """Ways to distribute flow on parallel edges.""" + """Ways to distribute flow across parallel equal cost paths.""" PROPORTIONAL = 1 # Flow is split proportional to capacity (Dinic-like approach) - EQUAL_BALANCED = 2 # Flow is equally divided among parallel edges + EQUAL_BALANCED = 2 # Flow is equally divided among parallel paths of equal cost diff --git a/ngraph/lib/algorithms/max_flow.py b/ngraph/lib/algorithms/max_flow.py index d51a4b5..eb62731a 100644 --- a/ngraph/lib/algorithms/max_flow.py +++ b/ngraph/lib/algorithms/max_flow.py @@ -1,6 +1,5 @@ from __future__ import annotations -from typing import Optional from ngraph.lib.algorithms.spf import spf from ngraph.lib.algorithms.place_flow import place_flow_on_graph from ngraph.lib.algorithms.base import EdgeSelect, FlowPlacement @@ -27,8 +26,8 @@ def calc_max_flow( By default, this function: 1. Creates or re-initializes a flow-aware copy of the graph (using ``init_flow_graph``). - 2. Repeatedly finds a path from ``src_node`` to any reachable node via `spf` with - capacity constraints (via ``EdgeSelect.ALL_MIN_COST_WITH_CAP_REMAINING``). + 2. 
Repeatedly finds a path from ``src_node`` to any reachable node via ``spf`` with + capacity constraints (through ``EdgeSelect.ALL_MIN_COST_WITH_CAP_REMAINING``). 3. Places flow along that path (using ``place_flow_on_graph``) until no augmenting path remains or the capacities are exhausted. @@ -43,14 +42,14 @@ def calc_max_flow( dst_node (NodeID): The destination node for flow. flow_placement (FlowPlacement): - Determines how flow is split among parallel edges. + Determines how flow is split among parallel edges of equal cost. Defaults to ``FlowPlacement.PROPORTIONAL``. shortest_path (bool): - If True, place flow only once along the first shortest path and return immediately, - rather than attempting to find the true max flow. Defaults to False. + If True, place flow only once along the first shortest path found and return + immediately, rather than iterating to find the true max flow. reset_flow_graph (bool): - If True, reset any existing flow data (``flow_attr`` and ``flows_attr``) on the graph. - Defaults to False. + If True, reset any existing flow data (e.g., attributes in ``flow_attr`` and + ``flows_attr``). Defaults to False. capacity_attr (str): The name of the capacity attribute on edges. Defaults to "capacity". flow_attr (str): @@ -63,21 +62,20 @@ def calc_max_flow( Returns: float: The total flow placed between ``src_node`` and ``dst_node``. - If ``shortest_path=True``, returns the flow placed by a single augmentation. + If ``shortest_path=True``, this is the flow placed by a single augmentation. 
Examples: - >>> # Basic usage: >>> g = StrictMultiDiGraph() >>> g.add_node('A') >>> g.add_node('B') >>> g.add_node('C') - >>> e1 = g.add_edge('A', 'B', capacity=10.0, flow=0.0, flows={}) - >>> e2 = g.add_edge('B', 'C', capacity=5.0, flow=0.0, flows={}) + >>> _ = g.add_edge('A', 'B', capacity=10.0, flow=0.0, flows={}) + >>> _ = g.add_edge('B', 'C', capacity=5.0, flow=0.0, flows={}) >>> max_flow_value = calc_max_flow(g, 'A', 'C') >>> print(max_flow_value) 5.0 """ - # Optionally copy/initialize a flow-aware version of the graph + # Initialize a flow-aware graph (copy or in-place). flow_graph = init_flow_graph( graph.copy() if copy_graph else graph, flow_attr, @@ -85,10 +83,10 @@ def calc_max_flow( reset_flow_graph, ) - # Cache the edge selection function for repeated use + # Prepare the edge selection function (selects edges with capacity remaining). edge_select_func = edge_select_fabric(EdgeSelect.ALL_MIN_COST_WITH_CAP_REMAINING) - # First path-finding iteration + # First path-finding iteration. _, pred = spf(flow_graph, src_node, edge_select_func=edge_select_func) flow_meta = place_flow_on_graph( flow_graph, @@ -102,15 +100,15 @@ def calc_max_flow( ) max_flow: float = flow_meta.placed_flow - # If only the "first shortest path" flow is requested, stop here + # If only one path (single augmentation) is desired, return early. if shortest_path: return max_flow - # Otherwise, repeatedly find augmenting paths and place flow + # Otherwise, repeatedly find augmenting paths until no new flow can be placed. while True: _, pred = spf(flow_graph, src_node, edge_select_func=edge_select_func) if dst_node not in pred: - # No path found, we're done + # No path found; we've reached the max flow. break flow_meta = place_flow_on_graph( @@ -123,8 +121,8 @@ def calc_max_flow( flow_attr=flow_attr, flows_attr=flows_attr, ) - # If no additional flow was placed, we are at capacity if flow_meta.placed_flow <= 0: + # No additional flow could be placed; we're at capacity. 
break max_flow += flow_meta.placed_flow diff --git a/ngraph/lib/algorithms/place_flow.py b/ngraph/lib/algorithms/place_flow.py index 33148df..b1e3d75 100644 --- a/ngraph/lib/algorithms/place_flow.py +++ b/ngraph/lib/algorithms/place_flow.py @@ -50,7 +50,7 @@ def place_flow_on_graph( pred: A dictionary of node->(adj_node->list_of_edge_IDs) giving path adjacency. flow: Requested flow amount; can be infinite. flow_index: Identifier for this flow (used to track multiple flows). - flow_placement: Strategy for distributing flow among parallel edges. + flow_placement: Strategy for distributing flow among parallel equal cost paths. capacity_attr: Attribute name on edges for capacity. flow_attr: Attribute name on edges/nodes for aggregated flow. flows_attr: Attribute name on edges/nodes for per-flow tracking. diff --git a/ngraph/lib/flow.py b/ngraph/lib/flow.py index 31a8037..d22ce70 100644 --- a/ngraph/lib/flow.py +++ b/ngraph/lib/flow.py @@ -1,11 +1,7 @@ from __future__ import annotations -from typing import ( - Hashable, - NamedTuple, - Optional, - Set, - Tuple, -) + +from typing import Hashable, NamedTuple, Optional, Set, Tuple + from ngraph.lib.algorithms.base import MIN_FLOW from ngraph.lib.algorithms.place_flow import ( FlowPlacement, @@ -21,10 +17,10 @@ class FlowIndex(NamedTuple): Describes a unique identifier for a Flow in the network. Attributes: - src_node: The source node of the flow. - dst_node: The destination node of the flow. - flow_class: An integer representing the 'class' of this flow (e.g. a traffic class). - flow_id: A unique integer ID for this flow. + src_node (NodeID): The source node of the flow. + dst_node (NodeID): The destination node of the flow. + flow_class (int): Integer representing the 'class' of this flow (e.g., traffic class). + flow_id (int): A unique integer ID for this flow. """ src_node: NodeID @@ -54,26 +50,30 @@ def __init__( Initialize a Flow object. 
Args: - path_bundle: A `PathBundle` representing the set of paths this flow will use. - flow_index: A unique identifier (can be any hashable) that tags this flow - in the network (e.g. an MPLS label, a tuple of (src, dst, class, id), etc.). - excluded_edges: An optional set of edges to exclude from consideration. - excluded_nodes: An optional set of nodes to exclude from consideration. + path_bundle (PathBundle): The set of paths this flow uses. + flow_index (Hashable): A unique identifier for this flow (e.g., MPLS label, tuple, etc.). + excluded_edges (Optional[Set[EdgeID]]): Edges to exclude from usage. + excluded_nodes (Optional[Set[NodeID]]): Nodes to exclude from usage. """ self.path_bundle: PathBundle = path_bundle self.flow_index: Hashable = flow_index self.excluded_edges: Set[EdgeID] = excluded_edges or set() self.excluded_nodes: Set[NodeID] = excluded_nodes or set() - # Store convenience references for the Flow's endpoints + # Convenience references for flow endpoints self.src_node: NodeID = path_bundle.src_node self.dst_node: NodeID = path_bundle.dst_node - # Track how much flow has been successfully placed so far + # Track how much flow has been successfully placed self.placed_flow: float = 0.0 def __str__(self) -> str: - """String representation of the Flow.""" + """ + Returns a string representation of the Flow. + + Returns: + str: String representation including flow index and placed flow amount. + """ return f"Flow(flow_index={self.flow_index}, placed_flow={self.placed_flow})" def place_flow( @@ -83,22 +83,21 @@ def place_flow( flow_placement: FlowPlacement, ) -> Tuple[float, float]: """ - Attempt to place (or update) this flow on `flow_graph`. + Attempt to place (or update) this flow on the given `flow_graph`. Args: - flow_graph: The network graph where flow capacities and usage are tracked. - to_place: The amount of flow requested to be placed on this path bundle. - flow_placement: Strategy determining how flow is distributed among parallel edges. 
+ flow_graph (StrictMultiDiGraph): The network graph tracking capacities and usage. + to_place (float): The amount of flow requested to be placed. + flow_placement (FlowPlacement): Strategy for distributing flow among equal-cost paths. Returns: - A tuple `(placed_flow, remaining_flow)` where: - - `placed_flow` is the amount of flow actually placed on `flow_graph`. - - `remaining_flow` is how much of `to_place` could not be placed - (due to capacity limits or other constraints). + Tuple[float, float]: A tuple of: + placed_flow (float): The amount of flow actually placed. + remaining_flow (float): The flow that could not be placed. """ placed_flow = 0.0 - # Only place flow if it's above the MIN_FLOW threshold + # Only place flow if above the minimum threshold if to_place >= MIN_FLOW: flow_placement_meta = place_flow_on_graph( flow_graph=flow_graph, @@ -117,10 +116,10 @@ def place_flow( def remove_flow(self, flow_graph: StrictMultiDiGraph) -> None: """ - Remove this flow's contribution from `flow_graph`. + Remove this flow's contribution from the provided `flow_graph`. Args: - flow_graph: The network graph from which this flow's usage should be removed. + flow_graph (StrictMultiDiGraph): The network graph from which to remove this flow's usage. 
""" remove_flow_from_graph(flow_graph, flow_index=self.flow_index) self.placed_flow = 0.0 diff --git a/ngraph/network.py b/ngraph/network.py index 44fd2af..2df6804 100644 --- a/ngraph/network.py +++ b/ngraph/network.py @@ -1,14 +1,14 @@ from __future__ import annotations -import uuid import base64 import re +import uuid from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple, Optional +from typing import Any, Dict, List, Optional, Tuple -from ngraph.lib.graph import StrictMultiDiGraph -from ngraph.lib.algorithms.max_flow import calc_max_flow from ngraph.lib.algorithms.base import FlowPlacement +from ngraph.lib.algorithms.max_flow import calc_max_flow +from ngraph.lib.graph import StrictMultiDiGraph def new_base64_uuid() -> str: @@ -32,7 +32,8 @@ class Node: Attributes: name (str): Unique identifier for the node. attrs (Dict[str, Any]): Optional metadata (e.g., type, coordinates, region). - Use attrs["disabled"] = True/False to mark active/inactive. + Set attrs["disabled"] = True to mark the node as inactive. + Defaults to an empty dict. """ name: str @@ -50,7 +51,8 @@ class Link: capacity (float): Link capacity (default 1.0). cost (float): Link cost (default 1.0). attrs (Dict[str, Any]): Optional metadata (e.g., type, distance). - Use attrs["disabled"] = True/False to mark active/inactive. + Set attrs["disabled"] = True to mark link as inactive. + Defaults to an empty dict. id (str): Auto-generated unique identifier in the form "{source}|{target}|". """ @@ -89,7 +91,8 @@ def add_node(self, node: Node) -> None: """ Add a node to the network (keyed by node.name). - Auto-tags node.attrs["type"] = "node" if not already set. + Auto-tags node.attrs["type"] = "node" if not already set, + and node.attrs["disabled"] = False if not specified. Args: node: Node to add. @@ -107,13 +110,14 @@ def add_link(self, link: Link) -> None: """ Add a link to the network (keyed by the link's auto-generated ID). 
- Auto-tags link.attrs["type"] = "link" if not already set. + Auto-tags link.attrs["type"] = "link" if not already set, + and link.attrs["disabled"] = False if not specified. Args: link: Link to add. Raises: - ValueError: If the link's source or target node is missing. + ValueError: If the link's source or target node does not exist. """ if link.source not in self.nodes: raise ValueError(f"Source node '{link.source}' not found in network.") @@ -128,30 +132,30 @@ def to_strict_multidigraph(self, add_reverse: bool = True) -> StrictMultiDiGraph """ Create a StrictMultiDiGraph representation of this Network. - Nodes and links with attrs["disabled"] = True are omitted. + Nodes and links whose attrs["disabled"] == True are omitted. Args: add_reverse: If True, also add a reverse edge for each link. Returns: - StrictMultiDiGraph: A directed multigraph representation. + StrictMultiDiGraph: A directed multigraph representation of the network. """ graph = StrictMultiDiGraph() - # Add enabled nodes + # Identify disabled nodes for quick checks disabled_nodes = { name for name, node in self.nodes.items() if node.attrs.get("disabled", False) } + + # Add enabled nodes for node_name, node in self.nodes.items(): - if node.attrs.get("disabled", False): - continue - graph.add_node(node_name, **node.attrs) + if not node.attrs.get("disabled", False): + graph.add_node(node_name, **node.attrs) # Add enabled links for link_id, link in self.links.items(): - # Skip if link is disabled or if source/target is disabled if link.attrs.get("disabled", False): continue if link.source in disabled_nodes or link.target in disabled_nodes: @@ -185,17 +189,17 @@ def select_node_groups_by_path(self, path: str) -> Dict[str, List[Node]]: """ Select and group nodes whose names match a given regular expression. - This method uses re.match(), so the pattern is automatically anchored - at the start of the node name. 
If the pattern includes capturing groups, + This method uses re.match(), so the pattern is anchored at the start + of the node name. If the pattern includes capturing groups, the group label is formed by joining all non-None captures with '|'. - If no capturing groups exist, the group label is simply the original + If no capturing groups exist, the group label is the original pattern string. Args: path: A Python regular expression pattern (e.g., "^foo", "bar(\\d+)", etc.). Returns: - A mapping from group label -> list of nodes that matched the pattern. + Dict[str, List[Node]]: A mapping from group label -> list of matching nodes. """ pattern = re.compile(path) groups_map: Dict[str, List[Node]] = {} @@ -222,27 +226,24 @@ def max_flow( ) -> Dict[Tuple[str, str], float]: """ Compute maximum flow between groups of source nodes and sink nodes. - Always returns a dictionary of flow values. The dict keys are - (source_label, sink_label), and the values are the flow amounts. + Returns a dictionary of flow values keyed by (source_label, sink_label). Args: source_path: Regex pattern for selecting source nodes. sink_path: Regex pattern for selecting sink nodes. - mode: "combine" or "pairwise". - - "combine": All matched sources become one combined source group, - all matched sinks become one combined sink group. Returns a dict - with a single entry {("", ""): flow_value}. - - "pairwise": Compute flow for each (source_group, sink_group) and - return a dict of flows for all pairs. - shortest_path: If True, flow is constrained to shortest paths. - flow_placement: Determines how parallel edges are handled. + mode: Either "combine" or "pairwise". + - "combine": Treat all matched sources as one group, + and all matched sinks as one group. Returns a single dict entry. + - "pairwise": Compute flow for each (source_group, sink_group) pair. + shortest_path: If True, flows are constrained to shortest paths. + flow_placement: Determines how parallel equal-cost paths are handled. 
Returns: - A dictionary mapping (src_label, snk_label) -> flow. + Dict[Tuple[str, str], float]: Flow values for each (src_label, snk_label) pair. Raises: ValueError: If no matching source or sink groups are found, - or if mode is invalid. + or if the mode is invalid. """ src_groups = self.select_node_groups_by_path(source_path) snk_groups = self.select_node_groups_by_path(sink_path) @@ -285,9 +286,7 @@ def max_flow( return results else: - raise ValueError( - f"Invalid mode '{mode}' for max_flow. Must be 'combine' or 'pairwise'." - ) + raise ValueError(f"Invalid mode '{mode}'. Must be 'combine' or 'pairwise'.") def _compute_flow_single_group( self, @@ -297,22 +296,21 @@ def _compute_flow_single_group( flow_placement: FlowPlacement, ) -> float: """ - Attach a pseudo-source and pseudo-sink to the given node lists, + Attach a pseudo-source and pseudo-sink to the provided node lists, then run calc_max_flow. Returns the resulting flow from all sources to all sinks as a single float. - Ignores disabled nodes. + + Disabled nodes are excluded from flow computation. Args: sources: List of source nodes. sinks: List of sink nodes. - shortest_path: Whether to use shortest paths only. - flow_placement: How parallel edges are handled. + shortest_path: If True, use only shortest paths for flow. + flow_placement: Strategy for placing flow among parallel equal-cost paths. Returns: - The computed max-flow value, or 0.0 if either list is empty - or all are disabled. + float: The computed maximum flow value, or 0.0 if there are no active sources or sinks. 
""" - # Filter out disabled nodes at the source/sink stage active_sources = [s for s in sources if not s.attrs.get("disabled", False)] active_sinks = [s for s in sinks if not s.attrs.get("disabled", False)] @@ -325,7 +323,6 @@ def _compute_flow_single_group( for src_node in active_sources: graph.add_edge("source", src_node.name, capacity=float("inf"), cost=0) - for sink_node in active_sinks: graph.add_edge(sink_node.name, "sink", capacity=float("inf"), cost=0) @@ -340,10 +337,13 @@ def _compute_flow_single_group( def disable_node(self, node_name: str) -> None: """ - Mark a node as disabled. Raises ValueError if the node doesn't exist. + Mark a node as disabled. Args: node_name: Name of the node to disable. + + Raises: + ValueError: If the specified node does not exist. """ if node_name not in self.nodes: raise ValueError(f"Node '{node_name}' does not exist.") @@ -351,10 +351,13 @@ def disable_node(self, node_name: str) -> None: def enable_node(self, node_name: str) -> None: """ - Mark a node as enabled. Raises ValueError if the node doesn't exist. + Mark a node as enabled. Args: node_name: Name of the node to enable. + + Raises: + ValueError: If the specified node does not exist. """ if node_name not in self.nodes: raise ValueError(f"Node '{node_name}' does not exist.") @@ -362,10 +365,13 @@ def enable_node(self, node_name: str) -> None: def disable_link(self, link_id: str) -> None: """ - Mark a link as disabled. Raises ValueError if the link doesn't exist. + Mark a link as disabled. Args: link_id: ID of the link to disable. + + Raises: + ValueError: If the specified link does not exist. """ if link_id not in self.links: raise ValueError(f"Link '{link_id}' does not exist.") @@ -373,10 +379,13 @@ def disable_link(self, link_id: str) -> None: def enable_link(self, link_id: str) -> None: """ - Mark a link as enabled. Raises ValueError if the link doesn't exist. + Mark a link as enabled. Args: link_id: ID of the link to enable. 
+ + Raises: + ValueError: If the specified link does not exist. """ if link_id not in self.links: raise ValueError(f"Link '{link_id}' does not exist.") @@ -402,15 +411,14 @@ def disable_all(self) -> None: def get_links_between(self, source: str, target: str) -> List[str]: """ - Return all link IDs that connect the specified source and target exactly. + Retrieve all link IDs that connect the specified source node to the target node. Args: source: Name of the source node. target: Name of the target node. Returns: - A list of link IDs for links where (link.source == source - and link.target == target). + List[str]: All link IDs where (link.source == source and link.target == target). """ matches = [] for link_id, link in self.links.items(): @@ -424,15 +432,15 @@ def find_links( target_regex: Optional[str] = None, ) -> List[Link]: """ - Search for links based on optional regex patterns for source or target. + Search for links using optional regex patterns for source or target node names. Args: - source_regex: Regex pattern to match the link's source node. - target_regex: Regex pattern to match the link's target node. + source_regex: Regex pattern to match link.source. If None, matches all. + target_regex: Regex pattern to match link.target. If None, matches all. Returns: - A list of Link objects matching the criteria. If both patterns - are None, returns all links. + List[Link]: A list of Link objects that match the provided criteria. + If both patterns are None, returns all links. """ if source_regex: src_pat = re.compile(source_regex) diff --git a/ngraph/workflow/capacity_probe.py b/ngraph/workflow/capacity_probe.py index e41af49..afa27d7 100644 --- a/ngraph/workflow/capacity_probe.py +++ b/ngraph/workflow/capacity_probe.py @@ -23,7 +23,7 @@ class CapacityProbe(WorkflowStep): - "pairwise": Compute flow for each (source_group, sink_group). probe_reverse (bool): If True, also compute flow in the reverse direction (sink→source). 
shortest_path (bool): If True, only use shortest paths when computing flow. - flow_placement (FlowPlacement): Handling strategy for parallel edges (default PROPORTIONAL). + flow_placement (FlowPlacement): Handling strategy for parallel equal cost paths (default PROPORTIONAL). """ source_path: str = "" @@ -33,6 +33,17 @@ class CapacityProbe(WorkflowStep): shortest_path: bool = False flow_placement: FlowPlacement = FlowPlacement.PROPORTIONAL + def __post_init__(self): + if isinstance(self.flow_placement, str): + try: + self.flow_placement = FlowPlacement[self.flow_placement.upper()] + except KeyError: + valid_values = ", ".join([e.name for e in FlowPlacement]) + raise ValueError( + f"Invalid flow_placement '{self.flow_placement}'. " + f"Valid values are: {valid_values}" + ) + def run(self, scenario: Scenario) -> None: """ Executes the capacity probe by computing max flow between node groups diff --git a/tests/scenarios/scenario_3.yaml b/tests/scenarios/scenario_3.yaml index adbeb33..b3df060 100644 --- a/tests/scenarios/scenario_3.yaml +++ b/tests/scenarios/scenario_3.yaml @@ -59,6 +59,14 @@ network: capacity: 2 cost: 1 + link_overrides: + # Overriding a link between two spine devices. 
+    - source: my_clos1/spine/t3-1$
+      target: my_clos2/spine/t3-1$
+      link_params:
+        capacity: 1
+        cost: 1
+
 workflow:
   - step_type: BuildGraph
     name: build_graph
@@ -70,3 +78,13 @@ workflow:
     mode: combine
     probe_reverse: True
     shortest_path: True
+    flow_placement: PROPORTIONAL
+
+  - step_type: CapacityProbe
+    name: capacity_probe2
+    source_path: my_clos1/b.*/t1
+    sink_path: my_clos2/b.*/t1
+    mode: combine
+    probe_reverse: True
+    shortest_path: True
+    flow_placement: EQUAL_BALANCED
\ No newline at end of file
diff --git a/tests/scenarios/test_scenario_3.py b/tests/scenarios/test_scenario_3.py
index 762c989..0d8af32 100644
--- a/tests/scenarios/test_scenario_3.py
+++ b/tests/scenarios/test_scenario_3.py
@@ -69,6 +69,7 @@ def test_scenario_3_build_graph_and_capacity_probe() -> None:
         ), "Missing expected node 'my_clos2/spine/t3-16' in expanded blueprint."

     # 10) The capacity probe step computed forward and reverse flows in 'combine' mode
+    #     with PROPORTIONAL flow placement.
     flow_result_label_fwd = "max_flow:[my_clos1/b.*/t1 -> my_clos2/b.*/t1]"
     flow_result_label_rev = "max_flow:[my_clos2/b.*/t1 -> my_clos1/b.*/t1]"

@@ -80,7 +81,31 @@ def test_scenario_3_build_graph_and_capacity_probe() -> None:
     # 11) Assert the expected flows
     # The main bottleneck is the 16 spine-to-spine links of capacity=2 => total 32
     # (same in both forward and reverse).
-    expected_flow = 32.0
+    # However, one link is overridden to capacity=1, so, with PROPORTIONAL flow placement,
+    # the max flow is 31.
+    expected_flow = 31.0
+    assert forward_flow == expected_flow, (
+        f"Expected forward max flow of {expected_flow}, got {forward_flow}. "
+        "Check blueprint or link capacities if this fails."
+    )
+
+    # 12) The capacity probe step computed with EQUAL_BALANCED flow placement
+
+    # Retrieve the forward flow
+    forward_flow = scenario.results.get("capacity_probe2", flow_result_label_fwd)
+    # Retrieve the reverse flow
+    reverse_flow = scenario.results.get("capacity_probe2", flow_result_label_rev)
+
+    # 13) Assert the expected flows
+    # The main bottleneck is the 16 spine-to-spine links of capacity=2 => total 32
+    # (same in both forward and reverse).
+    # However, one link is overridden to capacity=1, so, with EQUAL_BALANCED flow placement,
+    # the max flow is 16.
+    expected_flow = 16.0
     assert forward_flow == expected_flow, (
         f"Expected forward max flow of {expected_flow}, got {forward_flow}. "
         "Check blueprint or link capacities if this fails."