@@ -240,8 +240,8 @@ DEF(biref_relative_position_index)(model_ref m, span<tensor> input, param_dict c
240240DEF (biref_window_attention)(model_ref m, span<tensor> input, param_dict const & p) {
241241 int window_size = 3 ;
242242 tensor mask = m.find (" mask" );
243- auto rel_pos_index = birefnet::create_relative_position_index (m. weights_context , window_size);
244- ggml_backend_alloc_ctx_tensors (m. weights_context , workbench_backend ());
243+ auto rel_pos_index = birefnet::create_relative_position_index (m, window_size);
244+ ggml_backend_alloc_ctx_tensors (m, workbench_backend ());
245245 transfer_to_backend (rel_pos_index);
246246 return {birefnet::window_attention (m, input[0 ], mask, 2 , window_size)};
247247}
@@ -254,8 +254,8 @@ DEF(biref_swin_block)(model_ref m, span<tensor> input, param_dict const& p) {
254254 block.h = 6 ;
255255 block.shift = 0 ;
256256 tensor mask = m.find (" mask" );
257- auto rel_pos_index = birefnet::create_relative_position_index (m. weights_context , 3 );
258- ggml_backend_alloc_ctx_tensors (m. weights_context , workbench_backend ());
257+ auto rel_pos_index = birefnet::create_relative_position_index (m, 3 );
258+ ggml_backend_alloc_ctx_tensors (m, workbench_backend ());
259259 transfer_to_backend (rel_pos_index);
260260 return {birefnet::swin_block (m, input[0 ], mask, block)};
261261}
@@ -276,9 +276,11 @@ DEF(biref_swin_layer)(model_ref m, span<tensor> input, param_dict const& p) {
276276 layer.n_heads = 2 ;
277277 layer.n_features = 8 ;
278278 layer.downsample = true ;
279- auto rel_pos_index = birefnet::create_relative_position_index (m.weights_context , 3 );
280- ggml_backend_alloc_ctx_tensors (m.weights_context , workbench_backend ());
279+ auto rel_pos_index = birefnet::create_relative_position_index (m, 3 );
280+ auto attn_mask = birefnet::create_attention_mask (m, 6 , 6 , 3 );
281+ ggml_backend_alloc_ctx_tensors (m, workbench_backend ());
281282 transfer_to_backend (rel_pos_index);
283+ transfer_to_backend (attn_mask);
282284 auto result = birefnet::swin_layer (m, input[0 ], 6 , 6 , layer, 3 );
283285 ASSERT (result.w_down == 3 && result.h_down == 3 );
284286 return {result.x_down };
@@ -294,11 +296,11 @@ DEF(biref_swin_transformer)(model_ref m, span<tensor> input, param_dict const& p
294296 swin_layer_t {2 , 4 , 8 * 4 , true },
295297 swin_layer_t {2 , 2 , 8 * 8 , false },
296298 }};
297- auto rel_pos_index = birefnet::create_relative_position_index (m. weights_context , 3 );
299+ auto rel_pos_index = birefnet::create_relative_position_index (m, 3 );
298300 auto attn_masks = std::array{
299- birefnet::create_attention_mask (m. weights_context , 8 , 8 , 3 ), birefnet::create_attention_mask (m. weights_context , 4 , 4 , 3 ),
300- birefnet::create_attention_mask (m. weights_context , 2 , 2 , 3 ), birefnet::create_attention_mask (m. weights_context , 1 , 1 , 3 )};
301- ggml_backend_alloc_ctx_tensors (m. weights_context , workbench_backend ());
301+ birefnet::create_attention_mask (m, 8 , 8 , 3 ), birefnet::create_attention_mask (m, 4 , 4 , 3 ),
302+ birefnet::create_attention_mask (m, 2 , 2 , 3 ), birefnet::create_attention_mask (m, 1 , 1 , 3 )};
303+ ggml_backend_alloc_ctx_tensors (m, workbench_backend ());
302304 transfer_to_backend (rel_pos_index);
303305 for (auto && attn_mask : attn_masks) {
304306 transfer_to_backend (attn_mask);
0 commit comments