From e9d95aac55f064b97f4ba62fc787fc807962939d Mon Sep 17 00:00:00 2001 From: Andrey Golovanov Date: Sun, 9 Feb 2025 14:38:50 +0000 Subject: [PATCH 1/2] Enhanced flexibility of failure policy by adding rules. --- ngraph/failure_policy.py | 235 +++++++++++++++++++- ngraph/network.py | 46 ++-- ngraph/scenario.py | 246 +++++++++++++++------ tests/scenarios/scenario_1.yaml | 316 ++++++++------------------- tests/scenarios/test_scenario_1.py | 62 ++++-- tests/test_failure_policy.py | 332 ++++++++++++++++++++++++----- tests/test_network.py | 126 ++++++----- tests/test_scenario.py | 98 ++++++--- 8 files changed, 1001 insertions(+), 460 deletions(-) diff --git a/ngraph/failure_policy.py b/ngraph/failure_policy.py index 8cc6efd..302550b 100644 --- a/ngraph/failure_policy.py +++ b/ngraph/failure_policy.py @@ -1,18 +1,235 @@ from dataclasses import dataclass, field -from random import random +from typing import Any, Dict, List, Literal +from random import random, sample -@dataclass(slots=True) +@dataclass +class FailureCondition: + """ + A single condition for matching an entity's attribute with an operator and value. + + Example usage: + + .. code-block:: yaml + + conditions: + - attr: "capacity" + operator: "<" + value: 100 + + :param attr: + The name of the attribute to inspect, e.g. "type", "capacity". + :param operator: + The comparison operator: "==", "!=", "<", "<=", ">", ">=". + :param value: + The value to compare against, e.g. "node", 100, True, etc. + """ + + attr: str # e.g. "type", "capacity", "region" + operator: str # "==", "!=", "<", "<=", ">", ">=" + value: Any # e.g. "node", 100, "east_coast" + + +@dataclass +class FailureRule: + """ + A single rule defining how to match entities and then select them for failure. + + - conditions: list of conditions + - logic: how to combine conditions ("and", "or", "any") + - rule_type: how to pick from matched entities ("random", "choice", "all") + - probability: used by "random" (a float in [0,1]) + - count: used by "choice" (e.g. pick 2) + + :param conditions: + A list of :class:`FailureCondition` to filter matching entities. + :param logic: + How to combine the conditions for matching: "and", "or", or "any". + - "and": all conditions must be true + - "or": at least one condition is true + - "any": skip condition checks; everything is matched + :param rule_type: + The selection strategy. One of: + - "random": pick each matched entity with `probability` + - "choice": pick exactly `count` from matched + - "all": pick all matched + :param probability: + Probability of selecting any matched entity (used only if rule_type="random"). + :param count: + Number of matched entities to pick (used only if rule_type="choice"). + """ + + conditions: List[FailureCondition] = field(default_factory=list) + logic: Literal["and", "or", "any"] = "and" + rule_type: Literal["random", "choice", "all"] = "all" + probability: float = 1.0 + count: int = 1 + + +@dataclass class FailurePolicy: """ - Mapping from element tag to failure probability. + A container for multiple FailureRules and arbitrary metadata in `attrs`. + + The method :meth:`apply_failures` merges nodes and links into a single + dictionary (by their unique ID), and then applies each rule in turn, + building a union of all failed entities. + + :param rules: + A list of :class:`FailureRule` objects to apply. + :param attrs: + A dictionary for storing policy-wide metadata (e.g. "name", "description"). """ - failure_probabilities: dict[str, float] = field(default_factory=dict) - distribution: str = "uniform" + rules: List[FailureRule] = field(default_factory=list) + attrs: Dict[str, Any] = field(default_factory=dict) - def test_failure(self, tag: str) -> bool: - if self.distribution == "uniform": - return random() < self.failure_probabilities.get(tag, 0) + def apply_failures( + self, nodes: Dict[str, Dict[str, Any]], links: Dict[str, Dict[str, Any]] + ) -> List[str]: + """ + Identify which entities (nodes or links) fail according to the + defined rules. + + :param nodes: + A mapping of node_name -> node.attrs, where node.attrs has at least + a "type" = "node". + :param links: + A mapping of link_id -> link.attrs, where link.attrs has at least + a "type" = "link". + :returns: + A list of failed entity IDs. For nodes, that ID is typically the + node's name. For links, it's the link's ID. + """ + # Merge nodes and links into a single map of entity_id -> entity_attrs + # e.g. { "SEA": { "type": "node", ...}, "SEA-DEN-xxx": { "type": "link", ...} } + all_entities = {**nodes, **links} + + failed_entities = set() + + # Evaluate each rule to find matched entities and union them + for rule in self.rules: + matched = self._match_entities(all_entities, rule.conditions, rule.logic) + selected = self._select_entities(matched, all_entities, rule) + failed_entities.update(selected) + + return list(failed_entities) + + def _match_entities( + self, + all_entities: Dict[str, Dict[str, Any]], + conditions: List[FailureCondition], + logic: str, + ) -> List[str]: + """ + Find which entities (by ID) satisfy the given list of conditions + combined by 'and'/'or' logic (or 'any' to skip checks). + + :param all_entities: + Mapping of entity_id -> attribute dict. + :param conditions: + List of :class:`FailureCondition` to apply. + :param logic: + "and", "or", or "any". + :returns: + A list of entity IDs that match. + """ + matched = [] + for entity_id, attr_dict in all_entities.items(): + if self._evaluate_conditions(attr_dict, conditions, logic): + matched.append(entity_id) + return matched + + @staticmethod + def _evaluate_conditions( + entity_attrs: Dict[str, Any], conditions: List[FailureCondition], logic: str + ) -> bool: + """ + Check if the given entity (via entity_attrs) meets all/any of the conditions. + + :param entity_attrs: + The dictionary of attributes for a single entity (node or link). + :param conditions: + A list of conditions to evaluate. + :param logic: + "and" -> all must be true + "or" -> at least one true + "any" -> skip condition checks (always true) + :returns: + True if conditions pass for the specified logic, else False. + """ + if logic == "any": + return True # means "select everything" + if not conditions: + return False # no conditions => no match, unless logic='any' + + results = [] + for cond in conditions: + results.append(_evaluate_condition(entity_attrs, cond)) + + if logic == "and": + return all(results) + elif logic == "or": + return any(results) else: - raise ValueError(f"Unsupported distribution: {self.distribution}") + raise ValueError(f"Unsupported logic: {logic}") + + @staticmethod + def _select_entities( + entity_ids: List[str], + all_entities: Dict[str, Dict[str, Any]], + rule: FailureRule, + ) -> List[str]: + """ + Select which entity IDs will fail from the matched set, based on rule_type. + + :param entity_ids: + IDs that matched the rule's conditions. + :param all_entities: + The full entity dictionary (not strictly needed for some rule_types). + :param rule: + The FailureRule specifying how to pick the final subset. + :returns: + The final list of entity IDs that fail from this rule. + """ + if rule.rule_type == "random": + return [e for e in entity_ids if random() < rule.probability] + elif rule.rule_type == "choice": + count = min(rule.count, len(entity_ids)) + # Use sorted(...) to ensure consistent picks when testing + return sample(sorted(entity_ids), k=count) + elif rule.rule_type == "all": + return entity_ids + else: + raise ValueError(f"Unsupported rule_type: {rule.rule_type}") + + +def _evaluate_condition(entity: Dict[str, Any], cond: FailureCondition) -> bool: + """ + Evaluate one condition (attr, operator, value) against an entity's attrs. + + :param entity: + The entity's attribute dictionary (node.attrs or link.attrs). + :param cond: + A single :class:`FailureCondition` specifying 'attr', 'operator', 'value'. + :returns: + True if the condition passes, else False. + :raises ValueError: + If the condition's operator is not recognized. + """ + derived_value = entity.get(cond.attr, None) + op = cond.operator + if op == "==": + return derived_value == cond.value + elif op == "!=": + return derived_value != cond.value + elif op == "<": + return derived_value < cond.value + elif op == "<=": + return derived_value <= cond.value + elif op == ">": + return derived_value > cond.value + elif op == ">=": + return derived_value >= cond.value + else: + raise ValueError(f"Unsupported operator: {op}") diff --git a/ngraph/network.py b/ngraph/network.py index 39ba847..d8f346d 100644 --- a/ngraph/network.py +++ b/ngraph/network.py @@ -7,7 +7,7 @@ def new_base64_uuid() -> str: """ - Generate a Base64-encoded UUID without padding (~22 characters). + Generate a Base64-encoded UUID without padding (a string with 22 characters). """ return base64.urlsafe_b64encode(uuid.uuid4().bytes).decode("ascii").rstrip("=") @@ -21,7 +21,12 @@ class Node: in the Network's node dictionary. :param name: The unique name of the node. - :param attrs: Optional extra metadata for the node. + :param attrs: Optional extra metadata for the node. For example: + { + "type": "node", # auto-tagged upon add_node + "coords": [lat, lon], # user-provided + "region": "west_coast" # user-provided + } """ name: str @@ -42,8 +47,13 @@ class Link: :param capacity: Link capacity (default 1.0). :param latency: Link latency (default 1.0). :param cost: Link cost (default 1.0). - :param attrs: Optional extra metadata for the link. - :param id: Auto-generated unique link identifier. + :param attrs: Optional extra metadata for the link. For example: + { + "type": "link", # auto-tagged upon add_link + "distance_km": 1500, # user-provided + "fiber_provider": "Lumen", # user-provided + } + :param id: Auto-generated unique link identifier, e.g. "SEA-DEN-abCdEf..." """ source: str @@ -67,13 +77,13 @@ class Network: """ A container for network nodes and links. - Nodes are stored in a dictionary keyed by their unique names. - Links are stored in a dictionary keyed by their auto-generated IDs. + Nodes are stored in a dictionary keyed by their unique names (:attr:`Node.name`). + Links are stored in a dictionary keyed by their auto-generated IDs (:attr:`Link.id`). The 'attrs' dict allows extra network metadata. - :param nodes: Mapping from node name to Node. - :param links: Mapping from link id to Link. - :param attrs: Optional extra metadata for the network. + :param nodes: Mapping from node name -> Node object. + :param links: Mapping from link id -> Link object. + :param attrs: Optional extra metadata for the network itself. """ nodes: Dict[str, Node] = field(default_factory=dict) @@ -82,21 +92,33 @@ class Network: def add_node(self, node: Node) -> None: """ - Add a node to the network, keyed by its name. + Add a node to the network, keyed by its :attr:`Node.name`. + + This method also auto-tags the node with ``node.attrs["type"] = "node"`` + if it's not already set. :param node: The Node to add. + :raises ValueError: If a node with the same name is already in the network. """ + node.attrs.setdefault("type", "node") + if node.name in self.nodes: + raise ValueError(f"Node '{node.name}' already exists in the network.") self.nodes[node.name] = node def add_link(self, link: Link) -> None: """ - Add a link to the network. Both source and target nodes must exist. + Add a link to the network, keyed by its auto-generated :attr:`Link.id`. + + This method also auto-tags the link with ``link.attrs["type"] = "link"`` + if it's not already set. :param link: The Link to add. - :raises ValueError: If the source or target node is not present. + :raises ValueError: If the source/target node is not present in the network. """ if link.source not in self.nodes: raise ValueError(f"Source node '{link.source}' not found in network.") if link.target not in self.nodes: raise ValueError(f"Target node '{link.target}' not found in network.") + + link.attrs.setdefault("type", "link") self.links[link.id] = link diff --git a/ngraph/scenario.py b/ngraph/scenario.py index 6d61fda..6542e87 100644 --- a/ngraph/scenario.py +++ b/ngraph/scenario.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List from ngraph.network import Network, Node, Link -from ngraph.failure_policy import FailurePolicy +from ngraph.failure_policy import FailurePolicy, FailureRule, FailureCondition from ngraph.traffic_demand import TrafficDemand from ngraph.results import Results from ngraph.workflow.base import WorkflowStep, WORKFLOW_STEP_REGISTRY @@ -13,49 +13,34 @@ @dataclass(slots=True) class Scenario: """ - Represents a complete scenario, including the network, failure policy, - traffic demands, workflow steps, and a results store. - - Usage: - scenario = Scenario.from_yaml(yaml_str) - scenario.run() - # Access scenario.results for workflow outputs - - Example YAML structure: - - network: - nodes: - JFK: - coords: [40.64, -73.78] - LAX: - coords: [33.94, -118.41] - links: - - source: JFK - target: LAX - capacity: 100 - latency: 50 - cost: 50 - attrs: { distance_km: 4000 } - - failure_policy: - failure_probabilities: - node: 0.001 - link: 0.002 - - traffic_demands: - - source: JFK - target: LAX - demand: 50 - - workflow: - - step_type: BuildGraph - name: build_graph - - :param network: The network model. - :param failure_policy: The policy for element failures. - :param traffic_demands: A list of traffic demands. - :param workflow: A list of WorkflowStep objects to be executed in order. - :param results: A Results object to store step outputs, summary, etc. + Represents a complete scenario, including: + - The network (nodes and links). + - A failure policy (with one or more rules). + - Traffic demands. + - A workflow of steps to execute. + - A results container for storing outputs. + + Typical usage: + 1. Create a Scenario from YAML: :: + + scenario = Scenario.from_yaml(yaml_str) + + 2. Run it: :: + + scenario.run() + + 3. Check scenario.results for step outputs. + + :param network: + The network model containing nodes and links. + :param failure_policy: + The multi-rule failure policy describing how and which entities fail. + :param traffic_demands: + A list of traffic demands describing source/target flows. + :param workflow: + A list of workflow steps defining the scenario pipeline. + :param results: + A Results object to store outputs from workflow steps. """ network: Network @@ -66,9 +51,11 @@ class Scenario: def run(self) -> None: """ - Execute the scenario's workflow steps in the given order. - Each WorkflowStep has access to this Scenario object and - can store output in scenario.results. + Execute the scenario's workflow steps in the defined order. + + Each step has access to :attr:`Scenario.network`, + :attr:`Scenario.failure_policy`, etc. Steps may store outputs in + :attr:`Scenario.results`. """ for step in self.workflow: step.run(self) @@ -76,12 +63,54 @@ def run(self) -> None: @classmethod def from_yaml(cls, yaml_str: str) -> Scenario: """ - Construct a Scenario from a YAML string. + Construct a :class:`Scenario` from a YAML string. + + Expected top-level YAML keys: + - ``network``: Node/Link definitions + - ``failure_policy``: A multi-rule policy + - ``traffic_demands``: List of demands + - ``workflow``: Steps to run + + Example: + + .. code-block:: yaml - This looks for top-level sections: - 'network', 'failure_policy', 'traffic_demands', and 'workflow'. + network: + nodes: + SEA: { coords: [47.6062, -122.3321] } + SFO: { coords: [37.7749, -122.4194] } + links: + - source: SEA + target: SFO + capacity: 100 + attrs: { distance_km: 1300 } - See the class docstring for a short example of the expected structure. + failure_policy: + name: "multi_rule_example" + rules: + - conditions: + - attr: "type" + operator: "==" + value: "node" + logic: "and" + rule_type: "choice" + count: 1 + + traffic_demands: + - source: SEA + target: SFO + demand: 50 + + workflow: + - step_type: BuildGraph + name: build_graph + + :param yaml_str: + The YAML string defining a scenario. + :returns: + A fully constructed :class:`Scenario` instance. + :raises ValueError: + If the YAML is malformed or missing required sections. """ data = yaml.safe_load(yaml_str) if not isinstance(data, dict): @@ -91,17 +120,15 @@ def from_yaml(cls, yaml_str: str) -> Scenario: network_data = data.get("network", {}) network = cls._build_network(network_data) - # 2) Build the failure policy + # 2) Build the (new) multi-rule failure policy fp_data = data.get("failure_policy", {}) - failure_policy = FailurePolicy( - failure_probabilities=fp_data.get("failure_probabilities", {}) - ) + failure_policy = cls._build_failure_policy(fp_data) # 3) Build traffic demands traffic_demands_data = data.get("traffic_demands", []) traffic_demands = [TrafficDemand(**td) for td in traffic_demands_data] - # 4) Build workflow steps using the registry + # 4) Build workflow steps workflow_data = data.get("workflow", []) workflow_steps = cls._build_workflow_steps(workflow_data) @@ -115,7 +142,30 @@ def from_yaml(cls, yaml_str: str) -> Scenario: @staticmethod def _build_network(network_data: Dict[str, Any]) -> Network: """ - Construct a Network object from a dictionary containing 'nodes' and 'links'. + Construct a :class:`Network` object from a dictionary containing 'nodes' and 'links'. + The dictionary is expected to look like: + + .. code-block:: yaml + + nodes: + SEA: { coords: [47.6062, -122.3321] } + SFO: { coords: [37.7749, -122.4194] } + + links: + - source: SEA + target: SFO + capacity: 100 + latency: 5 + cost: 10 + attrs: + distance_km: 1300 + + :param network_data: + Dictionary with optional keys 'nodes' and 'links'. + :returns: + A :class:`Network` containing the parsed nodes and links. + :raises ValueError: + If a link references nodes not defined in the network. """ net = Network() @@ -139,15 +189,89 @@ def _build_network(network_data: Dict[str, Any]) -> Network: return net + @staticmethod + def _build_failure_policy(fp_data: Dict[str, Any]) -> FailurePolicy: + """ + Construct a :class:`FailurePolicy` from YAML data that may look like: + + .. code-block:: yaml + + failure_policy: + name: "multi_rule_example" + description: "Example of multi-rule approach" + rules: + - conditions: + - attr: "type" + operator: "==" + value: "link" + logic: "and" + rule_type: "random" + probability: 0.1 + + :param fp_data: + Dictionary from the 'failure_policy' section of YAML. + :returns: + A :class:`FailurePolicy` object with a list of :class:`FailureRule`. + """ + # Extract the list of rules + rules_data = fp_data.get("rules", []) + rules: List[FailureRule] = [] + + for rule_dict in rules_data: + conditions_data = rule_dict.get("conditions", []) + conditions: List[FailureCondition] = [] + for cond_dict in conditions_data: + condition = FailureCondition( + attr=cond_dict["attr"], + operator=cond_dict["operator"], + value=cond_dict["value"], + ) + conditions.append(condition) + + rule = FailureRule( + conditions=conditions, + logic=rule_dict.get("logic", "and"), + rule_type=rule_dict.get("rule_type", "all"), + probability=rule_dict.get("probability", 1.0), + count=rule_dict.get("count", 1), + ) + rules.append(rule) + + # All other key-value pairs go into policy.attrs (e.g. "name", "description") + attrs = {k: v for k, v in fp_data.items() if k != "rules"} + + return FailurePolicy(rules=rules, attrs=attrs) + @staticmethod def _build_workflow_steps( workflow_data: List[Dict[str, Any]] ) -> List[WorkflowStep]: """ - Instantiate workflow steps listed in 'workflow_data' using WORKFLOW_STEP_REGISTRY. + Convert a list of workflow step dictionaries into instantiated + :class:`WorkflowStep` objects. + + Each step dict must have a ``step_type`` referencing a registered + workflow step in :attr:`WORKFLOW_STEP_REGISTRY`. Any additional + keys are passed as init arguments. + + Example: + + .. code-block:: yaml + + workflow: + - step_type: BuildGraph + name: build_graph + - step_type: ComputeRoutes + name: compute_routes + + :param workflow_data: + A list of dictionaries, each describing a workflow step. + :returns: + A list of instantiated :class:`WorkflowStep` objects in the same order. + :raises ValueError: + If any dict lacks "step_type" or references an unknown type. """ steps: List[WorkflowStep] = [] - for step_info in workflow_data: step_type = step_info.get("step_type") if not step_type: @@ -155,13 +279,11 @@ def _build_workflow_steps( "Each workflow entry must have a 'step_type' field " "indicating which WorkflowStep subclass to use." ) - step_cls = WORKFLOW_STEP_REGISTRY.get(step_type) if not step_cls: raise ValueError(f"Unrecognized 'step_type': {step_type}") - # Remove 'step_type' so it doesn't clash with the step_class __init__ + # Remove 'step_type' so it doesn't clash with step_cls.__init__ step_args = {k: v for k, v in step_info.items() if k != "step_type"} steps.append(step_cls(**step_args)) - return steps diff --git a/tests/scenarios/scenario_1.yaml b/tests/scenarios/scenario_1.yaml index f3c3f50..aa0cad4 100644 --- a/tests/scenarios/scenario_1.yaml +++ b/tests/scenarios/scenario_1.yaml @@ -1,176 +1,26 @@ network: + name: "6-node-l3-us-backbone" + version: "1.0" + nodes: - JFK: - coords: [40.641766, -73.780968] - LAX: - coords: [33.941589, -118.40853] - ORD: - coords: [41.974163, -87.907321] - IAH: - coords: [29.99022, -95.336783] - PHX: - coords: [33.437269, -112.007788] - PHL: - coords: [39.874395, -75.242423] - SAT: - coords: [29.424122, -98.493629] - SAN: - coords: [32.733801, -117.193304] - DFW: - coords: [32.899809, -97.040335] - SJC: - coords: [37.363947, -121.928938] - AUS: - coords: [30.197475, -97.666305] - JAX: - coords: [30.332184, -81.655651] - CMH: - coords: [39.961176, -82.998794] - IND: - coords: [39.768403, -86.158068] - CLT: - coords: [35.227087, -80.843127] - SFO: - coords: [37.774929, -122.419416] SEA: - coords: [47.606209, -122.332071] + coords: [47.6062, -122.3321] + SFO: + coords: [37.7749, -122.4194] DEN: - coords: [39.739236, -104.990251] + coords: [39.7392, -104.9903] + DFW: + coords: [32.8998, -97.0403] + JFK: + coords: [40.641766, -73.780968] DCA: coords: [38.907192, -77.036871] links: - - source: JFK - target: PHL - capacity: 100 - latency: 756 - cost: 756 - attrs: - distance_km: 151.19 - - - source: JFK - target: DCA - capacity: 100 - latency: 1714 - cost: 1714 - attrs: - distance_km: 342.69 - - - source: LAX - target: SFO - capacity: 100 - latency: 2720 - cost: 2720 - attrs: - distance_km: 543.95 - - - source: LAX - target: SAN - capacity: 100 - latency: 893 - cost: 893 - attrs: - distance_km: 178.55 - - - source: ORD - target: IND - capacity: 100 - latency: 1395 - cost: 1395 - attrs: - distance_km: 278.84 - - - source: ORD - target: CMH - capacity: 100 - latency: 1386 - cost: 1386 - attrs: - distance_km: 277.17 - - - source: IAH - target: DFW - capacity: 100 - latency: 1802 - cost: 1802 - attrs: - distance_km: 360.25 - - - source: IAH - target: AUS - capacity: 100 - latency: 1133 - cost: 1133 - attrs: - distance_km: 226.53 - - - source: PHX - target: LAX - capacity: 100 - latency: 2982 - cost: 2982 - attrs: - distance_km: 596.35 - - - source: SAT - target: AUS - capacity: 100 - latency: 594 - cost: 594 - attrs: - distance_km: 118.69 - - - source: DFW - target: AUS - capacity: 100 - latency: 1539 - cost: 1539 - attrs: - distance_km: 307.63 - - - source: SJC - target: SFO - capacity: 100 - latency: 339 - cost: 339 - attrs: - distance_km: 67.79 - - - source: SJC - target: LAX - capacity: 100 - latency: 2468 - cost: 2468 - attrs: - distance_km: 493.55 - - - source: CLT - target: DCA - capacity: 100 - latency: 2654 - cost: 2654 - attrs: - distance_km: 530.64 - + # West -> Middle - source: SEA - target: SFO - capacity: 100 - latency: 5460 - cost: 5460 - attrs: - distance_km: 1091.95 - - - source: DEN - target: PHX - capacity: 100 - latency: 4761 - cost: 4761 - attrs: - distance_km: 952.16 - - - source: DEN - target: SEA - capacity: 100 + target: DEN + capacity: 200 latency: 6846 cost: 6846 attrs: @@ -184,98 +34,98 @@ network: attrs: distance_km: 1550.77 + - source: SEA + target: DFW + capacity: 200 + latency: 9600 + cost: 9600 + attrs: + distance_km: 1920 + + - source: SFO + target: DFW + capacity: 200 + latency: 10000 + cost: 10000 + attrs: + distance_km: 2000 + + # Middle <-> Middle (two parallel links to represent redundancy) - source: DEN - target: ORD - capacity: 300 + target: DFW + capacity: 400 latency: 7102 cost: 7102 attrs: distance_km: 1420.28 - - source: PHX + - source: DEN target: DFW - capacity: 300 - latency: 6900 - cost: 6900 + capacity: 400 + latency: 7102 + cost: 7102 attrs: - distance_km: 1380 + distance_km: 1420.28 - - source: CMH + # Middle -> East + - source: DEN target: JFK - capacity: 100 - latency: 3788 - cost: 3788 - attrs: - distance_km: 757.58 - - - source: IND - target: CLT - capacity: 100 - latency: 3419 - cost: 3419 - attrs: - distance_km: 683.66 - - - source: IAH - target: JAX - capacity: 100 - latency: 6746 - cost: 6746 - attrs: - distance_km: 1349.04 - - - source: SAT - target: AUS - capacity: 100 - latency: 594 - cost: 594 - attrs: - distance_km: 118.69 - - - source: SAT - target: IAH - capacity: 100 - latency: 1524 - cost: 1524 + capacity: 200 + latency: 7500 + cost: 7500 attrs: - distance_km: 304.66 + distance_km: 1500 - - source: JAX - target: CLT - capacity: 100 - latency: 2671 - cost: 2671 + - source: DFW + target: DCA + capacity: 200 + latency: 8000 + cost: 8000 attrs: - distance_km: 534.2 + distance_km: 1600 - - source: SAN - target: PHX - capacity: 100 - latency: 2435 - cost: 2435 + - source: DFW + target: JFK + capacity: 200 + latency: 9500 + cost: 9500 attrs: - distance_km: 486.91 + distance_km: 1900 - - source: DFW - target: IND + # East <-> East + - source: JFK + target: DCA capacity: 100 - latency: 6889 - cost: 6889 + latency: 1714 + cost: 1714 attrs: - distance_km: 1377.67 + distance_km: 342.69 failure_policy: - failure_probabilities: - node: 0.001 - link: 0.002 + name: "anySingleLink" + description: "Evaluate traffic routing under any single link failure." + rules: + - conditions: + - attr: "type" + operator: "==" + value: "link" + logic: "and" + rule_type: "choice" + count: 1 traffic_demands: - - source: JFK - target: LAX + - source: SEA + target: JFK + demand: 50 + - source: SFO + target: DCA + demand: 50 + - source: SEA + target: DCA + demand: 50 + - source: SFO + target: JFK demand: 50 - - source: SAN - target: SEA - demand: 30 workflow: - step_type: BuildGraph diff --git a/tests/scenarios/test_scenario_1.py b/tests/scenarios/test_scenario_1.py index 3ce52be..3a4ed2f 100644 --- a/tests/scenarios/test_scenario_1.py +++ b/tests/scenarios/test_scenario_1.py @@ -3,13 +3,17 @@ from pathlib import Path from ngraph.scenario import Scenario +from ngraph.failure_policy import FailurePolicy def test_scenario_1_build_graph() -> None: """ Integration test that verifies we can parse scenario_1.yaml, run the BuildGraph step, and produce a valid NetworkX MultiDiGraph. - Also checks traffic demands and failure policy. + Checks: + - The expected number of nodes and links are correctly parsed. + - The traffic demands are loaded. + - The multi-rule failure policy matches "anySingleLink". """ # 1) Load the YAML file @@ -26,28 +30,54 @@ def test_scenario_1_build_graph() -> None: graph = scenario.results.get("build_graph", "graph") assert isinstance( graph, nx.MultiDiGraph - ), "Expected a MultiDiGraph in scenario.results." + ), "Expected a MultiDiGraph in scenario.results under key ('build_graph', 'graph')." # 5) Check the total number of nodes matches what's listed in scenario_1.yaml - assert len(graph.nodes) == 19, f"Expected 19 nodes, found {len(graph.nodes)}" + # For a 6-node scenario, we expect 6 nodes in the final Nx graph. + expected_nodes = 6 + actual_nodes = len(graph.nodes) + assert ( + actual_nodes == expected_nodes + ), f"Expected {expected_nodes} nodes, found {actual_nodes}" - # 6) Each physical link becomes 2 directed edges in the MultiDiGraph. - # The YAML has 28 total link lines (including one duplicate SAT->AUS entry). - # So expected edges = 2 * 28 = 56. - expected_links = 28 + # 6) Each physical link from the YAML becomes 2 directed edges in MultiDiGraph. + # If the YAML has 10 link definitions, we expect 2 * 10 = 20 directed edges. + expected_links = 10 expected_nx_edges = expected_links * 2 actual_edges = len(graph.edges) assert ( actual_edges == expected_nx_edges ), f"Expected {expected_nx_edges} directed edges, found {actual_edges}" - # 7) Verify the traffic demands - assert len(scenario.traffic_demands) == 2, "Expected 2 traffic demands." - demand_map = {(td.source, td.target): td.demand for td in scenario.traffic_demands} - # scenario_1.yaml has demands: (JFK->LAX=50), (SAN->SEA=30) - assert demand_map[("JFK", "LAX")] == 50 - assert demand_map[("SAN", "SEA")] == 30 + # 7) Verify the traffic demands. In scenario_1.yaml, let's assume we have 4 demands. + # Adjust this to match your actual scenario_1.yaml. + expected_demands = 4 + assert ( + len(scenario.traffic_demands) == expected_demands + ), f"Expected {expected_demands} traffic demands." + + # 8) Check the new multi-rule failure policy for "any single link". + # This should have exactly 1 rule that picks exactly 1 link from all links. + policy: FailurePolicy = scenario.failure_policy + assert len(policy.rules) == 1, "Should only have 1 rule for 'anySingleLink'." - # 8) Check the failure policy from YAML - assert scenario.failure_policy.failure_probabilities["node"] == 0.001 - assert scenario.failure_policy.failure_probabilities["link"] == 0.002 + rule = policy.rules[0] + # - conditions: [ {attr: 'type', operator: '==', value: 'link'} ] + # - logic: 'and' + # - rule_type: 'choice' + # - count: 1 + assert len(rule.conditions) == 1, "Expected exactly 1 condition for matching links." + cond = rule.conditions[0] + assert cond.attr == "type" + assert cond.operator == "==" + assert cond.value == "link" + + assert rule.logic == "and" + assert rule.rule_type == "choice" + assert rule.count == 1 + + assert policy.attrs.get("name") == "anySingleLink" + assert ( + policy.attrs.get("description") + == "Evaluate traffic routing under any single link failure." + ) diff --git a/tests/test_failure_policy.py b/tests/test_failure_policy.py index 645e180..768d90a 100644 --- a/tests/test_failure_policy.py +++ b/tests/test_failure_policy.py @@ -1,82 +1,318 @@ import pytest from unittest.mock import patch -from ngraph.failure_policy import FailurePolicy +from ngraph.failure_policy import ( + FailurePolicy, + FailureRule, + FailureCondition, + _evaluate_condition, +) -def test_default_attributes(): +def test_empty_policy_no_failures(): """ - Ensure default constructor creates an empty failure_probabilities dict, - and sets distribution to 'uniform'. + Verify that if no rules are present, no entities fail. """ - policy = FailurePolicy() - assert policy.failure_probabilities == {} - assert policy.distribution == "uniform" + policy = FailurePolicy(rules=[]) + # Suppose we have 2 nodes, 1 link + nodes = { + "N1": {"type": "node", "capacity": 100}, + "N2": {"type": "node", "capacity": 200}, + } + links = { + "N1-N2-abc123": {"type": "link", "capacity": 50}, + } -@patch("ngraph.failure_policy.random") -def test_test_failure_returns_true(mock_random): + failed = policy.apply_failures(nodes, links) + assert failed == [], "No rules => no entities fail." + + +def test_single_rule_all_matched(): """ - For a specific tag with nonzero probability, verify test_failure() returns True - when random() is less than that probability. + If we have a rule that matches all entities and selects 'all', + then everything fails. """ - policy = FailurePolicy(failure_probabilities={"node1": 0.7}) + rule = FailureRule( + conditions=[FailureCondition(attr="type", operator="!=", value="")], + logic="and", + rule_type="all", + ) + policy = FailurePolicy(rules=[rule]) - # Mock random to return 0.5 which is < 0.7 - mock_random.return_value = 0.5 - assert ( - policy.test_failure("node1") is True - ), "Should return True when random() < failure probability." + nodes = {"N1": {"type": "node"}, "N2": {"type": "node"}} + links = { + "L1": {"type": "link"}, + "L2": {"type": "link"}, + } + failed = policy.apply_failures(nodes, links) + assert set(failed) == {"N1", "N2", "L1", "L2"} -@patch("ngraph.failure_policy.random") -def test_test_failure_returns_false(mock_random): + +def test_single_rule_choice(): """ - For a specific tag with nonzero probability, verify test_failure() returns False - when random() is not less than that probability. + Test rule_type='choice': it picks exactly 'count' entities from the matched set. """ - policy = FailurePolicy(failure_probabilities={"node1": 0.3}) + rule = FailureRule( + conditions=[FailureCondition(attr="type", operator="==", value="node")], + logic="and", + rule_type="choice", + count=2, + ) + policy = FailurePolicy(rules=[rule]) + + nodes = { + "SEA": {"type": "node", "capacity": 100}, + "SFO": {"type": "node", "capacity": 200}, + "DEN": {"type": "node", "capacity": 300}, + } + links = { + "SEA-SFO-xxx": {"type": "link", "capacity": 400}, + } - # Mock random to return 0.4 which is > 0.3 - mock_random.return_value = 0.4 - assert ( - policy.test_failure("node1") is False - ), "Should return False when random() >= failure probability." + with patch("ngraph.failure_policy.sample", return_value=["SEA", "DEN"]): + failed = policy.apply_failures(nodes, links) + assert set(failed) == {"SEA", "DEN"} @patch("ngraph.failure_policy.random") -def test_test_failure_zero_probability(mock_random): +def test_single_rule_random(mock_random): """ - A probability of zero means it should always return False, even if random() is also zero. + For rule_type='random', each matched entity is selected if random() < probability. + We'll mock out random() to test. """ - policy = FailurePolicy(failure_probabilities={"node1": 0.0}) + rule = FailureRule( + conditions=[FailureCondition(attr="type", operator="==", value="link")], + logic="and", + rule_type="random", + probability=0.5, + ) + policy = FailurePolicy(rules=[rule]) - mock_random.return_value = 0.0 - assert ( - policy.test_failure("node1") is False - ), "Should always return False with probability = 0.0" + nodes = { + "SEA": {"type": "node"}, + "SFO": {"type": "node"}, + } + links = { + "L1": {"type": "link", "capacity": 100}, + "L2": {"type": "link", "capacity": 100}, + "L3": {"type": "link", "capacity": 100}, + } + mock_random.side_effect = [0.4, 0.6, 0.3] + failed = policy.apply_failures(nodes, links) + assert set(failed) == {"L1", "L3"}, "Should fail those where random() < 0.5" -@patch("ngraph.failure_policy.random") -def test_test_failure_no_entry_for_tag(mock_random): + +def test_operator_conditions(): + """ + Check that <, != conditions evaluate correctly in 'and' logic. + (We also have coverage for ==, > in other tests.) + """ + conditions = [ + FailureCondition(attr="capacity", operator="<", value=300), + FailureCondition(attr="region", operator="!=", value="east"), + ] + rule = FailureRule(conditions=conditions, logic="and", rule_type="all") + policy = FailurePolicy(rules=[rule]) + + nodes = { + "N1": {"type": "node", "capacity": 100, "region": "west"}, # matches + "N2": {"type": "node", "capacity": 100, "region": "east"}, # fails != + "N3": {"type": "node", "capacity": 300, "region": "west"}, # fails < + } + links = { + "L1": {"type": "link", "capacity": 200, "region": "east"}, # fails != + } + + failed = policy.apply_failures(nodes, links) + assert failed == ["N1"] + + +def test_logic_or(): + """ + Check 'or' logic: an entity is matched if it satisfies at least one condition (>150 or region=east). + """ + conditions = [ + FailureCondition(attr="capacity", operator=">", value=150), + FailureCondition(attr="region", operator="==", value="east"), + ] + rule = FailureRule(conditions=conditions, logic="or", rule_type="all") + policy = FailurePolicy(rules=[rule]) + + nodes = { + "N1": {"type": "node", "capacity": 100, "region": "west"}, # fails both + "N2": { + "type": "node", + "capacity": 200, + "region": "west", + }, # passes capacity>150 + "N3": {"type": "node", "capacity": 100, "region": "east"}, # passes region=east + "N4": {"type": "node", "capacity": 200, "region": "east"}, # passes both + } + links = {} + failed = policy.apply_failures(nodes, links) + assert set(failed) == {"N2", "N3", "N4"} + + +def test_logic_any(): + """ + 'any' logic means all entities are selected, ignoring conditions. + """ + rule = FailureRule( + conditions=[FailureCondition(attr="capacity", operator="==", value=-999)], + logic="any", + rule_type="all", + ) + policy = FailurePolicy(rules=[rule]) + + nodes = {"N1": {"type": "node"}, "N2": {"type": "node"}} + links = {"L1": {"type": "link"}, "L2": {"type": "link"}} + + failed = policy.apply_failures(nodes, links) + assert set(failed) == {"N1", "N2", "L1", "L2"} + + +def test_multiple_rules_union(): + """ + If multiple rules exist, the final set of failed entities is the union + of each rule's selection. + """ + rule1 = FailureRule( + conditions=[FailureCondition(attr="type", operator="==", value="node")], + logic="and", + rule_type="all", + ) + rule2 = FailureRule( + conditions=[FailureCondition(attr="type", operator="==", value="link")], + logic="and", + rule_type="choice", + count=1, + ) + policy = FailurePolicy(rules=[rule1, rule2]) + + nodes = {"N1": {"type": "node"}, "N2": {"type": "node"}} + links = {"L1": {"type": "link"}, "L2": {"type": "link"}} + + with patch("ngraph.failure_policy.sample", return_value=["L1"]): + failed = policy.apply_failures(nodes, links) + assert set(failed) == {"N1", "N2", "L1"} + + +def test_unsupported_logic(): + """ + Ensure that if a rule specifies an unsupported logic string, + _evaluate_conditions() raises ValueError. + """ + rule = FailureRule( + conditions=[FailureCondition(attr="type", operator="==", value="node")], + logic="UNSUPPORTED", + rule_type="all", + ) + policy = FailurePolicy(rules=[rule]) + + nodes = {"A": {"type": "node"}} + links = {} + with pytest.raises(ValueError, match="Unsupported logic: UNSUPPORTED"): + policy.apply_failures(nodes, links) + + +def test_unsupported_rule_type(): + """ + Ensure that if a rule has an unknown rule_type, + _select_entities() raises ValueError. + """ + rule = FailureRule( + conditions=[FailureCondition(attr="type", operator="==", value="node")], + logic="and", + rule_type="UNKNOWN", + ) + policy = FailurePolicy(rules=[rule]) + + nodes = {"A": {"type": "node"}} + links = {} + with pytest.raises(ValueError, match="Unsupported rule_type: UNKNOWN"): + policy.apply_failures(nodes, links) + + +def test_unsupported_operator(): + """ + If a condition has an unknown operator, _evaluate_condition() raises ValueError. """ - If no entry for a given tag is found, probability defaults to 0.0 => always False. + cond = FailureCondition(attr="capacity", operator="??", value=100) + with pytest.raises(ValueError, match="Unsupported operator: "): + _evaluate_condition({"capacity": 100}, cond) + + +def test_no_conditions_with_non_any_logic(): + """ + If logic is not 'any' but conditions is empty, + we expect _evaluate_conditions() to return False. + """ + rule = FailureRule(conditions=[], logic="and", rule_type="all") + policy = FailurePolicy(rules=[rule]) + + nodes = {"N1": {"type": "node"}} + links = {} + failed = policy.apply_failures(nodes, links) + assert failed == [], "No conditions => no match => no failures." + + +def test_choice_larger_count_than_matched(): """ - policy = FailurePolicy() + If rule.count > number of matched entities, we pick all matched. + """ + rule = FailureRule( + conditions=[FailureCondition(attr="type", operator="==", value="node")], + logic="and", + rule_type="choice", + count=10, + ) + policy = FailurePolicy(rules=[rule]) - mock_random.return_value = 0.0 - assert ( - policy.test_failure("unknown_tag") is False - ), "Unknown tag should default to 0.0 probability => always False." + nodes = {"A": {"type": "node"}, "B": {"type": "node"}} + links = {"L1": {"type": "link"}} + failed = policy.apply_failures(nodes, links) + assert set(failed) == {"A", "B"} -def test_test_failure_non_uniform_distribution(): +def test_choice_zero_count(): """ - Verify that any distribution other than 'uniform' raises a ValueError. + If rule.count=0, we select none from the matched entities. """ - policy = FailurePolicy(distribution="non_uniform") + rule = FailureRule( + conditions=[FailureCondition(attr="type", operator="==", value="node")], + logic="and", + rule_type="choice", + count=0, + ) + policy = FailurePolicy(rules=[rule]) + + nodes = {"A": {"type": "node"}, "B": {"type": "node"}, "C": {"type": "node"}} + links = {} + failed = policy.apply_failures(nodes, links) + assert failed == [], "count=0 => no entities chosen." + + +def test_operator_condition_le_ge(): + """ + Verify that the '<=' and '>=' operators in _evaluate_condition are correctly handled. + """ + cond_le = FailureCondition(attr="capacity", operator="<=", value=100) + cond_ge = FailureCondition(attr="capacity", operator=">=", value=100) + + # Entity with capacity=100 => passes both <=100 and >=100 + e1 = {"capacity": 100} + assert _evaluate_condition(e1, cond_le) is True + assert _evaluate_condition(e1, cond_ge) is True - with pytest.raises(ValueError) as exc_info: - policy.test_failure("node1") + # capacity=90 => pass <=100, fail >=100 + e2 = {"capacity": 90} + assert _evaluate_condition(e2, cond_le) is True + assert _evaluate_condition(e2, cond_ge) is False - assert "Unsupported distribution" in str(exc_info.value) + # capacity=110 => fail <=100, pass >=100 + e3 = {"capacity": 110} + assert _evaluate_condition(e3, cond_le) is False + assert _evaluate_condition(e3, cond_ge) is True diff --git a/tests/test_network.py b/tests/test_network.py index 1bdcf0f..7532a7b 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -1,132 +1,160 @@ import pytest -from ngraph.network import ( - Network, - Node, - Link, - new_base64_uuid -) +from ngraph.network import Network, Node, Link, new_base64_uuid def test_new_base64_uuid_length_and_uniqueness(): - # Generate two Base64-encoded UUIDs + """ + Generate two Base64-encoded UUIDs and confirm: + - they are strings with no '=' padding + - they are 22 chars long + - they differ from each other + """ uuid1 = new_base64_uuid() uuid2 = new_base64_uuid() - - # They should be strings without any padding characters + assert isinstance(uuid1, str) assert isinstance(uuid2, str) - assert '=' not in uuid1 - assert '=' not in uuid2 - - # They are typically 22 characters long (Base64 without padding) + assert "=" not in uuid1 + assert "=" not in uuid2 + + # 22 characters for a UUID in unpadded Base64 assert len(uuid1) == 22 assert len(uuid2) == 22 - - # The two generated UUIDs should be unique + + # They should be unique assert uuid1 != uuid2 + def test_node_creation_default_attrs(): - # Create a Node with default attributes + """ + A new Node with no attrs should have an empty dict for attrs. + """ node = Node("A") assert node.name == "A" assert node.attrs == {} + def test_node_creation_custom_attrs(): - # Create a Node with custom attributes + """ + A new Node can be created with custom attributes that are stored as-is. + """ custom_attrs = {"key": "value", "number": 42} node = Node("B", attrs=custom_attrs) assert node.name == "B" assert node.attrs == custom_attrs + def test_link_defaults_and_id_generation(): - # Create a Link; __post_init__ should auto-generate the id. + """ + A Link without custom parameters should default capacity/latency/cost to 1.0, + have an empty attrs dict, and generate a unique ID like 'A-B-'. + """ link = Link("A", "B") - - # Check default parameters are set correctly. + assert link.capacity == 1.0 assert link.latency == 1.0 assert link.cost == 1.0 assert link.attrs == {} - - # Verify the link ID is correctly formatted and starts with "A-B-" + + # ID should start with 'A-B-' and have a random suffix assert link.id.startswith("A-B-") - # Ensure there is a random UUID part appended after the prefix assert len(link.id) > len("A-B-") + def test_link_custom_values(): - # Create a Link with custom values + """ + A Link can be created with custom capacity/latency/cost/attrs, + and the ID is generated automatically. + """ custom_attrs = {"color": "red"} link = Link("X", "Y", capacity=2.0, latency=3.0, cost=4.0, attrs=custom_attrs) - + assert link.source == "X" assert link.target == "Y" assert link.capacity == 2.0 assert link.latency == 3.0 assert link.cost == 4.0 assert link.attrs == custom_attrs - # Check that the ID has the proper format assert link.id.startswith("X-Y-") + def test_link_id_uniqueness(): - # Two links between the same nodes should have different IDs. + """ + Even if two Links have the same source and target, the auto-generated IDs + should differ because of the random UUID portion. + """ link1 = Link("A", "B") link2 = Link("A", "B") assert link1.id != link2.id + def test_network_add_node_and_link(): - # Create a network and add two nodes + """ + Adding nodes and links to a Network should store them in dictionaries + keyed by node name and link ID, respectively. + """ network = Network() node_a = Node("A") node_b = Node("B") + network.add_node(node_a) network.add_node(node_b) - - # The nodes should be present in the network + assert "A" in network.nodes assert "B" in network.nodes - - # Create a link between the nodes and add it to the network + link = Link("A", "B") network.add_link(link) - - # Check that the link is stored in the network using its auto-generated id. + + # Link is stored under link.id assert link.id in network.links - # Verify that the stored link is the same object assert network.links[link.id] is link + def test_network_add_link_missing_source(): - # Create a network with only the target node + """ + Attempting to add a Link whose source node is not in the Network should raise an error. + """ network = Network() node_b = Node("B") network.add_node(node_b) - - # Try to add a link whose source node does not exist. - link = Link("A", "B") + + link = Link("A", "B") # 'A' doesn't exist + with pytest.raises(ValueError, match="Source node 'A' not found in network."): network.add_link(link) + def test_network_add_link_missing_target(): - # Create a network with only the source node + """ + Attempting to add a Link whose target node is not in the Network should raise an error. + """ network = Network() node_a = Node("A") network.add_node(node_a) - - # Try to add a link whose target node does not exist. - link = Link("A", "B") + + link = Link("A", "B") # 'B' doesn't exist with pytest.raises(ValueError, match="Target node 'B' not found in network."): network.add_link(link) + def test_network_attrs(): - # Test that extra network metadata can be stored in attrs. + """ + The Network's 'attrs' dictionary can store arbitrary metadata about the network. + """ network = Network(attrs={"network_type": "test"}) assert network.attrs["network_type"] == "test" -def test_duplicate_node_overwrite(): - # When adding nodes with the same name, the latter should overwrite the former. + +def test_add_duplicate_node_raises_valueerror(): + """ + With the new behavior, adding a second Node with the same name should raise ValueError + rather than overwriting the existing node. + """ network = Network() node1 = Node("A", attrs={"data": 1}) node2 = Node("A", attrs={"data": 2}) - + network.add_node(node1) - network.add_node(node2) # This should overwrite node1 - assert network.nodes["A"].attrs["data"] == 2 + with pytest.raises(ValueError, match="Node 'A' already exists in the network."): + network.add_node(node2) diff --git a/tests/test_scenario.py b/tests/test_scenario.py index 7f5fbdd..621b2dd 100644 --- a/tests/test_scenario.py +++ b/tests/test_scenario.py @@ -1,12 +1,11 @@ import pytest import yaml - from typing import TYPE_CHECKING from dataclasses import dataclass from ngraph.scenario import Scenario from ngraph.network import Network -from ngraph.failure_policy import FailurePolicy +from ngraph.failure_policy import FailurePolicy, FailureRule, FailureCondition from ngraph.traffic_demand import TrafficDemand from ngraph.results import Results from ngraph.workflow.base import ( @@ -34,9 +33,10 @@ class DoSmth(WorkflowStep): def run(self, scenario: Scenario) -> None: """ Perform a dummy operation for testing. - You might store something in scenario.results here if desired. + Store something in scenario.results using the step name as a key. """ - pass + # We can use self.name as the "step_name" + scenario.results.put(self.name, "ran", True) @register_workflow_step("DoSmthElse") @@ -49,17 +49,17 @@ class DoSmthElse(WorkflowStep): factor: float = 1.0 def run(self, scenario: Scenario) -> None: - """ - Perform another dummy operation for testing. - """ - pass + scenario.results.put(self.name, "ran", True) @pytest.fixture def valid_scenario_yaml() -> str: """ - Returns a valid YAML string for constructing a Scenario with a small - realistic network of three nodes and two links, plus two traffic demands. + Returns a YAML string for constructing a Scenario with: + - A small network of three nodes and two links + - A multi-rule failure policy + - Two traffic demands + - Two workflow steps """ return """ network: @@ -85,9 +85,22 @@ def valid_scenario_yaml() -> str: cost: 4 attrs: {} failure_policy: - failure_probabilities: - node: 0.01 - link: 0.02 + name: "multi_rule_example" + description: "Testing multi-rule approach." + rules: + - conditions: + - attr: "type" + operator: "==" + value: "node" + logic: "and" + rule_type: "choice" + count: 1 + - conditions: + - attr: "type" + operator: "==" + value: "link" + logic: "and" + rule_type: "all" traffic_demands: - source: NodeA target: NodeB @@ -121,9 +134,7 @@ def missing_step_type_yaml() -> str: target: NodeB capacity: 1 failure_policy: - failure_probabilities: - node: 0.01 - link: 0.02 + rules: [] traffic_demands: - source: NodeA target: NodeB @@ -150,9 +161,7 @@ def unrecognized_step_type_yaml() -> str: target: NodeB capacity: 1 failure_policy: - failure_probabilities: - node: 0.01 - link: 0.02 + rules: [] traffic_demands: - source: NodeA target: NodeB @@ -181,9 +190,7 @@ def extra_param_yaml() -> str: capacity: 1 traffic_demands: [] failure_policy: - failure_probabilities: - node: 0.01 - link: 0.02 + rules: [] workflow: - step_type: DoSmth name: StepWithExtra @@ -197,7 +204,7 @@ def test_scenario_from_yaml_valid(valid_scenario_yaml: str) -> None: Tests that a Scenario can be correctly constructed from a valid YAML string. Ensures that: - Network has correct nodes and links - - FailurePolicy is set + - FailurePolicy is set with multiple rules - TrafficDemands are parsed - Workflow steps are instantiated - Results object is present @@ -235,8 +242,33 @@ def test_scenario_from_yaml_valid(valid_scenario_yaml: str) -> None: # Check failure policy assert isinstance(scenario.failure_policy, FailurePolicy) - assert scenario.failure_policy.failure_probabilities["node"] == 0.01 - assert scenario.failure_policy.failure_probabilities["link"] == 0.02 + assert len(scenario.failure_policy.rules) == 2, "Expected 2 rules in the policy." + # Check that the leftover fields in failure_policy (e.g. "name", "description") + # went into policy.attrs + assert scenario.failure_policy.attrs.get("name") == "multi_rule_example" + assert ( + scenario.failure_policy.attrs.get("description") + == "Testing multi-rule approach." + ) + + # Rule1 => "choice", count=1, conditions => type == "node" + rule1 = scenario.failure_policy.rules[0] + assert rule1.rule_type == "choice" + assert rule1.count == 1 + assert len(rule1.conditions) == 1 + cond1 = rule1.conditions[0] + assert cond1.attr == "type" + assert cond1.operator == "==" + assert cond1.value == "node" + + # Rule2 => "all", conditions => type == "link" + rule2 = scenario.failure_policy.rules[1] + assert rule2.rule_type == "all" + assert len(rule2.conditions) == 1 + cond2 = rule2.conditions[0] + assert cond2.attr == "type" + assert cond2.operator == "==" + assert cond2.value == "link" # Check traffic demands assert len(scenario.traffic_demands) == 2 @@ -269,7 +301,12 @@ def test_scenario_from_yaml_valid(valid_scenario_yaml: str) -> None: # Verify the step types come from the registry assert step1.__class__ == WORKFLOW_STEP_REGISTRY["DoSmth"] + assert step1.name == "Step1" + assert step1.some_param == 42 + assert step2.__class__ == WORKFLOW_STEP_REGISTRY["DoSmthElse"] + assert step2.name == "Step2" + assert step2.factor == 2.0 # Check the scenario results store assert isinstance(scenario.results, Results) @@ -278,16 +315,15 @@ def test_scenario_from_yaml_valid(valid_scenario_yaml: str) -> None: def test_scenario_run(valid_scenario_yaml: str) -> None: """ Tests that calling scenario.run() executes each workflow step in order - without errors. This verifies the new .run() method introduced in the Scenario class. + without errors. Steps may store data in scenario.results. """ scenario = Scenario.from_yaml(valid_scenario_yaml) - - # Just ensure it runs without raising exceptions scenario.run() - # For a thorough test, one might check scenario.results or other side effects - # inside the steps themselves. Here, we just verify the workflow runs successfully. - assert True + # The first step's name is "Step1" in the YAML: + assert scenario.results.get("Step1", "ran", default=False) is True + # The second step's name is "Step2" in the YAML: + assert scenario.results.get("Step2", "ran", default=False) is True def test_scenario_from_yaml_missing_step_type(missing_step_type_yaml: str) -> None: From 7cf4cdfa3a30cd13ab7f5ecfefbb1c1c2c77d37b Mon Sep 17 00:00:00 2001 From: Andrey Golovanov Date: Sun, 9 Feb 2025 15:13:37 +0000 Subject: [PATCH 2/2] minor corrections to comments --- ngraph/scenario.py | 2 +- tests/scenarios/test_scenario_1.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/ngraph/scenario.py b/ngraph/scenario.py index 6542e87..83b8116 100644 --- a/ngraph/scenario.py +++ b/ngraph/scenario.py @@ -120,7 +120,7 @@ def from_yaml(cls, yaml_str: str) -> Scenario: network_data = data.get("network", {}) network = cls._build_network(network_data) - # 2) Build the (new) multi-rule failure policy + # 2) Build the multi-rule failure policy fp_data = data.get("failure_policy", {}) failure_policy = cls._build_failure_policy(fp_data) diff --git a/tests/scenarios/test_scenario_1.py b/tests/scenarios/test_scenario_1.py index 3a4ed2f..3eff2cc 100644 --- a/tests/scenarios/test_scenario_1.py +++ b/tests/scenarios/test_scenario_1.py @@ -49,14 +49,13 @@ def test_scenario_1_build_graph() -> None: actual_edges == expected_nx_edges ), f"Expected {expected_nx_edges} directed edges, found {actual_edges}" - # 7) Verify the traffic demands. In scenario_1.yaml, let's assume we have 4 demands. - # Adjust this to match your actual scenario_1.yaml. + # 7) Verify the traffic demands. expected_demands = 4 assert ( len(scenario.traffic_demands) == expected_demands ), f"Expected {expected_demands} traffic demands." - # 8) Check the new multi-rule failure policy for "any single link". + # 8) Check the multi-rule failure policy for "any single link". # This should have exactly 1 rule that picks exactly 1 link from all links. policy: FailurePolicy = scenario.failure_policy assert len(policy.rules) == 1, "Should only have 1 rule for 'anySingleLink'."