From 8c1beab5e6bfd3cd0125c7e8abb6607af39adbac Mon Sep 17 00:00:00 2001 From: Andrey Golovanov Date: Fri, 7 Mar 2025 00:19:27 +0000 Subject: [PATCH 1/2] adding shared risk groups --- ngraph/failure_policy.py | 211 +++++++++++++++++++++-------------- tests/test_failure_policy.py | 117 +++++++++++++++++++ 2 files changed, 242 insertions(+), 86 deletions(-) diff --git a/ngraph/failure_policy.py b/ngraph/failure_policy.py index a7f330c..845b0a3 100644 --- a/ngraph/failure_policy.py +++ b/ngraph/failure_policy.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Literal from random import random, sample +from collections import defaultdict, deque @dataclass @@ -8,22 +9,21 @@ class FailureCondition: """ A single condition for matching an entity's attribute with an operator and value. - Example usage (YAML-ish): + Example usage (YAML): - .. code-block:: yaml - - conditions: - - attr: "capacity" - operator: "<" - value: 100 + conditions: + - attr: "capacity" + operator: "<" + value: 100 Attributes: attr (str): - The name of the attribute to inspect (e.g. "type", "capacity"). + The name of the attribute to inspect (e.g., "type", "capacity"). operator (str): - The comparison operator: "==", "!=", "<", "<=", ">", ">=". + The comparison operator: "==", "!=", "<", "<=", ">", ">=", "contains", + "not_contains", "any_value", or "no_value". value (Any): - The value to compare against (e.g. "node", 100, True, etc.). + The value to compare against (e.g., "node", 100, True, etc.). """ attr: str @@ -34,35 +34,24 @@ class FailureCondition: @dataclass class FailureRule: """ - A single rule defining how to match entities and then select them for failure. - - * conditions: list of conditions - * logic: how to combine conditions ("and", "or", "any") - * rule_type: how to pick from matched entities ("random", "choice", "all") - * probability: used by "random" (a float in [0,1]) - * count: used by "choice" (e.g. pick 2) - - When multiple FailureRules appear in a FailurePolicy, the final - set of failures is the **union** of all entities selected by each rule. + Defines how to match entities and then select them for failure. Attributes: conditions (List[FailureCondition]): A list of conditions to filter matching entities. logic (Literal["and", "or", "any"]): - - "and": All conditions must be true. - - "or": At least one condition is true. - - "any": Skip condition checks; everything is matched. + - "and": All conditions must be true for a match. + - "or": At least one condition is true for a match. + - "any": Skip condition checks and match all entities. rule_type (Literal["random", "choice", "all"]): The selection strategy among the matched set: - - "random": Each matched entity is chosen independently - with probability = `probability`. - - "choice": Pick exactly `count` items from the matched set - (randomly sampled). + - "random": Each matched entity is chosen with probability=`probability`. + - "choice": Pick exactly `count` items (random sample). - "all": Select every matched entity. probability (float): - Probability in [0,1], used only if `rule_type="random"`. + Probability in [0,1], used if `rule_type="random"`. count (int): - Number of matched entities to pick, used only if `rule_type="choice"`. + Number of entities to pick if `rule_type="choice"`. """ conditions: List[FailureCondition] = field(default_factory=list) @@ -73,29 +62,33 @@ class FailureRule: def __post_init__(self) -> None: """ - Validate certain fields after initialization. 
+ Validate the probability if rule_type is 'random'. """ if self.rule_type == "random": - if not (0.0 <= self.probability <= 1.0): + if not 0.0 <= self.probability <= 1.0: raise ValueError( - f"probability={self.probability} must be within [0,1] for rule_type='random'." + f"probability={self.probability} must be within [0,1] " + f"for rule_type='random'." ) @dataclass class FailurePolicy: """ - A container for multiple FailureRules and arbitrary metadata in `attrs`. + A container for multiple FailureRules plus optional metadata in `attrs`. - The method :meth:`apply_failures` merges nodes and links into a single - dictionary (by their unique ID), then applies each rule in turn. The final - result is the union of all failures from each rule. + The main entry point is `apply_failures`, which: + 1) Merges all nodes and links into a single entity dictionary. + 2) Applies each FailureRule, collecting a set of failed entity IDs. + 3) Optionally expands failures to include entities sharing a + 'shared_risk_group' with any entity that failed. Attributes: rules (List[FailureRule]): A list of FailureRules to apply. attrs (Dict[str, Any]): Arbitrary metadata about this policy (e.g. "name", "description"). + If `fail_shared_risk_groups=True`, then shared-risk expansion is used. """ rules: List[FailureRule] = field(default_factory=list) @@ -107,28 +100,29 @@ def apply_failures( links: Dict[str, Dict[str, Any]], ) -> List[str]: """ - Identify which entities (nodes or links) fail, given the defined rules. - Returns a combined list (union) of all entity IDs that fail. + Identify which entities fail given the defined rules, then optionally + expand by shared-risk groups. Args: - nodes: A mapping of node_name -> node.attrs (must have "type"="node"). - links: A mapping of link_id -> link.attrs (must have "type"="link"). + nodes: Dict[node_name, node_attributes]. Must have 'type'="node". + links: Dict[link_id, link_attributes]. Must have 'type'="link". Returns: - A list of failed entity IDs (node names or link IDs). + A list of failed entity IDs (union of all rule matches). """ - # Merge nodes and links into a single map of entity_id -> entity_attrs - # Example: { "SEA": {...}, "SEA-DEN-xxx": {...} } all_entities = {**nodes, **links} - failed_entities = set() - # Apply each rule, union all selected entities + # 1) Collect matched failures from each rule for rule in self.rules: matched = self._match_entities(all_entities, rule.conditions, rule.logic) selected = self._select_entities(matched, all_entities, rule) failed_entities.update(selected) + # 2) Optionally expand failures by shared-risk group + if self.attrs.get("fail_shared_risk_groups", False): + self._expand_shared_risk_groups(failed_entities, all_entities) + return list(failed_entities) def _match_entities( @@ -138,17 +132,24 @@ def _match_entities( logic: str, ) -> List[str]: """ - Find which entity IDs satisfy the given conditions - combined by 'and'/'or' logic (or 'any' to skip checks). + Return all entity IDs matching the given conditions based on 'and'/'or'/'any' logic. Args: - all_entities: Mapping of entity_id -> attribute dict. - conditions: List of FailureCondition to apply. + all_entities: {entity_id: attributes}. + conditions: List of FailureCondition to evaluate. logic: "and", "or", or "any". Returns: - A list of entity IDs that match according to the logic. + A list of matching entity IDs. """ + if logic == "any": + # Skip condition checks; everything matches. 
+ return list(all_entities.keys()) + + if not conditions: + # If zero conditions, we match nothing unless logic='any'. + return [] + matched = [] for entity_id, attr_dict in all_entities.items(): if self._evaluate_conditions(attr_dict, conditions, logic): @@ -162,32 +163,22 @@ def _evaluate_conditions( logic: str, ) -> bool: """ - Check if the given entity meets all or any of the conditions, or if logic='any'. + Evaluate multiple conditions on a single entity. All or any condition(s) + must pass, depending on 'logic'. Args: - entity_attrs: Attributes dict for one entity (node or link). - conditions: List of FailureCondition. - logic: "and", "or", or "any". + entity_attrs: Attribute dict for one entity. + conditions: List of FailureCondition to test. + logic: "and" or "or". Returns: True if conditions pass, else False. """ - if logic == "any": - # 'any' means skip condition checks and always match - return True - if not conditions: - # If we have zero conditions, we treat this as no match unless logic='any' - return False + if logic not in ("and", "or"): + raise ValueError(f"Unsupported logic: {logic}") - # Evaluate each condition results = [_evaluate_condition(entity_attrs, c) for c in conditions] - - if logic == "and": - return all(results) - elif logic == "or": - return any(results) - else: - raise ValueError(f"Unsupported logic: {logic}") + return all(results) if logic == "and" else any(results) @staticmethod def _select_entities( @@ -196,44 +187,78 @@ def _select_entities( rule: FailureRule, ) -> List[str]: """ - From the matched set, pick which entities fail according to rule_type. + From the matched IDs, pick which entities fail under the given rule_type. Args: - entity_ids: IDs that matched the rule's conditions. - all_entities: Full entity dictionary (for potential future use). - rule: The FailureRule specifying random/choice/all selection. + entity_ids: Matched entity IDs from _match_entities. + all_entities: Full entity map (unused now, but available if needed). + rule: The FailureRule specifying 'random', 'choice', or 'all'. Returns: - The final list of entity IDs that fail under this rule. + A list of selected entity IDs to fail. """ + if not entity_ids: + return [] + if rule.rule_type == "random": - # Each entity is chosen with probability=rule.probability - return [ent_id for ent_id in entity_ids if random() < rule.probability] + return [eid for eid in entity_ids if random() < rule.probability] elif rule.rule_type == "choice": - # Sample exactly 'count' from the matched set (or fewer if matched < count) count = min(rule.count, len(entity_ids)) - # Use sorted(...) for deterministic results return sample(sorted(entity_ids), k=count) elif rule.rule_type == "all": return entity_ids else: raise ValueError(f"Unsupported rule_type: {rule.rule_type}") + def _expand_shared_risk_groups( + self, failed_entities: set[str], all_entities: Dict[str, Dict[str, Any]] + ) -> None: + """ + Expand the 'failed_entities' set so that if an entity has + shared_risk_group=X, all other entities with the same group also fail. + + This is done iteratively until no new failures are found. + + Args: + failed_entities: Set of entity IDs already marked as failed. + all_entities: Map of entity_id -> attributes (which may contain 'shared_risk_group'). 
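+
+        Example:
+            If "L1" and "L2" both carry shared_risk_group="SRG1" and "L1" is
+            already in failed_entities, this method adds "L2" as well.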
+ """ + # Pre-compute SRG -> entity IDs mapping for efficiency + srg_map = defaultdict(set) + for eid, attrs in all_entities.items(): + srg = attrs.get("shared_risk_group") + if srg: + srg_map[srg].add(eid) + + queue = deque(failed_entities) + while queue: + current = queue.popleft() + current_srg = all_entities[current].get("shared_risk_group") + if not current_srg: + continue + + # All entities in the same SRG should fail + for other_eid in srg_map[current_srg]: + if other_eid not in failed_entities: + failed_entities.add(other_eid) + queue.append(other_eid) + def _evaluate_condition(entity: Dict[str, Any], cond: FailureCondition) -> bool: """ - Evaluate one FailureCondition (attr, operator, value) against entity attributes. + Evaluate a single FailureCondition against an entity's attributes. + + Operators supported: + ==, !=, <, <=, >, >=, contains, not_contains, any_value, no_value Args: - entity: The entity's attributes (e.g., node.attrs or link.attrs). + entity: Entity attributes (e.g., node.attrs or link.attrs). cond: FailureCondition specifying (attr, operator, value). Returns: True if the condition passes, else False. - - Raises: - ValueError: If the operator is not recognized. """ + has_attr = cond.attr in entity derived_value = entity.get(cond.attr, None) op = cond.operator @@ -242,12 +267,26 @@ def _evaluate_condition(entity: Dict[str, Any], cond: FailureCondition) -> bool: elif op == "!=": return derived_value != cond.value elif op == "<": - return derived_value < cond.value + return (derived_value is not None) and (derived_value < cond.value) elif op == "<=": - return derived_value <= cond.value + return (derived_value is not None) and (derived_value <= cond.value) elif op == ">": - return derived_value > cond.value + return (derived_value is not None) and (derived_value > cond.value) elif op == ">=": - return derived_value >= cond.value + return (derived_value is not None) and (derived_value >= cond.value) + elif op == "contains": + if derived_value is None: + return False + return cond.value in derived_value + elif op == "not_contains": + if derived_value is None: + return True + return cond.value not in derived_value + elif op == "any_value": + # Pass if the attribute key exists, even if the value is None + return has_attr + elif op == "no_value": + # Pass if the attribute key is missing or the value is None + return (not has_attr) or (derived_value is None) else: raise ValueError(f"Unsupported operator: {op}") diff --git a/tests/test_failure_policy.py b/tests/test_failure_policy.py index 768d90a..99f3526 100644 --- a/tests/test_failure_policy.py +++ b/tests/test_failure_policy.py @@ -316,3 +316,120 @@ def test_operator_condition_le_ge(): e3 = {"capacity": 110} assert _evaluate_condition(e3, cond_le) is False assert _evaluate_condition(e3, cond_ge) is True + + +def test_operator_contains_not_contains(): + """ + Verify that 'contains' and 'not_contains' operators work with string or list attributes. 
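+
+    Note: the implementation delegates to Python's `in` operator, so
+    'contains' covers substring checks on strings and membership checks
+    on lists alike.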
+    """
+    rule_contains = FailureRule(
+        conditions=[FailureCondition(attr="tags", operator="contains", value="foo")],
+        logic="and",
+        rule_type="all",
+    )
+    rule_not_contains = FailureRule(
+        conditions=[
+            FailureCondition(attr="tags", operator="not_contains", value="bar")
+        ],
+        logic="and",
+        rule_type="all",
+    )
+
+    # Entities with a 'tags' attribute
+    nodes = {
+        "N1": {"type": "node", "tags": ["foo", "bar"]},  # contains 'foo'
+        "N2": {"type": "node", "tags": ["baz", "qux"]},  # doesn't contain 'foo'
+        "N3": {"type": "node", "tags": "foobar"},  # string containing 'foo'
+        "N4": {"type": "node", "tags": ""},  # empty string, contains nothing
+    }
+    links = {}
+
+    # Test the 'contains' rule
+    failed_contains = FailurePolicy(rules=[rule_contains]).apply_failures(nodes, links)
+    # N1 has 'foo' in its list, N3 has 'foo' in the string "foobar"
+    assert set(failed_contains) == {"N1", "N3"}
+
+    # Test the 'not_contains' rule
+    failed_not_contains = FailurePolicy(rules=[rule_not_contains]).apply_failures(
+        nodes, links
+    )
+    # N2's tags lack 'bar'; N4's empty string also lacks 'bar'
+    assert set(failed_not_contains) == {"N2", "N4"}
+
+
+def test_operator_any_value_no_value():
+    """
+    Verify that 'any_value' matches entities where the attribute key exists
+    (even if its value is None), and 'no_value' matches entities where the
+    key is missing or the value is None.
+    """
+    any_rule = FailureRule(
+        conditions=[
+            FailureCondition(attr="capacity", operator="any_value", value=None)
+        ],
+        logic="and",
+        rule_type="all",
+    )
+    none_rule = FailureRule(
+        conditions=[FailureCondition(attr="capacity", operator="no_value", value=None)],
+        logic="and",
+        rule_type="all",
+    )
+
+    nodes = {
+        "N1": {"type": "node", "capacity": 100},  # has capacity
+        "N2": {"type": "node"},  # no 'capacity' attr
+        "N3": {"type": "node", "capacity": None},  # capacity is explicitly None
+    }
+    links = {}
+
+    failed_any = FailurePolicy(rules=[any_rule]).apply_failures(nodes, links)
+    # N1 has capacity=100, N3 has capacity=None (key still present)
+    assert set(failed_any) == {"N1", "N3"}
+
+    failed_none = FailurePolicy(rules=[none_rule]).apply_failures(nodes, links)
+    # N2 lacks the 'capacity' key entirely; N3 has the key set to None.
+    # "no_value" passes when the key is missing OR the value is None, so both
+    # match. Note that N3 therefore satisfies both 'any_value' and 'no_value'.
+    assert set(failed_none) == {"N2", "N3"}
+
+
+def test_shared_risk_groups_expansion():
+    """
+    Verify that if fail_shared_risk_groups=True is set, any failed entity
+    causes all entities in the same shared_risk_group to fail.
+    """
+    # This rule matches entities with type="link", then picks exactly one
+    rule = FailureRule(
+        conditions=[FailureCondition(attr="type", operator="==", value="link")],
+        logic="and",
+        rule_type="choice",
+        count=1,
+    )
+    policy = FailurePolicy(
+        rules=[rule],
+        attrs={"fail_shared_risk_groups": True},
+    )
+
+    nodes = {
+        "N1": {"type": "node"},
+        "N2": {"type": "node"},
+    }
+    # Suppose L1 and L2 are in SRG1, L3 in SRG2
+    links = {
+        "L1": {"type": "link", "shared_risk_group": "SRG1"},
+        "L2": {"type": "link", "shared_risk_group": "SRG1"},
+        "L3": {"type": "link", "shared_risk_group": "SRG2"},
+    }
+
+    # Mock the random sample so that "L1" is picked
+    with patch("ngraph.failure_policy.sample", return_value=["L1"]):
+        failed = policy.apply_failures(nodes, links)
+
+    # L1 was chosen, which triggers any others in SRG1 => L2 also fails.
+    # L3 is unaffected.
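+    # The expansion pre-computes an SRG -> members map and walks a queue, so
+    # every member of SRG1 fails regardless of which member was picked first.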
+ assert set(failed) == {"L1", "L2"} + + # If we pick "L3", L1 & L2 remain healthy + with patch("ngraph.failure_policy.sample", return_value=["L3"]): + failed = policy.apply_failures(nodes, links) + assert set(failed) == {"L3"} From f7995290c32f872a14a4158f57c8eee4756d7903 Mon Sep 17 00:00:00 2001 From: Andrey Golovanov Date: Fri, 7 Mar 2025 00:59:25 +0000 Subject: [PATCH 2/2] adding node overrides --- ngraph/blueprints.py | 261 +++++++++++++++++++---------- ngraph/scenario.py | 4 +- tests/scenarios/scenario_3.yaml | 33 +++- tests/scenarios/test_scenario_3.py | 58 ++++++- tests/test_blueprints_helpers.py | 128 ++++++++++++++ 5 files changed, 388 insertions(+), 96 deletions(-) diff --git a/ngraph/blueprints.py b/ngraph/blueprints.py index 1c13d2f..b44f3d3 100644 --- a/ngraph/blueprints.py +++ b/ngraph/blueprints.py @@ -18,7 +18,7 @@ class Blueprint: Attributes: name (str): Unique identifier of this blueprint. groups (Dict[str, Any]): A mapping of group_name -> group definition - (e.g. node_count, name_template). + (e.g. node_count, name_template, node_attrs). adjacency (List[Dict[str, Any]]): A list of adjacency dictionaries describing how groups are linked. """ @@ -51,14 +51,15 @@ def expand_network_dsl(data: Dict[str, Any]) -> Network: Overall flow: 1) Parse "blueprints" into Blueprint objects. - 2) Build a new Network from "network" metadata (name, version, etc.). - 3) Expand 'network["groups"]': + 2) Build a new Network from "network" metadata (e.g. name, version). + 3) Expand 'network["groups"]'. - If a group references a blueprint, incorporate that blueprint's subgroups. - Otherwise, directly create nodes (e.g., node_count). 4) Process any direct node definitions. 5) Expand adjacency definitions in 'network["adjacency"]'. 6) Process any direct link definitions. 7) Process link overrides. + 8) Process node overrides. Args: data (Dict[str, Any]): The YAML-parsed dictionary containing @@ -87,7 +88,7 @@ def expand_network_dsl(data: Dict[str, Any]) -> Network: # Create a context ctx = DSLExpansionContext(blueprints=blueprint_map, network=net) - # 3) Expand top-level groups + # 3) Expand top-level groups (blueprint usage or direct node groups) for group_name, group_def in network_data.get("groups", {}).items(): _expand_group( ctx, @@ -110,26 +111,10 @@ def expand_network_dsl(data: Dict[str, Any]) -> Network: # 7) Process link overrides _process_link_overrides(ctx.network, network_data) - return net - - -def _process_link_overrides(network: Network, network_data: Dict[str, Any]) -> None: - """ - Processes the 'link_overrides' section of the network DSL, updating - existing links with new parameters. + # 8) Process node overrides + _process_node_overrides(ctx.network, network_data) - Args: - network (Network): The Network whose links will be updated. - network_data (Dict[str, Any]): The overall DSL data for the 'network'. - Expected to contain 'link_overrides' as a list of dicts, each with - 'source', 'target', and 'link_params'. - """ - link_overrides = network_data.get("link_overrides", []) - for link_override in link_overrides: - source = link_override["source"] - target = link_override["target"] - link_params = link_override["link_params"] - _update_links(network, source, target, link_params) + return net def _expand_group( @@ -143,7 +128,10 @@ def _expand_group( """ Expands a single group definition into either: - Another blueprint's subgroups, or - - A direct node group (node_count, name_template). + - A direct node group (node_count, name_template, node_attrs). 
+ + If the group references 'use_blueprint', we expand that blueprint's groups + under the current hierarchy path. Otherwise, we create nodes directly. Args: ctx (DSLExpansionContext): The context containing all blueprint info @@ -155,7 +143,6 @@ def _expand_group( blueprint_expansion (bool): Indicates whether we are expanding within a blueprint context or not. """ - # Construct the effective path by appending group_name if parent_path is non-empty if parent_path: effective_path = f"{parent_path}/{group_name}" else: @@ -195,14 +182,17 @@ def _expand_group( # It's a direct node group node_count = group_def.get("node_count", 1) name_template = group_def.get("name_template", f"{group_name}-{{node_num}}") + node_attrs = group_def.get("node_attrs", {}) for i in range(1, node_count + 1): label = name_template.format(node_num=i) node_name = f"{effective_path}/{label}" if effective_path else label node = Node(name=node_name) + # Merge any extra attributes if "coords" in group_def: node.attrs["coords"] = group_def["coords"] + node.attrs.update(node_attrs) # Apply bulk attributes node.attrs.setdefault("type", "node") ctx.network.add_node(node) @@ -251,7 +241,7 @@ def _expand_adjacency( pattern = adj_def.get("pattern", "mesh") link_params = adj_def.get("link_params", {}) - # Strip leading '/' from source/target paths + # Convert to an absolute or relative path source_path = _join_paths("", source_path_raw) target_path = _join_paths("", target_path_raw) @@ -269,17 +259,17 @@ def _expand_adjacency_pattern( Generates Link objects for the chosen adjacency pattern among matched nodes. Supported Patterns: - * "mesh": Cross-connect every node from source side to every node on target side, + * "mesh": Connect every node from source side to every node on target side, skipping self-loops, and deduplicating reversed pairs. - * "one_to_one": Pair each source node with exactly one target node, supporting - wrap-around if one side is an integer multiple of the other. - Also skips self-loops. + * "one_to_one": Pair each source node with exactly one target node (wrap-around). + * "ring": (Example pattern) For demonstration, connect nodes in a ring among + the union of source + target sets (ignores directionality). Args: ctx (DSLExpansionContext): The context containing the target network. - source_path (str): The path pattern that identifies the source node group(s). - target_path (str): The path pattern that identifies the target node group(s). - pattern (str): The type of adjacency pattern (e.g., "mesh", "one_to_one"). + source_path (str): The path pattern identifying the source node group(s). + target_path (str): The path pattern identifying the target node group(s). + pattern (str): The type of adjacency pattern (e.g., "mesh", "one_to_one", "ring"). link_params (Dict[str, Any]): Additional link parameters (capacity, cost, attrs). 
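+
+    Example:
+        With 4 source nodes and 2 target nodes, "one_to_one" creates 4 links;
+        each target node is paired with two source nodes via wrap-around.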
""" source_node_groups = ctx.network.select_node_groups_by_path(source_path) @@ -288,6 +278,7 @@ def _expand_adjacency_pattern( source_nodes = [node for _, nodes in source_node_groups.items() for node in nodes] target_nodes = [node for _, nodes in target_node_groups.items() for node in nodes] + # If either list is empty, no links to create if not source_nodes or not target_nodes: return @@ -296,7 +287,6 @@ def _expand_adjacency_pattern( if pattern == "mesh": for sn in source_nodes: for tn in target_nodes: - # Skip self-loops if sn.name == tn.name: continue pair = tuple(sorted((sn.name, tn.name))) @@ -309,13 +299,13 @@ def _expand_adjacency_pattern( t_count = len(target_nodes) bigger, smaller = max(s_count, t_count), min(s_count, t_count) + # Basic check for wrap-around scenario if bigger % smaller != 0: raise ValueError( f"one_to_one pattern requires either equal node counts " f"or a valid wrap-around. Got {s_count} vs {t_count}." ) - # total 'bigger' connections for i in range(bigger): if s_count >= t_count: sn = source_nodes[i].name @@ -324,7 +314,6 @@ def _expand_adjacency_pattern( sn = source_nodes[i % s_count].name tn = target_nodes[i].name - # Skip self-loops if sn == tn: continue @@ -363,6 +352,121 @@ def _create_link( net.add_link(link) +def _process_direct_nodes(net: Network, network_data: Dict[str, Any]) -> None: + """ + Processes direct node definitions (network_data["nodes"]) and adds them to the network + if they do not already exist. + + Example: + nodes: + my_node: + coords: [10, 20] + hw_type: "X100" + + Args: + net (Network): The network to which nodes are added. + network_data (Dict[str, Any]): DSL data containing a "nodes" dict + keyed by node name -> attributes. + """ + for node_name, node_attrs in network_data.get("nodes", {}).items(): + if node_name not in net.nodes: + new_node = Node(name=node_name, attrs=node_attrs or {}) + new_node.attrs.setdefault("type", "node") + net.add_node(new_node) + + +def _process_direct_links(net: Network, network_data: Dict[str, Any]) -> None: + """ + Processes direct link definitions (network_data["links"]) and adds them to the network. + + Example: + links: + - source: A + target: B + link_params: + capacity: 100 + cost: 2 + attrs: + color: "blue" + + Args: + net (Network): The network to which links are added. + network_data (Dict[str, Any]): DSL data containing a "links" list, + each item must have "source", "target", and optionally "link_params". + """ + existing_node_names = set(net.nodes.keys()) + for link_info in network_data.get("links", []): + source = link_info["source"] + target = link_info["target"] + if source not in existing_node_names or target not in existing_node_names: + raise ValueError(f"Link references unknown node(s): {source}, {target}.") + if source == target: + raise ValueError(f"Link cannot have the same source and target: {source}") + link_params = link_info.get("link_params", {}) + link = Link( + source=source, + target=target, + capacity=link_params.get("capacity", 1.0), + cost=link_params.get("cost", 1.0), + attrs=link_params.get("attrs", {}), + ) + net.add_link(link) + + +def _process_link_overrides(network: Network, network_data: Dict[str, Any]) -> None: + """ + Processes the 'link_overrides' section of the network DSL, updating + existing links with new parameters. + + Example: + link_overrides: + - source: "/region1/*" + target: "/region2/*" + link_params: + capacity: 200 + attrs: + shared_risk_group: "SRG1" + + Args: + network (Network): The Network whose links will be updated. 
+ network_data (Dict[str, Any]): The overall DSL data for the 'network'. + Expected to contain 'link_overrides' as a list of dicts, each with + 'source', 'target', and 'link_params'. + """ + link_overrides = network_data.get("link_overrides", []) + for link_override in link_overrides: + source = link_override["source"] + target = link_override["target"] + link_params = link_override["link_params"] + any_direction = link_override.get("any_direction", True) + _update_links(network, source, target, link_params, any_direction) + + +def _process_node_overrides(net: Network, network_data: Dict[str, Any]) -> None: + """ + Processes the 'node_overrides' section of the network DSL, updating + existing nodes with new attributes in bulk. + + Example: + node_overrides: + - path: "/region1/spine*" + attrs: + hw_type: "DellX" + shared_risk_group: "SRG2" + + Args: + net (Network): The Network whose nodes will be updated. + network_data (Dict[str, Any]): The overall DSL data for the 'network'. + Expected to contain 'node_overrides' as a list of dicts, each with + 'path' and 'attrs'. + """ + node_overrides = network_data.get("node_overrides", []) + for override in node_overrides: + path = override["path"] + attrs_to_set = override["attrs"] + _update_nodes(net, path, attrs_to_set) + + def _update_links( net: Network, source: str, @@ -374,6 +478,9 @@ def _update_links( Update all Link objects between nodes matching 'source' and 'target' paths with new parameters. + If any_direction=True, both (source->target) and (target->source) links + are updated. + Args: net (Network): The network whose links should be updated. source (str): A path pattern identifying source node group(s). @@ -392,21 +499,37 @@ def _update_links( } for link in net.links.values(): - if link.source in source_nodes and link.target in target_nodes: - link.capacity = link_params.get("capacity", link.capacity) - link.cost = link_params.get("cost", link.cost) - link.attrs.update(link_params.get("attrs", {})) - - if ( + forward_match = link.source in source_nodes and link.target in target_nodes + reverse_match = ( any_direction and link.source in target_nodes and link.target in source_nodes - ): + ) + if forward_match or reverse_match: link.capacity = link_params.get("capacity", link.capacity) link.cost = link_params.get("cost", link.cost) link.attrs.update(link_params.get("attrs", {})) +def _update_nodes( + net: Network, + path: str, + node_attrs: Dict[str, Any], +) -> None: + """ + Updates attributes on all nodes matching a given path pattern. + + Args: + net (Network): The network containing nodes. + path (str): A path pattern identifying which node group(s) to modify. + node_attrs (Dict[str, Any]): A dictionary of new attributes to set/merge. + """ + node_groups = net.select_node_groups_by_path(path) + for _, nodes in node_groups.items(): + for node in nodes: + node.attrs.update(node_attrs) + + def _apply_parameters( subgroup_name: str, subgroup_def: Dict[str, Any], params_overrides: Dict[str, Any] ) -> Dict[str, Any]: @@ -420,7 +543,8 @@ def _apply_parameters( Args: subgroup_name (str): Name of the subgroup in the blueprint (e.g. 'spine'). subgroup_def (Dict[str, Any]): The default definition of the subgroup. - params_overrides (Dict[str, Any]): Overrides in the form of { 'spine.node_count': }. + params_overrides (Dict[str, Any]): Overrides in the form of + {'spine.node_count': 6, 'spine.node_attrs.hw_type': 'Dell'}. Returns: Dict[str, Any]: A copy of subgroup_def with parameter overrides applied. 
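+
+    Example:
+        With subgroup_name="spine" and params_overrides={"spine.node_count": 6},
+        the returned copy of subgroup_def has node_count=6.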
@@ -437,8 +561,8 @@ def _apply_parameters( def _join_paths(parent_path: str, rel_path: str) -> str: """ Joins two path segments according to NetGraph's DSL conventions: - - If rel_path starts with '/', remove the leading slash and treat it - as a relative path appended to parent_path (if present). + - If rel_path starts with '/', strip the leading slash and treat it + as appended to parent_path if parent_path is not empty. - Otherwise, simply append rel_path to parent_path if parent_path is non-empty. Args: @@ -457,48 +581,3 @@ def _join_paths(parent_path: str, rel_path: str) -> str: if parent_path: return f"{parent_path}/{rel_path}" return rel_path - - -def _process_direct_nodes(net: Network, network_data: Dict[str, Any]) -> None: - """ - Processes direct node definitions (network_data["nodes"]) and adds them to the network - if they do not already exist. - - Args: - net (Network): The network to which nodes are added. - network_data (Dict[str, Any]): DSL data containing a "nodes" dict - keyed by node name -> attributes. - """ - for node_name, node_attrs in network_data.get("nodes", {}).items(): - if node_name not in net.nodes: - new_node = Node(name=node_name, attrs=node_attrs or {}) - new_node.attrs.setdefault("type", "node") - net.add_node(new_node) - - -def _process_direct_links(net: Network, network_data: Dict[str, Any]) -> None: - """ - Processes direct link definitions (network_data["links"]) and adds them to the network. - - Args: - net (Network): The network to which links are added. - network_data (Dict[str, Any]): DSL data containing a "links" list, - each item must have "source", "target", and optionally "link_params". - """ - existing_node_names = set(net.nodes.keys()) - for link_info in network_data.get("links", []): - source = link_info["source"] - target = link_info["target"] - if source not in existing_node_names or target not in existing_node_names: - raise ValueError(f"Link references unknown node(s): {source}, {target}.") - if source == target: - raise ValueError(f"Link cannot have the same source and target: {source}") - link_params = link_info.get("link_params", {}) - link = Link( - source=source, - target=target, - capacity=link_params.get("capacity", 1.0), - cost=link_params.get("cost", 1.0), - attrs=link_params.get("attrs", {}), - ) - net.add_link(link) diff --git a/ngraph/scenario.py b/ngraph/scenario.py index b4484e0..448bc2c 100644 --- a/ngraph/scenario.py +++ b/ngraph/scenario.py @@ -79,7 +79,6 @@ def from_yaml(cls, yaml_str: str) -> Scenario: raise ValueError("The provided YAML must map to a dictionary at top-level.") # 1) Build the network using blueprint expansion logic - # This handles both "blueprints" and "network" sections if present. network = expand_network_dsl(data) # 2) Build the multi-rule failure policy @@ -110,6 +109,7 @@ def _build_failure_policy(fp_data: Dict[str, Any]) -> FailurePolicy: failure_policy: name: "anySingleLink" description: "Test single-link failures." + fail_shared_risk_groups: true rules: - conditions: - attr: "type" @@ -155,7 +155,7 @@ def _build_failure_policy(fp_data: Dict[str, Any]) -> FailurePolicy: @staticmethod def _build_workflow_steps( - workflow_data: List[Dict[str, Any]] + workflow_data: List[Dict[str, Any]], ) -> List[WorkflowStep]: """ Converts workflow step dictionaries into instantiated WorkflowStep objects. 
diff --git a/tests/scenarios/scenario_3.yaml b/tests/scenarios/scenario_3.yaml index b3df060..18f8063 100644 --- a/tests/scenarios/scenario_3.yaml +++ b/tests/scenarios/scenario_3.yaml @@ -67,6 +67,37 @@ network: capacity: 1 cost: 1 + # Demonstrates setting a shared_risk_group and optic_type on all spine to spine links. + - source: .*/spine/.* + target: .*/spine/.* + any_direction: True + link_params: + attrs: + shared_risk_group: "SpineSRG" + optic_type: "400G-LR4" + + # Example node overrides that assign SRGs and hardware types + node_overrides: + - path: my_clos1/b1/t1 + attrs: + shared_risk_group: "clos1-b1t1-SRG" + hw_type: "LeafHW-A" + + - path: my_clos2/b2/t1 + attrs: + shared_risk_group: "clos2-b2t1-SRG" + hw_type: "LeafHW-B" + + - path: my_clos1/spine/t3.* + attrs: + shared_risk_group: "clos1-spine-SRG" + hw_type: "SpineHW" + + - path: my_clos2/spine/t3.* + attrs: + shared_risk_group: "clos2-spine-SRG" + hw_type: "SpineHW" + workflow: - step_type: BuildGraph name: build_graph @@ -87,4 +118,4 @@ workflow: mode: combine probe_reverse: True shortest_path: True - flow_placement: EQUAL_BALANCED \ No newline at end of file + flow_placement: EQUAL_BALANCED diff --git a/tests/scenarios/test_scenario_3.py b/tests/scenarios/test_scenario_3.py index 0d8af32..58bfbdd 100644 --- a/tests/scenarios/test_scenario_3.py +++ b/tests/scenarios/test_scenario_3.py @@ -16,7 +16,9 @@ def test_scenario_3_build_graph_and_capacity_probe() -> None: 2) Presence of certain expanded node names. 3) No traffic demands in this scenario. 4) An empty failure policy by default. - 5) The max flow from my_clos1/b -> my_clos2/b (and reverse) is as expected. + 5) The max flow from my_clos1/b -> my_clos2/b (and reverse) is as expected for + the two capacity probe steps (PROPORTIONAL vs. EQUAL_BALANCED). + 6) That node overrides and link overrides have been applied (e.g. SRG, hw_type). """ # 1) Load the YAML file scenario_path = Path(__file__).parent / "scenario_3.yaml" @@ -68,6 +70,58 @@ def test_scenario_3_build_graph_and_capacity_probe() -> None: "my_clos2/spine/t3-16" in scenario.network.nodes ), "Missing expected node 'my_clos2/spine/t3-16' in expanded blueprint." + net = scenario.network + + # (A) Node attribute checks from node_overrides: + # For "my_clos1/b1/t1/t1-1", we expect hw_type="LeafHW-A" and SRG="clos1-b1t1-SRG" + node_a1 = net.nodes["my_clos1/b1/t1/t1-1"] + assert ( + node_a1.attrs.get("hw_type") == "LeafHW-A" + ), "Expected hw_type=LeafHW-A for 'my_clos1/b1/t1/t1-1', but not found." + assert ( + node_a1.attrs.get("shared_risk_group") == "clos1-b1t1-SRG" + ), "Expected shared_risk_group=clos1-b1t1-SRG for 'my_clos1/b1/t1/t1-1'." + + # For "my_clos2/b2/t1/t1-1", check hw_type="LeafHW-B" and SRG="clos2-b2t1-SRG" + node_b2 = net.nodes["my_clos2/b2/t1/t1-1"] + assert node_b2.attrs.get("hw_type") == "LeafHW-B" + assert node_b2.attrs.get("shared_risk_group") == "clos2-b2t1-SRG" + + # For "my_clos1/spine/t3-1", check hw_type="SpineHW" and SRG="clos1-spine-SRG" + node_spine1 = net.nodes["my_clos1/spine/t3-1"] + assert node_spine1.attrs.get("hw_type") == "SpineHW" + assert node_spine1.attrs.get("shared_risk_group") == "clos1-spine-SRG" + + # (B) Link attribute checks from link_overrides: + # The override sets capacity=1 for "my_clos1/spine/t3-1" <-> "my_clos2/spine/t3-1" + # Confirm link capacity=1 + link_id_1 = net.find_links( + "my_clos1/spine/t3-1$", + "my_clos2/spine/t3-1$", + ) + # find_links should return a list of Link objects (bidirectional included). + assert link_id_1, "Override link (t3-1) not found." 
+ for link_obj in link_id_1: + assert link_obj.capacity == 1, ( + "Expected capacity=1 on overridden link 'my_clos1/spine/t3-1' <-> " + "'my_clos2/spine/t3-1'" + ) + + # Another override sets shared_risk_group="SpineSRG" + optic_type="400G-LR4" on all spine-spine links + # We'll check a random spine pair, e.g. "t3-2" + link_id_2 = net.find_links( + "my_clos1/spine/t3-2$", + "my_clos2/spine/t3-2$", + ) + assert link_id_2, "Spine link (t3-2) not found for override check." + for link_obj in link_id_2: + assert ( + link_obj.attrs.get("shared_risk_group") == "SpineSRG" + ), "Expected SRG=SpineSRG on spine<->spine link." + assert ( + link_obj.attrs.get("optic_type") == "400G-LR4" + ), "Expected optic_type=400G-LR4 on spine<->spine link." + # 10) The capacity probe step computed forward and reverse flows in 'combine' mode # with PROPORTIONAL flow placement. flow_result_label_fwd = "max_flow:[my_clos1/b.*/t1 -> my_clos2/b.*/t1]" @@ -81,7 +135,7 @@ def test_scenario_3_build_graph_and_capacity_probe() -> None: # 11) Assert the expected flows # The main bottleneck is the 16 spine-to-spine links of capacity=2 => total 32 # (same in both forward and reverse). - # However, one link is overriden to capacity=1, so, with PROPORTIONAL flow placement, + # However, one link is overridden to capacity=1, so, with PROPORTIONAL flow placement, # the max flow is 31. expected_flow = 31.0 assert forward_flow == expected_flow, ( diff --git a/tests/test_blueprints_helpers.py b/tests/test_blueprints_helpers.py index 5d33373..d7a0901 100644 --- a/tests/test_blueprints_helpers.py +++ b/tests/test_blueprints_helpers.py @@ -13,6 +13,10 @@ _expand_blueprint_adjacency, _expand_adjacency, _expand_group, + _update_nodes, + _update_links, + _process_node_overrides, + _process_link_overrides, ) @@ -356,3 +360,127 @@ def test_expand_group_blueprint(): link = next(iter(ctx_net.links.values())) sources_targets = {link.source, link.target} assert sources_targets == {"Main/leaf/leaf-1", "Main/leaf/leaf-2"} + + +def test_update_nodes(): + """ + Tests _update_nodes to ensure it updates matching node attributes in bulk. + """ + net = Network() + net.add_node(Node("N1", attrs={"foo": "old"})) + net.add_node(Node("N2", attrs={"foo": "old"})) + net.add_node(Node("M1", attrs={"foo": "unchanged"})) + + # We only want to update nodes whose path matches "N" + _update_nodes(net, "N", {"hw_type": "X100", "foo": "new"}) + + # N1, N2 should get updated + assert net.nodes["N1"].attrs["hw_type"] == "X100" + assert net.nodes["N1"].attrs["foo"] == "new" + assert net.nodes["N2"].attrs["hw_type"] == "X100" + assert net.nodes["N2"].attrs["foo"] == "new" + + # M1 remains unchanged + assert "hw_type" not in net.nodes["M1"].attrs + assert net.nodes["M1"].attrs["foo"] == "unchanged" + + +def test_update_links(): + """ + Tests _update_links to ensure it updates matching links in bulk. 
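+
+    Note: _update_links is called without any_direction here; the parameter
+    defaults to True, so the reversed link T1 -> S2 is updated as well.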
+ """ + net = Network() + net.add_node(Node("S1")) + net.add_node(Node("S2")) + net.add_node(Node("T1")) + net.add_node(Node("T2")) + + # Create some links + net.add_link(Link("S1", "T1")) + net.add_link(Link("S2", "T2")) + net.add_link(Link("T1", "S2")) # reversed + + # Update all links from S->T with capacity=999 + _update_links(net, "S", "T", {"capacity": 999}) + + # The link S1->T1 is updated + link_st = [l for l in net.links.values() if l.source == "S1" and l.target == "T1"] + assert link_st[0].capacity == 999 + + link_st2 = [l for l in net.links.values() if l.source == "S2" and l.target == "T2"] + assert link_st2[0].capacity == 999 + + # The reversed link T1->S2 also matches if any_direction is True by default + link_ts = [l for l in net.links.values() if l.source == "T1" and l.target == "S2"] + assert link_ts[0].capacity == 999 + + +def test_process_node_overrides(): + """ + Tests _process_node_overrides to verify node attributes get updated + based on the DSL's node_overrides block. + """ + net = Network() + net.add_node(Node("A/1")) + net.add_node(Node("A/2")) + net.add_node(Node("B/1")) + + network_data = { + "node_overrides": [ + { + "path": "A", # matches "A/1" and "A/2" + "attrs": {"optics_type": "SR4", "shared_risk_group": "SRG1"}, + } + ] + } + _process_node_overrides(net, network_data) + + # "A/1" and "A/2" should be updated + assert net.nodes["A/1"].attrs["optics_type"] == "SR4" + assert net.nodes["A/1"].attrs["shared_risk_group"] == "SRG1" + assert net.nodes["A/2"].attrs["optics_type"] == "SR4" + assert net.nodes["A/2"].attrs["shared_risk_group"] == "SRG1" + + # "B/1" remains unchanged + assert "optics_type" not in net.nodes["B/1"].attrs + assert "shared_risk_group" not in net.nodes["B/1"].attrs + + +def test_process_link_overrides(): + """ + Tests _process_link_overrides to verify link attributes get updated + based on the DSL's link_overrides block. + """ + net = Network() + net.add_node(Node("A/1")) + net.add_node(Node("A/2")) + net.add_node(Node("B/1")) + + net.add_link(Link("A/1", "A/2", attrs={"color": "red"})) + net.add_link(Link("A/1", "B/1")) + + network_data = { + "link_overrides": [ + { + "source": "A/1", + "target": "A/2", + "link_params": {"capacity": 123, "attrs": {"color": "blue"}}, + } + ] + } + + _process_link_overrides(net, network_data) + + # Only the link A/1->A/2 is updated + link1 = [l for l in net.links.values() if l.source == "A/1" and l.target == "A/2"][ + 0 + ] + assert link1.capacity == 123 + assert link1.attrs["color"] == "blue" + + # The other link remains unmodified + link2 = [l for l in net.links.values() if l.source == "A/1" and l.target == "B/1"][ + 0 + ] + assert link2.capacity == 1.0 # default + assert "color" not in link2.attrs
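
Finally, the node_overrides path can be exercised end to end through expand_network_dsl. A minimal sketch follows; the group name "rack1" and the attribute values are hypothetical, and it assumes the prefix-style path matching exercised in test_update_nodes above:

    from ngraph.blueprints import expand_network_dsl

    data = {
        "network": {
            "name": "demo",
            "groups": {
                # Expands to nodes "rack1/rack1-1" and "rack1/rack1-2" via the
                # default name_template "rack1-{node_num}".
                "rack1": {"node_count": 2, "node_attrs": {"hw_type": "LeafHW-A"}},
            },
            "node_overrides": [
                {"path": "rack1", "attrs": {"shared_risk_group": "SRG1"}},
            ],
        }
    }

    net = expand_network_dsl(data)
    for node in net.nodes.values():
        assert node.attrs["hw_type"] == "LeafHW-A"  # merged from node_attrs
        assert node.attrs["shared_risk_group"] == "SRG1"  # from the override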