diff --git a/ngraph/failure_policy.py b/ngraph/failure_policy.py new file mode 100644 index 0000000..8cc6efd --- /dev/null +++ b/ngraph/failure_policy.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass, field +from random import random + + +@dataclass(slots=True) +class FailurePolicy: + """ + Mapping from element tag to failure probability. + """ + + failure_probabilities: dict[str, float] = field(default_factory=dict) + distribution: str = "uniform" + + def test_failure(self, tag: str) -> bool: + if self.distribution == "uniform": + return random() < self.failure_probabilities.get(tag, 0) + else: + raise ValueError(f"Unsupported distribution: {self.distribution}") diff --git a/ngraph/network.py b/ngraph/network.py index d5981fd..39ba847 100644 --- a/ngraph/network.py +++ b/ngraph/network.py @@ -1,207 +1,102 @@ from __future__ import annotations import uuid -from typing import Any, Dict, List, Optional, NamedTuple, Hashable -import concurrent.futures +import base64 +from dataclasses import dataclass, field +from typing import Any, Dict -from ngraph.lib.graph import MultiDiGraph -from ngraph.lib.common import init_flow_graph -from ngraph.lib.max_flow import calc_max_flow +def new_base64_uuid() -> str: + """ + Generate a Base64-encoded UUID without padding (~22 characters). + """ + return base64.urlsafe_b64encode(uuid.uuid4().bytes).decode("ascii").rstrip("=") -class LinkID(NamedTuple): - src_node: Hashable - dst_node: Hashable - unique_id: Hashable +@dataclass(slots=True) +class Node: + """ + Represents a node in the network. + Each node is uniquely identified by its name, which is used as the key + in the Network's node dictionary. -class Node: - def __init__(self, node_id: str, node_type: str = "simple", **attributes: Dict): - self.node_id: str = node_id - self.node_type: str = node_type - self.attributes: Dict[str, Any] = { - "node_id": node_id, - "node_type": node_type, - "plane_ids": [], - "total_link_capacity": 0, - "non_transit": False, - "transit_only": False, # no local sinks/sources - "lat": 0, - "lon": 0, - } - self.update_attributes(**attributes) - self.sub_nodes: Dict[str, "Node"] = {} # Used if node_type is 'composite' - self.sub_links: Dict[str, "Link"] = {} # Used if node_type is 'composite' - - def add_sub_node(self, sub_node_id: str, **attributes: Any): - # Logic to add a sub-node to a composite node - ... - - def add_sub_link( - self, sub_link_id: str, sub_node1: str, sub_node2: str, **attributes: Any - ): - # Logic to add a sub-link to a composite node - ... - - def update_attributes(self, **attributes: Any): - """ - Update the attributes of the node. - """ - self.attributes.update(attributes) + :param name: The unique name of the node. + :param attrs: Optional extra metadata for the node. + """ + name: str + attrs: Dict[str, Any] = field(default_factory=dict) + +@dataclass(slots=True) class Link: - def __init__( - self, - node1: str, - node2: str, - link_id: Optional[LinkID] = None, - **attributes: Dict, - ): - self.link_id: str = ( - LinkID(node1, node2, str(uuid.uuid4())) if link_id is None else link_id - ) - self.node1: str = node1 - self.node2: str = node2 - self.attributes: Dict[str, Any] = { - "link_id": self.link_id, - "node1": node1, - "node2": node2, - "plane_ids": [], - "capacity": 0, - "metric": 0, - "distance": 0, - } - self.update_attributes(**attributes) - - def update_attributes(self, **attributes: Any): + """ + Represents a link connecting two nodes in the network. + + The 'source' and 'target' fields reference node names. 
A unique link ID + is auto-generated from the source, target, and a random Base64-encoded UUID, + allowing multiple distinct links between the same nodes. + + :param source: Unique name of the source node. + :param target: Unique name of the target node. + :param capacity: Link capacity (default 1.0). + :param latency: Link latency (default 1.0). + :param cost: Link cost (default 1.0). + :param attrs: Optional extra metadata for the link. + :param id: Auto-generated unique link identifier. + """ + + source: str + target: str + capacity: float = 1.0 + latency: float = 1.0 + cost: float = 1.0 + attrs: Dict[str, Any] = field(default_factory=dict) + id: str = field(init=False) + + def __post_init__(self) -> None: """ - Update the attributes of the link. + Auto-generate a unique link ID by combining the source, target, + and a random Base64-encoded UUID. """ - self.attributes.update(attributes) + self.id = f"{self.source}-{self.target}-{new_base64_uuid()}" +@dataclass(slots=True) class Network: - def __init__(self): - self.planes: Dict[str, MultiDiGraph] = {} # Key is plane_id - self.nodes: Dict[str, Node] = {} # Key is unique node_id - self.links: Dict[str, Link] = {} # Key is unique link_id + """ + A container for network nodes and links. - @staticmethod - def generate_edge_id(from_node: str, to_node: str, link_id: LinkID) -> str: - """ - Generate a unique edge ID for a link between two nodes. - """ - return LinkID(from_node, to_node, link_id[2]) - - def add_plane(self, plane_id: str): - self.planes[plane_id] = init_flow_graph(MultiDiGraph()) - - def add_node( - self, - node_id: str, - plane_ids: Optional[List[str]] = None, - node_type: str = "simple", - **attributes: Any, - ) -> str: - new_node = Node(node_id, node_type, **attributes) - self.nodes[new_node.node_id] = new_node - - if plane_ids is None: - plane_ids = self.planes.keys() - - for plane_id in plane_ids: - self.planes[plane_id].add_node(new_node.node_id, **attributes) - new_node.attributes["plane_ids"].append(plane_id) - return new_node.node_id - - def add_link( - self, - node1: str, - node2: str, - plane_ids: Optional[List[str]] = None, - **attributes: Any, - ) -> str: - new_link = Link(node1, node2, **attributes) - self.links[new_link.link_id] = new_link - - if plane_ids is None: - plane_ids = self.planes.keys() - - for plane_id in plane_ids: - self.planes[plane_id].add_edge( - node1, - node2, - edge_id=self.generate_edge_id(node1, node2, new_link.link_id), - capacity=new_link.attributes["capacity"] / len(plane_ids), - metric=new_link.attributes["metric"], - ) - self.planes[plane_id].add_edge( - node2, - node1, - edge_id=self.generate_edge_id(node2, node1, new_link.link_id), - capacity=new_link.attributes["capacity"] / len(plane_ids), - metric=new_link.attributes["metric"], - ) - new_link.attributes["plane_ids"].append(plane_id) - - # Update the total link capacity of the nodes - self.nodes[node1].attributes["total_link_capacity"] += new_link.attributes[ - "capacity" - ] - self.nodes[node2].attributes["total_link_capacity"] += new_link.attributes[ - "capacity" - ] - return new_link.link_id - - @staticmethod - def plane_max_flow(plane_id, plane_graph, src_node, dst_nodes) -> Optional[float]: + Nodes are stored in a dictionary keyed by their unique names. + Links are stored in a dictionary keyed by their auto-generated IDs. + The 'attrs' dict allows extra network metadata. + + :param nodes: Mapping from node name to Node. + :param links: Mapping from link id to Link. + :param attrs: Optional extra metadata for the network. 
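+
+    Example usage (an illustrative sketch; node names are arbitrary):
+
+        net = Network()
+        net.add_node(Node(name="A"))
+        net.add_node(Node(name="B"))
+        net.add_link(Link(source="A", target="B", capacity=10.0))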
+ """ + + nodes: Dict[str, Node] = field(default_factory=dict) + links: Dict[str, Link] = field(default_factory=dict) + attrs: Dict[str, Any] = field(default_factory=dict) + + def add_node(self, node: Node) -> None: """ - Calculate the maximum flow between src and dst for a single plane. - There can be multiple dst nodes, they all are attached to the same virtual sink node. + Add a node to the network, keyed by its name. + + :param node: The Node to add. """ - if src_node in plane_graph: - for dst_node in dst_nodes: - if dst_node in plane_graph: - # add a pseudo node to the graph to act as the sink for the max flow calculation - plane_graph.add_edge( - dst_node, - "sink", - edge_id=-1, - capacity=2**32, - metric=0, - flow=0, - flows={}, - ) - if "sink" in plane_graph: - return calc_max_flow(plane_graph, src_node, "sink") - - def calc_max_flow( - self, src_nodes: List[str], dst_nodes: List[str] - ) -> Dict[str, Dict[str, float]]: + self.nodes[node.name] = node + + def add_link(self, link: Link) -> None: """ - Calculate the maximum flow between each of the src nodes and all of the dst nodes. - All the dst nodes are attached to the same virtual sink node. - Runs the calculation in parallel for all planes and src nodes. + Add a link to the network. Both source and target nodes must exist. + + :param link: The Link to add. + :raises ValueError: If the source or target node is not present. """ - with concurrent.futures.ProcessPoolExecutor() as executor: - future_to_plane_source = {} - for plane_id in self.planes: - for src_node in src_nodes: - future_to_plane_source[ - executor.submit( - self.plane_max_flow, - plane_id, - self.planes[plane_id], - src_node, - dst_nodes, - ) - ] = (plane_id, src_node, dst_nodes) - - results = {} - for future in concurrent.futures.as_completed(future_to_plane_source): - plane_id, src_node, dst_nodes = future_to_plane_source[future] - results.setdefault(src_node, {}) - results[src_node].setdefault(tuple(dst_nodes), {}) - results[src_node][tuple(dst_nodes)][plane_id] = future.result() - return results + if link.source not in self.nodes: + raise ValueError(f"Source node '{link.source}' not found in network.") + if link.target not in self.nodes: + raise ValueError(f"Target node '{link.target}' not found in network.") + self.links[link.id] = link diff --git a/ngraph/results.py b/ngraph/results.py new file mode 100644 index 0000000..174d815 --- /dev/null +++ b/ngraph/results.py @@ -0,0 +1,56 @@ +from dataclasses import dataclass, field +from typing import Any, Dict + + +@dataclass(slots=True) +class Results: + """ + A container for storing arbitrary key-value data that arises during workflow steps. + The data is organized by step name, then by key. + + Example usage: + results.put("Step1", "total_capacity", 123.45) + cap = results.get("Step1", "total_capacity") # returns 123.45 + all_caps = results.get_all("total_capacity") # might return {"Step1": 123.45, "Step2": 98.76} + """ + + # Internally, store per-step data in a nested dict: + # _store[step_name][key] = value + _store: Dict[str, Dict[str, Any]] = field(default_factory=dict) + + def put(self, step_name: str, key: str, value: Any) -> None: + """ + Store a value under (step_name, key). + If the step_name sub-dict does not exist, it is created. + + :param step_name: The workflow step that produced the result. + :param key: A short label describing the data (e.g. "total_capacity"). + :param value: The actual data to store (can be any Python object). 
+ """ + if step_name not in self._store: + self._store[step_name] = {} + self._store[step_name][key] = value + + def get(self, step_name: str, key: str, default: Any = None) -> Any: + """ + Retrieve the value from (step_name, key). If the key is missing, return `default`. + + :param step_name: The workflow step name. + :param key: The key under which the data was stored. + :param default: Value to return if the (step_name, key) is not present. + :return: The data, or `default` if not found. + """ + return self._store.get(step_name, {}).get(key, default) + + def get_all(self, key: str) -> Dict[str, Any]: + """ + Retrieve a dictionary of {step_name: value} for all step_names that contain the specified key. + + :param key: The key to look up in each step. + :return: A dict mapping step_name -> value for all steps that have stored something under 'key'. + """ + result = {} + for step_name, data in self._store.items(): + if key in data: + result[step_name] = data[key] + return result diff --git a/ngraph/scenario.py b/ngraph/scenario.py new file mode 100644 index 0000000..6d61fda --- /dev/null +++ b/ngraph/scenario.py @@ -0,0 +1,167 @@ +from __future__ import annotations +import yaml +from dataclasses import dataclass, field +from typing import Any, Dict, List + +from ngraph.network import Network, Node, Link +from ngraph.failure_policy import FailurePolicy +from ngraph.traffic_demand import TrafficDemand +from ngraph.results import Results +from ngraph.workflow.base import WorkflowStep, WORKFLOW_STEP_REGISTRY + + +@dataclass(slots=True) +class Scenario: + """ + Represents a complete scenario, including the network, failure policy, + traffic demands, workflow steps, and a results store. + + Usage: + scenario = Scenario.from_yaml(yaml_str) + scenario.run() + # Access scenario.results for workflow outputs + + Example YAML structure: + + network: + nodes: + JFK: + coords: [40.64, -73.78] + LAX: + coords: [33.94, -118.41] + links: + - source: JFK + target: LAX + capacity: 100 + latency: 50 + cost: 50 + attrs: { distance_km: 4000 } + + failure_policy: + failure_probabilities: + node: 0.001 + link: 0.002 + + traffic_demands: + - source: JFK + target: LAX + demand: 50 + + workflow: + - step_type: BuildGraph + name: build_graph + + :param network: The network model. + :param failure_policy: The policy for element failures. + :param traffic_demands: A list of traffic demands. + :param workflow: A list of WorkflowStep objects to be executed in order. + :param results: A Results object to store step outputs, summary, etc. + """ + + network: Network + failure_policy: FailurePolicy + traffic_demands: List[TrafficDemand] + workflow: List[WorkflowStep] + results: Results = field(default_factory=Results) + + def run(self) -> None: + """ + Execute the scenario's workflow steps in the given order. + Each WorkflowStep has access to this Scenario object and + can store output in scenario.results. + """ + for step in self.workflow: + step.run(self) + + @classmethod + def from_yaml(cls, yaml_str: str) -> Scenario: + """ + Construct a Scenario from a YAML string. + + This looks for top-level sections: + 'network', 'failure_policy', 'traffic_demands', and 'workflow'. + + See the class docstring for a short example of the expected structure. 
+ """ + data = yaml.safe_load(yaml_str) + if not isinstance(data, dict): + raise ValueError("The provided YAML must map to a dictionary at top-level.") + + # 1) Build the network + network_data = data.get("network", {}) + network = cls._build_network(network_data) + + # 2) Build the failure policy + fp_data = data.get("failure_policy", {}) + failure_policy = FailurePolicy( + failure_probabilities=fp_data.get("failure_probabilities", {}) + ) + + # 3) Build traffic demands + traffic_demands_data = data.get("traffic_demands", []) + traffic_demands = [TrafficDemand(**td) for td in traffic_demands_data] + + # 4) Build workflow steps using the registry + workflow_data = data.get("workflow", []) + workflow_steps = cls._build_workflow_steps(workflow_data) + + return cls( + network=network, + failure_policy=failure_policy, + traffic_demands=traffic_demands, + workflow=workflow_steps, + ) + + @staticmethod + def _build_network(network_data: Dict[str, Any]) -> Network: + """ + Construct a Network object from a dictionary containing 'nodes' and 'links'. + """ + net = Network() + + # Add nodes + nodes = network_data.get("nodes", {}) + for node_name, node_attrs in nodes.items(): + net.add_node(Node(name=node_name, attrs=node_attrs or {})) + + # Add links + links = network_data.get("links", []) + for link_info in links: + link = Link( + source=link_info["source"], + target=link_info["target"], + capacity=link_info.get("capacity", 1.0), + latency=link_info.get("latency", 1.0), + cost=link_info.get("cost", 1.0), + attrs=link_info.get("attrs", {}), + ) + net.add_link(link) + + return net + + @staticmethod + def _build_workflow_steps( + workflow_data: List[Dict[str, Any]] + ) -> List[WorkflowStep]: + """ + Instantiate workflow steps listed in 'workflow_data' using WORKFLOW_STEP_REGISTRY. + """ + steps: List[WorkflowStep] = [] + + for step_info in workflow_data: + step_type = step_info.get("step_type") + if not step_type: + raise ValueError( + "Each workflow entry must have a 'step_type' field " + "indicating which WorkflowStep subclass to use." + ) + + step_cls = WORKFLOW_STEP_REGISTRY.get(step_type) + if not step_cls: + raise ValueError(f"Unrecognized 'step_type': {step_type}") + + # Remove 'step_type' so it doesn't clash with the step_class __init__ + step_args = {k: v for k, v in step_info.items() if k != "step_type"} + steps.append(step_cls(**step_args)) + + return steps diff --git a/ngraph/traffic_demand.py b/ngraph/traffic_demand.py new file mode 100644 index 0000000..dfb85b7 --- /dev/null +++ b/ngraph/traffic_demand.py @@ -0,0 +1,27 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Any, Dict + + +@dataclass(slots=True) +class TrafficDemand: + """ + Represents a single traffic demand in a network. + + Attributes: + source (str): The name of the source node. + target (str): The name of the target node. + priority (int): The priority of this traffic demand. Lower values indicate higher priority (default=0). + demand (float): The total demand volume (default=0.0). + demand_placed (float): The placed portion of the demand (default=0.0). + demand_unplaced (float): The unplaced portion of the demand (default=0.0). + attrs (dict[str, Any]): A dictionary for any additional attributes (default={}). 
+ """ + + source: str + target: str + priority: int = 0 + demand: float = 0.0 + demand_placed: float = 0.0 + demand_unplaced: float = 0.0 + attrs: Dict[str, Any] = field(default_factory=dict) diff --git a/ngraph/workflow/__init__.py b/ngraph/workflow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ngraph/workflow/base.py b/ngraph/workflow/base.py new file mode 100644 index 0000000..0cc5a3f --- /dev/null +++ b/ngraph/workflow/base.py @@ -0,0 +1,38 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from abc import ABC, abstractmethod +from typing import Dict, Type, TYPE_CHECKING + +if TYPE_CHECKING: + # Only imported for type-checking; not at runtime, so no circular import occurs. + from ngraph.scenario import Scenario + +WORKFLOW_STEP_REGISTRY: Dict[str, Type["WorkflowStep"]] = {} + + +def register_workflow_step(step_type: str): + """ + A decorator that registers a WorkflowStep subclass under `step_type`. + """ + + def decorator(cls: Type["WorkflowStep"]): + WORKFLOW_STEP_REGISTRY[step_type] = cls + return cls + + return decorator + + +@dataclass +class WorkflowStep(ABC): + """ + Base class for all workflow steps. + """ + + name: str = "" + + @abstractmethod + def run(self, scenario: Scenario) -> None: + """ + Execute the workflow step logic. + """ + pass diff --git a/ngraph/workflow/build_graph.py b/ngraph/workflow/build_graph.py new file mode 100644 index 0000000..30cacf4 --- /dev/null +++ b/ngraph/workflow/build_graph.py @@ -0,0 +1,60 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import networkx as nx + +from ngraph.workflow.base import WorkflowStep, register_workflow_step + +if TYPE_CHECKING: + from ngraph.scenario import Scenario + + +@register_workflow_step("BuildGraph") +@dataclass +class BuildGraph(WorkflowStep): + """ + A workflow step that uses Scenario.network to build a NetworkX MultiDiGraph. + + Since links in Network are conceptually bidirectional but we need unique identifiers + for each direction, we add two directed edges per link: + - forward edge: key = link.id + - reverse edge: key = link.id + "_rev" + + The constructed graph is stored in scenario.results under (self.name, "graph"). 
+ """ + + def run(self, scenario: Scenario) -> None: + # Create a MultiDiGraph to hold bidirectional edges + graph = nx.MultiDiGraph() + + # 1) Add nodes + for node_name, node in scenario.network.nodes.items(): + graph.add_node(node_name, **node.attrs) + + # 2) For each physical Link, add forward and reverse edges with unique keys + for link_id, link in scenario.network.links.items(): + # Forward edge uses link.id + graph.add_edge( + link.source, + link.target, + key=link.id, + capacity=link.capacity, + cost=link.cost, + latency=link.latency, + **link.attrs, + ) + # Reverse edge uses link.id + "_rev" + reverse_id = f"{link.id}_rev" + graph.add_edge( + link.target, + link.source, + key=reverse_id, + capacity=link.capacity, + cost=link.cost, + latency=link.latency, + **link.attrs, + ) + + # 3) Store the resulting graph + scenario.results.put(self.name, "graph", graph) diff --git a/notebooks/lib_examples.ipynb b/notebooks/lib_examples.ipynb index f519dba..93446ce 100644 --- a/notebooks/lib_examples.ipynb +++ b/notebooks/lib_examples.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -57,7 +57,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -103,7 +103,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -152,7 +152,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -309,7 +309,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "ngraph-venv", "language": "python", "name": "python3" }, diff --git a/requirements.txt b/requirements.txt index f249d17..03952c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ geopy -networkx \ No newline at end of file +networkx +pyyaml \ No newline at end of file diff --git a/tests/test_io.py b/tests/lib/test_io.py similarity index 100% rename from tests/test_io.py rename to tests/lib/test_io.py diff --git a/tests/sample_data/sample_networks.py b/tests/sample_data/sample_networks.py deleted file mode 100644 index 25c4c70..0000000 --- a/tests/sample_data/sample_networks.py +++ /dev/null @@ -1,89 +0,0 @@ -import math - -import geopy.distance -import pytest - -from ngraph.network import Network - - -def calculate_latency_km(km_distance): - """Calculate latency in nanoseconds for a given distance in km.""" - speed_of_light_km_per_ns = 0.2 # Speed of light in fiber in km/ns - return math.ceil(km_distance / speed_of_light_km_per_ns) - - -# Coordinates of the airports (latitude, longitude) -airport_coords = { - "JFK": (40.641766, -73.780968), - "LAX": (33.941589, -118.40853), - "ORD": (41.974163, -87.907321), - "IAH": (29.99022, -95.336783), - "PHX": (33.437269, -112.007788), - "PHL": (39.874395, -75.242423), - "SAT": (29.424122, -98.493629), - "SAN": (32.733801, -117.193304), - "DFW": (32.899809, -97.040335), - "SJC": (37.363947, -121.928938), - "AUS": (30.197475, -97.666305), - "JAX": (30.332184, -81.655651), - "CMH": (39.961176, -82.998794), - "IND": (39.768403, -86.158068), - "CLT": (35.227087, -80.843127), - "SFO": (37.774929, -122.419416), - "SEA": (47.606209, -122.332071), - "DEN": (39.739236, -104.990251), - "DCA": (38.907192, -77.036871), -} - -connections = [ - ("JFK", "PHL", 100), - ("JFK", "DCA", 100), - ("LAX", "SFO", 100), - ("LAX", "SAN", 100), - ("ORD", "IND", 100), - 
("ORD", "CMH", 100), - ("IAH", "DFW", 100), - ("IAH", "AUS", 100), - ("PHX", "LAX", 100), - ("SAT", "AUS", 100), - ("DFW", "AUS", 100), - ("SJC", "SFO", 100), - ("SJC", "LAX", 100), - ("CLT", "DCA", 100), - ("SEA", "SFO", 100), - ("DEN", "PHX", 100), - ("DEN", "SEA", 100), - ("SFO", "DEN", 200), - ("DEN", "ORD", 300), - ("PHX", "DFW", 300), - ("CMH", "JFK", 100), - ("IND", "CLT", 100), - ("IAH", "JAX", 100), - ("SAT", "AUS", 100), - ("SAT", "IAH", 100), - ("JAX", "CLT", 100), - ("SAN", "PHX", 100), - ("DFW", "IND", 100), -] - - -connections_with_capacity_and_metric = [] -for src, dst, cap in connections: - src_coord = airport_coords[src] - dst_coord = airport_coords[dst] - distance_km = geopy.distance.distance(src_coord, dst_coord).km - latency_ns = calculate_latency_km(distance_km) - connections_with_capacity_and_metric.append((src, dst, cap, latency_ns)) - - -@pytest.fixture -def network1(): - network = Network() - for plane_id in ["Plane1", "Plane2"]: - network.add_plane(plane_id) - - for src, dst, cap, metric in connections_with_capacity_and_metric: - network.add_node(src) - network.add_node(dst) - network.add_link(src, dst, capacity=cap, metric=metric) - return network diff --git a/tests/scenarios/__init__.py b/tests/scenarios/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/scenarios/scenario_1.yaml b/tests/scenarios/scenario_1.yaml new file mode 100644 index 0000000..f3c3f50 --- /dev/null +++ b/tests/scenarios/scenario_1.yaml @@ -0,0 +1,282 @@ +network: + nodes: + JFK: + coords: [40.641766, -73.780968] + LAX: + coords: [33.941589, -118.40853] + ORD: + coords: [41.974163, -87.907321] + IAH: + coords: [29.99022, -95.336783] + PHX: + coords: [33.437269, -112.007788] + PHL: + coords: [39.874395, -75.242423] + SAT: + coords: [29.424122, -98.493629] + SAN: + coords: [32.733801, -117.193304] + DFW: + coords: [32.899809, -97.040335] + SJC: + coords: [37.363947, -121.928938] + AUS: + coords: [30.197475, -97.666305] + JAX: + coords: [30.332184, -81.655651] + CMH: + coords: [39.961176, -82.998794] + IND: + coords: [39.768403, -86.158068] + CLT: + coords: [35.227087, -80.843127] + SFO: + coords: [37.774929, -122.419416] + SEA: + coords: [47.606209, -122.332071] + DEN: + coords: [39.739236, -104.990251] + DCA: + coords: [38.907192, -77.036871] + + links: + - source: JFK + target: PHL + capacity: 100 + latency: 756 + cost: 756 + attrs: + distance_km: 151.19 + + - source: JFK + target: DCA + capacity: 100 + latency: 1714 + cost: 1714 + attrs: + distance_km: 342.69 + + - source: LAX + target: SFO + capacity: 100 + latency: 2720 + cost: 2720 + attrs: + distance_km: 543.95 + + - source: LAX + target: SAN + capacity: 100 + latency: 893 + cost: 893 + attrs: + distance_km: 178.55 + + - source: ORD + target: IND + capacity: 100 + latency: 1395 + cost: 1395 + attrs: + distance_km: 278.84 + + - source: ORD + target: CMH + capacity: 100 + latency: 1386 + cost: 1386 + attrs: + distance_km: 277.17 + + - source: IAH + target: DFW + capacity: 100 + latency: 1802 + cost: 1802 + attrs: + distance_km: 360.25 + + - source: IAH + target: AUS + capacity: 100 + latency: 1133 + cost: 1133 + attrs: + distance_km: 226.53 + + - source: PHX + target: LAX + capacity: 100 + latency: 2982 + cost: 2982 + attrs: + distance_km: 596.35 + + - source: SAT + target: AUS + capacity: 100 + latency: 594 + cost: 594 + attrs: + distance_km: 118.69 + + - source: DFW + target: AUS + capacity: 100 + latency: 1539 + cost: 1539 + attrs: + distance_km: 307.63 + + - source: SJC + target: SFO + capacity: 100 + latency: 
339 + cost: 339 + attrs: + distance_km: 67.79 + + - source: SJC + target: LAX + capacity: 100 + latency: 2468 + cost: 2468 + attrs: + distance_km: 493.55 + + - source: CLT + target: DCA + capacity: 100 + latency: 2654 + cost: 2654 + attrs: + distance_km: 530.64 + + - source: SEA + target: SFO + capacity: 100 + latency: 5460 + cost: 5460 + attrs: + distance_km: 1091.95 + + - source: DEN + target: PHX + capacity: 100 + latency: 4761 + cost: 4761 + attrs: + distance_km: 952.16 + + - source: DEN + target: SEA + capacity: 100 + latency: 6846 + cost: 6846 + attrs: + distance_km: 1369.13 + + - source: SFO + target: DEN + capacity: 200 + latency: 7754 + cost: 7754 + attrs: + distance_km: 1550.77 + + - source: DEN + target: ORD + capacity: 300 + latency: 7102 + cost: 7102 + attrs: + distance_km: 1420.28 + + - source: PHX + target: DFW + capacity: 300 + latency: 6900 + cost: 6900 + attrs: + distance_km: 1380 + + - source: CMH + target: JFK + capacity: 100 + latency: 3788 + cost: 3788 + attrs: + distance_km: 757.58 + + - source: IND + target: CLT + capacity: 100 + latency: 3419 + cost: 3419 + attrs: + distance_km: 683.66 + + - source: IAH + target: JAX + capacity: 100 + latency: 6746 + cost: 6746 + attrs: + distance_km: 1349.04 + + - source: SAT + target: AUS + capacity: 100 + latency: 594 + cost: 594 + attrs: + distance_km: 118.69 + + - source: SAT + target: IAH + capacity: 100 + latency: 1524 + cost: 1524 + attrs: + distance_km: 304.66 + + - source: JAX + target: CLT + capacity: 100 + latency: 2671 + cost: 2671 + attrs: + distance_km: 534.2 + + - source: SAN + target: PHX + capacity: 100 + latency: 2435 + cost: 2435 + attrs: + distance_km: 486.91 + + - source: DFW + target: IND + capacity: 100 + latency: 6889 + cost: 6889 + attrs: + distance_km: 1377.67 + +failure_policy: + failure_probabilities: + node: 0.001 + link: 0.002 + +traffic_demands: + - source: JFK + target: LAX + demand: 50 + - source: SAN + target: SEA + demand: 30 + +workflow: + - step_type: BuildGraph + name: build_graph diff --git a/tests/scenarios/test_scenario_1.py b/tests/scenarios/test_scenario_1.py new file mode 100644 index 0000000..3ce52be --- /dev/null +++ b/tests/scenarios/test_scenario_1.py @@ -0,0 +1,53 @@ +import pytest +import networkx as nx +from pathlib import Path + +from ngraph.scenario import Scenario + + +def test_scenario_1_build_graph() -> None: + """ + Integration test that verifies we can parse scenario_1.yaml, + run the BuildGraph step, and produce a valid NetworkX MultiDiGraph. + Also checks traffic demands and failure policy. + """ + + # 1) Load the YAML file + scenario_path = Path(__file__).parent / "scenario_1.yaml" + yaml_text = scenario_path.read_text() + + # 2) Parse into a Scenario object + scenario = Scenario.from_yaml(yaml_text) + + # 3) Run the scenario's workflow (in this YAML, there's only "BuildGraph") + scenario.run() + + # 4) Retrieve the graph built by BuildGraph + graph = scenario.results.get("build_graph", "graph") + assert isinstance( + graph, nx.MultiDiGraph + ), "Expected a MultiDiGraph in scenario.results." + + # 5) Check the total number of nodes matches what's listed in scenario_1.yaml + assert len(graph.nodes) == 19, f"Expected 19 nodes, found {len(graph.nodes)}" + + # 6) Each physical link becomes 2 directed edges in the MultiDiGraph. + # The YAML has 28 total link lines (including one duplicate SAT->AUS entry). + # So expected edges = 2 * 28 = 56. 
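+    #    (BuildGraph keys forward edges by link id and reverse edges by
+    #    link id + "_rev", which is why each link contributes two edges.)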
+ expected_links = 28 + expected_nx_edges = expected_links * 2 + actual_edges = len(graph.edges) + assert ( + actual_edges == expected_nx_edges + ), f"Expected {expected_nx_edges} directed edges, found {actual_edges}" + + # 7) Verify the traffic demands + assert len(scenario.traffic_demands) == 2, "Expected 2 traffic demands." + demand_map = {(td.source, td.target): td.demand for td in scenario.traffic_demands} + # scenario_1.yaml has demands: (JFK->LAX=50), (SAN->SEA=30) + assert demand_map[("JFK", "LAX")] == 50 + assert demand_map[("SAN", "SEA")] == 30 + + # 8) Check the failure policy from YAML + assert scenario.failure_policy.failure_probabilities["node"] == 0.001 + assert scenario.failure_policy.failure_probabilities["link"] == 0.002 diff --git a/tests/test_failure_policy.py b/tests/test_failure_policy.py new file mode 100644 index 0000000..645e180 --- /dev/null +++ b/tests/test_failure_policy.py @@ -0,0 +1,82 @@ +import pytest +from unittest.mock import patch + +from ngraph.failure_policy import FailurePolicy + + +def test_default_attributes(): + """ + Ensure default constructor creates an empty failure_probabilities dict, + and sets distribution to 'uniform'. + """ + policy = FailurePolicy() + assert policy.failure_probabilities == {} + assert policy.distribution == "uniform" + + +@patch("ngraph.failure_policy.random") +def test_test_failure_returns_true(mock_random): + """ + For a specific tag with nonzero probability, verify test_failure() returns True + when random() is less than that probability. + """ + policy = FailurePolicy(failure_probabilities={"node1": 0.7}) + + # Mock random to return 0.5 which is < 0.7 + mock_random.return_value = 0.5 + assert ( + policy.test_failure("node1") is True + ), "Should return True when random() < failure probability." + + +@patch("ngraph.failure_policy.random") +def test_test_failure_returns_false(mock_random): + """ + For a specific tag with nonzero probability, verify test_failure() returns False + when random() is not less than that probability. + """ + policy = FailurePolicy(failure_probabilities={"node1": 0.3}) + + # Mock random to return 0.4 which is > 0.3 + mock_random.return_value = 0.4 + assert ( + policy.test_failure("node1") is False + ), "Should return False when random() >= failure probability." + + +@patch("ngraph.failure_policy.random") +def test_test_failure_zero_probability(mock_random): + """ + A probability of zero means it should always return False, even if random() is also zero. + """ + policy = FailurePolicy(failure_probabilities={"node1": 0.0}) + + mock_random.return_value = 0.0 + assert ( + policy.test_failure("node1") is False + ), "Should always return False with probability = 0.0" + + +@patch("ngraph.failure_policy.random") +def test_test_failure_no_entry_for_tag(mock_random): + """ + If no entry for a given tag is found, probability defaults to 0.0 => always False. + """ + policy = FailurePolicy() + + mock_random.return_value = 0.0 + assert ( + policy.test_failure("unknown_tag") is False + ), "Unknown tag should default to 0.0 probability => always False." + + +def test_test_failure_non_uniform_distribution(): + """ + Verify that any distribution other than 'uniform' raises a ValueError. 
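+    The distribution is checked before any probability lookup, so no
+    failure_probabilities entries are needed for this test.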
+ """ + policy = FailurePolicy(distribution="non_uniform") + + with pytest.raises(ValueError) as exc_info: + policy.test_failure("node1") + + assert "Unsupported distribution" in str(exc_info.value) diff --git a/tests/test_link.py b/tests/test_link.py deleted file mode 100644 index 0f31a05..0000000 --- a/tests/test_link.py +++ /dev/null @@ -1,35 +0,0 @@ -import pytest -from ngraph.network import Link - - -class TestLink: - def test_link_creation(self): - link = Link("Node1", "Node2", attribute1="value1") - assert link.node1 == "Node1" - assert link.node2 == "Node2" - assert link.attributes["attribute1"] == "value1" - assert "link_id" in link.attributes - assert "capacity" in link.attributes - assert "metric" in link.attributes - - def test_link_creation_with_explicit_link_id(self): - link_id = "custom_link_id" - link = Link("Node1", "Node2", link_id=link_id) - assert link.link_id == link_id - - def test_link_creation_with_custom_capacity_and_metric(self): - link = Link("Node1", "Node2", capacity=100, metric=10) - assert link.attributes["capacity"] == 100 - assert link.attributes["metric"] == 10 - - def test_update_link_attributes(self): - link = Link("Node1", "Node2", attribute1="value1") - link.update_attributes(attribute1="new_value", metric=20) - assert link.attributes["attribute1"] == "new_value" - assert link.attributes["metric"] == 20 - - def test_update_link_attributes_with_new_attribute(self): - link = Link("Node1", "Node2") - link.update_attributes(new_attribute="new_value") - assert "new_attribute" in link.attributes - assert link.attributes["new_attribute"] == "new_value" diff --git a/tests/test_network.py b/tests/test_network.py index 1039194..1bdcf0f 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -1,113 +1,132 @@ import pytest +from ngraph.network import ( + Network, + Node, + Link, + new_base64_uuid +) -from ngraph.lib.graph import MultiDiGraph -from ngraph.network import Network, Node, Link, LinkID -from .sample_data.sample_networks import * +def test_new_base64_uuid_length_and_uniqueness(): + # Generate two Base64-encoded UUIDs + uuid1 = new_base64_uuid() + uuid2 = new_base64_uuid() + + # They should be strings without any padding characters + assert isinstance(uuid1, str) + assert isinstance(uuid2, str) + assert '=' not in uuid1 + assert '=' not in uuid2 + + # They are typically 22 characters long (Base64 without padding) + assert len(uuid1) == 22 + assert len(uuid2) == 22 + + # The two generated UUIDs should be unique + assert uuid1 != uuid2 -class TestNetwork: - def test_add_plane(self): - network = Network() - network.add_plane("Plane1") - assert "Plane1" in network.planes - assert isinstance(network.planes["Plane1"], MultiDiGraph) +def test_node_creation_default_attrs(): + # Create a Node with default attributes + node = Node("A") + assert node.name == "A" + assert node.attrs == {} - def test_generate_edge_id(self): - network = Network() - link = Link("Node1", "Node2", capacity=100) - edge_id = network.generate_edge_id("Node1", "Node2", link.link_id) - assert edge_id == LinkID("Node1", "Node2", link.link_id[2]) +def test_node_creation_custom_attrs(): + # Create a Node with custom attributes + custom_attrs = {"key": "value", "number": 42} + node = Node("B", attrs=custom_attrs) + assert node.name == "B" + assert node.attrs == custom_attrs - def test_add_node(self): - network = Network() - network.add_plane("Plane1") - node_id = network.add_node("Node1", plane_ids=["Plane1"]) - assert node_id in network.nodes - assert network.nodes[node_id].node_id == 
"Node1" - assert "Plane1" in network.nodes[node_id].attributes["plane_ids"] +def test_link_defaults_and_id_generation(): + # Create a Link; __post_init__ should auto-generate the id. + link = Link("A", "B") + + # Check default parameters are set correctly. + assert link.capacity == 1.0 + assert link.latency == 1.0 + assert link.cost == 1.0 + assert link.attrs == {} + + # Verify the link ID is correctly formatted and starts with "A-B-" + assert link.id.startswith("A-B-") + # Ensure there is a random UUID part appended after the prefix + assert len(link.id) > len("A-B-") - def test_add_link(self): - network = Network() - network.add_plane("Plane1") - network.add_node("Node1", plane_ids=["Plane1"]) - network.add_node("Node2", plane_ids=["Plane1"]) - link_id = network.add_link("Node1", "Node2", plane_ids=["Plane1"], capacity=100) - assert link_id in network.links - assert network.links[link_id].node1 == "Node1" - assert network.links[link_id].node2 == "Node2" - assert network.links[link_id].attributes["capacity"] == 100 - assert network.nodes["Node1"].attributes["total_link_capacity"] == 100 - assert network.nodes["Node2"].attributes["total_link_capacity"] == 100 +def test_link_custom_values(): + # Create a Link with custom values + custom_attrs = {"color": "red"} + link = Link("X", "Y", capacity=2.0, latency=3.0, cost=4.0, attrs=custom_attrs) + + assert link.source == "X" + assert link.target == "Y" + assert link.capacity == 2.0 + assert link.latency == 3.0 + assert link.cost == 4.0 + assert link.attrs == custom_attrs + # Check that the ID has the proper format + assert link.id.startswith("X-Y-") - def test_add_multiple_planes(self): - network = Network() - network.add_plane("Plane1") - network.add_plane("Plane2") - network.add_node("Node1", plane_ids=["Plane1", "Plane2"]) - assert "Plane1" in network.nodes["Node1"].attributes["plane_ids"] - assert "Plane2" in network.nodes["Node1"].attributes["plane_ids"] +def test_link_id_uniqueness(): + # Two links between the same nodes should have different IDs. + link1 = Link("A", "B") + link2 = Link("A", "B") + assert link1.id != link2.id - def test_add_node_to_all_planes_by_default(self): - network = Network() - network.add_plane("Plane1") - network.add_plane("Plane2") - network.add_node("Node1") - assert "Plane1" in network.nodes["Node1"].attributes["plane_ids"] - assert "Plane2" in network.nodes["Node1"].attributes["plane_ids"] +def test_network_add_node_and_link(): + # Create a network and add two nodes + network = Network() + node_a = Node("A") + node_b = Node("B") + network.add_node(node_a) + network.add_node(node_b) + + # The nodes should be present in the network + assert "A" in network.nodes + assert "B" in network.nodes + + # Create a link between the nodes and add it to the network + link = Link("A", "B") + network.add_link(link) + + # Check that the link is stored in the network using its auto-generated id. 
+ assert link.id in network.links + # Verify that the stored link is the same object + assert network.links[link.id] is link - def test_add_link_to_all_planes_by_default(self): - network = Network() - network.add_plane("Plane1") - network.add_plane("Plane2") - network.add_node("Node1") - network.add_node("Node2") - link_id = network.add_link("Node1", "Node2") - assert "Plane1" in network.links[link_id].attributes["plane_ids"] - assert "Plane2" in network.links[link_id].attributes["plane_ids"] +def test_network_add_link_missing_source(): + # Create a network with only the target node + network = Network() + node_b = Node("B") + network.add_node(node_b) + + # Try to add a link whose source node does not exist. + link = Link("A", "B") + with pytest.raises(ValueError, match="Source node 'A' not found in network."): + network.add_link(link) - def test_update_total_link_capacity(self): - network = Network() - network.add_plane("Plane1") - network.add_node("Node1", plane_ids=["Plane1"]) - network.add_node("Node2", plane_ids=["Plane1"]) - network.add_link("Node1", "Node2", plane_ids=["Plane1"], capacity=100) - network.add_link("Node2", "Node1", plane_ids=["Plane1"], capacity=200) - assert network.nodes["Node1"].attributes["total_link_capacity"] == 300 - assert network.nodes["Node2"].attributes["total_link_capacity"] == 300 +def test_network_add_link_missing_target(): + # Create a network with only the source node + network = Network() + node_a = Node("A") + network.add_node(node_a) + + # Try to add a link whose target node does not exist. + link = Link("A", "B") + with pytest.raises(ValueError, match="Target node 'B' not found in network."): + network.add_link(link) - def test_plane_max_flow(self): - network = Network() - network.add_plane("Plane1") - network.add_node("Node1", plane_ids=["Plane1"]) - network.add_node("Node2", plane_ids=["Plane1"]) - network.add_link( - "Node1", "Node2", plane_ids=["Plane1"], capacity=100, metric=10 - ) - max_flow = network.plane_max_flow( - "Plane1", network.planes["Plane1"], "Node1", ["Node2"] - ) - assert max_flow == 100 +def test_network_attrs(): + # Test that extra network metadata can be stored in attrs. + network = Network(attrs={"network_type": "test"}) + assert network.attrs["network_type"] == "test" - def test_network1_max_flow_1(self, network1): - max_flow = network1.calc_max_flow(["LAX"], ["SFO"]) - assert max_flow == {"LAX": {("SFO",): {"Plane1": 200.0, "Plane2": 200.0}}} - assert "sink" not in network1.planes["Plane1"] - assert "sink" not in network1.planes["Plane2"] - - def test_network1_max_flow_2(self, network1): - max_flow = network1.calc_max_flow(["SEA", "LAX"], ["SFO"]) - assert max_flow == { - "LAX": {("SFO",): {"Plane1": 200.0, "Plane2": 200.0}}, - "SEA": {("SFO",): {"Plane1": 100.0, "Plane2": 100.0}}, - } - assert "sink" not in network1.planes["Plane1"] - assert "sink" not in network1.planes["Plane2"] - - def test_network1_max_flow_3(self, network1): - max_flow = network1.calc_max_flow(["SFO", "LAX"], ["JFK", "SEA"]) - assert max_flow == { - "LAX": {("JFK", "SEA"): {"Plane1": 200.0, "Plane2": 200.0}}, - "SFO": {("JFK", "SEA"): {"Plane1": 200.0, "Plane2": 200.0}}, - } - assert "sink" not in network1.planes["Plane1"] - assert "sink" not in network1.planes["Plane2"] +def test_duplicate_node_overwrite(): + # When adding nodes with the same name, the latter should overwrite the former. 
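+    # (Network.add_node() is a plain dict assignment keyed by node name, so
+    # re-adding an existing name silently replaces the earlier Node.)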
+ network = Network() + node1 = Node("A", attrs={"data": 1}) + node2 = Node("A", attrs={"data": 2}) + + network.add_node(node1) + network.add_node(node2) # This should overwrite node1 + assert network.nodes["A"].attrs["data"] == 2 diff --git a/tests/test_node.py b/tests/test_node.py deleted file mode 100644 index 9f38152..0000000 --- a/tests/test_node.py +++ /dev/null @@ -1,17 +0,0 @@ -import pytest -from ngraph.network import Node - - -class TestNode: - def test_node_creation(self): - node = Node("Node1", node_type="simple", attribute1="value1", capacity=100) - assert node.node_id == "Node1" - assert node.node_type == "simple" - assert node.attributes["attribute1"] == "value1" - assert node.attributes["capacity"] == 100 - assert node.attributes["total_link_capacity"] == 0 # default value - - def test_update_node_attributes(self): - node = Node("Node1", attribute1="value1") - node.update_attributes(attribute1="new_value") - assert node.attributes["attribute1"] == "new_value" diff --git a/tests/test_result.py b/tests/test_result.py new file mode 100644 index 0000000..cef5f9e --- /dev/null +++ b/tests/test_result.py @@ -0,0 +1,69 @@ +import pytest +from ngraph.results import Results + + +def test_put_and_get(): + """ + Test that putting a value in the store and then getting it works as expected. + """ + results = Results() + results.put("Step1", "total_capacity", 123.45) + assert results.get("Step1", "total_capacity") == 123.45 + + +def test_get_with_default_missing_key(): + """ + Test retrieving a non-existent key with a default value. + """ + results = Results() + default_value = "not found" + assert results.get("StepX", "unknown_key", default_value) == default_value + + +def test_get_with_default_missing_step(): + """ + Test retrieving from a non-existent step with a default value. + """ + results = Results() + results.put("Step1", "some_key", 42) + default_value = "missing step" + assert results.get("Step2", "some_key", default_value) == default_value + + +def test_get_all_single_key_multiple_steps(): + """ + Test retrieving all values for a single key across multiple steps. + """ + results = Results() + results.put("Step1", "duration", 5.5) + results.put("Step2", "duration", 3.2) + results.put("Step2", "other_key", "unused") + results.put("Step3", "different_key", 99) + + durations = results.get_all("duration") + assert durations == {"Step1": 5.5, "Step2": 3.2} + + # No 'duration' key in Step3, so it won't appear in durations + assert "Step3" not in durations + + +def test_overwriting_value(): + """ + Test that storing a new value under an existing step/key pair overwrites the old value. + """ + results = Results() + results.put("Step1", "metric", 10) + assert results.get("Step1", "metric") == 10 + + # Overwrite + results.put("Step1", "metric", 20) + assert results.get("Step1", "metric") == 20 + + +def test_empty_results(): + """ + Test that a newly instantiated Results object does not have any stored data. 
+ """ + results = Results() + assert results.get("StepX", "keyX") is None + assert results.get_all("keyX") == {} diff --git a/tests/test_scenario.py b/tests/test_scenario.py new file mode 100644 index 0000000..7f5fbdd --- /dev/null +++ b/tests/test_scenario.py @@ -0,0 +1,325 @@ +import pytest +import yaml + +from typing import TYPE_CHECKING +from dataclasses import dataclass + +from ngraph.scenario import Scenario +from ngraph.network import Network +from ngraph.failure_policy import FailurePolicy +from ngraph.traffic_demand import TrafficDemand +from ngraph.results import Results +from ngraph.workflow.base import ( + WorkflowStep, + register_workflow_step, + WORKFLOW_STEP_REGISTRY, +) + +if TYPE_CHECKING: + from ngraph.scenario import Scenario + + +# ------------------------------------------------------------------- +# Dummy workflow steps for testing +# ------------------------------------------------------------------- +@register_workflow_step("DoSmth") +@dataclass +class DoSmth(WorkflowStep): + """ + Example step that has an extra field 'some_param'. + """ + + some_param: int = 0 + + def run(self, scenario: Scenario) -> None: + """ + Perform a dummy operation for testing. + You might store something in scenario.results here if desired. + """ + pass + + +@register_workflow_step("DoSmthElse") +@dataclass +class DoSmthElse(WorkflowStep): + """ + Example step that has an extra field 'factor'. + """ + + factor: float = 1.0 + + def run(self, scenario: Scenario) -> None: + """ + Perform another dummy operation for testing. + """ + pass + + +@pytest.fixture +def valid_scenario_yaml() -> str: + """ + Returns a valid YAML string for constructing a Scenario with a small + realistic network of three nodes and two links, plus two traffic demands. + """ + return """ +network: + nodes: + NodeA: + role: ingress + location: somewhere + NodeB: + role: transit + NodeC: + role: egress + links: + - source: NodeA + target: NodeB + capacity: 10 + latency: 2 + cost: 5 + attrs: {some_attr: some_value} + - source: NodeB + target: NodeC + capacity: 20 + latency: 3 + cost: 4 + attrs: {} +failure_policy: + failure_probabilities: + node: 0.01 + link: 0.02 +traffic_demands: + - source: NodeA + target: NodeB + demand: 15 + - source: NodeA + target: NodeC + demand: 5 +workflow: + - step_type: DoSmth + name: Step1 + some_param: 42 + - step_type: DoSmthElse + name: Step2 + factor: 2.0 +""" + + +@pytest.fixture +def missing_step_type_yaml() -> str: + """ + Returns a YAML string missing the 'step_type' in the workflow, + which should raise a ValueError. + """ + return """ +network: + nodes: + NodeA: {} + NodeB: {} + links: + - source: NodeA + target: NodeB + capacity: 1 +failure_policy: + failure_probabilities: + node: 0.01 + link: 0.02 +traffic_demands: + - source: NodeA + target: NodeB + demand: 10 +workflow: + - name: StepWithoutType + some_param: 123 +""" + + +@pytest.fixture +def unrecognized_step_type_yaml() -> str: + """ + Returns a YAML string with an unrecognized step_type in the workflow, + which should raise a ValueError. 
+ """ + return """ +network: + nodes: + NodeA: {} + NodeB: {} + links: + - source: NodeA + target: NodeB + capacity: 1 +failure_policy: + failure_probabilities: + node: 0.01 + link: 0.02 +traffic_demands: + - source: NodeA + target: NodeB + demand: 10 +workflow: + - step_type: NonExistentStep + name: BadStep + some_param: 999 +""" + + +@pytest.fixture +def extra_param_yaml() -> str: + """ + Returns a YAML string that attempts to pass an unsupported 'extra_param' + to a known workflow step type, which should raise a TypeError. + """ + return """ +network: + nodes: + NodeA: {} + NodeB: {} + links: + - source: NodeA + target: NodeB + capacity: 1 +traffic_demands: [] +failure_policy: + failure_probabilities: + node: 0.01 + link: 0.02 +workflow: + - step_type: DoSmth + name: StepWithExtra + some_param: 42 + extra_param: 99 +""" + + +def test_scenario_from_yaml_valid(valid_scenario_yaml: str) -> None: + """ + Tests that a Scenario can be correctly constructed from a valid YAML string. + Ensures that: + - Network has correct nodes and links + - FailurePolicy is set + - TrafficDemands are parsed + - Workflow steps are instantiated + - Results object is present + """ + scenario = Scenario.from_yaml(valid_scenario_yaml) + + # Check network + assert isinstance(scenario.network, Network) + assert len(scenario.network.nodes) == 3 # We defined NodeA, NodeB, NodeC + assert len(scenario.network.links) == 2 # NodeA->NodeB, NodeB->NodeC + + node_names = [node.name for node in scenario.network.nodes.values()] + assert "NodeA" in node_names + assert "NodeB" in node_names + assert "NodeC" in node_names + + links = list(scenario.network.links.values()) + assert len(links) == 2 + + link_ab = next((lk for lk in links if lk.source == "NodeA"), None) + link_bc = next((lk for lk in links if lk.source == "NodeB"), None) + + assert link_ab is not None, "Link from NodeA to NodeB was not found." + assert link_ab.target == "NodeB" + assert link_ab.capacity == 10 + assert link_ab.latency == 2 + assert link_ab.cost == 5 + assert link_ab.attrs.get("some_attr") == "some_value" + + assert link_bc is not None, "Link from NodeB to NodeC was not found." + assert link_bc.target == "NodeC" + assert link_bc.capacity == 20 + assert link_bc.latency == 3 + assert link_bc.cost == 4 + + # Check failure policy + assert isinstance(scenario.failure_policy, FailurePolicy) + assert scenario.failure_policy.failure_probabilities["node"] == 0.01 + assert scenario.failure_policy.failure_probabilities["link"] == 0.02 + + # Check traffic demands + assert len(scenario.traffic_demands) == 2 + demand_ab = next( + ( + d + for d in scenario.traffic_demands + if d.source == "NodeA" and d.target == "NodeB" + ), + None, + ) + demand_ac = next( + ( + d + for d in scenario.traffic_demands + if d.source == "NodeA" and d.target == "NodeC" + ), + None, + ) + assert demand_ab is not None, "Demand from NodeA to NodeB not found." + assert demand_ab.demand == 15 + + assert demand_ac is not None, "Demand from NodeA to NodeC not found." 
+ assert demand_ac.demand == 5 + + # Check workflow + assert len(scenario.workflow) == 2 + step1 = scenario.workflow[0] + step2 = scenario.workflow[1] + + # Verify the step types come from the registry + assert step1.__class__ == WORKFLOW_STEP_REGISTRY["DoSmth"] + assert step2.__class__ == WORKFLOW_STEP_REGISTRY["DoSmthElse"] + + # Check the scenario results store + assert isinstance(scenario.results, Results) + + +def test_scenario_run(valid_scenario_yaml: str) -> None: + """ + Tests that calling scenario.run() executes each workflow step in order + without errors. This verifies the new .run() method introduced in the Scenario class. + """ + scenario = Scenario.from_yaml(valid_scenario_yaml) + + # Just ensure it runs without raising exceptions + scenario.run() + + # For a thorough test, one might check scenario.results or other side effects + # inside the steps themselves. Here, we just verify the workflow runs successfully. + assert True + + +def test_scenario_from_yaml_missing_step_type(missing_step_type_yaml: str) -> None: + """ + Tests that Scenario.from_yaml raises a ValueError if a workflow step + is missing the 'step_type' field. + """ + with pytest.raises(ValueError) as excinfo: + _ = Scenario.from_yaml(missing_step_type_yaml) + assert "must have a 'step_type' field" in str(excinfo.value) + + +def test_scenario_from_yaml_unrecognized_step_type( + unrecognized_step_type_yaml: str, +) -> None: + """ + Tests that Scenario.from_yaml raises a ValueError if the step_type + is not found in the WORKFLOW_STEP_REGISTRY. + """ + with pytest.raises(ValueError) as excinfo: + _ = Scenario.from_yaml(unrecognized_step_type_yaml) + assert "Unrecognized 'step_type'" in str(excinfo.value) + + +def test_scenario_from_yaml_unsupported_param(extra_param_yaml: str) -> None: + """ + Tests that Scenario.from_yaml raises a TypeError if a workflow step + in the YAML has an unsupported parameter. + """ + with pytest.raises(TypeError) as excinfo: + _ = Scenario.from_yaml(extra_param_yaml) + + # Typically the error message is something like: + # "DoSmth.__init__() got an unexpected keyword argument 'extra_param'" + assert "extra_param" in str(excinfo.value) diff --git a/tests/test_traffic_demand.py b/tests/test_traffic_demand.py new file mode 100644 index 0000000..c83fb4e --- /dev/null +++ b/tests/test_traffic_demand.py @@ -0,0 +1,62 @@ +import pytest +from ngraph.traffic_demand import TrafficDemand + + +def test_traffic_demand_defaults(): + """ + Test creation of TrafficDemand with default values. + """ + demand = TrafficDemand(source="NodeA", target="NodeB") + assert demand.source == "NodeA" + assert demand.target == "NodeB" + assert demand.priority == 0 + assert demand.demand == 0.0 + assert demand.demand_placed == 0.0 + assert demand.demand_unplaced == 0.0 + assert demand.attrs == {} + + +def test_traffic_demand_custom_values(): + """ + Test creation of TrafficDemand with custom values. + """ + demand = TrafficDemand( + source="SourceNode", + target="TargetNode", + priority=5, + demand=42.5, + demand_placed=10.0, + demand_unplaced=32.5, + attrs={"description": "test"}, + ) + assert demand.source == "SourceNode" + assert demand.target == "TargetNode" + assert demand.priority == 5 + assert demand.demand == 42.5 + assert demand.demand_placed == 10.0 + assert demand.demand_unplaced == 32.5 + assert demand.attrs == {"description": "test"} + + +def test_traffic_demand_attrs_modification(): + """ + Test that the attrs dictionary can be modified after instantiation. 
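+    Because attrs uses field(default_factory=dict), each instance gets its
+    own dictionary, so mutating one demand does not affect others.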
+ """ + demand = TrafficDemand(source="NodeA", target="NodeB") + demand.attrs["key"] = "value" + assert demand.attrs == {"key": "value"} + + +def test_traffic_demand_partial_kwargs(): + """ + Test initialization with only a subset of fields, ensuring defaults work. + """ + demand = TrafficDemand(source="NodeA", target="NodeC", demand=15.0) + assert demand.source == "NodeA" + assert demand.target == "NodeC" + assert demand.demand == 15.0 + # Check the defaults + assert demand.priority == 0 + assert demand.demand_placed == 0.0 + assert demand.demand_unplaced == 0.0 + assert demand.attrs == {} diff --git a/tests/workflow/__init__.py b/tests/workflow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/workflow/test_base.py b/tests/workflow/test_base.py new file mode 100644 index 0000000..1fb4dd8 --- /dev/null +++ b/tests/workflow/test_base.py @@ -0,0 +1,53 @@ +import pytest +from unittest.mock import MagicMock + +from ngraph.workflow.base import ( + WorkflowStep, + register_workflow_step, + WORKFLOW_STEP_REGISTRY, +) + + +def test_workflow_step_is_abstract() -> None: + """ + Verify that WorkflowStep is an abstract class and cannot be instantiated directly. + """ + with pytest.raises(TypeError) as exc_info: + WorkflowStep() # type: ignore + assert "abstract class" in str(exc_info.value) + + +def test_register_workflow_step_decorator() -> None: + """ + Verify that using the @register_workflow_step decorator registers + the subclass in the WORKFLOW_STEP_REGISTRY with the correct key. + """ + + @register_workflow_step("TestStep") + class TestStep(WorkflowStep): + def run(self, scenario) -> None: + pass + + # Check if the class is registered correctly + assert "TestStep" in WORKFLOW_STEP_REGISTRY + assert WORKFLOW_STEP_REGISTRY["TestStep"] == TestStep + + +def test_workflow_step_subclass_run_method() -> None: + """ + Verify that a concrete subclass of WorkflowStep can implement and call the run() method. + """ + + class ConcreteStep(WorkflowStep): + def run(self, scenario) -> None: + scenario.called = True + + mock_scenario = MagicMock() + step_instance = ConcreteStep(name="test_step") + step_instance.run(mock_scenario) + + # Check if run() was actually invoked + # e.g., we set scenario.called = True in run() + # but here we can also rely on MagicMock calls or attributes if needed + assert hasattr(mock_scenario, "called") and mock_scenario.called is True + assert step_instance.name == "test_step" diff --git a/tests/workflow/test_build_graph.py b/tests/workflow/test_build_graph.py new file mode 100644 index 0000000..71d92e4 --- /dev/null +++ b/tests/workflow/test_build_graph.py @@ -0,0 +1,140 @@ +import pytest +import networkx as nx +from unittest.mock import MagicMock + +from ngraph.workflow.build_graph import BuildGraph + + +class MockNode: + """ + A simple mock Node to simulate scenario.network.nodes[node_name]. + """ + + def __init__(self, attrs=None): + self.attrs = attrs or {} + + +class MockLink: + """ + A simple mock Link to simulate scenario.network.links[link_id]. + """ + + def __init__(self, link_id, source, target, capacity, cost, latency, attrs=None): + self.id = link_id + self.source = source + self.target = target + self.capacity = capacity + self.cost = cost + self.latency = latency + self.attrs = attrs or {} + + +@pytest.fixture +def mock_scenario(): + """ + Provides a mock Scenario object for testing. 
+ """ + scenario = MagicMock() + scenario.network = MagicMock() + + # Sample data: + scenario.network.nodes = { + "A": MockNode(attrs={"type": "router", "location": "rack1"}), + "B": MockNode(attrs={"type": "router", "location": "rack2"}), + } + scenario.network.links = { + "L1": MockLink( + link_id="L1", + source="A", + target="B", + capacity=100, + cost=5, + latency=10, + attrs={"fiber": True}, + ), + "L2": MockLink( + link_id="L2", + source="B", + target="A", + capacity=50, + cost=2, + latency=5, + attrs={"copper": True}, + ), + } + + # Mock results object with a MagicMocked put method + scenario.results = MagicMock() + scenario.results.put = MagicMock() + return scenario + + +def test_build_graph_stores_multidigraph_in_results(mock_scenario): + """ + Ensure BuildGraph creates a MultiDiGraph, adds all nodes/edges, + and stores it in scenario.results with the key (step_name, "graph"). + """ + step = BuildGraph(name="MyBuildStep") + + step.run(mock_scenario) + + # Check scenario.results.put was called exactly once + mock_scenario.results.put.assert_called_once() + + # Extract the arguments from the .put call + call_args = mock_scenario.results.put.call_args + # Should look like ("MyBuildStep", "graph", ) + assert call_args[0][0] == "MyBuildStep" + assert call_args[0][1] == "graph" + created_graph = call_args[0][2] + assert isinstance( + created_graph, nx.MultiDiGraph + ), "Resulting object must be a MultiDiGraph." + + # Verify the correct nodes were added + assert set(created_graph.nodes()) == { + "A", + "B", + }, "MultiDiGraph should contain the correct node set." + # Check node attributes + assert created_graph.nodes["A"]["type"] == "router" + assert created_graph.nodes["B"]["location"] == "rack2" + + # Verify edges + # We expect two edges for each link: forward ("L1") and reverse ("L1_rev"), etc. + # So we should have 4 edges in total (2 from L1, 2 from L2). + assert ( + created_graph.number_of_edges() == 4 + ), "Should have two edges (forward/reverse) for each link." + + # Check forward edge from link 'L1' + edge_data = created_graph.get_edge_data("A", "B", key="L1") + assert edge_data is not None, "Forward edge 'L1' should exist from A to B." + assert edge_data["capacity"] == 100 + assert edge_data["cost"] == 5 + assert edge_data["latency"] == 10 + assert "fiber" in edge_data + + # Check reverse edge from link 'L1' + rev_edge_data = created_graph.get_edge_data("B", "A", key="L1_rev") + assert rev_edge_data is not None, "Reverse edge 'L1_rev' should exist from B to A." + assert ( + rev_edge_data["capacity"] == 100 + ), "Reverse edge should share the same capacity." + + # Check forward edge from link 'L2' + edge_data_l2 = created_graph.get_edge_data("B", "A", key="L2") + assert edge_data_l2 is not None, "Forward edge 'L2' should exist from B to A." + assert edge_data_l2["capacity"] == 50 + assert edge_data_l2["cost"] == 2 + assert "copper" in edge_data_l2 + + # Check reverse edge from link 'L2' + rev_edge_data_l2 = created_graph.get_edge_data("A", "B", key="L2_rev") + assert ( + rev_edge_data_l2 is not None + ), "Reverse edge 'L2_rev' should exist from A to B." + assert ( + rev_edge_data_l2["capacity"] == 50 + ), "Reverse edge should share the same capacity." + assert rev_edge_data_l2["latency"] == 5