Skip to content

Commit 87f3d76

Browse files
committed
birefnet: switch to direct version of conv_2d_deform on Vulkan
* times old -> new / new with coopmat2 * birefnet: 268ms -> 315ms / 243ms * birefnet-lite: 109ms -> 119 / 87ms * deform conv2d is a bit slower without coopmat2 support, but also requires much less vram, so still worth it
1 parent f82d2c6 commit 87f3d76

File tree

5 files changed

+21
-26
lines changed

5 files changed

+21
-26
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,9 @@ as other frameworks for inference speed, but with:
192192
| Model | | | _vision.cpp_ | PyTorch | ONNX Runtime |
193193
| :---- | :--- | :--- | -----------: | -------: | -----------: |
194194
| Full | cpu | f32 | 16333 ms | 18800 ms | |
195-
| Full | gpu | f16 | 268 ms | 140 ms | |
195+
| Full | gpu | f16 | 243 ms | 140 ms | |
196196
| Lite | cpu | f32 | 4505 ms | 10900 ms | 6978 ms |
197-
| Lite | gpu | f16 | 109 ms | 59 ms | |
197+
| Lite | gpu | f16 | 86 ms | 59 ms | |
198198

199199
#### MI-GAN, 512x512
200200

@@ -205,7 +205,7 @@ as other frameworks for inference speed, but with:
205205

206206
#### Setup
207207

208-
* vision.cpp: using vision-bench, GPU via Vulkan, eg. `vision-bench sam cpu`
208+
* vision.cpp: using vision-bench, GPU via Vulkan, eg. `vision-bench -m sam -b cpu`
209209
* PyTorch: v2.7.1+cu128, eager eval, GPU via CUDA, average n iterations after warm-up
210210

211211
## Dependencies (integrated)

src/visp/nn.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,6 @@ tensor conv_2d_deform(
154154

155155
if (m.flags & model_build_flag::cwhn) {
156156
x = permute_whcn_to_cwhn(m, x);
157-
} else if (!(m.flags & model_build_flag::f16_conv_transpose)) {
158-
// Vulkan WHCN implementation doesn't do the final permute atm
159-
// only worth fixing if WHCN ends up faster AND we dont implement
160-
// a direct version of conv_2d_deform
161-
auto [w, h, c, n] = nelements(x);
162-
x = ggml_reshape_4d(m, x, c, w, h, n);
163-
x = ggml_cont(m, ggml_permute(m, x, 2, 0, 1, 3));
164157
}
165158
return x;
166159
}

tests/test_birefnet.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -761,13 +761,15 @@ def test_encode():
761761
@pytest.mark.parametrize("backend", ["cpu", "vulkan"])
762762
def test_conv_2d_deform(scenario: str, memory_layout: str, backend: str):
763763
torch.manual_seed(42)
764+
if memory_layout == "nhwc" and backend == "vulkan":
765+
pytest.skip("conv_2d_deform with nhwc layout is not supported on Vulkan")
764766

765767
w, h, c_in, c_out, k = {
766768
"small": (4, 4, 5, 2, 3),
767769
"large": (49, 38, 81, 17, 7),
768770
}[scenario]
769-
x = input_tensor(1, c_in, h, w) - 0.5
770-
weight = input_tensor(c_out, c_in, k, k)
771+
x = torch.rand(1, c_in, h, w) - 0.5
772+
weight = torch.rand(c_out, c_in, k, k) - 0.5
771773
offset = 1.0 - input_tensor(1, 2 * k * k, h, w)
772774
mask = torch.rand(1, k * k, h, w)
773775
expected = torchvision.ops.deform_conv2d(x, offset, weight, mask=mask, padding=(k // 2, k // 2))
@@ -785,7 +787,7 @@ def test_conv_2d_deform(scenario: str, memory_layout: str, backend: str):
785787
if memory_layout == "nhwc":
786788
result = to_nchw(result)
787789

788-
assert torch.allclose(result, expected, atol=1e-2 if backend == "vulkan" else 1e-5)
790+
assert torch.allclose(result, expected, atol=0.1 if backend == "vulkan" else 0.001)
789791

790792

791793
class DeformableConv2d(nn.Module):
@@ -916,8 +918,6 @@ def test_global_avg_pool(backend: str):
916918

917919
state = fuse_all_conv_2d_batch_norm(state, "", "1", "2")
918920
state = convert_to_nhwc(state, key="1.weight")
919-
for k, v in state.items():
920-
print(f"{k}: {v.shape}")
921921
x = to_nhwc(x)
922922
result = workbench.invoke_test("biref_global_avg_pool", x, state, nhwc_layout, backend=backend)
923923
result = to_nchw(result)

tests/workbench.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,8 @@ DEF(biref_relative_position_index)(model_ref m, span<tensor> input, param_dict c
240240
DEF(biref_window_attention)(model_ref m, span<tensor> input, param_dict const& p) {
241241
int window_size = 3;
242242
tensor mask = m.find("mask");
243-
auto rel_pos_index = birefnet::create_relative_position_index(m.weights_context, window_size);
244-
ggml_backend_alloc_ctx_tensors(m.weights_context, workbench_backend());
243+
auto rel_pos_index = birefnet::create_relative_position_index(m, window_size);
244+
ggml_backend_alloc_ctx_tensors(m, workbench_backend());
245245
transfer_to_backend(rel_pos_index);
246246
return {birefnet::window_attention(m, input[0], mask, 2, window_size)};
247247
}
@@ -254,8 +254,8 @@ DEF(biref_swin_block)(model_ref m, span<tensor> input, param_dict const& p) {
254254
block.h = 6;
255255
block.shift = 0;
256256
tensor mask = m.find("mask");
257-
auto rel_pos_index = birefnet::create_relative_position_index(m.weights_context, 3);
258-
ggml_backend_alloc_ctx_tensors(m.weights_context, workbench_backend());
257+
auto rel_pos_index = birefnet::create_relative_position_index(m, 3);
258+
ggml_backend_alloc_ctx_tensors(m, workbench_backend());
259259
transfer_to_backend(rel_pos_index);
260260
return {birefnet::swin_block(m, input[0], mask, block)};
261261
}
@@ -276,9 +276,11 @@ DEF(biref_swin_layer)(model_ref m, span<tensor> input, param_dict const& p) {
276276
layer.n_heads = 2;
277277
layer.n_features = 8;
278278
layer.downsample = true;
279-
auto rel_pos_index = birefnet::create_relative_position_index(m.weights_context, 3);
280-
ggml_backend_alloc_ctx_tensors(m.weights_context, workbench_backend());
279+
auto rel_pos_index = birefnet::create_relative_position_index(m, 3);
280+
auto attn_mask = birefnet::create_attention_mask(m, 6, 6, 3);
281+
ggml_backend_alloc_ctx_tensors(m, workbench_backend());
281282
transfer_to_backend(rel_pos_index);
283+
transfer_to_backend(attn_mask);
282284
auto result = birefnet::swin_layer(m, input[0], 6, 6, layer, 3);
283285
ASSERT(result.w_down == 3 && result.h_down == 3);
284286
return {result.x_down};
@@ -294,11 +296,11 @@ DEF(biref_swin_transformer)(model_ref m, span<tensor> input, param_dict const& p
294296
swin_layer_t{2, 4, 8 * 4, true},
295297
swin_layer_t{2, 2, 8 * 8, false},
296298
}};
297-
auto rel_pos_index = birefnet::create_relative_position_index(m.weights_context, 3);
299+
auto rel_pos_index = birefnet::create_relative_position_index(m, 3);
298300
auto attn_masks = std::array{
299-
birefnet::create_attention_mask(m.weights_context, 8, 8, 3), birefnet::create_attention_mask(m.weights_context, 4, 4, 3),
300-
birefnet::create_attention_mask(m.weights_context, 2, 2, 3), birefnet::create_attention_mask(m.weights_context, 1, 1, 3)};
301-
ggml_backend_alloc_ctx_tensors(m.weights_context, workbench_backend());
301+
birefnet::create_attention_mask(m, 8, 8, 3), birefnet::create_attention_mask(m, 4, 4, 3),
302+
birefnet::create_attention_mask(m, 2, 2, 3), birefnet::create_attention_mask(m, 1, 1, 3)};
303+
ggml_backend_alloc_ctx_tensors(m, workbench_backend());
302304
transfer_to_backend(rel_pos_index);
303305
for (auto&& attn_mask : attn_masks) {
304306
transfer_to_backend(attn_mask);

0 commit comments

Comments
 (0)