Skip to content

Commit c90edaa

Browse files
committed
clean up of masking implementation
1 parent c5f17a2 commit c90edaa

21 files changed

+1947
-160
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ add_executable(netgraph_core_tests
2222
tests/cpp/flow_graph_tests.cpp
2323
tests/cpp/max_flow_tests.cpp
2424
tests/cpp/k_shortest_paths_tests.cpp
25-
tests/cpp/negative_safety_tests.cpp
25+
tests/cpp/negative_safety_tests.cpp
26+
tests/cpp/masking_tests.cpp
2627
)
2728
target_link_libraries(netgraph_core_tests PRIVATE netgraph_core GTest::gtest_main)
2829
target_include_directories(netgraph_core_tests PRIVATE tests/cpp)

bindings/python/module.cpp

Lines changed: 31 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -409,9 +409,23 @@ PYBIND11_MODULE(_netgraph_core, m) {
409409
.def("place_on_dag", [](FlowState& fs, std::int32_t src, std::int32_t dst, const PredDAG& dag, double requested_flow, FlowPlacement placement){
410410
py::gil_scoped_release rel; auto placed = fs.place_on_dag(src, dst, dag, requested_flow, placement); py::gil_scoped_acquire acq; return placed;
411411
}, py::arg("src"), py::arg("dst"), py::arg("dag"), py::arg("requested_flow") = std::numeric_limits<double>::infinity(), py::arg("flow_placement") = FlowPlacement::Proportional)
412-
.def("place_max_flow", [](FlowState& fs, std::int32_t src, std::int32_t dst, FlowPlacement placement, bool shortest_path, bool require_capacity){
413-
py::gil_scoped_release rel; auto total = fs.place_max_flow(src, dst, placement, shortest_path, require_capacity); py::gil_scoped_acquire acq; return total;
414-
}, py::arg("src"), py::arg("dst"), py::arg("flow_placement") = FlowPlacement::Proportional, py::arg("shortest_path") = false, py::arg("require_capacity") = true)
412+
.def("place_max_flow", [](py::object self_obj, std::int32_t src, std::int32_t dst,
413+
FlowPlacement placement, bool shortest_path, bool require_capacity,
414+
py::object node_mask, py::object edge_mask){
415+
FlowState& fs = py::cast<FlowState&>(self_obj);
416+
// Get graph reference to validate mask lengths
417+
const StrictMultiDiGraph& g = py::cast<const StrictMultiDiGraph&>(self_obj.attr("_graph_ref"));
418+
auto node_bs = to_bool_span_from_numpy(node_mask, static_cast<std::size_t>(g.num_nodes()), "node_mask");
419+
auto edge_bs = to_bool_span_from_numpy(edge_mask, static_cast<std::size_t>(g.num_edges()), "edge_mask");
420+
py::gil_scoped_release rel;
421+
auto total = fs.place_max_flow(src, dst, placement, shortest_path, require_capacity,
422+
node_bs.view, edge_bs.view);
423+
py::gil_scoped_acquire acq;
424+
return total;
425+
}, py::arg("src"), py::arg("dst"),
426+
py::arg("flow_placement") = FlowPlacement::Proportional,
427+
py::arg("shortest_path") = false, py::arg("require_capacity") = true,
428+
py::kw_only(), py::arg("node_mask") = py::none(), py::arg("edge_mask") = py::none())
415429
.def("compute_min_cut", [](py::object self_obj, std::int32_t src, py::object node_mask, py::object edge_mask){
416430
const FlowState& fs = py::cast<const FlowState&>(self_obj);
417431
// Get graph reference to validate mask lengths
@@ -546,63 +560,29 @@ PYBIND11_MODULE(_netgraph_core, m) {
546560
.def_readwrite("diminishing_returns_epsilon_frac", &FlowPolicyConfig::diminishing_returns_epsilon_frac);
547561

548562
py::class_<FlowPolicy>(m, "FlowPolicy", py::dynamic_attr())
549-
.def("__init__", [](py::object self, py::object algs_obj, py::object graph_obj, const FlowPolicyConfig& cfg){
563+
.def("__init__", [](py::object self, py::object algs_obj, py::object graph_obj, FlowPolicyConfig cfg,
564+
py::object node_mask, py::object edge_mask){
550565
std::shared_ptr<Algorithms> algs = py::cast<std::shared_ptr<Algorithms>>(algs_obj);
551566
const PyGraph& pg = py::cast<const PyGraph&>(graph_obj);
567+
568+
// Convert masks to spans (FlowPolicy will copy the data)
569+
auto node_bs = to_bool_span_from_numpy(node_mask, static_cast<std::size_t>(pg.num_nodes), "node_mask");
570+
auto edge_bs = to_bool_span_from_numpy(edge_mask, static_cast<std::size_t>(pg.num_edges), "edge_mask");
571+
572+
// Update config with mask spans (will be copied by FlowPolicy constructor)
573+
cfg.node_mask = node_bs.view;
574+
cfg.edge_mask = edge_bs.view;
575+
552576
ExecutionContext ctx(algs, pg.handle);
553577
FlowPolicy* fp = self.cast<FlowPolicy*>();
554578
new (fp) FlowPolicy(ctx, cfg);
555579
self.attr("_algorithms_ref") = algs_obj;
556580
self.attr("_graph_ref") = graph_obj;
557581
},
558582
py::arg("algorithms"), py::arg("graph"), py::arg("config"),
559-
py::keep_alive<1, 2>() // self keeps algorithms alive
560-
)
561-
.def("__init__", [](py::object self, py::object algs_obj, py::object graph_obj,
562-
PathAlg path_alg,
563-
FlowPlacement flow_placement,
564-
EdgeSelection selection,
565-
bool require_capacity,
566-
int min_flow_count,
567-
std::optional<int> max_flow_count,
568-
std::optional<Cost> max_path_cost,
569-
std::optional<double> max_path_cost_factor,
570-
bool shortest_path,
571-
bool reoptimize_flows_on_each_placement,
572-
int max_no_progress_iterations,
573-
int max_total_iterations,
574-
bool diminishing_returns_enabled,
575-
int diminishing_returns_window,
576-
double diminishing_returns_epsilon_frac){
577-
std::shared_ptr<Algorithms> algs = py::cast<std::shared_ptr<Algorithms>>(algs_obj);
578-
const PyGraph& pg = py::cast<const PyGraph&>(graph_obj);
579-
ExecutionContext ctx(algs, pg.handle);
580-
FlowPolicy* fp = self.cast<FlowPolicy*>();
581-
new (fp) FlowPolicy(ctx, path_alg, flow_placement, selection, require_capacity,
582-
min_flow_count, max_flow_count, max_path_cost, max_path_cost_factor,
583-
shortest_path, reoptimize_flows_on_each_placement,
584-
max_no_progress_iterations, max_total_iterations,
585-
diminishing_returns_enabled, diminishing_returns_window,
586-
diminishing_returns_epsilon_frac);
587-
self.attr("_algorithms_ref") = algs_obj;
588-
self.attr("_graph_ref") = graph_obj;
589-
},
590-
py::arg("algorithms"), py::arg("graph"),
591-
py::arg("path_alg") = PathAlg::SPF,
592-
py::arg("flow_placement") = FlowPlacement::Proportional,
593-
py::arg("selection") = EdgeSelection{},
594-
py::arg("require_capacity") = true,
595-
py::arg("min_flow_count") = 1,
596-
py::arg("max_flow_count") = py::none(),
597-
py::arg("max_path_cost") = py::none(),
598-
py::arg("max_path_cost_factor") = py::none(),
599-
py::arg("shortest_path") = false,
600-
py::arg("reoptimize_flows_on_each_placement") = false,
601-
py::arg("max_no_progress_iterations") = 100,
602-
py::arg("max_total_iterations") = 10000,
603-
py::arg("diminishing_returns_enabled") = true,
604-
py::arg("diminishing_returns_window") = 8,
605-
py::arg("diminishing_returns_epsilon_frac") = 1e-3,
583+
py::kw_only(),
584+
py::arg("node_mask") = py::none(),
585+
py::arg("edge_mask") = py::none(),
606586
py::keep_alive<1, 2>() // self keeps algorithms alive
607587
)
608588
.def("flow_count", &FlowPolicy::flow_count)

include/netgraph/core/flow_policy.hpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ struct FlowPolicyConfig {
4747
bool diminishing_returns_enabled { true };
4848
int diminishing_returns_window { 8 };
4949
double diminishing_returns_epsilon_frac { 1e-3 };
50+
std::span<const bool> node_mask {}; // Optional node mask for failure exclusions (True = include)
51+
std::span<const bool> edge_mask {}; // Optional edge mask for failure exclusions (True = include)
5052
};
5153

5254
// FlowPolicy manages flow creation, placement, reoptimization for a single demand
@@ -61,7 +63,31 @@ class FlowPolicy {
6163
max_path_cost_factor_(cfg.max_path_cost_factor), reoptimize_flows_on_each_placement_(cfg.reoptimize_flows_on_each_placement),
6264
max_no_progress_iterations_(cfg.max_no_progress_iterations), max_total_iterations_(cfg.max_total_iterations),
6365
diminishing_returns_enabled_(cfg.diminishing_returns_enabled), diminishing_returns_window_(cfg.diminishing_returns_window),
64-
diminishing_returns_epsilon_frac_(cfg.diminishing_returns_epsilon_frac) {}
66+
diminishing_returns_epsilon_frac_(cfg.diminishing_returns_epsilon_frac)
67+
{
68+
// Validate mask sizes early for clearer error messages
69+
const auto* g = ctx_.graph.graph.get();
70+
if (g) {
71+
if (!cfg.node_mask.empty() && cfg.node_mask.size() != static_cast<std::size_t>(g->num_nodes())) {
72+
throw std::invalid_argument("FlowPolicy: node_mask length mismatch");
73+
}
74+
if (!cfg.edge_mask.empty() && cfg.edge_mask.size() != static_cast<std::size_t>(g->num_edges())) {
75+
throw std::invalid_argument("FlowPolicy: edge_mask length mismatch");
76+
}
77+
}
78+
79+
// Copy mask data if provided
80+
if (!cfg.node_mask.empty()) {
81+
node_mask_storage_.reset(new bool[cfg.node_mask.size()]);
82+
std::copy(cfg.node_mask.begin(), cfg.node_mask.end(), node_mask_storage_.get());
83+
node_mask_ = std::span<const bool>(node_mask_storage_.get(), cfg.node_mask.size());
84+
}
85+
if (!cfg.edge_mask.empty()) {
86+
edge_mask_storage_.reset(new bool[cfg.edge_mask.size()]);
87+
std::copy(cfg.edge_mask.begin(), cfg.edge_mask.end(), edge_mask_storage_.get());
88+
edge_mask_ = std::span<const bool>(edge_mask_storage_.get(), cfg.edge_mask.size());
89+
}
90+
}
6591

6692
FlowPolicy(const ExecutionContext& ctx,
6793
PathAlg path_alg,
@@ -140,6 +166,12 @@ class FlowPolicy {
140166
int diminishing_returns_window_ { 8 };
141167
double diminishing_returns_epsilon_frac_ { 1e-3 };
142168

169+
// Mask storage and views for failure exclusions
170+
std::unique_ptr<bool[]> node_mask_storage_;
171+
std::unique_ptr<bool[]> edge_mask_storage_;
172+
std::span<const bool> node_mask_ {};
173+
std::span<const bool> edge_mask_ {};
174+
143175
// State
144176
std::unordered_map<FlowIndex, FlowRecord, FlowIndexHash> flows_;
145177
Cost best_path_cost_ { std::numeric_limits<Cost>::max() };

include/netgraph/core/flow_state.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ class FlowState {
7272
[[nodiscard]] Flow place_max_flow(NodeId src, NodeId dst,
7373
FlowPlacement placement,
7474
bool shortest_path = false,
75-
bool require_capacity = true);
75+
bool require_capacity = true,
76+
std::span<const bool> node_mask = {},
77+
std::span<const bool> edge_mask = {});
7678

7779
// Compute min-cut with respect to current residual state, starting reachability
7880
// from source s on the residual graph (forward arcs: residual>MIN; reverse arcs:

include/netgraph/core/shortest_paths.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ struct PredDAG {
4545
// For EB semantics this means re-running SPF on updated residuals
4646
// will remove saturated next-hops and therefore *change* the fixed
4747
// equal-split set (progressive behavior).
48-
// - node_mask: if provided, node_mask[v]==true means node v is allowed (false excludes it)
48+
// - node_mask: if provided, node_mask[v]==true means node v is allowed (false excludes it).
49+
// If the source node is masked (node_mask[src]==false), returns an empty predecessor DAG
50+
// with all distances at infinity, as no traversal can begin from an excluded source.
4951
// - edge_mask: if provided, edge_mask[e]==true means edge e is allowed (false excludes it)
5052
[[nodiscard]] std::pair<std::vector<Cost>, PredDAG>
5153
shortest_paths(const StrictMultiDiGraph& g, NodeId src,

python/netgraph_core/_docs.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -205,13 +205,18 @@ def place_max_flow(
205205
flow_placement: FlowPlacement = FlowPlacement.PROPORTIONAL,
206206
shortest_path: bool = False,
207207
require_capacity: bool = True,
208+
*,
209+
node_mask: "Optional[np.ndarray]" = None,
210+
edge_mask: "Optional[np.ndarray]" = None,
208211
) -> float:
209212
"""Place maximum flow from src to dst.
210213
211214
Args:
212215
require_capacity: Whether to require edges to have capacity.
213216
- True (default): Routes adapt to residuals (SDN/TE, progressive fill).
214217
- False: Routes based on costs only (IP/IGP, fixed routing).
218+
node_mask: Optional 1-D bool array (True=allowed). Copied for thread safety.
219+
edge_mask: Optional 1-D bool array (True=allowed). Copied for thread safety.
215220
216221
For *IP-style ECMP* max-flow, use: require_capacity=False, shortest_path=True,
217222
flow_placement=EQUAL_BALANCED.
@@ -285,29 +290,26 @@ class FlowPolicy:
285290
When static_paths is empty the policy may refresh the DAG per round using
286291
residual-aware shortest paths. This progressively prunes saturated next-hops
287292
(traffic-engineering style) and differs from one-shot ECMP admission.
293+
294+
Args:
295+
algorithms: Algorithms instance (kept alive by FlowPolicy)
296+
graph: Graph handle (kept alive by FlowPolicy)
297+
config: FlowPolicyConfig with all policy parameters
298+
node_mask: Optional 1-D bool array (True=allowed). **Copied for thread safety.**
299+
edge_mask: Optional 1-D bool array (True=allowed). **Copied for thread safety.**
300+
301+
Raises:
302+
TypeError: If masks have wrong dtype, ndim, or length.
288303
"""
289304

290305
def __init__(
291306
self,
292307
algorithms: "Algorithms",
293308
graph: "Graph",
294-
config: Optional["FlowPolicyConfig"] = None,
295-
/,
309+
config: "FlowPolicyConfig",
296310
*,
297-
path_alg: PathAlg = PathAlg.SPF,
298-
flow_placement: FlowPlacement = FlowPlacement.PROPORTIONAL,
299-
selection: EdgeSelection = ..., # default provided by runtime binding
300-
min_flow_count: int = 1,
301-
max_flow_count: Optional[int] = None,
302-
max_path_cost: Optional[int] = None,
303-
max_path_cost_factor: Optional[float] = None,
304-
shortest_path: bool = False,
305-
reoptimize_flows_on_each_placement: bool = False,
306-
max_no_progress_iterations: int = 100,
307-
max_total_iterations: int = 10000,
308-
diminishing_returns_enabled: bool = True,
309-
diminishing_returns_window: int = 8,
310-
diminishing_returns_epsilon_frac: float = 1e-3,
311+
node_mask: "Optional[np.ndarray]" = None,
312+
edge_mask: "Optional[np.ndarray]" = None,
311313
) -> None: ...
312314

313315
def flow_count(self) -> int: ...
@@ -419,13 +421,21 @@ def spf(
419421
selection: Edge selection policy
420422
residual: Optional 1-D float64 array of residuals. **Copied for thread safety.**
421423
node_mask: Optional 1-D bool mask (length num_nodes). **Copied for thread safety.**
424+
True = node allowed, False = node excluded.
425+
**If source node is masked (False), returns empty DAG with all distances at infinity.**
422426
edge_mask: Optional 1-D bool mask (length num_edges). **Copied for thread safety.**
427+
True = edge allowed, False = edge excluded.
423428
multipath: Whether to track multiple equal-cost paths
424429
dtype: "float64" (inf for unreachable) or "int64" (max for unreachable)
425430
426431
Returns:
427432
(distances, predecessor_dag)
428433
434+
Note:
435+
When the source node is masked out (node_mask[src] == False), the algorithm
436+
immediately returns an empty predecessor DAG with all distances set to infinity,
437+
as no traversal can begin from an excluded source.
438+
429439
Raises:
430440
TypeError: If arrays have wrong dtype, ndim, or length.
431441
ValueError: If src/dst out of range.
@@ -518,6 +528,7 @@ def batch_max_flow(
518528
edge_masks: Optional[list["np.ndarray"]] = None,
519529
flow_placement: FlowPlacement = FlowPlacement.PROPORTIONAL,
520530
shortest_path: bool = False,
531+
require_capacity: bool = True,
521532
with_edge_flows: bool = False,
522533
with_reachable: bool = False,
523534
with_residuals: bool = False,
@@ -528,6 +539,7 @@ def batch_max_flow(
528539
pairs: int32 array of shape [B, 2] with (src, dst) pairs
529540
node_masks: Optional list of B bool masks. **Each copied for thread safety.**
530541
edge_masks: Optional list of B bool masks. **Each copied for thread safety.**
542+
require_capacity: If True, exclude saturated edges (SDN/TE). If False, route by cost only (IP/IGP).
531543
532544
Returns:
533545
List of FlowSummary objects, one per pair.

0 commit comments

Comments
 (0)