Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/visp/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ struct image_data {
VISP_API image_data image_alloc(i32x2 extent, image_format format);

// Set all pixels to zero.
void image_clear(image_span const&);
VISP_API void image_clear(image_span const&);

// Load image from file (PNG, JPEG, etc.)
VISP_API image_data image_load(char const* filepath);
Expand Down
12 changes: 10 additions & 2 deletions src/visp/arch/birefnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ tensor mlp(model_ref m, tensor x) {
return named(m, x);
}

// Ensures that the tensor's data is not overwritten during computation.
tensor make_constant(tensor x, tensor_name name) {
ggml_set_name(x, name.c_str());
ggml_set_input(x); // allocate at the beginning of the graph buffer
ggml_set_output(x); // don't reuse memory for computations
return x;
}

void compute_relative_position_index(span<int32_t> dst, int window_size) {
int n = window_size;
int n2 = n * n;
Expand All @@ -34,7 +42,7 @@ tensor_data create_relative_position_index(ggml_context* ctx, int window_size) {
auto result = tensor_alloc(ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n * n * n * n));
auto name = format<tensor_name>("window_attention_{}.rel_pos_index", n);
compute_relative_position_index(result.as_i32(), n);
ggml_set_name(result.x, name.c_str());
make_constant(result.x, name);
return result;
}

Expand Down Expand Up @@ -226,7 +234,7 @@ tensor_data create_attention_mask(ggml_context* ctx, int64_t w, int64_t h, int w
auto result = tensor_alloc(ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n * n, n * n, nw_x * nw_y));
auto name = format<tensor_name>("swin_layer_{}x{}.attn_mask", w, h);
compute_attention_mask(result.as_f32(), w, h, window_size);
ggml_set_name(result.x, name.c_str());
make_constant(result.x, name);
return result;
}

Expand Down
2 changes: 1 addition & 1 deletion src/visp/vision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ birefnet_model birefnet_load_model(char const* filepath, backend_device const& d

image_data birefnet_compute(birefnet_model& model, image_view image) {
i32x2 res = birefnet_image_extent(image.extent, model.params);
if (!model.input || res != model.params.image_extent) {
if (!model.graph || res != model.params.image_extent) {
model.params.image_extent = res;
model.graph = compute_graph_init(6 * 1024);

Expand Down
Loading