@@ -409,9 +409,23 @@ PYBIND11_MODULE(_netgraph_core, m) {
409409 .def (" place_on_dag" , [](FlowState& fs, std::int32_t src, std::int32_t dst, const PredDAG& dag, double requested_flow, FlowPlacement placement){
410410 py::gil_scoped_release rel; auto placed = fs.place_on_dag (src, dst, dag, requested_flow, placement); py::gil_scoped_acquire acq; return placed;
411411 }, py::arg (" src" ), py::arg (" dst" ), py::arg (" dag" ), py::arg (" requested_flow" ) = std::numeric_limits<double >::infinity (), py::arg (" flow_placement" ) = FlowPlacement::Proportional)
412- .def (" place_max_flow" , [](FlowState& fs, std::int32_t src, std::int32_t dst, FlowPlacement placement, bool shortest_path, bool require_capacity){
413- py::gil_scoped_release rel; auto total = fs.place_max_flow (src, dst, placement, shortest_path, require_capacity); py::gil_scoped_acquire acq; return total;
414- }, py::arg (" src" ), py::arg (" dst" ), py::arg (" flow_placement" ) = FlowPlacement::Proportional, py::arg (" shortest_path" ) = false , py::arg (" require_capacity" ) = true )
412+ .def (" place_max_flow" , [](py::object self_obj, std::int32_t src, std::int32_t dst,
413+ FlowPlacement placement, bool shortest_path, bool require_capacity,
414+ py::object node_mask, py::object edge_mask){
415+ FlowState& fs = py::cast<FlowState&>(self_obj);
416+ // Get graph reference to validate mask lengths
417+ const StrictMultiDiGraph& g = py::cast<const StrictMultiDiGraph&>(self_obj.attr (" _graph_ref" ));
418+ auto node_bs = to_bool_span_from_numpy (node_mask, static_cast <std::size_t >(g.num_nodes ()), " node_mask" );
419+ auto edge_bs = to_bool_span_from_numpy (edge_mask, static_cast <std::size_t >(g.num_edges ()), " edge_mask" );
420+ py::gil_scoped_release rel;
421+ auto total = fs.place_max_flow (src, dst, placement, shortest_path, require_capacity,
422+ node_bs.view , edge_bs.view );
423+ py::gil_scoped_acquire acq;
424+ return total;
425+ }, py::arg (" src" ), py::arg (" dst" ),
426+ py::arg (" flow_placement" ) = FlowPlacement::Proportional,
427+ py::arg (" shortest_path" ) = false , py::arg (" require_capacity" ) = true ,
428+ py::kw_only (), py::arg (" node_mask" ) = py::none (), py::arg (" edge_mask" ) = py::none ())
415429 .def (" compute_min_cut" , [](py::object self_obj, std::int32_t src, py::object node_mask, py::object edge_mask){
416430 const FlowState& fs = py::cast<const FlowState&>(self_obj);
417431 // Get graph reference to validate mask lengths
@@ -546,63 +560,29 @@ PYBIND11_MODULE(_netgraph_core, m) {
546560 .def_readwrite (" diminishing_returns_epsilon_frac" , &FlowPolicyConfig::diminishing_returns_epsilon_frac);
547561
548562 py::class_<FlowPolicy>(m, " FlowPolicy" , py::dynamic_attr ())
549- .def (" __init__" , [](py::object self, py::object algs_obj, py::object graph_obj, const FlowPolicyConfig& cfg){
563+ .def (" __init__" , [](py::object self, py::object algs_obj, py::object graph_obj, FlowPolicyConfig cfg,
564+ py::object node_mask, py::object edge_mask){
550565 std::shared_ptr<Algorithms> algs = py::cast<std::shared_ptr<Algorithms>>(algs_obj);
551566 const PyGraph& pg = py::cast<const PyGraph&>(graph_obj);
567+
568+ // Convert masks to spans (FlowPolicy will copy the data)
569+ auto node_bs = to_bool_span_from_numpy (node_mask, static_cast <std::size_t >(pg.num_nodes ), " node_mask" );
570+ auto edge_bs = to_bool_span_from_numpy (edge_mask, static_cast <std::size_t >(pg.num_edges ), " edge_mask" );
571+
572+ // Update config with mask spans (will be copied by FlowPolicy constructor)
573+ cfg.node_mask = node_bs.view ;
574+ cfg.edge_mask = edge_bs.view ;
575+
552576 ExecutionContext ctx (algs, pg.handle );
553577 FlowPolicy* fp = self.cast <FlowPolicy*>();
554578 new (fp) FlowPolicy (ctx, cfg);
555579 self.attr (" _algorithms_ref" ) = algs_obj;
556580 self.attr (" _graph_ref" ) = graph_obj;
557581 },
558582 py::arg (" algorithms" ), py::arg (" graph" ), py::arg (" config" ),
559- py::keep_alive<1 , 2 >() // self keeps algorithms alive
560- )
561- .def (" __init__" , [](py::object self, py::object algs_obj, py::object graph_obj,
562- PathAlg path_alg,
563- FlowPlacement flow_placement,
564- EdgeSelection selection,
565- bool require_capacity,
566- int min_flow_count,
567- std::optional<int > max_flow_count,
568- std::optional<Cost> max_path_cost,
569- std::optional<double > max_path_cost_factor,
570- bool shortest_path,
571- bool reoptimize_flows_on_each_placement,
572- int max_no_progress_iterations,
573- int max_total_iterations,
574- bool diminishing_returns_enabled,
575- int diminishing_returns_window,
576- double diminishing_returns_epsilon_frac){
577- std::shared_ptr<Algorithms> algs = py::cast<std::shared_ptr<Algorithms>>(algs_obj);
578- const PyGraph& pg = py::cast<const PyGraph&>(graph_obj);
579- ExecutionContext ctx (algs, pg.handle );
580- FlowPolicy* fp = self.cast <FlowPolicy*>();
581- new (fp) FlowPolicy (ctx, path_alg, flow_placement, selection, require_capacity,
582- min_flow_count, max_flow_count, max_path_cost, max_path_cost_factor,
583- shortest_path, reoptimize_flows_on_each_placement,
584- max_no_progress_iterations, max_total_iterations,
585- diminishing_returns_enabled, diminishing_returns_window,
586- diminishing_returns_epsilon_frac);
587- self.attr (" _algorithms_ref" ) = algs_obj;
588- self.attr (" _graph_ref" ) = graph_obj;
589- },
590- py::arg (" algorithms" ), py::arg (" graph" ),
591- py::arg (" path_alg" ) = PathAlg::SPF,
592- py::arg (" flow_placement" ) = FlowPlacement::Proportional,
593- py::arg (" selection" ) = EdgeSelection{},
594- py::arg (" require_capacity" ) = true ,
595- py::arg (" min_flow_count" ) = 1 ,
596- py::arg (" max_flow_count" ) = py::none (),
597- py::arg (" max_path_cost" ) = py::none (),
598- py::arg (" max_path_cost_factor" ) = py::none (),
599- py::arg (" shortest_path" ) = false ,
600- py::arg (" reoptimize_flows_on_each_placement" ) = false ,
601- py::arg (" max_no_progress_iterations" ) = 100 ,
602- py::arg (" max_total_iterations" ) = 10000 ,
603- py::arg (" diminishing_returns_enabled" ) = true ,
604- py::arg (" diminishing_returns_window" ) = 8 ,
605- py::arg (" diminishing_returns_epsilon_frac" ) = 1e-3 ,
583+ py::kw_only (),
584+ py::arg (" node_mask" ) = py::none (),
585+ py::arg (" edge_mask" ) = py::none (),
606586 py::keep_alive<1 , 2 >() // self keeps algorithms alive
607587 )
608588 .def (" flow_count" , &FlowPolicy::flow_count)
0 commit comments