Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/visp/ml.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ struct backend_device {
VISP_API ggml_type preferred_float_type() const;
VISP_API tensor_data_layout preferred_layout() const;
VISP_API size_t total_memory() const;
VISP_API size_t max_alloc() const;

operator ggml_backend_t() const { return handle.get(); }
};
Expand Down
6 changes: 4 additions & 2 deletions include/visp/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,11 @@ struct birefnet_params {

using birefnet_buffers = std::array<tensor_data, swin_params::n_layers + 2>;

// Reads the BiRefNet parameters from a model file.
// dynamic_extent: extent of the image to process; it is forwarded to
//   birefnet_image_extent to determine the processing resolution.
// max_alloc: upper bound in bytes used when choosing the resolution of models with
//   a dynamic image size. The default (SIZE_MAX) imposes no limit.
VISP_API birefnet_params birefnet_detect_params(
    model_file const&, i32x2 dynamic_extent = {}, size_t max_alloc = SIZE_MAX);
VISP_API birefnet_buffers birefnet_precompute(model_ref, birefnet_params const&);
// Returns the extent at which an image of size input_extent is processed.
// max_alloc: upper bound in bytes for the largest layer when the model uses a
//   dynamic image size. The default (SIZE_MAX) imposes no limit.
VISP_API i32x2 birefnet_image_extent(
    i32x2 input_extent, birefnet_params const&, size_t max_alloc = SIZE_MAX);

VISP_API image_data birefnet_process_input(image_view, birefnet_params const&);
VISP_API image_data birefnet_process_output(
Expand Down
2 changes: 1 addition & 1 deletion src/cli/cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ void run_birefnet(cli_args const& args) {

require_inputs(args.inputs, 1, "<image>");
image_data image = image_load(args.inputs[0]);
birefnet_params params = birefnet_detect_params(file, image.extent);
birefnet_params params = birefnet_detect_params(file, image.extent, backend.max_alloc());
image_data input_data = birefnet_process_input(image, params);

i32x2 extent = params.image_extent;
Expand Down
18 changes: 14 additions & 4 deletions src/visp/arch/birefnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ tensor mlp(model_ref m, tensor x) {
// Ensures that the tensor's data is not overwritten during computation.
// Names the tensor and flags it as both a graph input and a graph output, then
// returns the same tensor so the call can be used inline.
tensor make_constant(tensor x, tensor_name name) {
    ggml_set_name(x, name.c_str());
    ggml_set_input(x);  // allocate at the beginning of the graph buffer
    ggml_set_output(x); // don't reuse memory for computations
    return x;
}
Expand Down Expand Up @@ -611,25 +611,35 @@ swin_params swin_detect_params(model_file const& f) {
}
}

// Computes the extent at which an image is processed by BiRefNet.
//
// For models with a fixed resolution the result is {image_size, image_size}.
// For models with a dynamic resolution (image_size == -1) the result is
// input_extent passed through next_multiple with p.image_multiple. If the largest
// layer at that resolution would need more than max_alloc bytes, input_extent is
// scaled down (keeping its aspect ratio) first.
//
// input_extent: size of the input image; both components must be positive when
//   the model is dynamic.
// max_alloc: maximum size in bytes allowed for the largest layer.
i32x2 birefnet_image_extent(i32x2 input_extent, birefnet_params const& p, size_t max_alloc) {
    i32x2 extent{p.image_size, p.image_size};
    if (p.image_size == -1) {
        ASSERT(input_extent[0] > 0 && input_extent[1] > 0);

        // largest layer in BiRefNet-dynamic is input for 240-channel conv-2d at full resolution
        constexpr unsigned long long largest_layer_channels = 240ULL;
        size_t req_alloc = size_t(input_extent[0]) * input_extent[1] * largest_layer_channels *
            sizeof(float);

        if (req_alloc > max_alloc) {
            // Memory grows with width * height, so each side shrinks by the square
            // root of the ratio between the available and the required size.
            float scale = std::sqrt(float(max_alloc) / float(req_alloc));
            // Subtract one multiple before the extent is passed to next_multiple below,
            // so the result does not end up above the scaled size again.
            // Never go below 1 pixel.
            input_extent = {
                std::max(1, int(input_extent[0] * scale) - p.image_multiple),
                std::max(1, int(input_extent[1] * scale) - p.image_multiple)};
        }
        extent = {
            next_multiple(input_extent[0], p.image_multiple),
            next_multiple(input_extent[1], p.image_multiple)};
    }
    return extent;
}

// Reads the BiRefNet parameters from the metadata of a model file.
//
// f: model file; its architecture must be "birefnet".
// dynamic_extent: extent of the image to process, forwarded to birefnet_image_extent.
// max_alloc: upper bound in bytes forwarded to birefnet_image_extent.
//
// Throws if the file's architecture is not "birefnet".
birefnet_params birefnet_detect_params(
    model_file const& f, i32x2 dynamic_extent, size_t max_alloc) {

    if (std::string_view arch = f.arch(); arch != "birefnet") {
        throw except("Architecture expected to be 'birefnet', but was '{}' ({})", arch, f.path);
    }
    birefnet_params p;
    p.image_size = f.get_int("birefnet.image_size");
    p.image_multiple = f.get_int("birefnet.image_multiple");
    // image_size and image_multiple must be set before the extent is derived from them.
    p.image_extent = birefnet_image_extent(dynamic_extent, p, max_alloc);
    p.encoder = swin_detect_params(f);
    return p;
}
Expand Down
5 changes: 5 additions & 0 deletions src/visp/ml.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ size_t backend_device::total_memory() const {
return total;
}

// Returns the maximum allocation size in bytes that is assumed for this device.
// The CPU backend is treated as unlimited (SIZE_MAX); every other backend type
// currently reports a fixed 4 GiB.
size_t backend_device::max_alloc() const {
    if (type() == backend_type::cpu) {
        return SIZE_MAX;
    }
    // TODO: query from backend
    const size_t fixed_gpu_limit = 4 * 1024 * 1024 * 1024ULL;
    return fixed_gpu_limit;
}

void backend_set_n_threads(backend_device& b, int n_threads) {
if (b.type() != backend_type::cpu) {
return;
Expand Down
2 changes: 1 addition & 1 deletion src/visp/vision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ birefnet_model birefnet_load_model(char const* filepath, backend_device const& d
}

image_data birefnet_compute(birefnet_model& model, image_view image) {
i32x2 res = birefnet_image_extent(image.extent, model.params);
i32x2 res = birefnet_image_extent(image.extent, model.params, model.backend->max_alloc());
if (!model.graph || res != model.params.image_extent) {
model.params.image_extent = res;
model.graph = compute_graph_init(6 * 1024);
Expand Down
Loading