Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 117 additions & 77 deletions ngraph/blueprints.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from __future__ import annotations

import copy
from dataclasses import dataclass
from typing import Any, Dict, List

Expand Down Expand Up @@ -151,33 +148,36 @@ def _expand_group(
if "use_blueprint" in group_def:
# Expand blueprint subgroups
blueprint_name: str = group_def["use_blueprint"]
bp = ctx.blueprints.get(blueprint_name)
if not bp:
raise ValueError(
f"Group '{group_name}' references unknown blueprint '{blueprint_name}'."
)

param_overrides: Dict[str, Any] = group_def.get("parameters", {})
coords = group_def.get("coords")

# For each subgroup in the blueprint, apply overrides and expand
for bp_sub_name, bp_sub_def in bp.groups.items():
merged_def = _apply_parameters(bp_sub_name, bp_sub_def, param_overrides)
if coords is not None and "coords" not in merged_def:
merged_def["coords"] = coords

_expand_group(
ctx,
parent_path=effective_path,
group_name=bp_sub_name,
group_def=merged_def,
blueprint_expansion=True,
)

# Expand blueprint adjacency
for adj_def in bp.adjacency:
_expand_blueprint_adjacency(ctx, adj_def, effective_path)

try:
bp = ctx.blueprints.get(blueprint_name)
if not bp:
raise ValueError(
f"Group '{group_name}' references unknown blueprint '{blueprint_name}'."
)

param_overrides: Dict[str, Any] = group_def.get("parameters", {})
coords = group_def.get("coords")

# For each subgroup in the blueprint, apply overrides and expand
for bp_sub_name, bp_sub_def in bp.groups.items():
merged_def = _apply_parameters(bp_sub_name, bp_sub_def, param_overrides)
if coords is not None and "coords" not in merged_def:
merged_def["coords"] = coords

_expand_group(
ctx,
parent_path=effective_path,
group_name=bp_sub_name,
group_def=merged_def,
blueprint_expansion=True,
)

# Expand blueprint adjacency
for adj_def in bp.adjacency:
_expand_blueprint_adjacency(ctx, adj_def, effective_path)

except Exception as e:
raise ValueError(f"Error expanding blueprint '{blueprint_name}': {e}")
else:
# It's a direct node group
node_count = group_def.get("node_count", 1)
Expand Down Expand Up @@ -217,11 +217,12 @@ def _expand_blueprint_adjacency(
target_rel = adj_def["target"]
pattern = adj_def.get("pattern", "mesh")
link_params = adj_def.get("link_params", {})
link_count = adj_def.get("link_count", 1)

src_path = _join_paths(parent_path, source_rel)
tgt_path = _join_paths(parent_path, target_rel)

_expand_adjacency_pattern(ctx, src_path, tgt_path, pattern, link_params)
_expand_adjacency_pattern(ctx, src_path, tgt_path, pattern, link_params, link_count)


def _expand_adjacency(
Expand All @@ -239,13 +240,16 @@ def _expand_adjacency(
source_path_raw = adj_def["source"]
target_path_raw = adj_def["target"]
pattern = adj_def.get("pattern", "mesh")
link_count = adj_def.get("link_count", 1)
link_params = adj_def.get("link_params", {})

# Convert to an absolute or relative path
source_path = _join_paths("", source_path_raw)
target_path = _join_paths("", target_path_raw)

_expand_adjacency_pattern(ctx, source_path, target_path, pattern, link_params)
_expand_adjacency_pattern(
ctx, source_path, target_path, pattern, link_params, link_count
)


def _expand_adjacency_pattern(
Expand All @@ -254,23 +258,25 @@ def _expand_adjacency_pattern(
target_path: str,
pattern: str,
link_params: Dict[str, Any],
link_count: int = 1,
) -> None:
"""
Generates Link objects for the chosen adjacency pattern among matched nodes.

Supported Patterns:
* "mesh": Connect every node from source side to every node on target side,
skipping self-loops, and deduplicating reversed pairs.
* "one_to_one": Pair each source node with exactly one target node (wrap-around).
* "ring": (Example pattern) For demonstration, connect nodes in a ring among
the union of source + target sets (ignores directionality).
skipping self-loops and deduplicating reversed pairs.
* "one_to_one": Pair each source node with exactly one target node (wrap-around),
requiring that the larger set size is an integer multiple
of the smaller set size.

Args:
ctx (DSLExpansionContext): The context containing the target network.
source_path (str): The path pattern identifying the source node group(s).
target_path (str): The path pattern identifying the target node group(s).
pattern (str): The type of adjacency pattern (e.g., "mesh", "one_to_one", "ring").
pattern (str): The type of adjacency pattern (e.g., "mesh", "one_to_one").
link_params (Dict[str, Any]): Additional link parameters (capacity, cost, attrs).
link_count (int): Number of parallel links to create for each adjacency.
"""
source_node_groups = ctx.network.select_node_groups_by_path(source_path)
target_node_groups = ctx.network.select_node_groups_by_path(target_path)
Expand All @@ -292,21 +298,21 @@ def _expand_adjacency_pattern(
pair = tuple(sorted((sn.name, tn.name)))
if pair not in dedup_pairs:
dedup_pairs.add(pair)
_create_link(ctx.network, sn.name, tn.name, link_params)
_create_link(ctx.network, sn.name, tn.name, link_params, link_count)

elif pattern == "one_to_one":
s_count = len(source_nodes)
t_count = len(target_nodes)
bigger, smaller = max(s_count, t_count), min(s_count, t_count)
bigger_count = max(s_count, t_count)
smaller_count = min(s_count, t_count)

# Basic check for wrap-around scenario
if bigger % smaller != 0:
if bigger_count % smaller_count != 0:
raise ValueError(
f"one_to_one pattern requires either equal node counts "
f"or a valid wrap-around. Got {s_count} vs {t_count}."
f"one_to_one pattern requires sizes with a multiple factor. "
f"Got source={s_count}, target={t_count}."
)

for i in range(bigger):
for i in range(bigger_count):
if s_count >= t_count:
sn = source_nodes[i].name
tn = target_nodes[i % t_count].name
Expand All @@ -320,36 +326,46 @@ def _expand_adjacency_pattern(
pair = tuple(sorted((sn, tn)))
if pair not in dedup_pairs:
dedup_pairs.add(pair)
_create_link(ctx.network, sn, tn, link_params)
_create_link(ctx.network, sn, tn, link_params, link_count)
else:
raise ValueError(f"Unknown adjacency pattern: {pattern}")


def _create_link(
net: Network, source: str, target: str, link_params: Dict[str, Any]
net: Network,
source: str,
target: str,
link_params: Dict[str, Any],
link_count: int = 1,
) -> None:
"""
Creates and adds a Link to the network, applying capacity/cost/attrs from link_params.
Creates and adds one or more Links to the network, applying capacity, cost,
and attributes from link_params. Uses deep copies of the attributes to avoid
accidental shared mutations.

Args:
net (Network): The network to which the new link is added.
net (Network): The network to which the new link(s) is/are added.
source (str): Source node name for the link.
target (str): Target node name for the link.
link_params (Dict[str, Any]): A dict possibly containing 'capacity', 'cost',
and 'attrs' keys.
link_count (int): Number of parallel links to create between source and target.
"""
capacity = link_params.get("capacity", 1.0)
cost = link_params.get("cost", 1.0)
attrs = copy.deepcopy(link_params.get("attrs", {}))
import copy

link = Link(
source=source,
target=target,
capacity=capacity,
cost=cost,
attrs=attrs,
)
net.add_link(link)
for _ in range(link_count):
capacity = link_params.get("capacity", 1.0)
cost = link_params.get("cost", 1.0)
attrs = copy.deepcopy(link_params.get("attrs", {}))

link = Link(
source=source,
target=target,
capacity=capacity,
cost=cost,
attrs=attrs,
)
net.add_link(link)


def _process_direct_nodes(net: Network, network_data: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -403,14 +419,8 @@ def _process_direct_links(net: Network, network_data: Dict[str, Any]) -> None:
if source == target:
raise ValueError(f"Link cannot have the same source and target: {source}")
link_params = link_info.get("link_params", {})
link = Link(
source=source,
target=target,
capacity=link_params.get("capacity", 1.0),
cost=link_params.get("cost", 1.0),
attrs=link_params.get("attrs", {}),
)
net.add_link(link)
link_count = link_info.get("link_count", 1)
_create_link(net, source, target, link_params, link_count)


def _process_link_overrides(network: Network, network_data: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -475,7 +485,7 @@ def _update_links(
any_direction: bool = True,
) -> None:
"""
Update all Link objects between nodes matching 'source' and 'target' paths
Updates all Link objects between nodes matching 'source' and 'target' paths
with new parameters.

If any_direction=True, both (source->target) and (target->source) links
Expand Down Expand Up @@ -537,8 +547,11 @@ def _apply_parameters(
Applies user-provided parameter overrides to a blueprint subgroup.

Example:
If 'spine.node_count'=6 is in params_overrides,
we set 'node_count'=6 for the 'spine' subgroup.
If 'spine.node_count' = 6 is in params_overrides,
it sets 'node_count'=6 for the 'spine' subgroup.

If 'spine.node_attrs.hw_type' = 'Dell',
it sets subgroup_def['node_attrs']['hw_type'] = 'Dell'.

Args:
subgroup_name (str): Name of the subgroup in the blueprint (e.g. 'spine').
Expand All @@ -547,23 +560,50 @@ def _apply_parameters(
{'spine.node_count': 6, 'spine.node_attrs.hw_type': 'Dell'}.

Returns:
Dict[str, Any]: A copy of subgroup_def with parameter overrides applied.
Dict[str, Any]: A copy of subgroup_def with parameter overrides applied,
including nested dictionary fields if specified by dotted paths (e.g. node_attrs.foo).
"""
out = dict(subgroup_def)
import copy

out = copy.deepcopy(subgroup_def)

for key, val in params_overrides.items():
parts = key.split(".")
if parts[0] == subgroup_name and len(parts) > 1:
field_name = ".".join(parts[1:])
out[field_name] = val
# We have a dotted path that might refer to nested dictionaries.
subpath = parts[1:]
_apply_nested_path(out, subpath, val)

return out


def _apply_nested_path(
node_def: Dict[str, Any], path_parts: List[str], value: Any
) -> None:
"""
Recursively applies a path like ["node_attrs", "role"] to set node_def["node_attrs"]["role"] = value.
Creates intermediate dicts as needed.
"""
if not path_parts:
return
key = path_parts[0]
if len(path_parts) == 1:
node_def[key] = value
return

# Ensure that node_def[key] is a dict
if key not in node_def or not isinstance(node_def[key], dict):
node_def[key] = {}
_apply_nested_path(node_def[key], path_parts[1:], value)


def _join_paths(parent_path: str, rel_path: str) -> str:
"""
Joins two path segments according to NetGraph's DSL conventions:
- If rel_path starts with '/', strip the leading slash and treat it
as appended to parent_path if parent_path is not empty.
- Otherwise, simply append rel_path to parent_path if parent_path is non-empty.

- If rel_path starts with '/', strip the leading slash and treat it as
appended to parent_path if parent_path is not empty.
- Otherwise, simply append rel_path to parent_path if parent_path is non-empty.

Args:
parent_path (str): The existing path prefix.
Expand Down
Loading