diff --git a/include/visp/ml.h b/include/visp/ml.h index 4b9af39..ed108a4 100644 --- a/include/visp/ml.h +++ b/include/visp/ml.h @@ -42,6 +42,7 @@ struct backend_device { VISP_API ggml_type preferred_float_type() const; VISP_API tensor_data_layout preferred_layout() const; VISP_API size_t total_memory() const; + VISP_API size_t max_alloc() const; operator ggml_backend_t() const { return handle.get(); } }; diff --git a/include/visp/vision.h b/include/visp/vision.h index 8fa4211..4daeaab 100644 --- a/include/visp/vision.h +++ b/include/visp/vision.h @@ -150,9 +150,11 @@ struct birefnet_params { using birefnet_buffers = std::array; -VISP_API birefnet_params birefnet_detect_params(model_file const&, i32x2 dynamic_extent = {}); +VISP_API birefnet_params birefnet_detect_params( + model_file const&, i32x2 dynamic_extent = {}, size_t max_alloc = SIZE_MAX); VISP_API birefnet_buffers birefnet_precompute(model_ref, birefnet_params const&); -VISP_API i32x2 birefnet_image_extent(i32x2 input_extent, birefnet_params const&); +VISP_API i32x2 birefnet_image_extent( + i32x2 input_extent, birefnet_params const&, size_t max_alloc = SIZE_MAX); VISP_API image_data birefnet_process_input(image_view, birefnet_params const&); VISP_API image_data birefnet_process_output( diff --git a/src/cli/cli.cpp b/src/cli/cli.cpp index 25f2ee9..3e37434 100644 --- a/src/cli/cli.cpp +++ b/src/cli/cli.cpp @@ -399,7 +399,7 @@ void run_birefnet(cli_args const& args) { require_inputs(args.inputs, 1, ""); image_data image = image_load(args.inputs[0]); - birefnet_params params = birefnet_detect_params(file, image.extent); + birefnet_params params = birefnet_detect_params(file, image.extent, backend.max_alloc()); image_data input_data = birefnet_process_input(image, params); i32x2 extent = params.image_extent; diff --git a/src/visp/arch/birefnet.cpp b/src/visp/arch/birefnet.cpp index 43bee9b..b294b87 100644 --- a/src/visp/arch/birefnet.cpp +++ b/src/visp/arch/birefnet.cpp @@ -19,7 +19,7 @@ tensor mlp(model_ref m, tensor x) { // Ensures that the tensor's data is not overwritten during computation. tensor make_constant(tensor x, tensor_name name) { ggml_set_name(x, name.c_str()); - ggml_set_input(x); // allocate at the beginning of the graph buffer + ggml_set_input(x); // allocate at the beginning of the graph buffer ggml_set_output(x); // don't reuse memory for computations return x; } @@ -611,10 +611,18 @@ swin_params swin_detect_params(model_file const& f) { } } -i32x2 birefnet_image_extent(i32x2 input_extent, birefnet_params const& p) { +i32x2 birefnet_image_extent(i32x2 input_extent, birefnet_params const& p, size_t max_alloc) { i32x2 extent{p.image_size, p.image_size}; if (p.image_size == -1) { ASSERT(input_extent[0] > 0 && input_extent[1] > 0); + // largest layer in BiRefNet-dynamic is input for 240-channel conv-2d at full resolution + size_t req_alloc = size_t(input_extent[0]) * input_extent[1] * 240ULL * sizeof(float); + if (req_alloc > max_alloc) { + float scale = std::sqrt(float(max_alloc) / float(req_alloc)); + input_extent = { + std::max(1, int(input_extent[0] * scale) - p.image_multiple), + std::max(1, int(input_extent[1] * scale) - p.image_multiple)}; + } extent = { next_multiple(input_extent[0], p.image_multiple), next_multiple(input_extent[1], p.image_multiple)}; @@ -622,14 +630,16 @@ i32x2 birefnet_image_extent(i32x2 input_extent, birefnet_params const& p) { return extent; } -birefnet_params birefnet_detect_params(model_file const& f, i32x2 dynamic_extent) { +birefnet_params birefnet_detect_params( + model_file const& f, i32x2 dynamic_extent, size_t max_alloc) { + if (std::string_view arch = f.arch(); arch != "birefnet") { throw except("Architecture expected to be 'birefnet', but was '{}' ({})", arch, f.path); } birefnet_params p; p.image_size = f.get_int("birefnet.image_size"); p.image_multiple = f.get_int("birefnet.image_multiple"); - p.image_extent = birefnet_image_extent(dynamic_extent, p); + p.image_extent = birefnet_image_extent(dynamic_extent, p, max_alloc); p.encoder = swin_detect_params(f); return p; } diff --git a/src/visp/ml.cpp b/src/visp/ml.cpp index 1ac564c..1c8e55d 100644 --- a/src/visp/ml.cpp +++ b/src/visp/ml.cpp @@ -108,6 +108,11 @@ size_t backend_device::total_memory() const { return total; } +size_t backend_device::max_alloc() const { + const size_t vulkan_max = 4 * 1024 * 1024 * 1024ULL; // TODO: query from backend + return type() == backend_type::cpu ? SIZE_MAX : vulkan_max; +} + void backend_set_n_threads(backend_device& b, int n_threads) { if (b.type() != backend_type::cpu) { return; diff --git a/src/visp/vision.cpp b/src/visp/vision.cpp index 7c9c424..bd8216e 100644 --- a/src/visp/vision.cpp +++ b/src/visp/vision.cpp @@ -90,7 +90,7 @@ birefnet_model birefnet_load_model(char const* filepath, backend_device const& d } image_data birefnet_compute(birefnet_model& model, image_view image) { - i32x2 res = birefnet_image_extent(image.extent, model.params); + i32x2 res = birefnet_image_extent(image.extent, model.params, model.backend->max_alloc()); if (!model.graph || res != model.params.image_extent) { model.params.image_extent = res; model.graph = compute_graph_init(6 * 1024);