From 766b6db9f26edc35d28b6243a4b04df8d8e9b8b2 Mon Sep 17 00:00:00 2001
From: Acly
Date: Wed, 1 Oct 2025 23:34:21 +0200
Subject: [PATCH 01/24] ml: bicubic interpolation tests

---
 tests/test_primitives.py | 28 ++++++++++++++++++++++------
 tests/workbench.cpp      | 11 +++++++++++
 2 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/tests/test_primitives.py b/tests/test_primitives.py
index 08c6414..28df77d 100644
--- a/tests/test_primitives.py
+++ b/tests/test_primitives.py
@@ -16,9 +16,7 @@ def test_linear():
     assert torch.allclose(result, expected)
 
 
-@pytest.mark.parametrize(
-    "scenario", ["stride_1_pad_0", "stride_2_pad_1", "dilation_2_pad_2"]
-)
+@pytest.mark.parametrize("scenario", ["stride_1_pad_0", "stride_2_pad_1", "dilation_2_pad_2"])
 @pytest.mark.parametrize("memory_layout", ["nchw", "nhwc"])
 @pytest.mark.parametrize("batch", ["single", "batch"])
 @pytest.mark.parametrize("backend", ["cpu", "vulkan"])
@@ -128,9 +126,7 @@ def test_window_partition(backend: str):
     nW = pW // win
 
     # window partition
     expected = (
-        expected.view(B, nH, win, nW, win, C)
-        .transpose(2, 3)
-        .reshape(B * nH * nW, win * win, C)
+        expected.view(B, nH, win, nW, win, C).transpose(2, 3).reshape(B * nH * nW, win * win, C)
     )
 
     result = workbench.invoke_test("sam_window_partition", x, {}, backend=backend)
@@ -150,3 +146,23 @@ def test_roll(shift: tuple[int, int, int, int], backend: str):
 
     result = workbench.invoke_test("roll", x, {}, params, backend)
     assert torch.allclose(result, expected)
+
+
+@pytest.mark.parametrize("mode", ["bilinear", "bicubic"])
+@pytest.mark.parametrize("align_corners", [True, False])
+@pytest.mark.parametrize("size", ["small", "large"])
+@pytest.mark.parametrize("scale", [0.6, 2.0])
+def test_interpolate(mode: str, align_corners: bool, size: str, scale: float):
+    b, c, h, w = {
+        "small": (1, 3, 2, 3),
+        "large": (4, 19, 20, 30),
+    }[size]
+    target = (round(h * scale), round(w * scale))
+    x = torch.arange(b * c * h * w).reshape(b, c, h, w).float()
+    expected = torch.nn.functional.interpolate(
+        x, size=target, mode=mode, align_corners=align_corners
+    )
+
+    params = dict(mode=mode, h=target[0], w=target[1], align_corners=1 if align_corners else 0)
+    result = workbench.invoke_test("interpolate", x, {}, params)
+    assert torch.allclose(result, expected)
diff --git a/tests/workbench.cpp b/tests/workbench.cpp
index f31e83d..1c6181b 100644
--- a/tests/workbench.cpp
+++ b/tests/workbench.cpp
@@ -116,6 +116,17 @@ DEF(linear)(model_ref m, span<tensor> input, param_dict const& p) {
     return {linear(m, input[0])};
 }
 
+DEF(interpolate)(model_ref m, span<tensor> input, param_dict const& p) {
+    int w = p.get("w", 8);
+    int h = p.get("h", 8);
+    uint32_t mode = p.get("mode", "bilinear") == "bilinear"sv ? 
GGML_SCALE_MODE_BILINEAR + : GGML_SCALE_MODE_BICUBIC; + if (p.get("align_corners", 0)) { + mode |= GGML_SCALE_FLAG_ALIGN_CORNERS; + } + return {ggml_interpolate(m, input[0], w, h, input[0]->ne[2], input[0]->ne[3], mode)}; +} + // // Mobile SAM From 74a5d9ae1c1e1dbaea66e9d746c58e6a0ab3667a Mon Sep 17 00:00:00 2001 From: Acly Date: Thu, 2 Oct 2025 20:29:14 +0200 Subject: [PATCH 02/24] dino: implement dino-v2 intermediate layers --- src/util/math.h | 4 + src/visp/arch/birefnet.cpp | 12 -- src/visp/arch/birefnet.h | 2 - src/visp/arch/dino.h | 158 +++++++++++++++++++ src/visp/nn.cpp | 15 ++ src/visp/nn.h | 3 + tests/test_depth_anything.py | 289 +++++++++++++++++++++++++++++++++++ tests/workbench.cpp | 61 ++++++-- 8 files changed, 514 insertions(+), 30 deletions(-) create mode 100644 src/visp/arch/dino.h create mode 100644 tests/test_depth_anything.py diff --git a/src/util/math.h b/src/util/math.h index 835229d..40bbbcf 100644 --- a/src/util/math.h +++ b/src/util/math.h @@ -59,5 +59,9 @@ constexpr i32x2 div_ceil(i32x2 a, i32x2 b) { return {div_ceil(a[0], b[0]), div_c constexpr i32x2 div_ceil(i32x2 a, int32_t b) { return div_ceil(a, i32x2{b, b}); } constexpr i32x2 min(i32x2 a, i32x2 b) { return {std::min(a[0], b[0]), std::min(a[1], b[1])}; } +// i64x2 operations +constexpr i64x2 operator*(i64x2 a, int64_t b) { return {a[0] * b, a[1] * b}; } +constexpr i64x2 operator/(i64x2 a, int64_t b) { return {a[0] / b, a[1] / b}; } + // clang-format on } // namespace visp \ No newline at end of file diff --git a/src/visp/arch/birefnet.cpp b/src/visp/arch/birefnet.cpp index b294b87..0bcd3e2 100644 --- a/src/visp/arch/birefnet.cpp +++ b/src/visp/arch/birefnet.cpp @@ -261,18 +261,6 @@ swin_layer_result swin_layer( return {x, w, h, x, w, h}; } -tensor patch_embed(model_ref m, tensor x, int patch_size) { - ASSERT(x->ne[1] % patch_size == 0 && x->ne[2] % patch_size == 0); - - m.flags |= model_build_flag::cwhn; - x = conv_2d(m["proj"], x, patch_size); - auto [c, ww, wh, b] = nelements(x); - x = ggml_reshape_3d(m, x, c, ww * wh, b); - x = layer_norm(m["norm"], x); - x = ggml_reshape_4d(m, x, c, ww, wh, b); - return named(m, x); -} - swin_result swin_transformer(model_ref m, tensor x, swin_params const& p) { x = patch_embed(m["patch_embed"], x, 4); diff --git a/src/visp/arch/birefnet.h b/src/visp/arch/birefnet.h index 7f109ad..2ad1b3f 100644 --- a/src/visp/arch/birefnet.h +++ b/src/visp/arch/birefnet.h @@ -6,7 +6,6 @@ #include namespace visp { - namespace birefnet { // SWIN Transformer @@ -37,7 +36,6 @@ tensor_data create_attention_mask(ggml_context* ctx, int64_t w, int64_t h, int w tensor mlp(model_ref m, tensor x); tensor patch_merging(model_ref m, tensor x, int64_t w, int64_t h); -tensor patch_embed(model_ref m, tensor x, int patch_size = 4); tensor window_partition(model_ref m, tensor x, int window); tensor window_reverse(model_ref m, tensor x, int w, int h, int window); tensor window_attention(model_ref m, tensor x, tensor mask, int num_heads, int window); diff --git a/src/visp/arch/dino.h b/src/visp/arch/dino.h new file mode 100644 index 0000000..8c56908 --- /dev/null +++ b/src/visp/arch/dino.h @@ -0,0 +1,158 @@ +#pragma once + +#include "util/math.h" +#include "visp/image.h" +#include "visp/ml.h" +#include "visp/nn.h" + +#include + +#pragma optimize("", off) + +namespace visp { +namespace dino { + +inline tensor interpolate_pos_encoding( + model_ref m, tensor x, int64_t w, int64_t h, int patch_size) { + + tensor pos_embed = ggml_cast(m, m.weights("pos_embed"), GGML_TYPE_F32); + int64_t n_patch = x->ne[1] - 1; + 
int64_t n = pos_embed->ne[1] - 1; + if (n_patch == n && w == h) { + return pos_embed; + } + + tensor class_embed = slice(m, pos_embed, {}, {0}, {}, {}); + tensor patch_embed = slice(m, pos_embed, {}, {1, n + 1}, {}, {}); + int64_t dim = x->ne[0]; + i64x2 target = i64x2{w, h} / patch_size; + int64_t sqrt_n = int64_t(std::sqrt(float(n)) + 0.01f); + + patch_embed = ggml_reshape_4d(m, patch_embed, dim, sqrt_n, sqrt_n, 1); + patch_embed = ggml_cont(m, permute_cwhn_to_whcn(m, patch_embed)); + patch_embed = interpolate(m, patch_embed, target, GGML_SCALE_MODE_BICUBIC); + patch_embed = ggml_cont(m, permute_whcn_to_cwhn(m, patch_embed)); + patch_embed = ggml_reshape_3d(m, patch_embed, dim, target[0] * target[1], 1); + return concat(m, {class_embed, patch_embed}, 1); +} + +inline tensor prepare_tokens(model_ref m, tensor x, int patch_size) { + auto [c, w, h, n] = nelements(x); + x = patch_embed(m["patch_embed"], x, patch_size); + x = ggml_reshape_3d(m, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); + + tensor cls_token = m.weights("cls_token"); + if (cls_token->ne[2] != n) { + cls_token = ggml_repeat_4d(m, cls_token, cls_token->ne[0], 1, n, 1); + } + x = concat(m, {cls_token, x}, 1); + + tensor pos_enc = interpolate_pos_encoding(m, x, w, h, patch_size); + x = ggml_add_inplace(m, x, pos_enc); + return x; +} + +inline tensor layer_scale(model_ref m, tensor x) { + return ggml_mul(m, x, m.weights("gamma")); +} + +inline tensor mlp(model_ref m, tensor x) { + x = linear(m["fc1"], x); + x = ggml_gelu(m, x); + x = linear(m["fc2"], x); + return x; +} + +inline tensor attention(model_ref m, tensor x, int n_heads, bool flash_attn) { + auto [c, n, b, _] = nelements(x); + float scale = 1.0f / std::sqrt(float(c) / float(n_heads)); + + tensor qkv = linear(m["qkv"], x); + qkv = ggml_reshape_4d(m, qkv, c / n_heads, n_heads, 3, n * b); + qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2)); + + auto split = [=](tensor tensor, size_t index, bool transpose = false) mutable { + tensor = slice(m, tensor, {}, {}, {}, index); + tensor = ggml_reshape_4d(m, tensor, c / n_heads, n_heads, n, b); + if (transpose) { + tensor = ggml_cont(m, ggml_permute(m, tensor, 1, 2, 0, 3)); + } else { + tensor = ggml_cont(m, ggml_permute(m, tensor, 0, 2, 1, 3)); + } + return tensor; + }; + tensor q = split(qkv, 0); + tensor k = split(qkv, 1); + tensor v = split(qkv, 2, !flash_attn); + + if (flash_attn) { + int64_t c_pad = GGML_PAD(c, 4) - c; + int64_t n_pad = GGML_PAD(n, 32) - n; + q = ggml_pad(m, q, c_pad, n_pad, 0, 0); + k = ggml_pad(m, k, c_pad, n_pad, 0, 0); + v = ggml_pad(m, v, c_pad, n_pad, 0, 0); + + ggml_type dtype = m.weights("qkv.weight")->type; + k = ggml_cast(m, k, dtype); + v = ggml_cast(m, v, dtype); + + x = ggml_flash_attn_ext(m, q, k, v, nullptr, scale, 0.0f, 0.0f); + x = slice(m, x, {}, {}, {0, n}, {}); + } else { + q = ggml_scale_inplace(m, q, scale); + + tensor attn = ggml_mul_mat(m, k, q); + attn = ggml_soft_max(m, attn); + + x = ggml_mul_mat(m, v, attn); + x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); + x = ggml_reshape_3d(m, x, c, n, b); + } + + x = linear(m["proj"], x); + return named(m, x); +} + +struct dino_params { + int patch_size = 16; + int embed_dim = 384; + int n_blocks = 12; + int n_heads = 6; + int mlp_ratio = 4; + bool flash_attention = false; +}; + +inline tensor block(model_ref m, tensor x, dino_params const& p) { + tensor attn = x; + attn = layer_norm(m["norm1"], attn); + attn = attention(m["attn"], attn, p.n_heads, p.flash_attention); + attn = layer_scale(m["ls1"], attn); + x = ggml_add_inplace(m, x, 
attn);
+
+    tensor ffn = x;
+    ffn = layer_norm(m["norm2"], ffn);
+    ffn = mlp(m["mlp"], ffn);
+    ffn = layer_scale(m["ls2"], ffn);
+    x = ggml_add_inplace(m, x, ffn);
+
+    return named(m, x);
+}
+
+inline std::vector<tensor> get_intermediate_layers(
+    model_ref m, tensor x, int n, dino_params const& p) {
+
+    x = prepare_tokens(m, x, p.patch_size);
+
+    std::vector<tensor> outputs;
+    model_ref blocks = m["blocks"];
+    for (int i = 0; i < p.n_blocks; ++i) {
+        x = block(blocks[i], x, p);
+        if (i >= p.n_blocks - n) {
+            outputs.push_back(x);
+        }
+    }
+    return outputs;
+}
+
+} // namespace dino
+} // namespace visp
diff --git a/src/visp/nn.cpp b/src/visp/nn.cpp
index 7b6065b..e6092bb 100644
--- a/src/visp/nn.cpp
+++ b/src/visp/nn.cpp
@@ -174,4 +174,19 @@ tensor batch_norm_2d(model_ref m, tensor x) {
     return named(m, x);
 }
 
+tensor patch_embed(model_ref m, tensor x, int patch_size) {
+    ASSERT(x->ne[1] % patch_size == 0 && x->ne[2] % patch_size == 0);
+
+    m.flags |= model_build_flag::cwhn;
+    x = conv_2d(m["proj"], x, patch_size);
+
+    if (m.find("norm.weight")) {
+        auto [c, w, h, b] = nelements(x);
+        x = ggml_reshape_3d(m, x, c, w * h, b);
+        x = layer_norm(m["norm"], x);
+        x = ggml_reshape_4d(m, x, c, w, h, b);
+    }
+    return named(m, x);
+}
+
 } // namespace visp
\ No newline at end of file
diff --git a/src/visp/nn.h b/src/visp/nn.h
index eb8c106..9b7e762 100644
--- a/src/visp/nn.h
+++ b/src/visp/nn.h
@@ -38,4 +38,7 @@ tensor conv_2d_deform(
 tensor conv_transpose_2d(model_ref m, tensor x, int stride);
 tensor batch_norm_2d(model_ref, tensor x);
 
+// 2D image to patch embedding using convolution and optional norm. CWHN input and output.
+tensor patch_embed(model_ref, tensor x, int patch_size);
+
 } // namespace visp
diff --git a/tests/test_depth_anything.py b/tests/test_depth_anything.py
new file mode 100644
index 0000000..e933c43
--- /dev/null
+++ b/tests/test_depth_anything.py
@@ -0,0 +1,289 @@
+import math
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from tests import workbench
+from tests.workbench import convert_to_nhwc, generate_state, input_tensor, to_nchw, to_nhwc
+
+
+class PatchEmbed(nn.Module):
+    def __init__(
+        self,
+        img_size=(224, 224),
+        patch_size=(16, 16),
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.patch_size = patch_size
+        self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x: Tensor, flatten=False) -> Tensor:
+        _, _, H, W = x.shape
+        patch_H, patch_W = self.patch_size
+        assert H % patch_H == 0, (
+            f"Input image height {H} is not a multiple of patch height {patch_H}"
+        )
+        assert W % patch_W == 0, (
+            f"Input image width {W} is not a multiple of patch width: {patch_W}"
+        )
+        x = self.proj(x)  # B C H W
+        H, W = x.size(2), x.size(3)
+        x = x.flatten(2).transpose(1, 2)  # B HW C
+        # x = self.norm(x)
+        if not flatten:
+            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
+        return x
+
+
+def test_patch_embed():
+    patch_embed = PatchEmbed(img_size=(16, 16), patch_size=(4, 4), in_chans=3, embed_dim=8)
+    state = generate_state(patch_embed.state_dict())
+    patch_embed.load_state_dict(state)
+    patch_embed.eval()
+
+    x = input_tensor(1, 3, 8, 12)
+    expected = patch_embed(x)
+
+    x = to_nhwc(x)
+    state = convert_to_nhwc(state, key="proj")
+    result = workbench.invoke_test("biref_patch_embed", x, state)
+
+    assert torch.allclose(result, expected)
+
+
+def interpolate_pos_encoding(pos_embed: Tensor, x: 
Tensor, w: int, h: int, patch_size: int): + # This is 0.1 in official code, which would cause a small difference because ggml + # does not support passing a scale_factor to interpolate + interpolate_offset = 0.0 + interpolate_antialias = False + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = pos_embed.shape[1] - 1 + if npatch == N and w == h: + return pos_embed + pos_embed = pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // patch_size + h0 = h // patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 + w0, h0 = w0 + interpolate_offset, h0 + interpolate_offset + # w0, h0 = w0 + 0.1, h0 + 0.1 + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + # (int(w0), int(h0)), # to solve the upsampling shape issue + mode="bicubic", + antialias=interpolate_antialias, + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + +def test_interpolate_pos_encoding(): + img_size = 12 + patch_size = 4 + num_patches = (img_size // patch_size) ** 2 + embed_dim = 8 + pos_embed = torch.randn(1, num_patches + 1, embed_dim) + + x = input_tensor(1, num_patches, embed_dim) + expected = interpolate_pos_encoding(pos_embed, x, img_size, img_size, patch_size) + + state = {"pos_embed": pos_embed} + params = {"img_size": img_size, "patch_size": patch_size} + result = workbench.invoke_test("dino_interpolate_pos_encoding", x, state, params) + + assert torch.allclose(result, expected) + + +class PrepareTokensModule(nn.Module): + def __init__(self, img_size, patch_size, embed_dim: int): + super().__init__() + self.patch_size = patch_size + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=(patch_size, patch_size), embed_dim=embed_dim + ) + num_patches = self.patch_embed.num_patches + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + + def prepare_tokens_with_masks(self, x: Tensor, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x, flatten=True) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + interpolate_pos_encoding(self.pos_embed, x, w, h, self.patch_size) + return x + + +def test_prepare_tokens_with_masks(): + img_size = 12 + patch_size = 4 + embed_dim = 6 + module = PrepareTokensModule((img_size, img_size), patch_size, embed_dim) + state = generate_state(module.state_dict()) + module.load_state_dict(state) + module.eval() + + x = input_tensor(1, 3, img_size, img_size) + expected = module.prepare_tokens_with_masks(x) + + x = to_nhwc(x) + state = convert_to_nhwc(state, key="patch_embed.proj") + result = workbench.invoke_test("dino_prepare_tokens", x, state) + + assert torch.allclose(result, expected) + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int | None = None, + out_features: int | None = None, 
+ act_layer=nn.GELU, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + x = self.fc2(x) + # x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + ) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + # attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + # x = self.proj_drop(x) + return x + + +def test_attention(): + dim = 6 + num_heads = 3 + module = Attention(dim=dim, num_heads=num_heads, qkv_bias=True, proj_bias=True) + state = generate_state(module.state_dict()) + module.load_state_dict(state) + module.eval() + + x = input_tensor(1, 12, dim) + expected = module(x) + result = workbench.invoke_test( + "dino_attention", x, state, dict(n_heads=num_heads, flash_attn=0) + ) + + assert torch.allclose(result, expected) + + +class LayerScale(nn.Module): + def __init__(self, dim: int, init_values=1e-5, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + init_values=None, + ) -> None: + super().__init__() + self.norm1 = nn.LayerNorm(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = nn.Identity() + + self.norm2 = nn.LayerNorm(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=nn.GELU, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def test_block(): + dim = 6 + num_heads = 3 + module = Block(dim=dim, num_heads=num_heads, init_values=1.0) + state = generate_state(module.state_dict()) + module.load_state_dict(state) + module.eval() + + x = input_tensor(1, 12, dim) + expected = module(x) + result = workbench.invoke_test("dino_block", x, state, dict(n_heads=num_heads)) + + workbench.print_results(result, expected) + assert torch.allclose(result, expected, 
atol=1e-2) # precision drop due to GELU in MLP
diff --git a/tests/workbench.cpp b/tests/workbench.cpp
index 1c6181b..83d34ee 100644
--- a/tests/workbench.cpp
+++ b/tests/workbench.cpp
@@ -1,5 +1,7 @@
 #include "util/string.h"
 #include "visp/arch/birefnet.h"
+#include "visp/arch/depth-anything.h"
+#include "visp/arch/dino.h"
 #include "visp/arch/esrgan.h"
 #include "visp/arch/migan.h"
 #include "visp/arch/mobile-sam.h"
@@ -240,7 +242,7 @@ DEF(sam_predict_masks)(model_ref m, span<tensor> input, param_dict const& p) {
 // BiRefNet
 
 DEF(biref_patch_embed)(model_ref m, span<tensor> input, param_dict const& p) {
-    return {birefnet::patch_embed(m, input[0])};
+    return {patch_embed(m, input[0], 4)};
 }
 
 DEF(biref_relative_position_index)(model_ref m, span<tensor> input, param_dict const& p) {
@@ -413,6 +415,33 @@ DEF(esrgan_rrdbnet)(model_ref m, span<tensor> input, param_dict const& p) {
     return {esrgan_generate(m, input[0], params)};
 }
 
+//
+// DINO
+
+DEF(dino_interpolate_pos_encoding)(model_ref m, span<tensor> input, param_dict const& p) {
+    int s = p.get("img_size", 64);
+    int patch_size = p.get("patch_size", 16);
+    return {dino::interpolate_pos_encoding(m, input[0], s, s, patch_size)};
+}
+
+DEF(dino_prepare_tokens)(model_ref m, span<tensor> input, param_dict const& p) {
+    return {dino::prepare_tokens(m, input[0], 4)};
+}
+
+DEF(dino_attention)(model_ref m, span<tensor> input, param_dict const& p) {
+    return {dino::attention(m, input[0], p.get("n_heads", 8), p.get("flash_attn", 0) != 0)};
+}
+
+DEF(dino_block)(model_ref m, span<tensor> input, param_dict const& p) {
+    dino::dino_params params{};
+    params.n_heads = p.get("n_heads", 8);
+    params.flash_attention = p.get("flash_attn", 0) != 0;
+    return {dino::block(m, input[0], params)};
+}
+
+//
+// Depth Anything
+
 //
 // Workbench implementation
 //
@@ -430,19 +459,19 @@ param_dict build_dict(span<raw_param> raw_params) {
         param.name = raw.name;
 
         switch (param_type(raw.type)) {
-        case param_type::int32:
-            param.type = param_type::int32;
-            param.value.i = std::stoi(raw.value);
-            break;
-        case param_type::float32:
-            param.type = param_type::float32;
-            param.value.f = std::stof(raw.value);
-            break;
-        case param_type::string:
-            param.type = param_type::string;
-            param.value.s = raw.value;
-            break;
-        default: throw except("Unknown parameter type");
+            case param_type::int32:
+                param.type = param_type::int32;
+                param.value.i = std::stoi(raw.value);
+                break;
+            case param_type::float32:
+                param.type = param_type::float32;
+                param.value.f = std::stof(raw.value);
+                break;
+            case param_type::string:
+                param.type = param_type::string;
+                param.value.s = raw.value;
+                break;
+            default: throw except("Unknown parameter type");
         }
         dict.params.push_back(param);
     }
@@ -490,7 +519,6 @@ struct raw_tensor {
     size_t size_bytes() const { return size() * ggml_type_size(type()); }
 };
 
-
 struct test_case {
     char const* name;
     test_function func;
@@ -605,7 +633,8 @@ extern "C" {
 #ifdef _MSC_VER
 __declspec(dllexport)
 #endif
-int32_t visp_workbench(
+int32_t
+visp_workbench(
    char const* testcase,
    visp::raw_tensor const* inputs,
    int32_t n_inputs,
From eec1f4e649ee4037c3bae5acdbfc0153e27298bb Mon Sep 17 00:00:00 2001
From: Acly
Date: Tue, 7 Oct 2025 17:28:31 +0200
Subject: [PATCH 03/24] dino: implement depth-anything v2 head

---
 src/util/math.h                |   1 +
 src/visp/arch/depth-anything.h | 132 ++++++++++++++
 src/visp/arch/dino.h           |  40 +++--
 tests/test_depth_anything.py   | 307 ++++++++++++++++++++++++++++++++-
 4 files changed, 465 insertions(+), 15 deletions(-)
 create mode 100644 src/visp/arch/depth-anything.h

diff --git a/src/util/math.h b/src/util/math.h
index 40bbbcf..ed4dd24 100644
--- a/src/util/math.h
+++ b/src/util/math.h
@@ -57,6 +57,7 @@ constexpr i32x2 operator/(i32x2 a, int32_t b) { return {a[0] / b, a[1] / b}; }
 constexpr i32x2 div_ceil(i32x2 a, i32x2 b) { return {div_ceil(a[0], b[0]), div_ceil(a[1], b[1])}; }
 constexpr i32x2 div_ceil(i32x2 a, int32_t b) { return div_ceil(a, i32x2{b, b}); }
+constexpr i32x2 next_multiple(i32x2 x, int32_t mult) { return div_ceil(x, mult) * mult; }
 constexpr i32x2 min(i32x2 a, i32x2 b) { return {std::min(a[0], b[0]), std::min(a[1], b[1])}; }
 
 // i64x2 operations
diff --git a/src/visp/arch/depth-anything.h b/src/visp/arch/depth-anything.h
new file mode 100644
index 0000000..9fd0ee5
--- /dev/null
+++ b/src/visp/arch/depth-anything.h
@@ -0,0 +1,132 @@
+#include "visp/arch/dino.h"
+#include "visp/ml.h"
+#include "visp/nn.h"
+
+namespace visp {
+
+struct depthany_params {
+    int image_size = 518;
+    int image_multiple = 14;
+    std::array<int, 4> feature_layers = {2, 5, 8, 11};
+    dino_params dino;
+};
+
+namespace dpt {
+
+i32x2 compute_inference_extent(i32x2 extent, depthany_params const& p) {
+    int min_side = std::min(extent[0], extent[1]);
+    int tgt_side = std::max(p.image_size, next_multiple(min_side, p.image_multiple));
+    i32x2 target = extent * tgt_side / min_side;
+    return next_multiple(target, p.image_multiple);
+}
+
+tensor residual_conv(model_ref m, tensor x) {
+    tensor out = x;
+    out = ggml_relu(m, out);
+    out = conv_2d(m["conv1"], out, 1, 1);
+    out = ggml_relu(m, out);
+    out = conv_2d(m["conv2"], out, 1, 1);
+    x = ggml_add_inplace(m, x, out);
+    return named(m, x);
+}
+
+tensor feature_fusion(model_ref m, tensor x0, tensor x1, int64_t const* size = nullptr) {
+    tensor x = x0;
+    if (x1) {
+        tensor res = residual_conv(m["resConfUnit1"], x1);
+        x = ggml_add_inplace(m, x, res);
+    }
+    x = residual_conv(m["resConfUnit2"], x);
+
+    int64_t w = size ? size[0] : x->ne[0] * 2;
+    int64_t h = size ? size[1] : x->ne[1] * 2;
+    x = interpolate(m, x, {w, h}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS);
+
+    x = conv_2d(m["out_conv"], x);
+    return named(m, x);
+}
+
+tensor head(model_ref m, span<tensor> features, int patch_w, int patch_h) {
+    ASSERT(features.size() == 4);
+    std::array<tensor, 4> layer;
+    for (int i = 0; i < 4; ++i) {
+        tensor x = features[i];
+        x = slice(m, x, {}, {}, {}, 0);
+        x = ggml_reshape_4d(m, x, x->ne[0], patch_w, patch_h, x->ne[3]);
+
+        model_ref proj = m["projects"][i];
+        proj.flags |= model_build_flag::cwhn;
+        x = conv_2d(proj, x); // 1x1 conv, keep CWHN layout and directly use mul_mat
+
+        x = cwhn_to_contiguous_2d(m, x);
+        switch (i) {
+            case 0: x = conv_transpose_2d(m["resize_layers"][i], x, 4); break;
+            case 1: x = conv_transpose_2d(m["resize_layers"][i], x, 2); break;
+            case 3: x = conv_2d(m["resize_layers"][i], x, 2, 1); break;
+        }
+        layer[i] = x;
+    }
+
+    model_ref scratch = m["scratch"];
+    tensor layer1_rn = conv_2d(scratch["layer1_rn"], layer[0], 1, 1);
+    tensor layer2_rn = conv_2d(scratch["layer2_rn"], layer[1], 1, 1);
+    tensor layer3_rn = conv_2d(scratch["layer3_rn"], layer[2], 1, 1);
+    tensor layer4_rn = conv_2d(scratch["layer4_rn"], layer[3], 1, 1);
+
+    tensor path4 = feature_fusion(scratch["refinenet4"], layer4_rn, nullptr, layer3_rn->ne);
+    tensor path3 = feature_fusion(scratch["refinenet3"], path4, layer3_rn, layer2_rn->ne);
+    tensor path2 = feature_fusion(scratch["refinenet2"], path3, layer2_rn, layer1_rn->ne);
+    tensor path1 = feature_fusion(scratch["refinenet1"], path2, layer1_rn);
+
+    tensor out = conv_2d(scratch["output_conv1"], path1, 1, 1);
+    out = interpolate(
+        m, out, {patch_w * 14, patch_h * 14},
+        GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS);
+
+    model_ref output_conv2 = scratch["output_conv2"];
+    out = conv_2d(output_conv2[0], out, 1, 1);
+    out = ggml_relu_inplace(m, out);
+    out = conv_2d(output_conv2[2], out);
+    out = ggml_relu_inplace(m, out);
+    return out;
+}
+
+} // namespace dpt
+
+inline tensor depthany_predict(model_ref m, tensor image, depthany_params const& p) {
+    auto [c, w, h, n] = nelements(image);
+    int w_patch = w / p.dino.patch_size;
+    int h_patch = h / p.dino.patch_size;
+
+    auto features = dino_intermediate_layers(m["pretrained"], image, p.feature_layers, p.dino);
+    tensor depth = dpt::head(m["depth_head"], features, w_patch, h_patch);
+    depth = ggml_relu_inplace(m, depth);
+    return compute_graph_output(m, depth);
+}
+
+inline image_data depthany_process_input(image_view image, depthany_params const& p) {
+    constexpr f32x4 mean = f32x4{0.485f, 0.456f, 0.406f, 0.f};
+    constexpr f32x4 std = f32x4{0.229f, 0.224f, 0.225f, 1.f};
+
+    i32x2 target = dpt::compute_inference_extent(image.extent, p);
+    image_data resized;
+    if (image.extent != target) {
+        resized = image_scale(image, target);
+        image = image_view(resized);
+    }
+    return image_u8_to_f32(image, image_format::rgb_f32, -mean, 1.f / std);
+}
+
+inline image_data depthany_process_output(
+    span<float const> data, i32x2 extent, depthany_params const& p) {
+
+    image_view depth_output(dpt::compute_inference_extent(extent, p), data);
+    image_data depth_resized;
+    if (depth_output.extent != extent) {
+        depth_resized = image_scale(depth_output, extent);
+        depth_output = depth_resized;
+    }
+    return image_f32_to_u8(depth_output, image_format::alpha_u8);
+}
+
+} // namespace visp
\ No newline at end of file
diff --git a/src/visp/arch/dino.h b/src/visp/arch/dino.h
index 8c56908..c774b6a 100644
--- a/src/visp/arch/dino.h
+++ b/src/visp/arch/dino.h
@@ -10,6 +10,16 @@ #pragma optimize("", off)
 
 namespace visp {
+
+struct dino_params {
+    int patch_size = 16;
+    int embed_dim = 384;
+    int n_blocks = 12;
+    int n_heads = 6;
+    int mlp_ratio = 4;
+    bool flash_attention = false;
+};
+
 namespace dino {
 
 inline tensor interpolate_pos_encoding(
@@ -113,15 +123,6 @@ inline tensor attention(model_ref m, tensor x, int n_heads, bool flash_attn) {
     return named(m, x);
 }
 
-struct dino_params {
-    int patch_size = 16;
-    int embed_dim = 384;
-    int n_blocks = 12;
-    int n_heads = 6;
-    int mlp_ratio = 4;
-    bool flash_attention = false;
-};
-
 inline tensor block(model_ref m, tensor x, dino_params const& p) {
     tensor attn = x;
     attn = layer_norm(m["norm1"], attn);
@@ -138,21 +139,34 @@ inline tensor block(model_ref m, tensor x, dino_params const& p) {
     return named(m, x);
 }
 
+template <typename T>
+bool contains(std::span<T const> r, T const& value) {
+    return std::find(r.begin(), r.end(), value) != r.end();
+}
+
 inline std::vector<tensor> get_intermediate_layers(
-    model_ref m, tensor x, int n, dino_params const& p) {
-
+    model_ref m, tensor x, std::span<int const> layers, dino_params const& p) {
+
     x = prepare_tokens(m, x, p.patch_size);
 
     std::vector<tensor> outputs;
     model_ref blocks = m["blocks"];
     for (int i = 0; i < p.n_blocks; ++i) {
         x = block(blocks[i], x, p);
-        if (i >= p.n_blocks - n) {
-            outputs.push_back(x);
+
+        if (contains(layers, i)) {
+            tensor out = layer_norm(m["norm"], x);
+            outputs.push_back(out);
         }
     }
     return outputs;
 }
 
 } // namespace dino
+
+inline std::vector<tensor> dino_intermediate_layers(
+    model_ref m, tensor x, std::span<int const> layers, dino_params const& p) {
+    return dino::get_intermediate_layers(m, x, layers, p);
+}
+
 } // namespace visp
diff --git a/tests/test_depth_anything.py b/tests/test_depth_anything.py
index e933c43..e3d428e 100644
--- a/tests/test_depth_anything.py
+++ b/tests/test_depth_anything.py
@@ -1,6 +1,8 @@
 import math
+import pytest
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import Tensor
 
 from tests import workbench
@@ -285,5 +287,306 @@ def test_block():
     expected = module(x)
     result = workbench.invoke_test("dino_block", x, state, dict(n_heads=num_heads))
 
-    workbench.print_results(result, expected)
-    assert torch.allclose(result, expected, atol=1e-2) # precision drop due to GELU in MLP
+    assert torch.allclose(result, expected, atol=1e-2) # precision drop due to GELU in MLP
+
+
+class ResidualConvUnit(nn.Module):
+    def __init__(self, features, activation, bn=False):
+        super().__init__()
+        self.bn = bn
+        self.groups = 1
+        self.conv1 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+        self.conv2 = nn.Conv2d(
+            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+        )
+        self.activation = activation
+        self.skip_add = nn.quantized.FloatFunctional()
+
+    def forward(self, x):
+        out = self.activation(x)
+        out = self.conv1(out)
+        if self.bn == True:
+            out = self.bn1(out)
+
+        out = self.activation(out)
+        out = self.conv2(out)
+        if self.bn == True:
+            out = self.bn2(out)
+
+        if self.groups > 1:
+            out = self.conv_merge(out)
+
+        return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock(nn.Module):
+    """Feature fusion block."""
+
+    def __init__(
+        self,
+        features,
+        activation,
+        deconv=False,
+        bn=False,
+        expand=False,
+        align_corners=True,
+        size=None,
+    ):
+        super(FeatureFusionBlock, self).__init__()
+        self.deconv = deconv
+        self.align_corners = align_corners
+        self.groups = 1
+        self.expand = expand
+        out_features = features
+        if self.expand == True:
+            out_features = features // 2
+
+        self.out_conv = nn.Conv2d(
+            
features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1 + ) + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + self.skip_add = nn.quantized.FloatFunctional() + self.size = size + + def forward(self, *xs, size=None): + output = xs[0] + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + output = self.out_conv(output) + return output + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +@pytest.mark.parametrize("inputs", [1, 2]) +def test_feature_fusion(inputs): + features = 6 + x = [input_tensor(1, features, 4, 4)] + size = (8, 8) + if inputs == 2: + x.append(input_tensor(1, features, 4, 4)) + size = None + + module = _make_fusion_block(features, use_bn=False) + state = generate_state(module.state_dict()) + module.load_state_dict(state) + module.eval() + + expected = module(*x, size=size) + result = workbench.invoke_test("depthany_feature_fusion", x, state=state) + + assert torch.allclose(result, expected) + + +class ConvBlock(nn.Module): + def __init__(self, in_feature, out_feature): + super().__init__() + + self.conv_block = nn.Sequential( + nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_feature), + nn.ReLU(True), + ) + + def forward(self, x): + return self.conv_block(x) + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +class DPTHead(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False, + ): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) + for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0, + ), + nn.Identity(), + nn.Conv2d( + 
in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1, + ), + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()) + ) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + out = F.interpolate( + out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True + ) + out = self.scratch.output_conv2(out) + + return out + + +def test_dpt_head(): + in_channels = 4 + features = 6 + h, w = 8, 8 + module = DPTHead(in_channels=in_channels, features=features, use_clstoken=False) + state = generate_state(module.state_dict()) + module.load_state_dict(state) + module.eval() + + x = [input_tensor(2, 1, h * w, in_channels) for _ in range(4)] + expected = module(x, h, w) + + state = convert_to_nhwc(state, key="projects") + result = workbench.invoke_test("depthany_head", x, state) + + assert torch.allclose(result, expected, atol=1e-3) From ed991d8dc284986efb564dd4f52645a7d0f6fd3f Mon Sep 17 00:00:00 2001 From: Acly Date: Tue, 7 Oct 2025 17:28:51 +0200 Subject: [PATCH 04/24] tests: fix passing multiple inputs to workbench --- tests/workbench.cpp | 23 +++++++++++++++++++---- tests/workbench.py | 2 +- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/workbench.cpp b/tests/workbench.cpp index 83d34ee..76d7ac7 100644 --- a/tests/workbench.cpp +++ b/tests/workbench.cpp @@ -433,7 +433,7 @@ DEF(dino_attention)(model_ref m, span input, param_dict const& p) { } DEF(dino_block)(model_ref m, span input, param_dict const& p) { - dino::dino_params params{}; + dino_params params{}; params.n_heads = p.get("n_heads", 8); 
params.flash_attention = p.get("flash_attn", 0) != 0;
     return {dino::block(m, input[0], params)};
@@ -442,6 +442,22 @@ DEF(dino_block)(model_ref m, span<tensor> input, param_dict const& p) {
 //
 // Depth Anything
 
+DEF(depthany_feature_fusion)(model_ref m, span<tensor> input, param_dict const& p) {
+    if (input.size() == 1) {
+        int64_t size[] = {8, 8, 6, 1};
+        return {dpt::feature_fusion(m, input[0], nullptr, size)};
+    } else {
+        ASSERT(input.size() == 2);
+        return {dpt::feature_fusion(m, input[0], input[1])};
+    }
+}
+
+DEF(depthany_head)(model_ref m, span<tensor> input, param_dict const& p) {
+    int patch_w = p.get("patch_w", 8);
+    int patch_h = p.get("patch_h", 8);
+    return {dpt::head(m, input, patch_w, patch_h)};
+}
+
 //
 // Workbench implementation
 //
@@ -572,9 +588,8 @@ void workbench_run(
     for (raw_tensor const& raw : tensors) {
         auto tensor = ggml_new_tensor_4d(
             m.weights_context, raw.type(), raw.ne[0], raw.ne[1], raw.ne[2], raw.ne[3]);
-        if (raw.name && raw.name[0] != '\0' && raw.name != std::string_view("input")) {
-            ggml_set_name(tensor, raw.name);
-        } else {
+        ggml_set_name(tensor, raw.name);
+        if (std::string_view(raw.name).starts_with("input")) {
             inputs.push_back(tensor);
         }
     }
diff --git a/tests/workbench.py b/tests/workbench.py
index 7e7da42..a0b51ae 100644
--- a/tests/workbench.py
+++ b/tests/workbench.py
@@ -112,7 +112,7 @@ def invoke_test(
     backend: str = "cpu",
 ):
     input = input if isinstance(input, list) else [input]
-    raw_inputs = [torch_to_raw_tensor("", tensor) for tensor in input]
+    raw_inputs = [torch_to_raw_tensor(f"input{i}", tensor) for i, tensor in enumerate(input)]
     raw_inputs += [torch_to_raw_tensor(name, tensor) for name, tensor in state.items()]
     input_tensors = [t for _, t in raw_inputs]
     input_tensors  # keep the tensors alive
From 7178f551a6f94622acedc34acc4a24cb290be136 Mon Sep 17 00:00:00 2001
From: Acly
Date: Wed, 8 Oct 2025 00:39:15 +0200
Subject: [PATCH 05/24] depth-anything: model conversion, api and cli

---
 include/visp/ml.h              | 13 ++++++++
 include/visp/vision.h          | 32 ++++++++++++++++++++
 scripts/convert.py             | 27 ++++++++++++++++-
 src/cli/cli.cpp                | 43 ++++++++++++++++++++++++++-
 src/visp/arch/depth-anything.h | 54 ++++++++++++++++++----------------
 src/visp/arch/dino.h           | 14 ++-------
 src/visp/nn.cpp                |  1 -
 src/visp/vision.cpp            | 35 ++++++++++++++++++++++
 8 files changed, 179 insertions(+), 40 deletions(-)

diff --git a/include/visp/ml.h b/include/visp/ml.h
index ed108a4..530ca0d 100644
--- a/include/visp/ml.h
+++ b/include/visp/ml.h
@@ -296,6 +296,19 @@ extern swin_params const swin_t_params;
 extern swin_params const swin_l_params;
 VISP_API swin_params swin_detect_params(model_file const&);
 
+//
+// DINO
+
+struct dino_params {
+    int patch_size = 16;
+    int embed_dim = 384;
+    int n_blocks = 12;
+    int n_heads = 6;
+    int mlp_ratio = 4;
+    bool flash_attention = false;
+};
+
+
 //
 // implementation
 
diff --git a/include/visp/vision.h b/include/visp/vision.h
index 4daeaab..3ac2093 100644
--- a/include/visp/vision.h
+++ b/include/visp/vision.h
@@ -162,6 +162,27 @@ VISP_API image_data birefnet_process_output(
 
 VISP_API tensor birefnet_predict(model_ref, tensor image, birefnet_params const&);
 
+//
+// Depth Anything - depth estimation
+
+struct depthany_model;
+
+VISP_API depthany_model depthany_load_model(char const* filepath, backend_device const&);
+VISP_API image_data depthany_compute(depthany_model&, image_view image);
+
+// --- Depth Anything pipeline
+
+struct depthany_params {
+    int image_size = 518;
+    int image_multiple = 14;
+    i32x2 image_extent = {518, 518};
+    std::array<int, 4> feature_layers = {2, 5, 
8, 11}; + dino_params dino; +}; + +VISP_API depthany_params depthany_detect_params(model_file const&, i32x2 input_extent = {}); + + // // MI-GAN - image inpainting @@ -246,6 +267,17 @@ struct birefnet_model { tensor output = nullptr; }; +// internal +struct depthany_model { + backend_device const* backend = nullptr; + model_weights weights; + depthany_params params; + + compute_graph graph; + tensor input = nullptr; + tensor output = nullptr; +}; + // internal struct migan_model { backend_device const* backend = nullptr; diff --git a/scripts/convert.py b/scripts/convert.py index 054bf42..feaefaf 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -100,7 +100,7 @@ def is_conv_2d(name: str, tensor: Tensor): return ( tensor.ndim == 4 and tensor.shape[2] == tensor.shape[3] - and tensor.shape[2] in (1, 3, 4, 7) + and tensor.shape[2] in (1, 3, 4, 7, 14) and name.endswith("weight") ) @@ -341,6 +341,28 @@ def convert_birefnet(input_filepath: Path, writer: Writer): writer.add_tensor(name, tensor) +# +# Depth-Anything + + +def convert_depth_anything(input_filepath: Path, writer: Writer): + writer.add_license("apache-2.0") + writer.set_tensor_layout_default(TensorLayout.nchw) + + model: dict[str, Tensor] = torch.load(input_filepath, map_location="cpu", weights_only=True) + + for key, tensor in model.items(): + name = key + + if is_conv_2d(name, tensor): + if "patch_embed" in name or "projects" in name: + tensor = conv_2d_to_nhwc(tensor) + else: + tensor = writer.convert_tensor_2d(tensor) + + writer.add_tensor(name, tensor) + + # # MI-GAN @@ -400,6 +422,7 @@ def convert_esrgan(input_filepath: Path, writer: Writer): arch_names = { "sam": "mobile-sam", "birefnet": "birefnet", + "depth-anything": "depth-anything", "migan": "migan", "esrgan": "esrgan", } @@ -448,6 +471,8 @@ def convert_esrgan(input_filepath: Path, writer: Writer): convert_sam(input_path, writer) case "birefnet": convert_birefnet(input_path, writer) + case "depthany" | "depth-anything": + convert_depth_anything(input_path, writer) case "migan": convert_migan(input_path, writer) case "esrgan": diff --git a/src/cli/cli.cpp b/src/cli/cli.cpp index 3e37434..a1a9472 100644 --- a/src/cli/cli.cpp +++ b/src/cli/cli.cpp @@ -2,6 +2,8 @@ #include "util/string.h" #include "visp/vision.h" +#include "visp/arch/depth-anything.h" + #include #include #include @@ -13,7 +15,7 @@ namespace visp { using std::filesystem::path; -enum class cli_command { none, sam, birefnet, migan, esrgan }; +enum class cli_command { none, sam, birefnet, depth_anything, migan, esrgan }; struct cli_args { cli_command command = cli_command::none; @@ -38,6 +40,7 @@ Usage: vision-cli [options] Commands: sam - MobileSAM image segmentation birefnet - BirefNet background removal + depthany - Depth-Anything depth estimation migan - MI-GAN inpainting esrgan - ESRGAN/Real-ESRGAN upscaling @@ -119,6 +122,8 @@ cli_args cli_parse(int argc, char** argv) { r.command = cli_command::sam; } else if (arg1 == "birefnet") { r.command = cli_command::birefnet; + } else if (arg1 == "depthany" || arg1 == "depth-anything") { + r.command = cli_command::depth_anything; } else if (arg1 == "migan") { r.command = cli_command::migan; } else if (arg1 == "esrgan") { @@ -162,6 +167,7 @@ cli_args cli_parse(int argc, char** argv) { void run_sam(cli_args const&); void run_birefnet(cli_args const&); +void run_depth_anything(cli_args const&); void run_migan(cli_args const&); void run_esrgan(cli_args const&); @@ -179,6 +185,7 @@ int main(int argc, char** argv) { switch (args.command) { case cli_command::sam: 
run_sam(args); break; case cli_command::birefnet: run_birefnet(args); break; + case cli_command::depth_anything: run_depth_anything(args); break; case cli_command::migan: run_migan(args); break; case cli_command::esrgan: run_esrgan(args); break; case cli_command::none: break; @@ -432,6 +439,40 @@ void run_birefnet(cli_args const& args) { composite_image_with_mask(image, mask_resized, args.composite); } +// +// Depth Anything + +void run_depth_anything(cli_args const& args) { + backend_device backend = backend_init(args); + auto [file, weights] = load_model_weights( + args, backend, "models/DepthAnythingV2-Small-F32.gguf"); + + require_inputs(args.inputs, 1, ""); + image_data image = image_load(args.inputs[0]); + depthany_params params = depthany_detect_params(file, image.extent); + image_data input_data = depthany_process_input(image, params); + + i32x2 extent = params.image_extent; + printf("- model image size: %d\n", params.image_size); + printf("- inference image size: %dx%d\n", params.image_extent[0], params.image_extent[1]); + + compute_graph graph = compute_graph_init(); + model_ref m(weights, graph); + + tensor input = compute_graph_input(m, GGML_TYPE_F32, {3, extent[0], extent[1], 1}); + tensor output = depthany_predict(m, input, params); + + compute_graph_allocate(graph, backend); + transfer_to_backend(input, input_data); + + compute_timed(graph, backend); + + tensor_data output_data = transfer_from_backend(output); + image_data depth_image = depthany_process_output(output_data.as_f32(), image.extent, params); + image_save(depth_image, args.output); + printf("-> depth image saved to %s\n", args.output); +} + // // MI-GAN diff --git a/src/visp/arch/depth-anything.h b/src/visp/arch/depth-anything.h index 9fd0ee5..fb41c2b 100644 --- a/src/visp/arch/depth-anything.h +++ b/src/visp/arch/depth-anything.h @@ -1,25 +1,13 @@ +#pragma once + #include "visp/arch/dino.h" +#include "visp/vision.h" #include "visp/ml.h" #include "visp/nn.h" namespace visp { - -struct depthany_params { - int image_size = 518; - int image_multiple = 14; - std::array feature_layers = {2, 5, 8, 11}; - dino_params dino; -}; - namespace dpt { -i32x2 compute_inference_extent(i32x2 extent, depthany_params const& p) { - int min_side = std::min(extent[0], extent[1]); - int tgt_side = std::max(p.image_size, next_multiple(min_side, p.image_multiple)); - i32x2 target = extent * tgt_side / min_side; - return next_multiple(target, p.image_multiple); -} - tensor residual_conv(model_ref m, tensor x) { tensor out = x; out = ggml_relu(m, out); @@ -40,18 +28,19 @@ tensor feature_fusion(model_ref m, tensor x0, tensor x1, int64_t const* size = n int64_t w = size ? size[0] : x->ne[0] * 2; int64_t h = size ? 
size[1] : x->ne[1] * 2;
-    x = interpolate(m, x, {w, h}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS);
+    int32_t mode = int32_t(GGML_SCALE_MODE_BILINEAR) | GGML_SCALE_FLAG_ALIGN_CORNERS;
+    x = interpolate(m, x, {w, h}, mode);
 
     x = conv_2d(m["out_conv"], x);
     return named(m, x);
 }
 
-tensor head(model_ref m, span<tensor> features, int patch_w, int patch_h) {
+tensor head(model_ref m, span<tensor> features, int64_t patch_w, int64_t patch_h) {
     ASSERT(features.size() == 4);
     std::array<tensor, 4> layer;
     for (int i = 0; i < 4; ++i) {
         tensor x = features[i];
-        x = slice(m, x, {}, {}, {}, 0);
+        x = slice(m, x, {}, {1, x->ne[1]}, {}, {});
         x = ggml_reshape_4d(m, x, x->ne[0], patch_w, patch_h, x->ne[3]);
 
         model_ref proj = m["projects"][i];
@@ -81,7 +70,7 @@ tensor head(model_ref m, span<tensor> features, int patch_w, int patch_h) {
     tensor out = conv_2d(scratch["output_conv1"], path1, 1, 1);
     out = interpolate(
         m, out, {patch_w * 14, patch_h * 14},
-        GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS);
+        int32_t(GGML_SCALE_MODE_BILINEAR) | GGML_SCALE_FLAG_ALIGN_CORNERS);
 
     model_ref output_conv2 = scratch["output_conv2"];
     out = conv_2d(output_conv2[0], out, 1, 1);
@@ -95,8 +84,8 @@ tensor head(model_ref m, span<tensor> features, int patch_w, int patch_h) {
 
 inline tensor depthany_predict(model_ref m, tensor image, depthany_params const& p) {
     auto [c, w, h, n] = nelements(image);
-    int w_patch = w / p.dino.patch_size;
-    int h_patch = h / p.dino.patch_size;
+    int64_t w_patch = w / p.dino.patch_size;
+    int64_t h_patch = h / p.dino.patch_size;
 
     auto features = dino_intermediate_layers(m["pretrained"], image, p.feature_layers, p.dino);
     tensor depth = dpt::head(m["depth_head"], features, w_patch, h_patch);
@@ -104,14 +93,29 @@ inline tensor depthany_predict(model_ref m, tensor image, depthany_params const&
     return compute_graph_output(m, depth);
 }
 
+i32x2 depthany_image_extent(i32x2 extent, depthany_params const& p) {
+    int min_side = std::min(extent[0], extent[1]);
+    int tgt_side = std::max(p.image_size, next_multiple(min_side, p.image_multiple));
+    i32x2 target = extent * tgt_side / min_side;
+    return next_multiple(target, p.image_multiple);
+}
+
+inline depthany_params depthany_detect_params(model_file const&, i32x2 input_extent) {
+    depthany_params p;
+    p.dino.patch_size = 14;
+    if (input_extent[0] > 0 && input_extent[1] > 0) {
+        p.image_extent = depthany_image_extent(input_extent, p);
+    }
+    return p;
+}
+
 inline image_data depthany_process_input(image_view image, depthany_params const& p) {
     constexpr f32x4 mean = f32x4{0.485f, 0.456f, 0.406f, 0.f};
     constexpr f32x4 std = f32x4{0.229f, 0.224f, 0.225f, 1.f};
 
-    i32x2 target = dpt::compute_inference_extent(image.extent, p);
     image_data resized;
-    if (image.extent != target) {
-        resized = image_scale(image, target);
+    if (image.extent != p.image_extent) {
+        resized = image_scale(image, p.image_extent);
         image = image_view(resized);
     }
     return image_u8_to_f32(image, image_format::rgb_f32, -mean, 1.f / std);
@@ -120,7 +124,7 @@ inline image_data depthany_process_output(
     span<float const> data, i32x2 extent, depthany_params const& p) {
 
-    image_view depth_output(dpt::compute_inference_extent(extent, p), data);
+    image_view depth_output(p.image_extent, data);
     image_data depth_resized;
     if (depth_output.extent != extent) {
         depth_resized = image_scale(depth_output, extent);
diff --git a/src/visp/arch/dino.h b/src/visp/arch/dino.h
index c774b6a..0f0e53d 100644
--- a/src/visp/arch/dino.h
+++ b/src/visp/arch/dino.h
@@ -10,16 +10,6 @@ #pragma 
optimize("", off) namespace visp { - -struct dino_params { - int patch_size = 16; - int embed_dim = 384; - int n_blocks = 12; - int n_heads = 6; - int mlp_ratio = 4; - bool flash_attention = false; -}; - namespace dino { inline tensor interpolate_pos_encoding( @@ -96,8 +86,8 @@ inline tensor attention(model_ref m, tensor x, int n_heads, bool flash_attn) { tensor v = split(qkv, 2, !flash_attn); if (flash_attn) { - int64_t c_pad = GGML_PAD(c, 4) - c; - int64_t n_pad = GGML_PAD(n, 32) - n; + int c_pad = int(GGML_PAD(c, 4) - c); + int n_pad = int(GGML_PAD(n, 32) - n); q = ggml_pad(m, q, c_pad, n_pad, 0, 0); k = ggml_pad(m, k, c_pad, n_pad, 0, 0); v = ggml_pad(m, v, c_pad, n_pad, 0, 0); diff --git a/src/visp/nn.cpp b/src/visp/nn.cpp index e6092bb..53ab710 100644 --- a/src/visp/nn.cpp +++ b/src/visp/nn.cpp @@ -3,7 +3,6 @@ namespace visp { - tensor linear(model_ref m, tensor x) { x = ggml_mul_mat(m, m.weights("weight"), x); if (tensor bias = m.find("bias")) { diff --git a/src/visp/vision.cpp b/src/visp/vision.cpp index bd8216e..99e43c6 100644 --- a/src/visp/vision.cpp +++ b/src/visp/vision.cpp @@ -115,6 +115,41 @@ image_data birefnet_compute(birefnet_model& model, image_view image) { return birefnet_process_output(mask_data.as_f32(), image.extent, model.params); } +// +// Depth Anything + +depthany_model depthany_load_model(char const* filepath, backend_device const& dev) { + depthany_model model; + model.backend = &dev; + model_file file = model_load(filepath); + model.weights = model_init(file.n_tensors()); + model_transfer(file, model.weights, dev, dev.preferred_float_type(), dev.preferred_layout()); + return model; +} + +image_data depthany_compute(depthany_model& model, image_view image) { + depthany_params params{}; + i32x2 res = depthany_image_extent(image.extent, params); + + if (!model.graph || res != model.params.image_extent) { + model.params.image_extent = res; + model.graph = compute_graph_init(); + + model_ref m(model.weights, model.graph); + model.input = compute_graph_input(m, GGML_TYPE_F32, {3, res[0], res[1], 1}); + model.output = depthany_predict(m, model.input, params); + compute_graph_allocate(model.graph, *model.backend); + } + + image_data img_data = depthany_process_input(image, params); + transfer_to_backend(model.input, img_data); + + compute(model.graph, *model.backend); + + tensor_data output_data = transfer_from_backend(model.output); + return depthany_process_output(output_data.as_f32(), image.extent, params); +} + // // MI-GAN From a30a3535f0f63f7679f30c4180936c7f04300a1e Mon Sep 17 00:00:00 2001 From: Acly Date: Wed, 8 Oct 2025 20:20:11 +0200 Subject: [PATCH 06/24] ml: add helper to save tensor data to file (mostly for debug) --- include/visp/ml.h | 1 + src/visp/ml.cpp | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/include/visp/ml.h b/include/visp/ml.h index 530ca0d..5b1da1b 100644 --- a/include/visp/ml.h +++ b/include/visp/ml.h @@ -225,6 +225,7 @@ VISP_API tensor_data tensor_alloc(tensor x); // Loads tensor data from a file storing raw numbers as binary. VISP_API tensor_data tensor_load(tensor x, char const* filepath); +VISP_API void tensor_save(tensor x, char const* filepath); // Copies data to the tensor's backend buffer (which should already be allocated). 
VISP_API void transfer_to_backend(tensor_data const&);
diff --git a/src/visp/ml.cpp b/src/visp/ml.cpp
index 8a85888..b214f8d 100644
--- a/src/visp/ml.cpp
+++ b/src/visp/ml.cpp
@@ -587,6 +587,18 @@ tensor_data tensor_load(tensor x, char const* filepath) {
     return result;
 }
 
+void tensor_save(tensor x, char const* filepath) {
+    FILE* file = fopen(filepath, "wb");
+    if (!file) {
+        throw except("Failed to open file for writing: {}", filepath);
+    }
+    size_t written = fwrite(x->data, 1, ggml_nbytes(x), file);
+    fclose(file);
+    if (written != ggml_nbytes(x)) {
+        throw except("Failed to write tensor data to file: {}", filepath);
+    }
+}
+
 std::span<float> tensor_data::as_f32() {
     ASSERT(x->type == GGML_TYPE_F32);
     return span(reinterpret_cast<float*>(data.get()), ggml_nelements(x));
From 7382f25bf86b8dc4534ca67874865f88bdedf41d Mon Sep 17 00:00:00 2001
From: Acly
Date: Wed, 8 Oct 2025 21:26:11 +0200
Subject: [PATCH 07/24] dino: fix bad _inplace causing output to be
 overwritten sometimes

---
 src/cli/cli.cpp                |   3 +-
 src/visp/arch/depth-anything.h |   6 +-
 src/visp/arch/dino.h           |  14 +-
 tests/test_depth_anything.py   | 278 +++++++++++++++++++++++++++++++++
 tests/workbench.cpp            |  11 ++
 5 files changed, 302 insertions(+), 10 deletions(-)

diff --git a/src/cli/cli.cpp b/src/cli/cli.cpp
index a1a9472..db61365 100644
--- a/src/cli/cli.cpp
+++ b/src/cli/cli.cpp
@@ -460,7 +460,8 @@ void run_depth_anything(cli_args const& args) {
     model_ref m(weights, graph);
 
     tensor input = compute_graph_input(m, GGML_TYPE_F32, {3, extent[0], extent[1], 1});
-    tensor output = depthany_predict(m, input, params);
+    tensor depth = depthany_predict(m, input, params);
+    tensor output = compute_graph_output(m, ggml_sigmoid(m, depth));
 
     compute_graph_allocate(graph, backend);
     transfer_to_backend(input, input_data);
diff --git a/src/visp/arch/depth-anything.h b/src/visp/arch/depth-anything.h
index fb41c2b..efa0f91 100644
--- a/src/visp/arch/depth-anything.h
+++ b/src/visp/arch/depth-anything.h
@@ -1,9 +1,9 @@
 #pragma once
 
 #include "visp/arch/dino.h"
-#include "visp/vision.h"
 #include "visp/ml.h"
 #include "visp/nn.h"
+#include "visp/vision.h"
 
 namespace visp {
 namespace dpt {
@@ -87,9 +87,9 @@ inline tensor depthany_predict(model_ref m, tensor image, depthany_params const&
     int64_t w_patch = w / p.dino.patch_size;
     int64_t h_patch = h / p.dino.patch_size;
 
-    auto features = dino_intermediate_layers(m["pretrained"], image, p.feature_layers, p.dino);
+    auto features = dino_get_intermediate_layers(m["pretrained"], image, p.feature_layers, p.dino);
     tensor depth = dpt::head(m["depth_head"], features, w_patch, h_patch);
-    depth = ggml_relu_inplace(m, depth);
+    // depth = ggml_relu_inplace(m, depth); <- reference does another ReLU here
     return compute_graph_output(m, depth);
 }
 
diff --git a/src/visp/arch/dino.h b/src/visp/arch/dino.h
index 0f0e53d..616d37c 100644
--- a/src/visp/arch/dino.h
+++ b/src/visp/arch/dino.h
@@ -115,16 +115,16 @@ inline tensor attention(model_ref m, tensor x, int n_heads, bool flash_attn) {
 inline tensor block(model_ref m, tensor x, dino_params const& p) {
     tensor attn = x;
-    attn = layer_norm(m["norm1"], attn);
+    attn = layer_norm(m["norm1"], attn, 1e-6f);
     attn = attention(m["attn"], attn, p.n_heads, p.flash_attention);
     attn = layer_scale(m["ls1"], attn);
-    x = ggml_add_inplace(m, x, attn);
+    x = ggml_add(m, x, attn);
 
     tensor ffn = x;
-    ffn = layer_norm(m["norm2"], ffn);
+    ffn = layer_norm(m["norm2"], ffn, 1e-6f);
     ffn = mlp(m["mlp"], ffn);
     ffn = layer_scale(m["ls2"], ffn);
-    x = ggml_add_inplace(m, x, ffn);
+    x = ggml_add(m, x, ffn);
 
     return named(m, x);
} @@ -145,7 +145,9 @@ inline std::vector get_intermediate_layers( x = block(blocks[i], x, p); if (contains(layers, i)) { - tensor out = layer_norm(m["norm"], x); + tensor out = layer_norm(m["norm"], x, 1e-6f); + ggml_format_name(out, "dino_layer_%d", i); + ggml_build_forward_expand(m.graph, out); outputs.push_back(out); } } @@ -154,7 +156,7 @@ inline std::vector get_intermediate_layers( } // namespace dino -inline std::vector dino_intermediate_layers( +inline std::vector dino_get_intermediate_layers( model_ref m, tensor x, std::span layers, dino_params const& p) { return dino::get_intermediate_layers(m, x, layers, p); } diff --git a/tests/test_depth_anything.py b/tests/test_depth_anything.py index e3d428e..235fae2 100644 --- a/tests/test_depth_anything.py +++ b/tests/test_depth_anything.py @@ -1,3 +1,4 @@ +from functools import partial import math import pytest import torch @@ -8,6 +9,10 @@ from tests import workbench from tests.workbench import convert_to_nhwc, generate_state, input_tensor, to_nchw, to_nhwc +# +# DINOv2 +# + class PatchEmbed(nn.Module): def __init__( @@ -290,6 +295,279 @@ def test_block(): assert torch.allclose(result, expected, atol=1e-2) # precision drop due to GELU in MLP +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) + if num_register_tokens + else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + if ffn_layer == "mlp": + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + assert False, "swiglu not implemented" + elif ffn_layer == "identity": + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + # drop_path=dpr[i], + # norm_layer=norm_layer, + # act_layer=act_layer, + # ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + self.norm = norm_layer(embed_dim) + 
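+        # note: the head is left as an identity mapping; these tests only consume intermediate features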
self.head = nn.Identity() + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + # self.init_weights() + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + # w0, h0 = w0 + 0.1, h0 + 0.1 + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + # (int(w0), int(h0)), # to solve the upsampling shape issue + mode="bicubic", + antialias=self.interpolate_antialias, + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x, flatten=True) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + print(x.shape, self.cls_token.shape) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append({ + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + }) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. 
If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), ( + f"only {len(output)} / {len(blocks_to_take)} blocks found" + ) + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: int | list[int] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ): + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def test_dino_intermediate_layers(): + img_size = 8 + patch_size = 4 + embed_dim = 6 + depth = 4 + num_heads = 3 + module = DinoVisionTransformer( + img_size=(img_size, img_size), + patch_size=(patch_size, patch_size), + embed_dim=embed_dim, + depth=depth, + num_heads=num_heads, + init_values=1.0, + interpolate_offset=0.0, + ) + state = generate_state(module.state_dict()) + module.load_state_dict(state) + module.eval() + + x = input_tensor(1, 3, img_size, img_size) + expected = module.get_intermediate_layers(x, n=4) + + state = convert_to_nhwc(state, key="patch_embed.proj") + x = to_nhwc(x) + result = workbench.invoke_test("dino_intermediate_layers", x, state) + + for r, e in zip(result, expected): + r = r.squeeze(0) + r = r[:, 1:, :] # remove cls token + assert torch.allclose(r, e) + + +# +# Depth Anything +# + + class ResidualConvUnit(nn.Module): def __init__(self, features, activation, bn=False): super().__init__() diff --git a/tests/workbench.cpp b/tests/workbench.cpp index 76d7ac7..8c70076 100644 --- a/tests/workbench.cpp +++ b/tests/workbench.cpp @@ -439,6 +439,17 @@ DEF(dino_block)(model_ref m, span input, param_dict const& p) { return {dino::block(m, input[0], params)}; } +DEF(dino_intermediate_layers)(model_ref m, span input, param_dict const& p) { + dino_params params{}; + params.patch_size = 4; + params.embed_dim = 6; + params.n_blocks = 4; + params.n_heads = 3; + params.flash_attention = p.get("flash_attn", 0) != 0; + auto layers = std::array{0, 1, 2, 3}; + return dino::get_intermediate_layers(m, input[0], layers, params); +} + // // Depth Anything From 597feed97b7bda1f1f232539fe4c11045927965a Mon Sep 17 00:00:00 2001 From: Acly Date: Wed, 8 Oct 2025 21:35:16 +0200 Subject: [PATCH 08/24] dino: move implementation to .cpp file --- include/visp/ml.h | 2 + src/visp/CMakeLists.txt | 1 + src/visp/arch/dino.cpp | 158 ++++++++++++++++++++++++++++++++++++++ src/visp/arch/dino.h | 164 +++------------------------------------- 4 files changed, 171 insertions(+), 154 deletions(-) create mode 100644 src/visp/arch/dino.cpp diff --git a/include/visp/ml.h b/include/visp/ml.h index 5b1da1b..eebf1c6 100644 --- a/include/visp/ml.h +++ b/include/visp/ml.h @@ -309,6 +309,8 @@ struct dino_params { bool flash_attention = false; 
}; +VISP_API std::vector dino_get_intermediate_layers( + model_ref, tensor image, std::span layers, dino_params const&); // // implementation diff --git a/src/visp/CMakeLists.txt b/src/visp/CMakeLists.txt index 5cdbd54..ca45ca9 100644 --- a/src/visp/CMakeLists.txt +++ b/src/visp/CMakeLists.txt @@ -2,6 +2,7 @@ add_library(visioncpp SHARED) target_sources(visioncpp PRIVATE arch/birefnet.cpp + arch/dino.cpp arch/esrgan.cpp arch/migan.cpp arch/mobile-sam.cpp diff --git a/src/visp/arch/dino.cpp b/src/visp/arch/dino.cpp new file mode 100644 index 0000000..3da855a --- /dev/null +++ b/src/visp/arch/dino.cpp @@ -0,0 +1,158 @@ +#include "util/math.h" +#include "visp/arch/dino.h" +#include "visp/ml.h" +#include "visp/nn.h" + +namespace visp { +namespace dino { + +tensor interpolate_pos_encoding( + model_ref m, tensor x, int64_t w, int64_t h, int patch_size) { + + tensor pos_embed = ggml_cast(m, m.weights("pos_embed"), GGML_TYPE_F32); + int64_t n_patch = x->ne[1] - 1; + int64_t n = pos_embed->ne[1] - 1; + if (n_patch == n && w == h) { + return pos_embed; + } + + tensor class_embed = slice(m, pos_embed, {}, {0}, {}, {}); + tensor patch_embed = slice(m, pos_embed, {}, {1, n + 1}, {}, {}); + int64_t dim = x->ne[0]; + i64x2 target = i64x2{w, h} / patch_size; + int64_t sqrt_n = int64_t(std::sqrt(float(n)) + 0.01f); + + patch_embed = ggml_reshape_4d(m, patch_embed, dim, sqrt_n, sqrt_n, 1); + patch_embed = ggml_cont(m, permute_cwhn_to_whcn(m, patch_embed)); + patch_embed = interpolate(m, patch_embed, target, GGML_SCALE_MODE_BICUBIC); + patch_embed = ggml_cont(m, permute_whcn_to_cwhn(m, patch_embed)); + patch_embed = ggml_reshape_3d(m, patch_embed, dim, target[0] * target[1], 1); + return concat(m, {class_embed, patch_embed}, 1); +} + +tensor prepare_tokens(model_ref m, tensor x, int patch_size) { + auto [c, w, h, n] = nelements(x); + x = patch_embed(m["patch_embed"], x, patch_size); + x = ggml_reshape_3d(m, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); + + tensor cls_token = m.weights("cls_token"); + if (cls_token->ne[2] != n) { + cls_token = ggml_repeat_4d(m, cls_token, cls_token->ne[0], 1, n, 1); + } + x = concat(m, {cls_token, x}, 1); + + tensor pos_enc = interpolate_pos_encoding(m, x, w, h, patch_size); + x = ggml_add_inplace(m, x, pos_enc); + return x; +} + +tensor layer_scale(model_ref m, tensor x) { + return ggml_mul(m, x, m.weights("gamma")); +} + +tensor mlp(model_ref m, tensor x) { + x = linear(m["fc1"], x); + x = ggml_gelu(m, x); + x = linear(m["fc2"], x); + return x; +} + +tensor attention(model_ref m, tensor x, int n_heads, bool flash_attn) { + auto [c, n, b, _] = nelements(x); + float scale = 1.0f / std::sqrt(float(c) / float(n_heads)); + + tensor qkv = linear(m["qkv"], x); + qkv = ggml_reshape_4d(m, qkv, c / n_heads, n_heads, 3, n * b); + qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2)); + + auto split = [=](tensor tensor, size_t index, bool transpose = false) mutable { + tensor = slice(m, tensor, {}, {}, {}, index); + tensor = ggml_reshape_4d(m, tensor, c / n_heads, n_heads, n, b); + if (transpose) { + tensor = ggml_cont(m, ggml_permute(m, tensor, 1, 2, 0, 3)); + } else { + tensor = ggml_cont(m, ggml_permute(m, tensor, 0, 2, 1, 3)); + } + return tensor; + }; + tensor q = split(qkv, 0); + tensor k = split(qkv, 1); + tensor v = split(qkv, 2, !flash_attn); + + if (flash_attn) { + int c_pad = int(GGML_PAD(c, 4) - c); + int n_pad = int(GGML_PAD(n, 32) - n); + q = ggml_pad(m, q, c_pad, n_pad, 0, 0); + k = ggml_pad(m, k, c_pad, n_pad, 0, 0); + v = ggml_pad(m, v, c_pad, n_pad, 0, 0); + + ggml_type 
dtype = m.weights("qkv.weight")->type; + k = ggml_cast(m, k, dtype); + v = ggml_cast(m, v, dtype); + + x = ggml_flash_attn_ext(m, q, k, v, nullptr, scale, 0.0f, 0.0f); + x = slice(m, x, {}, {}, {0, n}, {}); + } else { + q = ggml_scale_inplace(m, q, scale); + + tensor attn = ggml_mul_mat(m, k, q); + attn = ggml_soft_max(m, attn); + + x = ggml_mul_mat(m, v, attn); + x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); + x = ggml_reshape_3d(m, x, c, n, b); + } + + x = linear(m["proj"], x); + return named(m, x); +} + +tensor block(model_ref m, tensor x, dino_params const& p) { + tensor attn = x; + attn = layer_norm(m["norm1"], attn, 1e-6f); + attn = attention(m["attn"], attn, p.n_heads, p.flash_attention); + attn = layer_scale(m["ls1"], attn); + x = ggml_add(m, x, attn); + + tensor ffn = x; + ffn = layer_norm(m["norm2"], ffn, 1e-6f); + ffn = mlp(m["mlp"], ffn); + ffn = layer_scale(m["ls2"], ffn); + x = ggml_add(m, x, ffn); + + return named(m, x); +} + +template +bool contains(std::span r, T const& value) { + return std::find(r.begin(), r.end(), value) != r.end(); +} + +std::vector get_intermediate_layers( + model_ref m, tensor x, std::span layers, dino_params const& p) { + + x = prepare_tokens(m, x, p.patch_size); + + std::vector outputs; + model_ref blocks = m["blocks"]; + for (int i = 0; i < p.n_blocks; ++i) { + x = block(blocks[i], x, p); + + if (contains(layers, i)) { + tensor out = layer_norm(m["norm"], x, 1e-6f); + ggml_format_name(out, "dino_layer_%d", i); + ggml_build_forward_expand(m.graph, out); + outputs.push_back(out); + } + } + return outputs; +} + +} // namespace dino + +std::vector dino_get_intermediate_layers( + model_ref m, tensor x, std::span layers, dino_params const& p) { + return dino::get_intermediate_layers(m, x, layers, p); +} + +} // namespace visp diff --git a/src/visp/arch/dino.h b/src/visp/arch/dino.h index 616d37c..51ffdc5 100644 --- a/src/visp/arch/dino.h +++ b/src/visp/arch/dino.h @@ -1,164 +1,20 @@ #pragma once #include "util/math.h" -#include "visp/image.h" #include "visp/ml.h" -#include "visp/nn.h" #include -#pragma optimize("", off) +namespace visp::dino { -namespace visp { -namespace dino { +tensor interpolate_pos_encoding(model_ref m, tensor x, int64_t w, int64_t h, int patch_size); +tensor prepare_tokens(model_ref m, tensor x, int patch_size); +tensor layer_scale(model_ref m, tensor x); +tensor mlp(model_ref m, tensor x); +tensor attention(model_ref m, tensor x, int n_heads, bool flash_attn); +tensor block(model_ref m, tensor x, dino_params const& p); -inline tensor interpolate_pos_encoding( - model_ref m, tensor x, int64_t w, int64_t h, int patch_size) { +std::vector get_intermediate_layers( + model_ref m, tensor x, std::span layers, dino_params const& p); - tensor pos_embed = ggml_cast(m, m.weights("pos_embed"), GGML_TYPE_F32); - int64_t n_patch = x->ne[1] - 1; - int64_t n = pos_embed->ne[1] - 1; - if (n_patch == n && w == h) { - return pos_embed; - } - - tensor class_embed = slice(m, pos_embed, {}, {0}, {}, {}); - tensor patch_embed = slice(m, pos_embed, {}, {1, n + 1}, {}, {}); - int64_t dim = x->ne[0]; - i64x2 target = i64x2{w, h} / patch_size; - int64_t sqrt_n = int64_t(std::sqrt(float(n)) + 0.01f); - - patch_embed = ggml_reshape_4d(m, patch_embed, dim, sqrt_n, sqrt_n, 1); - patch_embed = ggml_cont(m, permute_cwhn_to_whcn(m, patch_embed)); - patch_embed = interpolate(m, patch_embed, target, GGML_SCALE_MODE_BICUBIC); - patch_embed = ggml_cont(m, permute_whcn_to_cwhn(m, patch_embed)); - patch_embed = ggml_reshape_3d(m, patch_embed, dim, target[0] * 
target[1], 1); - return concat(m, {class_embed, patch_embed}, 1); -} - -inline tensor prepare_tokens(model_ref m, tensor x, int patch_size) { - auto [c, w, h, n] = nelements(x); - x = patch_embed(m["patch_embed"], x, patch_size); - x = ggml_reshape_3d(m, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); - - tensor cls_token = m.weights("cls_token"); - if (cls_token->ne[2] != n) { - cls_token = ggml_repeat_4d(m, cls_token, cls_token->ne[0], 1, n, 1); - } - x = concat(m, {cls_token, x}, 1); - - tensor pos_enc = interpolate_pos_encoding(m, x, w, h, patch_size); - x = ggml_add_inplace(m, x, pos_enc); - return x; -} - -inline tensor layer_scale(model_ref m, tensor x) { - return ggml_mul(m, x, m.weights("gamma")); -} - -inline tensor mlp(model_ref m, tensor x) { - x = linear(m["fc1"], x); - x = ggml_gelu(m, x); - x = linear(m["fc2"], x); - return x; -} - -inline tensor attention(model_ref m, tensor x, int n_heads, bool flash_attn) { - auto [c, n, b, _] = nelements(x); - float scale = 1.0f / std::sqrt(float(c) / float(n_heads)); - - tensor qkv = linear(m["qkv"], x); - qkv = ggml_reshape_4d(m, qkv, c / n_heads, n_heads, 3, n * b); - qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2)); - - auto split = [=](tensor tensor, size_t index, bool transpose = false) mutable { - tensor = slice(m, tensor, {}, {}, {}, index); - tensor = ggml_reshape_4d(m, tensor, c / n_heads, n_heads, n, b); - if (transpose) { - tensor = ggml_cont(m, ggml_permute(m, tensor, 1, 2, 0, 3)); - } else { - tensor = ggml_cont(m, ggml_permute(m, tensor, 0, 2, 1, 3)); - } - return tensor; - }; - tensor q = split(qkv, 0); - tensor k = split(qkv, 1); - tensor v = split(qkv, 2, !flash_attn); - - if (flash_attn) { - int c_pad = int(GGML_PAD(c, 4) - c); - int n_pad = int(GGML_PAD(n, 32) - n); - q = ggml_pad(m, q, c_pad, n_pad, 0, 0); - k = ggml_pad(m, k, c_pad, n_pad, 0, 0); - v = ggml_pad(m, v, c_pad, n_pad, 0, 0); - - ggml_type dtype = m.weights("qkv.weight")->type; - k = ggml_cast(m, k, dtype); - v = ggml_cast(m, v, dtype); - - x = ggml_flash_attn_ext(m, q, k, v, nullptr, scale, 0.0f, 0.0f); - x = slice(m, x, {}, {}, {0, n}, {}); - } else { - q = ggml_scale_inplace(m, q, scale); - - tensor attn = ggml_mul_mat(m, k, q); - attn = ggml_soft_max(m, attn); - - x = ggml_mul_mat(m, v, attn); - x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); - x = ggml_reshape_3d(m, x, c, n, b); - } - - x = linear(m["proj"], x); - return named(m, x); -} - -inline tensor block(model_ref m, tensor x, dino_params const& p) { - tensor attn = x; - attn = layer_norm(m["norm1"], attn, 1e-6f); - attn = attention(m["attn"], attn, p.n_heads, p.flash_attention); - attn = layer_scale(m["ls1"], attn); - x = ggml_add(m, x, attn); - - tensor ffn = x; - ffn = layer_norm(m["norm2"], ffn, 1e-6f); - ffn = mlp(m["mlp"], ffn); - ffn = layer_scale(m["ls2"], ffn); - x = ggml_add(m, x, ffn); - - return named(m, x); -} - -template -bool contains(std::span r, T const& value) { - return std::find(r.begin(), r.end(), value) != r.end(); -} - -inline std::vector get_intermediate_layers( - model_ref m, tensor x, std::span layers, dino_params const& p) { - - x = prepare_tokens(m, x, p.patch_size); - - std::vector outputs; - model_ref blocks = m["blocks"]; - for (int i = 0; i < p.n_blocks; ++i) { - x = block(blocks[i], x, p); - - if (contains(layers, i)) { - tensor out = layer_norm(m["norm"], x, 1e-6f); - ggml_format_name(out, "dino_layer_%d", i); - ggml_build_forward_expand(m.graph, out); - outputs.push_back(out); - } - } - return outputs; -} - -} // namespace dino - -inline std::vector 
dino_get_intermediate_layers( - model_ref m, tensor x, std::span layers, dino_params const& p) { - return dino::get_intermediate_layers(m, x, layers, p); -} - -} // namespace visp +} // namespace visp::dino From 58c18d73a1c8d3e0198218289490d14c66f9d298 Mon Sep 17 00:00:00 2001 From: Acly Date: Wed, 8 Oct 2025 21:46:12 +0200 Subject: [PATCH 09/24] depth-anything: move implementation to .cpp file --- include/visp/vision.h | 6 ++ src/cli/cli.cpp | 2 - src/visp/CMakeLists.txt | 1 + src/visp/arch/depth-anything.cpp | 137 +++++++++++++++++++++++++++++++ src/visp/arch/depth-anything.h | 134 ++---------------------------- 5 files changed, 149 insertions(+), 131 deletions(-) create mode 100644 src/visp/arch/depth-anything.cpp diff --git a/include/visp/vision.h b/include/visp/vision.h index 3ac2093..23b628d 100644 --- a/include/visp/vision.h +++ b/include/visp/vision.h @@ -181,7 +181,13 @@ struct depthany_params { }; VISP_API depthany_params depthany_detect_params(model_file const&, i32x2 input_extent = {}); +VISP_API i32x2 depthany_image_extent(i32x2 input_extent, depthany_params const&); +VISP_API image_data depthany_process_input(image_view image, depthany_params const&); +image_data depthany_process_output( + span output_data, i32x2 target_extent, depthany_params const&); + +VISP_API tensor depthany_predict(model_ref, tensor image, depthany_params const&); // // MI-GAN - image inpainting diff --git a/src/cli/cli.cpp b/src/cli/cli.cpp index db61365..fe61748 100644 --- a/src/cli/cli.cpp +++ b/src/cli/cli.cpp @@ -2,8 +2,6 @@ #include "util/string.h" #include "visp/vision.h" -#include "visp/arch/depth-anything.h" - #include #include #include diff --git a/src/visp/CMakeLists.txt b/src/visp/CMakeLists.txt index ca45ca9..39d321b 100644 --- a/src/visp/CMakeLists.txt +++ b/src/visp/CMakeLists.txt @@ -2,6 +2,7 @@ add_library(visioncpp SHARED) target_sources(visioncpp PRIVATE arch/birefnet.cpp + arch/depth-anything.cpp arch/dino.cpp arch/esrgan.cpp arch/migan.cpp diff --git a/src/visp/arch/depth-anything.cpp b/src/visp/arch/depth-anything.cpp new file mode 100644 index 0000000..e9d6b8b --- /dev/null +++ b/src/visp/arch/depth-anything.cpp @@ -0,0 +1,137 @@ + +#include "visp/arch/depth-anything.h" +#include "util/math.h" +#include "util/string.h" +#include "visp/arch/dino.h" +#include "visp/ml.h" +#include "visp/nn.h" + +namespace visp { +namespace dpt { + +tensor residual_conv(model_ref m, tensor x) { + tensor out = x; + out = ggml_relu(m, out); + out = conv_2d(m["conv1"], out, 1, 1); + out = ggml_relu(m, out); + out = conv_2d(m["conv2"], out, 1, 1); + x = ggml_add_inplace(m, x, out); + return named(m, x); +} + +tensor feature_fusion(model_ref m, tensor x0, tensor x1, int64_t const* size) { + tensor x = x0; + if (x1) { + tensor res = residual_conv(m["resConfUnit1"], x1); + x = ggml_add_inplace(m, x, res); + } + x = residual_conv(m["resConfUnit2"], x); + + int64_t w = size ? size[0] : x->ne[0] * 2; + int64_t h = size ? 
size[1] : x->ne[1] * 2; + int32_t mode = int32_t(GGML_SCALE_MODE_BILINEAR) | GGML_SCALE_FLAG_ALIGN_CORNERS; + x = interpolate(m, x, {w, h}, mode); + + x = conv_2d(m["out_conv"], x); + return named(m, x); +} + +tensor head(model_ref m, span features, int64_t patch_w, int64_t patch_h) { + ASSERT(features.size() == 4); + + std::array layer; + for (int i = 0; i < 4; ++i) { + tensor x = features[i]; + x = slice(m, x, {}, {1, x->ne[1]}, {}, {}); + x = ggml_reshape_4d(m, x, x->ne[0], patch_w, patch_h, x->ne[3]); + + model_ref proj = m["projects"][i]; + proj.flags |= model_build_flag::cwhn; + x = conv_2d(proj, x); // 1x1 conv, keep CWHN layout and directly use mul_mat + + x = cwhn_to_contiguous_2d(m, x); + switch (i) { + case 0: x = conv_transpose_2d(m["resize_layers"][i], x, 4); break; + case 1: x = conv_transpose_2d(m["resize_layers"][i], x, 2); break; + case 3: x = conv_2d(m["resize_layers"][i], x, 2, 1); break; + } + layer[i] = x; + } + + model_ref scratch = m["scratch"]; + tensor layer1_rn = conv_2d(scratch["layer1_rn"], layer[0], 1, 1); + tensor layer2_rn = conv_2d(scratch["layer2_rn"], layer[1], 1, 1); + tensor layer3_rn = conv_2d(scratch["layer3_rn"], layer[2], 1, 1); + tensor layer4_rn = conv_2d(scratch["layer4_rn"], layer[3], 1, 1); + + tensor path4 = feature_fusion(scratch["refinenet4"], layer4_rn, nullptr, layer3_rn->ne); + tensor path3 = feature_fusion(scratch["refinenet3"], path4, layer3_rn, layer2_rn->ne); + tensor path2 = feature_fusion(scratch["refinenet2"], path3, layer2_rn, layer1_rn->ne); + tensor path1 = feature_fusion(scratch["refinenet1"], path2, layer1_rn); + + tensor out = conv_2d(scratch["output_conv1"], path1, 1, 1); + out = interpolate( + m, out, {patch_w * 14, patch_h * 14}, + int32_t(GGML_SCALE_MODE_BILINEAR) | GGML_SCALE_FLAG_ALIGN_CORNERS); + + model_ref output_conv2 = scratch["output_conv2"]; + out = conv_2d(output_conv2[0], out, 1, 1); + out = ggml_relu_inplace(m, out); + out = conv_2d(output_conv2[2], out); + out = ggml_relu_inplace(m, out); + return out; +} + +} // namespace dpt + +tensor depthany_predict(model_ref m, tensor image, depthany_params const& p) { + auto [c, w, h, n] = nelements(image); + int64_t w_patch = w / p.dino.patch_size; + int64_t h_patch = h / p.dino.patch_size; + + auto features = dino_get_intermediate_layers(m["pretrained"], image, p.feature_layers, p.dino); + tensor depth = dpt::head(m["depth_head"], features, w_patch, h_patch); + // depth = ggml_relu_inplace(m, depth); <- reference does another ReLU here + return compute_graph_output(m, depth); +} + +i32x2 depthany_image_extent(i32x2 extent, depthany_params const& p) { + int min_side = std::min(extent[0], extent[1]); + int tgt_side = std::max(p.image_size, next_multiple(min_side, p.image_multiple)); + i32x2 target = extent * tgt_side / min_side; + return next_multiple(target, p.image_multiple); +} + +depthany_params depthany_detect_params(model_file const&, i32x2 input_extent) { + depthany_params p; + p.dino.patch_size = 14; + if (input_extent[0] > 0 && input_extent[1] > 0) { + p.image_extent = depthany_image_extent(input_extent, p); + } + return p; +} + +image_data depthany_process_input(image_view image, depthany_params const& p) { + constexpr f32x4 mean = f32x4{0.485f, 0.456f, 0.406f, 0.f}; + constexpr f32x4 std = f32x4{0.229f, 0.224f, 0.225f, 1.f}; + + image_data resized; + if (image.extent != p.image_extent) { + resized = image_scale(image, p.image_extent); + image = image_view(resized); + } + return image_u8_to_f32(image, image_format::rgb_f32, -mean, 1.f / std); +} + +image_data 
depthany_process_output(span data, i32x2 extent, depthany_params const& p) { + + image_view depth_output(p.image_extent, data); + image_data depth_resized; + if (depth_output.extent != extent) { + depth_resized = image_scale(depth_output, extent); + depth_output = depth_resized; + } + return image_f32_to_u8(depth_output, image_format::alpha_u8); +} + +} // namespace visp \ No newline at end of file diff --git a/src/visp/arch/depth-anything.h b/src/visp/arch/depth-anything.h index efa0f91..d937124 100644 --- a/src/visp/arch/depth-anything.h +++ b/src/visp/arch/depth-anything.h @@ -1,136 +1,12 @@ #pragma once -#include "visp/arch/dino.h" #include "visp/ml.h" -#include "visp/nn.h" #include "visp/vision.h" -namespace visp { -namespace dpt { +namespace visp::dpt { -tensor residual_conv(model_ref m, tensor x) { - tensor out = x; - out = ggml_relu(m, out); - out = conv_2d(m["conv1"], out, 1, 1); - out = ggml_relu(m, out); - out = conv_2d(m["conv2"], out, 1, 1); - x = ggml_add_inplace(m, x, out); - return named(m, x); -} +tensor residual_conv(model_ref m, tensor x); +tensor feature_fusion(model_ref m, tensor x0, tensor x1, int64_t const* size = nullptr); +tensor head(model_ref m, span features, int64_t patch_w, int64_t patch_h); -tensor feature_fusion(model_ref m, tensor x0, tensor x1, int64_t const* size = nullptr) { - tensor x = x0; - if (x1) { - tensor res = residual_conv(m["resConfUnit1"], x1); - x = ggml_add_inplace(m, x, res); - } - x = residual_conv(m["resConfUnit2"], x); - - int64_t w = size ? size[0] : x->ne[0] * 2; - int64_t h = size ? size[1] : x->ne[1] * 2; - int32_t mode = int32_t(GGML_SCALE_MODE_BILINEAR) | GGML_SCALE_FLAG_ALIGN_CORNERS; - x = interpolate(m, x, {w, h}, mode); - - x = conv_2d(m["out_conv"], x); - return named(m, x); -} - -tensor head(model_ref m, span features, int64_t patch_w, int64_t patch_h) { - ASSERT(features.size() == 4); - std::array layer; - for (int i = 0; i < 4; ++i) { - tensor x = features[i]; - x = slice(m, x, {}, {1, x->ne[1]}, {}, {}); - x = ggml_reshape_4d(m, x, x->ne[0], patch_w, patch_h, x->ne[3]); - - model_ref proj = m["projects"][i]; - proj.flags |= model_build_flag::cwhn; - x = conv_2d(proj, x); // 1x1 conv, keep CWHN layout and directly use mul_mat - - x = cwhn_to_contiguous_2d(m, x); - switch (i) { - case 0: x = conv_transpose_2d(m["resize_layers"][i], x, 4); break; - case 1: x = conv_transpose_2d(m["resize_layers"][i], x, 2); break; - case 3: x = conv_2d(m["resize_layers"][i], x, 2, 1); break; - } - layer[i] = x; - } - - model_ref scratch = m["scratch"]; - tensor layer1_rn = conv_2d(scratch["layer1_rn"], layer[0], 1, 1); - tensor layer2_rn = conv_2d(scratch["layer2_rn"], layer[1], 1, 1); - tensor layer3_rn = conv_2d(scratch["layer3_rn"], layer[2], 1, 1); - tensor layer4_rn = conv_2d(scratch["layer4_rn"], layer[3], 1, 1); - - tensor path4 = feature_fusion(scratch["refinenet4"], layer4_rn, nullptr, layer3_rn->ne); - tensor path3 = feature_fusion(scratch["refinenet3"], path4, layer3_rn, layer2_rn->ne); - tensor path2 = feature_fusion(scratch["refinenet2"], path3, layer2_rn, layer1_rn->ne); - tensor path1 = feature_fusion(scratch["refinenet1"], path2, layer1_rn); - - tensor out = conv_2d(scratch["output_conv1"], path1, 1, 1); - out = interpolate( - m, out, {patch_w * 14, patch_h * 14}, - int32_t(GGML_SCALE_MODE_BILINEAR) | GGML_SCALE_FLAG_ALIGN_CORNERS); - - model_ref output_conv2 = scratch["output_conv2"]; - out = conv_2d(output_conv2[0], out, 1, 1); - out = ggml_relu_inplace(m, out); - out = conv_2d(output_conv2[2], out); - out = 
ggml_relu_inplace(m, out);
-    return out;
-}
-
-} // namespace dpt
-
-inline tensor depthany_predict(model_ref m, tensor image, depthany_params const& p) {
-    auto [c, w, h, n] = nelements(image);
-    int64_t w_patch = w / p.dino.patch_size;
-    int64_t h_patch = h / p.dino.patch_size;
-
-    auto features = dino_get_intermediate_layers(m["pretrained"], image, p.feature_layers, p.dino);
-    tensor depth = dpt::head(m["depth_head"], features, w_patch, h_patch);
-    // depth = ggml_relu_inplace(m, depth); <- reference does another ReLU here
-    return compute_graph_output(m, depth);
-}
-
-i32x2 depthany_image_extent(i32x2 extent, depthany_params const& p) {
-    int min_side = std::min(extent[0], extent[1]);
-    int tgt_side = std::max(p.image_size, next_multiple(min_side, p.image_multiple));
-    i32x2 target = extent * tgt_side / min_side;
-    return next_multiple(target, p.image_multiple);
-}
-
-inline depthany_params depthany_detect_params(model_file const&, i32x2 input_extent) {
-    depthany_params p;
-    p.dino.patch_size = 14;
-    if (input_extent[0] > 0 && input_extent[1] > 0) {
-        p.image_extent = depthany_image_extent(input_extent, p);
-    }
-    return p;
-}
-
-inline image_data depthany_process_input(image_view image, depthany_params const& p) {
-    constexpr f32x4 mean = f32x4{0.485f, 0.456f, 0.406f, 0.f};
-    constexpr f32x4 std = f32x4{0.229f, 0.224f, 0.225f, 1.f};
-
-    image_data resized;
-    if (image.extent != p.image_extent) {
-        resized = image_scale(image, p.image_extent);
-        image = image_view(resized);
-    }
-    return image_u8_to_f32(image, image_format::rgb_f32, -mean, 1.f / std);
-}
-
-inline image_data depthany_process_output(
-    span<float const> data, i32x2 extent, depthany_params const& p) {
-
-    image_view depth_output(p.image_extent, data);
-    image_data depth_resized;
-    if (depth_output.extent != extent) {
-        depth_resized = image_scale(depth_output, extent);
-        depth_output = depth_resized;
-    }
-    return image_f32_to_u8(depth_output, image_format::alpha_u8);
-}
-
-} // namespace visp
\ No newline at end of file
+} // namespace visp::dpt

From a90679fae4f664b6ba20d71e42ad020d8e6ad40f Mon Sep 17 00:00:00 2001
From: Acly
Date: Thu, 9 Oct 2025 21:28:48 +0200
Subject: [PATCH 10/24] depth-anything: min/max-normalize output depth to [0, 1]

---
 include/visp/image.h             |  6 ++++
 src/cli/cli.cpp                  |  6 ++--
 src/visp/arch/depth-anything.cpp | 10 +++----
 src/visp/image.cpp               | 49 +++++++++++++++++++++++++++++++-
 tests/test-image.cpp             | 20 +++++++++++++
 5 files changed, 81 insertions(+), 10 deletions(-)

diff --git a/include/visp/image.h b/include/visp/image.h
index cb766cb..ddc2596 100644
--- a/include/visp/image.h
+++ b/include/visp/image.h
@@ -169,6 +169,12 @@ VISP_API void image_alpha_composite(
 VISP_API image_data image_alpha_composite(
     image_view const& fg, image_view const& bg, image_view const& mask);
 
+// Rescale pixel values such that the minimum value over all pixels becomes `min` and
+// the maximum becomes `max`. Channels are processed independently.
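+// Per channel c this computes out = (v - min_c) * (max - min) / (max_c - min_c) + min,
+// where min_c/max_c are that channel's extrema; a constant channel maps to `min`.
+// Sketch of typical use (the defaults stretch to [0, 1]): image_data norm = image_normalize(depth);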
+VISP_API void image_normalize( + image_view const& src, image_span const& dst, float min = 0, float max = 1); +VISP_API image_data image_normalize(image_view const& img, float min = 0, float max = 1); + // Compute root-mean-square difference between two images VISP_API float image_difference_rms(image_view const& a, image_view const& b); diff --git a/src/cli/cli.cpp b/src/cli/cli.cpp index fe61748..f0e66d3 100644 --- a/src/cli/cli.cpp +++ b/src/cli/cli.cpp @@ -458,8 +458,7 @@ void run_depth_anything(cli_args const& args) { model_ref m(weights, graph); tensor input = compute_graph_input(m, GGML_TYPE_F32, {3, extent[0], extent[1], 1}); - tensor depth = depthany_predict(m, input, params); - tensor output = compute_graph_output(m, ggml_sigmoid(m, depth)); + tensor output = depthany_predict(m, input, params); compute_graph_allocate(graph, backend); transfer_to_backend(input, input_data); @@ -467,7 +466,8 @@ void run_depth_anything(cli_args const& args) { compute_timed(graph, backend); tensor_data output_data = transfer_from_backend(output); - image_data depth_image = depthany_process_output(output_data.as_f32(), image.extent, params); + image_data depth_raw = depthany_process_output(output_data.as_f32(), image.extent, params); + image_data depth_image = image_f32_to_u8(depth_raw, image_format::alpha_u8); image_save(depth_image, args.output); printf("-> depth image saved to %s\n", args.output); } diff --git a/src/visp/arch/depth-anything.cpp b/src/visp/arch/depth-anything.cpp index e9d6b8b..e58aa3e 100644 --- a/src/visp/arch/depth-anything.cpp +++ b/src/visp/arch/depth-anything.cpp @@ -124,14 +124,12 @@ image_data depthany_process_input(image_view image, depthany_params const& p) { } image_data depthany_process_output(span data, i32x2 extent, depthany_params const& p) { - image_view depth_output(p.image_extent, data); - image_data depth_resized; - if (depth_output.extent != extent) { - depth_resized = image_scale(depth_output, extent); - depth_output = depth_resized; + image_data normalized = image_normalize(depth_output); + if (normalized.extent != extent) { + return image_scale(normalized, extent); } - return image_f32_to_u8(depth_output, image_format::alpha_u8); + return normalized; } } // namespace visp \ No newline at end of file diff --git a/src/visp/image.cpp b/src/visp/image.cpp index f230876..77cb42c 100644 --- a/src/visp/image.cpp +++ b/src/visp/image.cpp @@ -197,7 +197,7 @@ image_data image_load(char const* filepath) { void image_save(image_view const& img, char const* filepath) { ASSERT(img.extent[0] > 0 && img.extent[1] > 0); - + if (!(img.format == image_format::alpha_u8 || img.format == image_format::rgb_u8 || img.format == image_format::rgba_u8)) { throw except("Unsupported image format [{}]", int(img.format)); @@ -534,6 +534,53 @@ void image_erosion(image_view const& src, image_span const& dst, int radius) { } } +void image_normalize(image_view const& src, image_span const& dst, float min, float max) { + ASSERT(src.extent == dst.extent); + ASSERT(is_float(src.format) && is_float(dst.format)); + ASSERT(min < max); + + float const fmax = std::numeric_limits::max(); + int const channels = n_channels(src); + float const* src_data = (float const*)src.data; + float* dst_data = (float*)dst.data; + + f32x4 min_val = {fmax, fmax, fmax, fmax}; + f32x4 max_val = {-fmax, -fmax, -fmax, -fmax}; + + for (int y = 0; y < src.extent[1]; ++y) { + for (int x = 0; x < src.extent[0]; ++x) { + for (int c = 0; c < channels; ++c) { + float v = src_data[y * src.stride / 4 + x * channels + c]; + 
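+                // first pass: track the per-channel running min/max across all pixels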
min_val[c] = std::min(min_val[c], v); + max_val[c] = std::max(max_val[c], v); + } + } + } + + f32x4 delta = max_val - min_val; + for (int c = 0; c < channels; ++c) { + delta[c] = delta[c] < 1e-5f ? 1.0f : delta[c]; + } + f32x4 scale = f32x4{max - min} / delta; + f32x4 offset = -min_val * scale + f32x4{min}; + + for (int y = 0; y < src.extent[1]; ++y) { + for (int x = 0; x < src.extent[0]; ++x) { + for (int c = 0; c < channels; ++c) { + float v = src_data[y * src.stride / 4 + x * channels + c]; + v = v * scale[c] + offset[c]; + dst_data[y * dst.stride / 4 + x * channels + c] = v; + } + } + } +} + +image_data image_normalize(image_view const& img, float min, float max) { + image_data dst = image_alloc(img.extent, img.format); + image_normalize(img, dst, min, max); + return dst; +} + template float difference_rms(image_source a, image_source b) { float sum_sq_diff = 0.0f; diff --git a/tests/test-image.cpp b/tests/test-image.cpp index 85a94c6..83837e3 100644 --- a/tests/test-image.cpp +++ b/tests/test-image.cpp @@ -280,6 +280,26 @@ VISP_TEST(image_erosion) { CHECK_IMAGES_EQUAL(output, expected); } +VISP_TEST(image_normalize) { + constexpr i32x2 extent{2, 2}; + std::array input_data = { + f32x3{-1.0f, 4.2f, 0.5f}, f32x3{5.0f, 4.2f, 0.0f}, // + f32x3{-5.0f, 4.2f, 0.6f}, f32x3{1.0f, 4.2f, 1.0f}, // + }; + std::array expected_data = { + f32x3{0.4f, 0.0f, 0.5f}, f32x3{1.0f, 0.0f, 0.0f}, // + f32x3{0.0f, 0.0f, 0.6f}, f32x3{0.6f, 0.0f, 1.0f}, // + }; + std::array output_data{}; + + auto input = image_view(extent, input_data); + auto output = image_span(extent, output_data); + image_normalize(input, output); + + auto expected = image_view(extent, expected_data); + CHECK_IMAGES_EQUAL(output, expected); +} + VISP_TEST(tile_merge) { std::array, 4> tiles; for (int t = 0; t < 4; ++t) { From 085e483e6a343127f0add6be9ccec002ae1304ee Mon Sep 17 00:00:00 2001 From: Acly Date: Thu, 9 Oct 2025 22:16:29 +0200 Subject: [PATCH 11/24] depth-anything: support CWHN layout and add to benchmark * CWHN is a bit faster on CPU (900ms -> 750ms for ViT-S) --- scripts/convert.py | 4 +++- src/visp/arch/depth-anything.cpp | 9 +++++++-- src/visp/vision.cpp | 10 +++++----- tests/benchmark.cpp | 15 +++++++++++++++ 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/scripts/convert.py b/scripts/convert.py index feaefaf..239ff14 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -357,6 +357,8 @@ def convert_depth_anything(input_filepath: Path, writer: Writer): if is_conv_2d(name, tensor): if "patch_embed" in name or "projects" in name: tensor = conv_2d_to_nhwc(tensor) + elif "resize_layers.0" in name or "resize_layers.1" in name: + pass # ConvTranspose2D, don't change layout else: tensor = writer.convert_tensor_2d(tensor) @@ -422,7 +424,7 @@ def convert_esrgan(input_filepath: Path, writer: Writer): arch_names = { "sam": "mobile-sam", "birefnet": "birefnet", - "depth-anything": "depth-anything", + "depth-anything": "depthanything", "migan": "migan", "esrgan": "esrgan", } diff --git a/src/visp/arch/depth-anything.cpp b/src/visp/arch/depth-anything.cpp index e58aa3e..211d624 100644 --- a/src/visp/arch/depth-anything.cpp +++ b/src/visp/arch/depth-anything.cpp @@ -27,10 +27,13 @@ tensor feature_fusion(model_ref m, tensor x0, tensor x1, int64_t const* size) { } x = residual_conv(m["resConfUnit2"], x); - int64_t w = size ? size[0] : x->ne[0] * 2; - int64_t h = size ? size[1] : x->ne[1] * 2; + int const dim = m.flags & model_build_flag::cwhn ? 1 : 0; + int64_t w = size ? 
size[dim + 0] : x->ne[dim + 0] * 2; + int64_t h = size ? size[dim + 1] : x->ne[dim + 1] * 2; int32_t mode = int32_t(GGML_SCALE_MODE_BILINEAR) | GGML_SCALE_FLAG_ALIGN_CORNERS; + x = contiguous_2d_to_whcn(m, x); x = interpolate(m, x, {w, h}, mode); + x = whcn_to_contiguous_2d(m, x); x = conv_2d(m["out_conv"], x); return named(m, x); @@ -70,9 +73,11 @@ tensor head(model_ref m, span features, int64_t patch_w, int64_t patch_h tensor path1 = feature_fusion(scratch["refinenet1"], path2, layer1_rn); tensor out = conv_2d(scratch["output_conv1"], path1, 1, 1); + out = contiguous_2d_to_whcn(m, out); out = interpolate( m, out, {patch_w * 14, patch_h * 14}, int32_t(GGML_SCALE_MODE_BILINEAR) | GGML_SCALE_FLAG_ALIGN_CORNERS); + out = whcn_to_contiguous_2d(m, out); model_ref output_conv2 = scratch["output_conv2"]; out = conv_2d(output_conv2[0], out, 1, 1); diff --git a/src/visp/vision.cpp b/src/visp/vision.cpp index 99e43c6..36d324c 100644 --- a/src/visp/vision.cpp +++ b/src/visp/vision.cpp @@ -122,14 +122,14 @@ depthany_model depthany_load_model(char const* filepath, backend_device const& d depthany_model model; model.backend = &dev; model_file file = model_load(filepath); + model.params = depthany_detect_params(file); model.weights = model_init(file.n_tensors()); model_transfer(file, model.weights, dev, dev.preferred_float_type(), dev.preferred_layout()); return model; } image_data depthany_compute(depthany_model& model, image_view image) { - depthany_params params{}; - i32x2 res = depthany_image_extent(image.extent, params); + i32x2 res = depthany_image_extent(image.extent, model.params); if (!model.graph || res != model.params.image_extent) { model.params.image_extent = res; @@ -137,17 +137,17 @@ image_data depthany_compute(depthany_model& model, image_view image) { model_ref m(model.weights, model.graph); model.input = compute_graph_input(m, GGML_TYPE_F32, {3, res[0], res[1], 1}); - model.output = depthany_predict(m, model.input, params); + model.output = depthany_predict(m, model.input, model.params); compute_graph_allocate(model.graph, *model.backend); } - image_data img_data = depthany_process_input(image, params); + image_data img_data = depthany_process_input(image, model.params); transfer_to_backend(model.input, img_data); compute(model.graph, *model.backend); tensor_data output_data = transfer_from_backend(model.output); - return depthany_process_output(output_data.as_f32(), image.extent, params); + return depthany_process_output(output_data.as_f32(), image.extent, model.params); } // diff --git a/tests/benchmark.cpp b/tests/benchmark.cpp index a75bd13..8c28bb6 100644 --- a/tests/benchmark.cpp +++ b/tests/benchmark.cpp @@ -93,6 +93,17 @@ bench_timings benchmark_birefnet(path model_path, backend_device& backend) { return run_benchmark(model.graph, backend, 8, {{model.input, input_data}}); } +bench_timings benchmark_depth_anything(path model_path, backend_device& backend) { + path input_path = test_dir().input / "cat-and-hat.jpg"; + + depthany_model model = depthany_load_model(model_path.string().c_str(), backend); + image_data input = image_load(input_path.string().c_str()); + image_data input_data = depthany_process_input(input, model.params); + + depthany_compute(model, input); + return run_benchmark(model.graph, backend, 8, {{model.input, input_data}}); +} + bench_timings benchmark_migan(path model_path, backend_device& backend) { path image_path = test_dir().input / "bench-image.jpg"; path mask_path = test_dir().input / "bench-mask.png"; @@ -172,6 +183,10 @@ bench_result benchmark_model( 
path model_path = select_model(model, "BiRefNet-lite-F16.gguf"); result.time = benchmark_birefnet(model_path, backend); + } else if (arch == "depthany") { + path model_path = select_model(model, "DepthAnythingV2-Small-F16.gguf"); + result.time = benchmark_depth_anything(model_path, backend); + } else if (arch == "migan") { path model_path = select_model(model, "MIGAN-512-places2-F16.gguf"); result.time = benchmark_migan(model_path, backend); From d5960ae4355dad95efcbfcaead7c7b96a0e50043 Mon Sep 17 00:00:00 2001 From: Acly Date: Thu, 9 Oct 2025 22:30:57 +0200 Subject: [PATCH 12/24] depth-anything: add model test --- tests/test-models.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test-models.cpp b/tests/test-models.cpp index 3f7b803..70b69a8 100644 --- a/tests/test-models.cpp +++ b/tests/test-models.cpp @@ -70,6 +70,22 @@ VISP_TEST(test_birefnet_dynamic) { compare_images("birefnet-dynamic.png", output2, 0.015f); } +VISP_BACKEND_TEST(test_depth_anything)(backend_type bt) { + path model_path = test_dir().models / "DepthAnythingV2-Small-F16.gguf"; + path input_path = test_dir().input / "cat-and-hat.jpg"; + std::string name = "depth-anything"; + name += bt == backend_type::cpu ? "-cpu.png" : "-gpu.png"; + + backend_device b = backend_init(bt); + depthany_model model = depthany_load_model(model_path.string().c_str(), b); + image_data input = image_load(input_path.string().c_str()); + image_data depth = depthany_compute(model, input); + image_data output = image_f32_to_u8(depth, image_format::alpha_u8); + + float tolerance = bt == backend_type::cpu ? 0.01f : 0.015f; + compare_images(name, output, tolerance); +} + VISP_BACKEND_TEST(test_migan)(backend_type bt) { path model_path = test_dir().models / "MIGAN-512-places2-F16.gguf"; path image_path = test_dir().input / "bench-image.jpg"; From 4aea83d9c3ed3029f6f8805bcb2d411aae12d9eb Mon Sep 17 00:00:00 2001 From: Acly Date: Fri, 10 Oct 2025 11:12:59 +0200 Subject: [PATCH 13/24] depth-anything: use flash attention --- include/visp/ml.h | 4 ++-- src/cli/cli.cpp | 4 +++- src/visp/arch/dino.cpp | 25 ++++++------------------- src/visp/arch/dino.h | 2 +- src/visp/ml.cpp | 5 +++-- tests/test_depth_anything.py | 14 +++++++------- tests/workbench.cpp | 7 ++++--- 7 files changed, 26 insertions(+), 35 deletions(-) diff --git a/include/visp/ml.h b/include/visp/ml.h index eebf1c6..24e8b13 100644 --- a/include/visp/ml.h +++ b/include/visp/ml.h @@ -65,7 +65,8 @@ enum class model_build_flag { conv_2d_direct_cwhn = 1 << 1, concat_n = 1 << 2, f16_conv_transpose = 1 << 3, - window_partition = 1 << 4 + window_partition = 1 << 4, + flash_attention = 1 << 5 }; // clang-format on using model_build_flags = flags; @@ -306,7 +307,6 @@ struct dino_params { int n_blocks = 12; int n_heads = 6; int mlp_ratio = 4; - bool flash_attention = false; }; VISP_API std::vector dino_get_intermediate_layers( diff --git a/src/cli/cli.cpp b/src/cli/cli.cpp index f0e66d3..e7600f5 100644 --- a/src/cli/cli.cpp +++ b/src/cli/cli.cpp @@ -443,7 +443,7 @@ void run_birefnet(cli_args const& args) { void run_depth_anything(cli_args const& args) { backend_device backend = backend_init(args); auto [file, weights] = load_model_weights( - args, backend, "models/DepthAnythingV2-Small-F32.gguf"); + args, backend, "models/DepthAnythingV2-Small-F32.gguf", 0, backend.preferred_layout()); require_inputs(args.inputs, 1, ""); image_data image = image_load(args.inputs[0]); @@ -456,6 +456,8 @@ void run_depth_anything(cli_args const& args) { compute_graph graph = compute_graph_init(); 
model_ref m(weights, graph); + bool flash_attn = !!(m.flags & model_build_flag::flash_attention); + printf("- flash attention: %s\n", flash_attn ? "on" : "off"); tensor input = compute_graph_input(m, GGML_TYPE_F32, {3, extent[0], extent[1], 1}); tensor output = depthany_predict(m, input, params); diff --git a/src/visp/arch/dino.cpp b/src/visp/arch/dino.cpp index 3da855a..ef257a0 100644 --- a/src/visp/arch/dino.cpp +++ b/src/visp/arch/dino.cpp @@ -57,7 +57,7 @@ tensor mlp(model_ref m, tensor x) { return x; } -tensor attention(model_ref m, tensor x, int n_heads, bool flash_attn) { +tensor attention(model_ref m, tensor x, int n_heads) { auto [c, n, b, _] = nelements(x); float scale = 1.0f / std::sqrt(float(c) / float(n_heads)); @@ -77,32 +77,19 @@ tensor attention(model_ref m, tensor x, int n_heads, bool flash_attn) { }; tensor q = split(qkv, 0); tensor k = split(qkv, 1); - tensor v = split(qkv, 2, !flash_attn); - - if (flash_attn) { - int c_pad = int(GGML_PAD(c, 4) - c); - int n_pad = int(GGML_PAD(n, 32) - n); - q = ggml_pad(m, q, c_pad, n_pad, 0, 0); - k = ggml_pad(m, k, c_pad, n_pad, 0, 0); - v = ggml_pad(m, v, c_pad, n_pad, 0, 0); - - ggml_type dtype = m.weights("qkv.weight")->type; - k = ggml_cast(m, k, dtype); - v = ggml_cast(m, v, dtype); + tensor v = split(qkv, 2, !(m.flags & model_build_flag::flash_attention)); + if (m.flags & model_build_flag::flash_attention) { x = ggml_flash_attn_ext(m, q, k, v, nullptr, scale, 0.0f, 0.0f); - x = slice(m, x, {}, {}, {0, n}, {}); } else { - q = ggml_scale_inplace(m, q, scale); - tensor attn = ggml_mul_mat(m, k, q); - attn = ggml_soft_max(m, attn); + attn = ggml_soft_max_ext(m, attn, nullptr, scale, 0.0f); x = ggml_mul_mat(m, v, attn); x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); - x = ggml_reshape_3d(m, x, c, n, b); } + x = ggml_reshape_3d(m, x, c, n, b); x = linear(m["proj"], x); return named(m, x); } @@ -110,7 +97,7 @@ tensor attention(model_ref m, tensor x, int n_heads, bool flash_attn) { tensor block(model_ref m, tensor x, dino_params const& p) { tensor attn = x; attn = layer_norm(m["norm1"], attn, 1e-6f); - attn = attention(m["attn"], attn, p.n_heads, p.flash_attention); + attn = attention(m["attn"], attn, p.n_heads); attn = layer_scale(m["ls1"], attn); x = ggml_add(m, x, attn); diff --git a/src/visp/arch/dino.h b/src/visp/arch/dino.h index 51ffdc5..192a358 100644 --- a/src/visp/arch/dino.h +++ b/src/visp/arch/dino.h @@ -11,7 +11,7 @@ tensor interpolate_pos_encoding(model_ref m, tensor x, int64_t w, int64_t h, int tensor prepare_tokens(model_ref m, tensor x, int patch_size); tensor layer_scale(model_ref m, tensor x); tensor mlp(model_ref m, tensor x); -tensor attention(model_ref m, tensor x, int n_heads, bool flash_attn); +tensor attention(model_ref m, tensor x, int n_heads); tensor block(model_ref m, tensor x, dino_params const& p); std::vector get_intermediate_layers( diff --git a/src/visp/ml.cpp b/src/visp/ml.cpp index b214f8d..6e88685 100644 --- a/src/visp/ml.cpp +++ b/src/visp/ml.cpp @@ -142,8 +142,9 @@ model_build_flags backend_default_flags(backend_type type) { using enum model_build_flag; switch (type) { case backend_type::cpu: - return conv_2d_direct_cwhn | concat_n | f16_conv_transpose | window_partition; - case backend_type::gpu: return {}; + return conv_2d_direct_cwhn | concat_n | f16_conv_transpose | window_partition | + flash_attention; + case backend_type::gpu: return flash_attention; } return {}; } diff --git a/tests/test_depth_anything.py b/tests/test_depth_anything.py index 235fae2..dc2be0b 100644 --- 
a/tests/test_depth_anything.py +++ b/tests/test_depth_anything.py @@ -213,7 +213,8 @@ def forward(self, x: Tensor) -> Tensor: return x -def test_attention(): +@pytest.mark.parametrize("flash_attn", [0, 1]) +def test_attention(flash_attn: int): dim = 6 num_heads = 3 module = Attention(dim=dim, num_heads=num_heads, qkv_bias=True, proj_bias=True) @@ -224,7 +225,7 @@ def test_attention(): x = input_tensor(1, 12, dim) expected = module(x) result = workbench.invoke_test( - "dino_attention", x, state, dict(n_heads=num_heads, flash_attn=0) + "dino_attention", x, state, dict(n_heads=num_heads, flash_attn=flash_attn) ) assert torch.allclose(result, expected) @@ -292,7 +293,7 @@ def test_block(): expected = module(x) result = workbench.invoke_test("dino_block", x, state, dict(n_heads=num_heads)) - assert torch.allclose(result, expected, atol=1e-2) # precision drop due to GELU in MLP + assert torch.allclose(result, expected, rtol=1e-2) # precision drop due to GELU in MLP class DinoVisionTransformer(nn.Module): @@ -431,8 +432,6 @@ def prepare_tokens_with_masks(self, x, masks=None): if masks is not None: x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) - print(x.shape, self.cls_token.shape) - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = x + self.interpolate_pos_encoding(x, w, h) @@ -861,8 +860,9 @@ def test_dpt_head(): module.load_state_dict(state) module.eval() - x = [input_tensor(2, 1, h * w, in_channels) for _ in range(4)] - expected = module(x, h, w) + x = [input_tensor(1, 1 + h * w, in_channels) for _ in range(4)] + x_tuples = [(t[:, 1:, :], None) for t in x] + expected = module(x_tuples, h, w) state = convert_to_nhwc(state, key="projects") result = workbench.invoke_test("depthany_head", x, state) diff --git a/tests/workbench.cpp b/tests/workbench.cpp index 8c70076..93d86c3 100644 --- a/tests/workbench.cpp +++ b/tests/workbench.cpp @@ -429,13 +429,15 @@ DEF(dino_prepare_tokens)(model_ref m, span input, param_dict const& p) { } DEF(dino_attention)(model_ref m, span input, param_dict const& p) { - return {dino::attention(m, input[0], p.get("n_heads", 8), p.get("flash_attn", 0) != 0)}; + if (p.get("flash_attn", 0) != 0) { + m.flags |= model_build_flag::flash_attention; + } + return {dino::attention(m, input[0], p.get("n_heads", 8))}; } DEF(dino_block)(model_ref m, span input, param_dict const& p) { dino_params params{}; params.n_heads = p.get("n_heads", 8); - params.flash_attention = p.get("flash_attn", 0) != 0; return {dino::block(m, input[0], params)}; } @@ -445,7 +447,6 @@ DEF(dino_intermediate_layers)(model_ref m, span input, param_dict const& params.embed_dim = 6; params.n_blocks = 4; params.n_heads = 3; - params.flash_attention = p.get("flash_attn", 0) != 0; auto layers = std::array{0, 1, 2, 3}; return dino::get_intermediate_layers(m, input[0], layers, params); } From 46cc0d1f684c970c6b35a2f1b1e97b55a506bae3 Mon Sep 17 00:00:00 2001 From: Acly Date: Fri, 10 Oct 2025 13:06:11 +0200 Subject: [PATCH 14/24] nn: remove the cwhn im2col path for gpu (not worth maintaining) --- src/visp/nn.cpp | 14 ++++---------- tests/test-models.cpp | 4 ++++ 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/visp/nn.cpp b/src/visp/nn.cpp index 53ab710..3073d06 100644 --- a/src/visp/nn.cpp +++ b/src/visp/nn.cpp @@ -87,16 +87,10 @@ tensor conv_2d(model_ref m, tensor x, int stride, int pad) { x = permute_whcn_to_cwhn(m, x); } else { - x = permute_cwhn_to_whcn(m, x); - tensor permuted_weight = permute_cwhn_to_whcn(m, weight); - tensor cols = 
ggml_im2col( - m, permuted_weight, x, stride, stride, pad, pad, 1, 1, true, GGML_TYPE_F32); - tensor a = ggml_reshape_2d( - m, cols, cols->ne[0], cols->ne[1] * cols->ne[2] * cols->ne[3]); - tensor b = ggml_reshape_2d( - m, weight, weight->ne[0] * weight->ne[1] * weight->ne[2], weight->ne[3]); - x = ggml_mul_mat(m, b, a); - x = ggml_reshape_4d(m, x, weight->ne[3], cols->ne[1], cols->ne[2], cols->ne[3]); + weight = ggml_cont(m, permute_cwhn_to_whcn(m, weight)); + x = ggml_cont(m, permute_cwhn_to_whcn(m, x)); + x = ggml_conv_2d(m, weight, x, stride, stride, pad, pad, 1, 1); + x = ggml_cont(m, permute_whcn_to_cwhn(m, x)); } } else { // WHCN layout x = ggml_conv_2d_direct(m, weight, x, stride, stride, pad, pad, 1, 1); diff --git a/tests/test-models.cpp b/tests/test-models.cpp index 70b69a8..6031220 100644 --- a/tests/test-models.cpp +++ b/tests/test-models.cpp @@ -71,6 +71,10 @@ VISP_TEST(test_birefnet_dynamic) { } VISP_BACKEND_TEST(test_depth_anything)(backend_type bt) { + if (bt == backend_type::gpu) { + throw test_skip{"DepthAnything does not support GPU backend"}; + } + path model_path = test_dir().models / "DepthAnythingV2-Small-F16.gguf"; path input_path = test_dir().input / "cat-and-hat.jpg"; std::string name = "depth-anything"; From ffc0ea39ff5b5e59638d8cb6fefc2b72a553633f Mon Sep 17 00:00:00 2001 From: Acly Date: Fri, 10 Oct 2025 19:42:15 +0200 Subject: [PATCH 15/24] depth-anything: support vulkan with flash attention --- scripts/convert.py | 4 ++++ src/visp/arch/dino.cpp | 28 +++++++++++----------------- tests/test-models.cpp | 4 ---- 3 files changed, 15 insertions(+), 21 deletions(-) diff --git a/scripts/convert.py b/scripts/convert.py index 239ff14..30fac4b 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -362,6 +362,10 @@ def convert_depth_anything(input_filepath: Path, writer: Writer): else: tensor = writer.convert_tensor_2d(tensor) + if "pos_embed" in name or "cls_token" in name: + writer.add_tensor(name, tensor, "f32") + continue + writer.add_tensor(name, tensor) diff --git a/src/visp/arch/dino.cpp b/src/visp/arch/dino.cpp index ef257a0..d3fd62c 100644 --- a/src/visp/arch/dino.cpp +++ b/src/visp/arch/dino.cpp @@ -1,15 +1,13 @@ -#include "util/math.h" #include "visp/arch/dino.h" +#include "util/math.h" #include "visp/ml.h" #include "visp/nn.h" namespace visp { namespace dino { -tensor interpolate_pos_encoding( - model_ref m, tensor x, int64_t w, int64_t h, int patch_size) { - - tensor pos_embed = ggml_cast(m, m.weights("pos_embed"), GGML_TYPE_F32); +tensor interpolate_pos_encoding(model_ref m, tensor x, int64_t w, int64_t h, int patch_size) { + tensor pos_embed = m.weights("pos_embed"); int64_t n_patch = x->ne[1] - 1; int64_t n = pos_embed->ne[1] - 1; if (n_patch == n && w == h) { @@ -65,19 +63,15 @@ tensor attention(model_ref m, tensor x, int n_heads) { qkv = ggml_reshape_4d(m, qkv, c / n_heads, n_heads, 3, n * b); qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2)); - auto split = [=](tensor tensor, size_t index, bool transpose = false) mutable { - tensor = slice(m, tensor, {}, {}, {}, index); - tensor = ggml_reshape_4d(m, tensor, c / n_heads, n_heads, n, b); - if (transpose) { - tensor = ggml_cont(m, ggml_permute(m, tensor, 1, 2, 0, 3)); - } else { - tensor = ggml_cont(m, ggml_permute(m, tensor, 0, 2, 1, 3)); - } - return tensor; + auto split = [=](tensor qkv, size_t index, ggml_type type, bool transpose = false) mutable { + tensor t = slice(m, qkv, {}, {}, {}, index); + t = ggml_reshape_4d(m, t, c / n_heads, n_heads, n, b); + t = transpose ? 
ggml_permute(m, t, 1, 2, 0, 3) : ggml_permute(m, t, 0, 2, 1, 3);
+        return ggml_cast(m, t, type);
     };
-    tensor q = split(qkv, 0);
-    tensor k = split(qkv, 1);
-    tensor v = split(qkv, 2, !(m.flags & model_build_flag::flash_attention));
+    tensor q = split(qkv, 0, GGML_TYPE_F32);
+    tensor k = split(qkv, 1, GGML_TYPE_F16);
+    tensor v = split(qkv, 2, GGML_TYPE_F16, !(m.flags & model_build_flag::flash_attention));
 
     if (m.flags & model_build_flag::flash_attention) {
         x = ggml_flash_attn_ext(m, q, k, v, nullptr, scale, 0.0f, 0.0f);
diff --git a/tests/test-models.cpp b/tests/test-models.cpp
index 6031220..70b69a8 100644
--- a/tests/test-models.cpp
+++ b/tests/test-models.cpp
@@ -71,10 +71,6 @@ VISP_TEST(test_birefnet_dynamic) {
 }
 
 VISP_BACKEND_TEST(test_depth_anything)(backend_type bt) {
-    if (bt == backend_type::gpu) {
-        throw test_skip{"DepthAnything does not support GPU backend"};
-    }
-
     path model_path = test_dir().models / "DepthAnythingV2-Small-F16.gguf";
     path input_path = test_dir().input / "cat-and-hat.jpg";
     std::string name = "depth-anything";

From b1a48e95885ca05bbf8cffa901f32fbd6cb373af Mon Sep 17 00:00:00 2001
From: Acly
Date: Fri, 10 Oct 2025 22:12:06 +0200
Subject: [PATCH 16/24] ggml: tests for bicubic interpolation on gpu and fix
 align-corners issue

* if any of the spatial dimensions is 1, align-corners would run into an
  integer division-by-zero or an out-of-bounds access
* updated birefnet gpu images after the fix
* added reference images for depth-anything tests
---
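Note: the fix itself lands with the ggml submodule bump below. The root cause
is the align-corners scale factor, which is derived from (n - 1) terms; with a
single-sample dimension those terms collapse to 0, producing an integer
division-by-zero on one side and an out-of-bounds source index on the other.
A minimal sketch of the guard, assuming a scale helper along these lines
(names are illustrative, not actual ggml symbols):

    // Hypothetical helper: align-corners maps the corner samples of src and
    // dst onto each other, i.e. x_src = x_dst * (src_n - 1) / (dst_n - 1).
    // When either dimension is 1 that mapping is undefined, so fall back to
    // the default pixel-ratio mapping instead of dividing by zero.
    static float interp_scale(int64_t src_n, int64_t dst_n, bool align_corners) {
        if (align_corners && src_n > 1 && dst_n > 1) {
            return float(src_n - 1) / float(dst_n - 1);
        }
        return float(src_n) / float(dst_n);
    }

The tests added to test_primitives.py below cover the 1-pixel edge case only
indirectly via the "small" size; the division-by-zero path is exercised by the
updated ggml submodule's own tests.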
"https://lfs.interstice.cloud/vision.cpp/tests/reference/depth-anything-cpu.png/36adde57ebd2589fe37bf7c0efbf9d3a013f98f7d7a45bb19fd2c492c8ade7a9" "tests/reference/depth-anything-cpu.png" EXPECTED_HASH SHA256=36adde57ebd2589fe37bf7c0efbf9d3a013f98f7d7a45bb19fd2c492c8ade7a9) +file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/depth-anything-gpu.png/b3639c0e049081ea35d2fdc37c12634457d52c320a6b839f4d6099319103464b" "tests/reference/depth-anything-gpu.png" EXPECTED_HASH SHA256=b3639c0e049081ea35d2fdc37c12634457d52c320a6b839f4d6099319103464b) file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/esrgan-cpu.png/481dcc0eb617feb9f8f7403ce179e77e2eba2c7a067f4a1ea90e0fb47083d814" "tests/reference/esrgan-cpu.png" EXPECTED_HASH SHA256=481dcc0eb617feb9f8f7403ce179e77e2eba2c7a067f4a1ea90e0fb47083d814) file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/esrgan-gpu.png/a8bfab0e07aeca16b737872bb3dbbe0e6b76cfff5616d2f02f2b0465cc7a0937" "tests/reference/esrgan-gpu.png" EXPECTED_HASH SHA256=a8bfab0e07aeca16b737872bb3dbbe0e6b76cfff5616d2f02f2b0465cc7a0937) file(DOWNLOAD "https://lfs.interstice.cloud/vision.cpp/tests/reference/migan-cpu.png/9fb32419246e3e073c73df8f4a0fefd334934ffddf8a157535b8b2fc3c1d93ee" "tests/reference/migan-cpu.png" EXPECTED_HASH SHA256=9fb32419246e3e073c73df8f4a0fefd334934ffddf8a157535b8b2fc3c1d93ee) diff --git a/tests/test-models.cpp b/tests/test-models.cpp index 70b69a8..4444e84 100644 --- a/tests/test-models.cpp +++ b/tests/test-models.cpp @@ -72,7 +72,7 @@ VISP_TEST(test_birefnet_dynamic) { VISP_BACKEND_TEST(test_depth_anything)(backend_type bt) { path model_path = test_dir().models / "DepthAnythingV2-Small-F16.gguf"; - path input_path = test_dir().input / "cat-and-hat.jpg"; + path input_path = test_dir().input / "wardrobe.jpg"; std::string name = "depth-anything"; name += bt == backend_type::cpu ? 
"-cpu.png" : "-gpu.png"; diff --git a/tests/test_primitives.py b/tests/test_primitives.py index 28df77d..a151a40 100644 --- a/tests/test_primitives.py +++ b/tests/test_primitives.py @@ -152,7 +152,8 @@ def test_roll(shift: tuple[int, int, int, int], backend: str): @pytest.mark.parametrize("align_corners", [True, False]) @pytest.mark.parametrize("size", ["small", "large"]) @pytest.mark.parametrize("scale", [0.6, 2.0]) -def test_interpolate(mode: str, align_corners: bool, size: str, scale: float): +@pytest.mark.parametrize("backend", ["cpu", "vulkan"]) +def test_interpolate(mode: str, align_corners: bool, size: str, scale: float, backend: str): b, c, h, w = { "small": (1, 3, 2, 3), "large": (4, 19, 20, 30), @@ -164,5 +165,5 @@ def test_interpolate(mode: str, align_corners: bool, size: str, scale: float): ) params = dict(mode=mode, h=target[0], w=target[1], align_corners=1 if align_corners else 0) - result = workbench.invoke_test("interpolate", x, {}, params) + result = workbench.invoke_test("interpolate", x, {}, params, backend) assert torch.allclose(result, expected) From 5e173b940caf7fd5c96082f5b242c742b7323e13 Mon Sep 17 00:00:00 2001 From: Acly Date: Mon, 13 Oct 2025 20:00:41 +0200 Subject: [PATCH 17/24] depth-anything: change weight names to follow transformers and support more model sizes --- include/visp/ml.h | 10 ++-- include/visp/vision.h | 1 + scripts/convert.py | 39 ++++++++++++++-- src/visp/arch/depth-anything.cpp | 79 ++++++++++++++++++-------------- src/visp/arch/depth-anything.h | 3 +- src/visp/arch/dino.cpp | 62 ++++++++++++++----------- src/visp/arch/dino.h | 2 +- src/visp/ml.cpp | 12 +++++ src/visp/nn.cpp | 3 +- tests/benchmark.cpp | 4 +- tests/test-models.cpp | 2 +- tests/workbench.cpp | 8 ++-- 12 files changed, 147 insertions(+), 78 deletions(-) diff --git a/include/visp/ml.h b/include/visp/ml.h index 24e8b13..5a3e0dc 100644 --- a/include/visp/ml.h +++ b/include/visp/ml.h @@ -88,6 +88,7 @@ struct model_file { VISP_API int64_t key(char const* name) const; VISP_API int get_int(char const* name) const; VISP_API std::string_view get_string(char const* name) const; + VISP_API void get_array(char const* name, span out_values) const; }; // Opens a .gguf file and reads its contents into memory. 
@@ -303,12 +304,13 @@ VISP_API swin_params swin_detect_params(model_file const&); struct dino_params { int patch_size = 16; - int embed_dim = 384; - int n_blocks = 12; - int n_heads = 6; - int mlp_ratio = 4; + int embed_dim = 768; + int n_layers = 12; + int n_heads = 12; }; +VISP_API dino_params dino_detect_params(model_file const&); + VISP_API std::vector dino_get_intermediate_layers( model_ref, tensor image, std::span layers, dino_params const&); diff --git a/include/visp/vision.h b/include/visp/vision.h index 23b628d..e529754 100644 --- a/include/visp/vision.h +++ b/include/visp/vision.h @@ -176,6 +176,7 @@ struct depthany_params { int image_size = 518; int image_multiple = 14; i32x2 image_extent = {518, 518}; + float max_depth = 1; std::array feature_layers = {2, 5, 8, 11}; dino_params dino; }; diff --git a/scripts/convert.py b/scripts/convert.py index 30fac4b..d99476c 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -93,6 +93,14 @@ def add_conv2d_weight_indices(self): self.add_array(f"{self.arch}.conv2d_weights", self.conv2d_weights) +def load_model(path: Path) -> dict[str, Tensor]: + if path.suffix in [".safetensors", ".safetensor"]: + weights = safetensors.safe_open(path, "pt") + return {k: weights.get_tensor(k) for k in weights.keys()} + else: + return torch.load(path, map_location="cpu", weights_only=True) + + batch_norm_eps = 1e-5 @@ -349,20 +357,43 @@ def convert_depth_anything(input_filepath: Path, writer: Writer): writer.add_license("apache-2.0") writer.set_tensor_layout_default(TensorLayout.nchw) - model: dict[str, Tensor] = torch.load(input_filepath, map_location="cpu", weights_only=True) + model: dict[str, Tensor] = load_model(input_filepath) + + if "pretrained.cls_token" in model: + print("The converter is written for the transformers (.safetensors) version of the model.") + print("The original weights (.pth) are currently not supported.") + raise ValueError("Weights not supported") + + shape = model["backbone.embeddings.patch_embeddings.projection.weight"].shape + writer.add_int32("dino.patch_size", shape[2]) + writer.add_int32("dino.embed_dim", shape[0]) + writer.add_int32("depthanything.image_size", 518) + match shape[0]: + case 384: # Small + writer.add_int32("dino.n_heads", 6) + writer.add_int32("dino.n_layers", 12) + writer.add_array("depthanything.feature_layers", [2, 5, 8, 11]) + case 768: # Base + writer.add_int32("dino.n_heads", 12) + writer.add_int32("dino.n_layers", 12) + writer.add_array("depthanything.feature_layers", [2, 5, 8, 11]) + case 1024: # Large + writer.add_int32("dino.n_heads", 16) + writer.add_int32("dino.n_layers", 24) + writer.add_array("depthanything.feature_layers", [4, 11, 17, 23]) for key, tensor in model.items(): name = key if is_conv_2d(name, tensor): - if "patch_embed" in name or "projects" in name: + if "patch_embeddings" in name or ("projection" in name and "fusion" not in name): tensor = conv_2d_to_nhwc(tensor) - elif "resize_layers.0" in name or "resize_layers.1" in name: + elif "0.resize" in name or "1.resize" in name: pass # ConvTranspose2D, don't change layout else: tensor = writer.convert_tensor_2d(tensor) - if "pos_embed" in name or "cls_token" in name: + if "position_embeddings" in name or "cls_token" in name: writer.add_tensor(name, tensor, "f32") continue diff --git a/src/visp/arch/depth-anything.cpp b/src/visp/arch/depth-anything.cpp index 211d624..22a4127 100644 --- a/src/visp/arch/depth-anything.cpp +++ b/src/visp/arch/depth-anything.cpp @@ -9,12 +9,15 @@ namespace visp { namespace dpt { +int32_t const 
bilinear_align_corners = int32_t(GGML_SCALE_MODE_BILINEAR) | + GGML_SCALE_FLAG_ALIGN_CORNERS; + tensor residual_conv(model_ref m, tensor x) { tensor out = x; out = ggml_relu(m, out); - out = conv_2d(m["conv1"], out, 1, 1); + out = conv_2d(m["convolution1"], out, 1, 1); out = ggml_relu(m, out); - out = conv_2d(m["conv2"], out, 1, 1); + out = conv_2d(m["convolution2"], out, 1, 1); x = ggml_add_inplace(m, x, out); return named(m, x); } @@ -22,68 +25,73 @@ tensor residual_conv(model_ref m, tensor x) { tensor feature_fusion(model_ref m, tensor x0, tensor x1, int64_t const* size) { tensor x = x0; if (x1) { - tensor res = residual_conv(m["resConfUnit1"], x1); + tensor res = residual_conv(m["residual_layer1"], x1); x = ggml_add_inplace(m, x, res); } - x = residual_conv(m["resConfUnit2"], x); + x = residual_conv(m["residual_layer2"], x); int const dim = m.flags & model_build_flag::cwhn ? 1 : 0; int64_t w = size ? size[dim + 0] : x->ne[dim + 0] * 2; int64_t h = size ? size[dim + 1] : x->ne[dim + 1] * 2; - int32_t mode = int32_t(GGML_SCALE_MODE_BILINEAR) | GGML_SCALE_FLAG_ALIGN_CORNERS; x = contiguous_2d_to_whcn(m, x); - x = interpolate(m, x, {w, h}, mode); + x = interpolate(m, x, {w, h}, bilinear_align_corners); x = whcn_to_contiguous_2d(m, x); - x = conv_2d(m["out_conv"], x); + x = conv_2d(m["projection"], x); return named(m, x); } -tensor head(model_ref m, span features, int64_t patch_w, int64_t patch_h) { +tensor neck(model_ref m, span features, int64_t patch_w, int64_t patch_h) { ASSERT(features.size() == 4); - std::array layer; + + model_ref reassemble = m["reassemble_stage.layers"]; for (int i = 0; i < 4; ++i) { tensor x = features[i]; x = slice(m, x, {}, {1, x->ne[1]}, {}, {}); x = ggml_reshape_4d(m, x, x->ne[0], patch_w, patch_h, x->ne[3]); - model_ref proj = m["projects"][i]; + model_ref proj = reassemble[i]["projection"]; proj.flags |= model_build_flag::cwhn; x = conv_2d(proj, x); // 1x1 conv, keep CWHN layout and directly use mul_mat x = cwhn_to_contiguous_2d(m, x); switch (i) { - case 0: x = conv_transpose_2d(m["resize_layers"][i], x, 4); break; - case 1: x = conv_transpose_2d(m["resize_layers"][i], x, 2); break; - case 3: x = conv_2d(m["resize_layers"][i], x, 2, 1); break; + case 0: x = conv_transpose_2d(reassemble[i]["resize"], x, 4); break; + case 1: x = conv_transpose_2d(reassemble[i]["resize"], x, 2); break; + case 3: x = conv_2d(reassemble[i]["resize"], x, 2, 1); break; } layer[i] = x; } - model_ref scratch = m["scratch"]; - tensor layer1_rn = conv_2d(scratch["layer1_rn"], layer[0], 1, 1); - tensor layer2_rn = conv_2d(scratch["layer2_rn"], layer[1], 1, 1); - tensor layer3_rn = conv_2d(scratch["layer3_rn"], layer[2], 1, 1); - tensor layer4_rn = conv_2d(scratch["layer4_rn"], layer[3], 1, 1); + model_ref convs = m["convs"]; + for (int i = 0; i < 4; ++i) { + layer[i] = conv_2d(convs[i], layer[i], 1, 1); + } - tensor path4 = feature_fusion(scratch["refinenet4"], layer4_rn, nullptr, layer3_rn->ne); - tensor path3 = feature_fusion(scratch["refinenet3"], path4, layer3_rn, layer2_rn->ne); - tensor path2 = feature_fusion(scratch["refinenet2"], path3, layer2_rn, layer1_rn->ne); - tensor path1 = feature_fusion(scratch["refinenet1"], path2, layer1_rn); + model_ref fusion = m["fusion_stage.layers"]; + tensor fused; + fused = feature_fusion(fusion[0], layer[3], nullptr, layer[2]->ne); + fused = feature_fusion(fusion[1], fused, layer[2], layer[1]->ne); + fused = feature_fusion(fusion[2], fused, layer[1], layer[0]->ne); + fused = feature_fusion(fusion[3], fused, layer[0]); + return fused; +} - 
tensor out = conv_2d(scratch["output_conv1"], path1, 1, 1); +tensor head(model_ref m, tensor x, int64_t w, int64_t h, float max_depth) { + tensor out = conv_2d(m["conv1"], x, 1, 1); out = contiguous_2d_to_whcn(m, out); - out = interpolate( - m, out, {patch_w * 14, patch_h * 14}, - int32_t(GGML_SCALE_MODE_BILINEAR) | GGML_SCALE_FLAG_ALIGN_CORNERS); + out = interpolate(m, out, {w, h}, bilinear_align_corners); out = whcn_to_contiguous_2d(m, out); - model_ref output_conv2 = scratch["output_conv2"]; - out = conv_2d(output_conv2[0], out, 1, 1); + out = conv_2d(m["conv2"], out, 1, 1); out = ggml_relu_inplace(m, out); - out = conv_2d(output_conv2[2], out); + out = conv_2d(m["conv3"], out); out = ggml_relu_inplace(m, out); + + if (max_depth != 1) { + out = ggml_scale(m, out, max_depth); + } return out; } @@ -94,9 +102,10 @@ tensor depthany_predict(model_ref m, tensor image, depthany_params const& p) { int64_t w_patch = w / p.dino.patch_size; int64_t h_patch = h / p.dino.patch_size; - auto features = dino_get_intermediate_layers(m["pretrained"], image, p.feature_layers, p.dino); - tensor depth = dpt::head(m["depth_head"], features, w_patch, h_patch); - // depth = ggml_relu_inplace(m, depth); <- reference does another ReLU here + auto features = dino_get_intermediate_layers(m["backbone"], image, p.feature_layers, p.dino); + tensor fused = dpt::neck(m["neck"], features, w_patch, h_patch); + tensor depth = dpt::head(m["head"], fused, w, h, p.max_depth); + return compute_graph_output(m, depth); } @@ -107,9 +116,11 @@ i32x2 depthany_image_extent(i32x2 extent, depthany_params const& p) { return next_multiple(target, p.image_multiple); } -depthany_params depthany_detect_params(model_file const&, i32x2 input_extent) { +depthany_params depthany_detect_params(model_file const& file, i32x2 input_extent) { depthany_params p; - p.dino.patch_size = 14; + p.dino = dino_detect_params(file); + p.image_size = file.get_int("depthanything.image_size"); + file.get_array("depthanything.feature_layers", p.feature_layers); if (input_extent[0] > 0 && input_extent[1] > 0) { p.image_extent = depthany_image_extent(input_extent, p); } diff --git a/src/visp/arch/depth-anything.h b/src/visp/arch/depth-anything.h index d937124..cc8a0c3 100644 --- a/src/visp/arch/depth-anything.h +++ b/src/visp/arch/depth-anything.h @@ -7,6 +7,7 @@ namespace visp::dpt { tensor residual_conv(model_ref m, tensor x); tensor feature_fusion(model_ref m, tensor x0, tensor x1, int64_t const* size = nullptr); -tensor head(model_ref m, span features, int64_t patch_w, int64_t patch_h); +tensor neck(model_ref m, span features, int64_t patch_w, int64_t patch_h); +tensor head(model_ref m, tensor fused, int64_t patch_w, int64_t patch_h, float max_depth); } // namespace visp::dpt diff --git a/src/visp/arch/dino.cpp b/src/visp/arch/dino.cpp index d3fd62c..a1717c4 100644 --- a/src/visp/arch/dino.cpp +++ b/src/visp/arch/dino.cpp @@ -7,7 +7,7 @@ namespace visp { namespace dino { tensor interpolate_pos_encoding(model_ref m, tensor x, int64_t w, int64_t h, int patch_size) { - tensor pos_embed = m.weights("pos_embed"); + tensor pos_embed = m.weights("position_embeddings"); int64_t n_patch = x->ne[1] - 1; int64_t n = pos_embed->ne[1] - 1; if (n_patch == n && w == h) { @@ -30,7 +30,7 @@ tensor interpolate_pos_encoding(model_ref m, tensor x, int64_t w, int64_t h, int tensor prepare_tokens(model_ref m, tensor x, int patch_size) { auto [c, w, h, n] = nelements(x); - x = patch_embed(m["patch_embed"], x, patch_size); + x = patch_embed(m["patch_embeddings"], x, patch_size); x = 
ggml_reshape_3d(m, x, x->ne[0], x->ne[1] * x->ne[2], x->ne[3]); tensor cls_token = m.weights("cls_token"); @@ -45,7 +45,7 @@ tensor prepare_tokens(model_ref m, tensor x, int patch_size) { } tensor layer_scale(model_ref m, tensor x) { - return ggml_mul(m, x, m.weights("gamma")); + return ggml_mul(m, x, m.weights("lambda1")); } tensor mlp(model_ref m, tensor x) { @@ -58,22 +58,21 @@ tensor mlp(model_ref m, tensor x) { tensor attention(model_ref m, tensor x, int n_heads) { auto [c, n, b, _] = nelements(x); float scale = 1.0f / std::sqrt(float(c) / float(n_heads)); - - tensor qkv = linear(m["qkv"], x); - qkv = ggml_reshape_4d(m, qkv, c / n_heads, n_heads, 3, n * b); - qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2)); - - auto split = [=](tensor qkv, size_t index, ggml_type type, bool transpose = false) mutable { - tensor t = slice(m, qkv, {}, {}, {}, index); - t = ggml_reshape_4d(m, t, c / n_heads, n_heads, n, b); - t = transpose ? ggml_permute(m, t, 1, 2, 0, 3) : ggml_permute(m, t, 0, 2, 1, 3); - return ggml_cast(m, t, type); + bool flash_attn = bool(m.flags & model_build_flag::flash_attention); + ggml_type kv_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + + auto split = [=](model_ref m, tensor x, ggml_type type, bool transpose = false) mutable { + x = linear(m, x); + x = ggml_reshape_4d(m, x, c / n_heads, n_heads, n, b); + x = transpose ? ggml_permute(m, x, 1, 2, 0, 3) : ggml_permute(m, x, 0, 2, 1, 3); + return ggml_cast(m, x, type); }; - tensor q = split(qkv, 0, GGML_TYPE_F32); - tensor k = split(qkv, 1, GGML_TYPE_F16); - tensor v = split(qkv, 2, GGML_TYPE_F16, !(m.flags & model_build_flag::flash_attention)); - if (m.flags & model_build_flag::flash_attention) { + tensor q = split(m["attention.query"], x, GGML_TYPE_F32); + tensor k = split(m["attention.key"], x, kv_type); + tensor v = split(m["attention.value"], x, kv_type, !flash_attn); + + if (flash_attn) { x = ggml_flash_attn_ext(m, q, k, v, nullptr, scale, 0.0f, 0.0f); } else { tensor attn = ggml_mul_mat(m, k, q); @@ -84,21 +83,21 @@ tensor attention(model_ref m, tensor x, int n_heads) { } x = ggml_reshape_3d(m, x, c, n, b); - x = linear(m["proj"], x); + x = linear(m["output.dense"], x); return named(m, x); } -tensor block(model_ref m, tensor x, dino_params const& p) { +tensor layer(model_ref m, tensor x, dino_params const& p) { tensor attn = x; attn = layer_norm(m["norm1"], attn, 1e-6f); - attn = attention(m["attn"], attn, p.n_heads); - attn = layer_scale(m["ls1"], attn); + attn = attention(m["attention"], attn, p.n_heads); + attn = layer_scale(m["layer_scale1"], attn); x = ggml_add(m, x, attn); tensor ffn = x; ffn = layer_norm(m["norm2"], ffn, 1e-6f); ffn = mlp(m["mlp"], ffn); - ffn = layer_scale(m["ls2"], ffn); + ffn = layer_scale(m["layer_scale2"], ffn); x = ggml_add(m, x, ffn); return named(m, x); @@ -112,15 +111,15 @@ bool contains(std::span r, T const& value) { std::vector get_intermediate_layers( model_ref m, tensor x, std::span layers, dino_params const& p) { - x = prepare_tokens(m, x, p.patch_size); + x = prepare_tokens(m["embeddings"], x, p.patch_size); std::vector outputs; - model_ref blocks = m["blocks"]; - for (int i = 0; i < p.n_blocks; ++i) { - x = block(blocks[i], x, p); + model_ref encoder = m["encoder.layer"]; + for (int i = 0; i < p.n_layers; ++i) { + x = layer(encoder[i], x, p); if (contains(layers, i)) { - tensor out = layer_norm(m["norm"], x, 1e-6f); + tensor out = layer_norm(m["layernorm"], x, 1e-6f); ggml_format_name(out, "dino_layer_%d", i); ggml_build_forward_expand(m.graph, out); 
outputs.push_back(out); @@ -136,4 +135,13 @@ std::vector dino_get_intermediate_layers( return dino::get_intermediate_layers(m, x, layers, p); } +dino_params dino_detect_params(model_file const& file) { + dino_params p{}; + p.patch_size = file.get_int("dino.patch_size"); + p.embed_dim = file.get_int("dino.embed_dim"); + p.n_heads = file.get_int("dino.n_heads"); + p.n_layers = file.get_int("dino.n_layers"); + return p; +} + } // namespace visp diff --git a/src/visp/arch/dino.h b/src/visp/arch/dino.h index 192a358..fe65afa 100644 --- a/src/visp/arch/dino.h +++ b/src/visp/arch/dino.h @@ -12,7 +12,7 @@ tensor prepare_tokens(model_ref m, tensor x, int patch_size); tensor layer_scale(model_ref m, tensor x); tensor mlp(model_ref m, tensor x); tensor attention(model_ref m, tensor x, int n_heads); -tensor block(model_ref m, tensor x, dino_params const& p); +tensor layer(model_ref m, tensor x, dino_params const& p); std::vector get_intermediate_layers( model_ref m, tensor x, std::span layers, dino_params const& p); diff --git a/src/visp/ml.cpp b/src/visp/ml.cpp index 6e88685..287ba7d 100644 --- a/src/visp/ml.cpp +++ b/src/visp/ml.cpp @@ -200,6 +200,18 @@ int model_file::get_int(char const* key_name) const { return gguf_get_val_i32(gguf.get(), key(key_name)); } +void model_file::get_array(char const* key_name, span out_values) const { + int64_t key_id = key(key_name); + if (gguf_get_arr_n(gguf.get(), key_id) != out_values.size()) { + throw except("Array size mismatch for key '{}' in model file {}", key_name, path); + } + if (gguf_get_arr_type(gguf.get(), key_id) != GGUF_TYPE_INT32) { + throw except("Array type mismatch for key '{}' in model file {}, expected int32", key_name, path); + } + auto ptr = (int const*)gguf_get_arr_data(gguf.get(), key_id); + std::copy(ptr, ptr + out_values.size(), out_values.data()); +} + std::string_view model_file::arch() const { return get_string("general.architecture"); } diff --git a/src/visp/nn.cpp b/src/visp/nn.cpp index 3073d06..6d3268c 100644 --- a/src/visp/nn.cpp +++ b/src/visp/nn.cpp @@ -169,9 +169,10 @@ tensor batch_norm_2d(model_ref m, tensor x) { tensor patch_embed(model_ref m, tensor x, int patch_size) { ASSERT(x->ne[1] % patch_size == 0 && x->ne[2] % patch_size == 0); + char const* proj = m.find("proj.weight") ? 
"proj" : "projection"; m.flags |= model_build_flag::cwhn; - x = conv_2d(m["proj"], x, patch_size); + x = conv_2d(m[proj], x, patch_size); if (m.find("norm.weight")) { auto [c, w, h, b] = nelements(x); diff --git a/tests/benchmark.cpp b/tests/benchmark.cpp index 8c28bb6..7d12ed2 100644 --- a/tests/benchmark.cpp +++ b/tests/benchmark.cpp @@ -101,7 +101,7 @@ bench_timings benchmark_depth_anything(path model_path, backend_device& backend) image_data input_data = depthany_process_input(input, model.params); depthany_compute(model, input); - return run_benchmark(model.graph, backend, 8, {{model.input, input_data}}); + return run_benchmark(model.graph, backend, 12, {{model.input, input_data}}); } bench_timings benchmark_migan(path model_path, backend_device& backend) { @@ -184,7 +184,7 @@ bench_result benchmark_model( result.time = benchmark_birefnet(model_path, backend); } else if (arch == "depthany") { - path model_path = select_model(model, "DepthAnythingV2-Small-F16.gguf"); + path model_path = select_model(model, "Depth-Anything-V2-Small-F16.gguf"); result.time = benchmark_depth_anything(model_path, backend); } else if (arch == "migan") { diff --git a/tests/test-models.cpp b/tests/test-models.cpp index 4444e84..2ca1cd6 100644 --- a/tests/test-models.cpp +++ b/tests/test-models.cpp @@ -71,7 +71,7 @@ VISP_TEST(test_birefnet_dynamic) { } VISP_BACKEND_TEST(test_depth_anything)(backend_type bt) { - path model_path = test_dir().models / "DepthAnythingV2-Small-F16.gguf"; + path model_path = test_dir().models / "Depth-Anything-V2-Small-F16.gguf"; path input_path = test_dir().input / "wardrobe.jpg"; std::string name = "depth-anything"; name += bt == backend_type::cpu ? "-cpu.png" : "-gpu.png"; diff --git a/tests/workbench.cpp b/tests/workbench.cpp index 93d86c3..2b877b9 100644 --- a/tests/workbench.cpp +++ b/tests/workbench.cpp @@ -438,14 +438,14 @@ DEF(dino_attention)(model_ref m, span input, param_dict const& p) { DEF(dino_block)(model_ref m, span input, param_dict const& p) { dino_params params{}; params.n_heads = p.get("n_heads", 8); - return {dino::block(m, input[0], params)}; + return {dino::layer(m, input[0], params)}; } DEF(dino_intermediate_layers)(model_ref m, span input, param_dict const& p) { dino_params params{}; params.patch_size = 4; params.embed_dim = 6; - params.n_blocks = 4; + params.n_layers = 4; params.n_heads = 3; auto layers = std::array{0, 1, 2, 3}; return dino::get_intermediate_layers(m, input[0], layers, params); @@ -467,7 +467,9 @@ DEF(depthany_feature_fusion)(model_ref m, span input, param_dict const& DEF(depthany_head)(model_ref m, span input, param_dict const& p) { int patch_w = p.get("patch_w", 8); int patch_h = p.get("patch_h", 8); - return {dpt::head(m, input, patch_w, patch_h)}; + tensor fused = dpt::neck(m, input, patch_w, patch_h); + tensor depth = dpt::head(m, fused, patch_w * 14, patch_h * 14, 1.0f); + return {depth}; } // From 4be7932afd43fcccd4739da493c90d87d79eee0d Mon Sep 17 00:00:00 2001 From: Acly Date: Mon, 13 Oct 2025 20:32:38 +0200 Subject: [PATCH 18/24] depth-anything: remove tests using old arch * too much work to migrate atm --- tests/test_depth_anything.py | 870 ----------------------------------- 1 file changed, 870 deletions(-) delete mode 100644 tests/test_depth_anything.py diff --git a/tests/test_depth_anything.py b/tests/test_depth_anything.py deleted file mode 100644 index dc2be0b..0000000 --- a/tests/test_depth_anything.py +++ /dev/null @@ -1,870 +0,0 @@ -from functools import partial -import math -import pytest -import torch -import torch.nn as 
nn -import torch.nn.functional as F -from torch import Tensor - -from tests import workbench -from tests.workbench import convert_to_nhwc, generate_state, input_tensor, to_nchw, to_nhwc - -# -# DINOv2 -# - - -class PatchEmbed(nn.Module): - def __init__( - self, - img_size=(224, 224), - patch_size=(16, 16), - in_chans: int = 3, - embed_dim: int = 768, - ): - super().__init__() - self.embed_dim = embed_dim - self.patch_size = patch_size - self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - - def forward(self, x: Tensor, flatten=False) -> Tensor: - _, _, H, W = x.shape - patch_H, patch_W = self.patch_size - assert H % patch_H == 0, ( - f"Input image height {H} is not a multiple of patch height {patch_H}" - ) - assert W % patch_W == 0, ( - f"Input image width {W} is not a multiple of patch width: {patch_W}" - ) - x = self.proj(x) # B C H W - H, W = x.size(2), x.size(3) - x = x.flatten(2).transpose(1, 2) # B HW C - # x = self.norm(x) - if not flatten: - x = x.reshape(-1, H, W, self.embed_dim) # B H W C - return x - - -def test_patch_embed(): - patch_embed = PatchEmbed(img_size=(16, 16), patch_size=(4, 4), in_chans=3, embed_dim=8) - state = generate_state(patch_embed.state_dict()) - patch_embed.load_state_dict(state) - patch_embed.eval() - - x = input_tensor(1, 3, 8, 12) - expected = patch_embed(x) - - x = to_nhwc(x) - state = convert_to_nhwc(state, key="proj") - result = workbench.invoke_test("biref_patch_embed", x, state) - - assert torch.allclose(result, expected) - - -def interpolate_pos_encoding(pos_embed: Tensor, x: Tensor, w: int, h: int, patch_size: int): - # This is 0.1 in official code, which would cause a small difference because ggml - # does not support passing a scale_factor to interpolate - interpolate_offset = 0.0 - interpolate_antialias = False - previous_dtype = x.dtype - npatch = x.shape[1] - 1 - N = pos_embed.shape[1] - 1 - if npatch == N and w == h: - return pos_embed - pos_embed = pos_embed.float() - class_pos_embed = pos_embed[:, 0] - patch_pos_embed = pos_embed[:, 1:] - dim = x.shape[-1] - w0 = w // patch_size - h0 = h // patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 - w0, h0 = w0 + interpolate_offset, h0 + interpolate_offset - # w0, h0 = w0 + 0.1, h0 + 0.1 - - sqrt_N = math.sqrt(N) - sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), - scale_factor=(sx, sy), - # (int(w0), int(h0)), # to solve the upsampling shape issue - mode="bicubic", - antialias=interpolate_antialias, - ) - - assert int(w0) == patch_pos_embed.shape[-2] - assert int(h0) == patch_pos_embed.shape[-1] - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) - - -def test_interpolate_pos_encoding(): - img_size = 12 - patch_size = 4 - num_patches = (img_size // patch_size) ** 2 - embed_dim = 8 - pos_embed = torch.randn(1, num_patches + 1, embed_dim) - - x = input_tensor(1, num_patches, embed_dim) - expected = interpolate_pos_encoding(pos_embed, x, img_size, img_size, patch_size) - - state = {"pos_embed": pos_embed} - params = {"img_size": img_size, "patch_size": patch_size} - 
result = workbench.invoke_test("dino_interpolate_pos_encoding", x, state, params) - - assert torch.allclose(result, expected) - - -class PrepareTokensModule(nn.Module): - def __init__(self, img_size, patch_size, embed_dim: int): - super().__init__() - self.patch_size = patch_size - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=(patch_size, patch_size), embed_dim=embed_dim - ) - num_patches = self.patch_embed.num_patches - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) - - def prepare_tokens_with_masks(self, x: Tensor, masks=None): - B, nc, w, h = x.shape - x = self.patch_embed(x, flatten=True) - if masks is not None: - x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) - - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = x + interpolate_pos_encoding(self.pos_embed, x, w, h, self.patch_size) - return x - - -def test_prepare_tokens_with_masks(): - img_size = 12 - patch_size = 4 - embed_dim = 6 - module = PrepareTokensModule((img_size, img_size), patch_size, embed_dim) - state = generate_state(module.state_dict()) - module.load_state_dict(state) - module.eval() - - x = input_tensor(1, 3, img_size, img_size) - expected = module.prepare_tokens_with_masks(x) - - x = to_nhwc(x) - state = convert_to_nhwc(state, key="patch_embed.proj") - result = workbench.invoke_test("dino_prepare_tokens", x, state) - - assert torch.allclose(result, expected) - - -class Mlp(nn.Module): - def __init__( - self, - in_features: int, - hidden_features: int | None = None, - out_features: int | None = None, - act_layer=nn.GELU, - bias: bool = True, - ) -> None: - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) - - def forward(self, x: Tensor) -> Tensor: - x = self.fc1(x) - x = self.act(x) - # x = self.drop(x) - x = self.fc2(x) - # x = self.drop(x) - return x - - -class Attention(nn.Module): - def __init__( - self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True - ) -> None: - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = head_dim**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.proj = nn.Linear(dim, dim, bias=proj_bias) - - def forward(self, x: Tensor) -> Tensor: - B, N, C = x.shape - qkv = ( - self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - ) - - q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] - attn = q @ k.transpose(-2, -1) - - attn = attn.softmax(dim=-1) - # attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - # x = self.proj_drop(x) - return x - - -@pytest.mark.parametrize("flash_attn", [0, 1]) -def test_attention(flash_attn: int): - dim = 6 - num_heads = 3 - module = Attention(dim=dim, num_heads=num_heads, qkv_bias=True, proj_bias=True) - state = generate_state(module.state_dict()) - module.load_state_dict(state) - module.eval() - - x = input_tensor(1, 12, dim) - expected = module(x) - result = workbench.invoke_test( - "dino_attention", x, state, dict(n_heads=num_heads, flash_attn=flash_attn) - ) - - assert torch.allclose(result, expected) - - -class LayerScale(nn.Module): - def __init__(self, dim: int, init_values=1e-5, inplace: bool = False) -> None: - 
super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x: Tensor) -> Tensor: - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - -class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - mlp_ratio: float = 4.0, - qkv_bias: bool = True, - proj_bias: bool = True, - ffn_bias: bool = True, - init_values=None, - ) -> None: - super().__init__() - self.norm1 = nn.LayerNorm(dim) - self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias) - self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path1 = nn.Identity() - - self.norm2 = nn.LayerNorm(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=nn.GELU, - bias=ffn_bias, - ) - self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() - self.drop_path2 = nn.Identity() - - def forward(self, x: Tensor) -> Tensor: - def attn_residual_func(x: Tensor) -> Tensor: - return self.ls1(self.attn(self.norm1(x))) - - def ffn_residual_func(x: Tensor) -> Tensor: - return self.ls2(self.mlp(self.norm2(x))) - - x = x + attn_residual_func(x) - x = x + ffn_residual_func(x) - return x - - -def test_block(): - dim = 6 - num_heads = 3 - module = Block(dim=dim, num_heads=num_heads, init_values=1.0) - state = generate_state(module.state_dict()) - module.load_state_dict(state) - module.eval() - - x = input_tensor(1, 12, dim) - expected = module(x) - result = workbench.invoke_test("dino_block", x, state, dict(n_heads=num_heads)) - - assert torch.allclose(result, expected, rtol=1e-2) # precision drop due to GELU in MLP - - -class DinoVisionTransformer(nn.Module): - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=True, - ffn_bias=True, - proj_bias=True, - drop_path_rate=0.0, - drop_path_uniform=False, - init_values=None, # for layerscale: None or 0 => no layerscale - embed_layer=PatchEmbed, - act_layer=nn.GELU, - block_fn=Block, - ffn_layer="mlp", - block_chunks=1, - num_register_tokens=0, - interpolate_antialias=False, - interpolate_offset=0.1, - ): - super().__init__() - norm_layer = partial(nn.LayerNorm, eps=1e-6) - - self.num_features = self.embed_dim = ( - embed_dim # num_features for consistency with other models - ) - self.num_tokens = 1 - self.n_blocks = depth - self.num_heads = num_heads - self.patch_size = patch_size - self.num_register_tokens = num_register_tokens - self.interpolate_antialias = interpolate_antialias - self.interpolate_offset = interpolate_offset - - self.patch_embed = embed_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim - ) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) - assert num_register_tokens >= 0 - self.register_tokens = ( - nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) - if num_register_tokens - else None - ) - - if drop_path_uniform is True: - dpr = [drop_path_rate] * depth - else: - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, depth) - ] # stochastic depth decay rule - - if ffn_layer == "mlp": - ffn_layer = Mlp - elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": - assert False, "swiglu not implemented" - elif ffn_layer == "identity": - 
- def f(*args, **kwargs): - return nn.Identity() - - ffn_layer = f - else: - raise NotImplementedError - - blocks_list = [ - block_fn( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - proj_bias=proj_bias, - ffn_bias=ffn_bias, - # drop_path=dpr[i], - # norm_layer=norm_layer, - # act_layer=act_layer, - # ffn_layer=ffn_layer, - init_values=init_values, - ) - for i in range(depth) - ] - - self.chunked_blocks = False - self.blocks = nn.ModuleList(blocks_list) - self.norm = norm_layer(embed_dim) - self.head = nn.Identity() - self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) - # self.init_weights() - - def interpolate_pos_encoding(self, x, w, h): - previous_dtype = x.dtype - npatch = x.shape[1] - 1 - N = self.pos_embed.shape[1] - 1 - if npatch == N and w == h: - return self.pos_embed - pos_embed = self.pos_embed.float() - class_pos_embed = pos_embed[:, 0] - patch_pos_embed = pos_embed[:, 1:] - dim = x.shape[-1] - w0 = w // self.patch_size - h0 = h // self.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 - w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset - # w0, h0 = w0 + 0.1, h0 + 0.1 - - sqrt_N = math.sqrt(N) - sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), - scale_factor=(sx, sy), - # (int(w0), int(h0)), # to solve the upsampling shape issue - mode="bicubic", - antialias=self.interpolate_antialias, - ) - - assert int(w0) == patch_pos_embed.shape[-2] - assert int(h0) == patch_pos_embed.shape[-1] - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) - - def prepare_tokens_with_masks(self, x, masks=None): - B, nc, w, h = x.shape - x = self.patch_embed(x, flatten=True) - if masks is not None: - x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) - - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = x + self.interpolate_pos_encoding(x, w, h) - - if self.register_tokens is not None: - x = torch.cat( - ( - x[:, :1], - self.register_tokens.expand(x.shape[0], -1, -1), - x[:, 1:], - ), - dim=1, - ) - - return x - - def forward_features_list(self, x_list, masks_list): - x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] - for blk in self.blocks: - x = blk(x) - - all_x = x - output = [] - for x, masks in zip(all_x, masks_list): - x_norm = self.norm(x) - output.append({ - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], - "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], - "x_prenorm": x, - "masks": masks, - }) - return output - - def forward_features(self, x, masks=None): - if isinstance(x, list): - return self.forward_features_list(x, masks) - - x = self.prepare_tokens_with_masks(x, masks) - - for blk in self.blocks: - x = blk(x) - - x_norm = self.norm(x) - return { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], - "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], - "x_prenorm": x, - "masks": masks, - } - - def _get_intermediate_layers_not_chunked(self, x, n=1): - x = self.prepare_tokens_with_masks(x) - # 
If n is an int, take the n last blocks. If it's a list, take them - output, total_block_len = [], len(self.blocks) - blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n - for i, blk in enumerate(self.blocks): - x = blk(x) - if i in blocks_to_take: - output.append(x) - assert len(output) == len(blocks_to_take), ( - f"only {len(output)} / {len(blocks_to_take)} blocks found" - ) - return output - - def get_intermediate_layers( - self, - x: torch.Tensor, - n: int | list[int] = 1, # Layers or n last layers to take - reshape: bool = False, - return_class_token: bool = False, - norm=True, - ): - outputs = self._get_intermediate_layers_not_chunked(x, n) - if norm: - outputs = [self.norm(out) for out in outputs] - class_tokens = [out[:, 0] for out in outputs] - outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] - if reshape: - B, _, w, h = x.shape - outputs = [ - out.reshape(B, w // self.patch_size, h // self.patch_size, -1) - .permute(0, 3, 1, 2) - .contiguous() - for out in outputs - ] - if return_class_token: - return tuple(zip(outputs, class_tokens)) - return tuple(outputs) - - def forward(self, *args, is_training=False, **kwargs): - ret = self.forward_features(*args, **kwargs) - if is_training: - return ret - else: - return self.head(ret["x_norm_clstoken"]) - - -def test_dino_intermediate_layers(): - img_size = 8 - patch_size = 4 - embed_dim = 6 - depth = 4 - num_heads = 3 - module = DinoVisionTransformer( - img_size=(img_size, img_size), - patch_size=(patch_size, patch_size), - embed_dim=embed_dim, - depth=depth, - num_heads=num_heads, - init_values=1.0, - interpolate_offset=0.0, - ) - state = generate_state(module.state_dict()) - module.load_state_dict(state) - module.eval() - - x = input_tensor(1, 3, img_size, img_size) - expected = module.get_intermediate_layers(x, n=4) - - state = convert_to_nhwc(state, key="patch_embed.proj") - x = to_nhwc(x) - result = workbench.invoke_test("dino_intermediate_layers", x, state) - - for r, e in zip(result, expected): - r = r.squeeze(0) - r = r[:, 1:, :] # remove cls token - assert torch.allclose(r, e) - - -# -# Depth Anything -# - - -class ResidualConvUnit(nn.Module): - def __init__(self, features, activation, bn=False): - super().__init__() - self.bn = bn - self.groups = 1 - self.conv1 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) - self.conv2 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) - self.activation = activation - self.skip_add = nn.quantized.FloatFunctional() - - def forward(self, x): - out = self.activation(x) - out = self.conv1(out) - if self.bn == True: - out = self.bn1(out) - - out = self.activation(out) - out = self.conv2(out) - if self.bn == True: - out = self.bn2(out) - - if self.groups > 1: - out = self.conv_merge(out) - - return self.skip_add.add(out, x) - - -class FeatureFusionBlock(nn.Module): - """Feature fusion block.""" - - def __init__( - self, - features, - activation, - deconv=False, - bn=False, - expand=False, - align_corners=True, - size=None, - ): - super(FeatureFusionBlock, self).__init__() - self.deconv = deconv - self.align_corners = align_corners - self.groups = 1 - self.expand = expand - out_features = features - if self.expand == True: - out_features = features // 2 - - self.out_conv = nn.Conv2d( - features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1 - ) - self.resConfUnit1 = ResidualConvUnit(features, activation, bn) - 
self.resConfUnit2 = ResidualConvUnit(features, activation, bn) - self.skip_add = nn.quantized.FloatFunctional() - self.size = size - - def forward(self, *xs, size=None): - output = xs[0] - if len(xs) == 2: - res = self.resConfUnit1(xs[1]) - output = self.skip_add.add(output, res) - - output = self.resConfUnit2(output) - - if (size is None) and (self.size is None): - modifier = {"scale_factor": 2} - elif size is None: - modifier = {"size": self.size} - else: - modifier = {"size": size} - - output = nn.functional.interpolate( - output, **modifier, mode="bilinear", align_corners=self.align_corners - ) - output = self.out_conv(output) - return output - - -def _make_fusion_block(features, use_bn, size=None): - return FeatureFusionBlock( - features, - nn.ReLU(False), - deconv=False, - bn=use_bn, - expand=False, - align_corners=True, - size=size, - ) - - -@pytest.mark.parametrize("inputs", [1, 2]) -def test_feature_fusion(inputs): - features = 6 - x = [input_tensor(1, features, 4, 4)] - size = (8, 8) - if inputs == 2: - x.append(input_tensor(1, features, 4, 4)) - size = None - - module = _make_fusion_block(features, use_bn=False) - state = generate_state(module.state_dict()) - module.load_state_dict(state) - module.eval() - - expected = module(*x, size=size) - result = workbench.invoke_test("depthany_feature_fusion", x, state=state) - - assert torch.allclose(result, expected) - - -class ConvBlock(nn.Module): - def __init__(self, in_feature, out_feature): - super().__init__() - - self.conv_block = nn.Sequential( - nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), - nn.BatchNorm2d(out_feature), - nn.ReLU(True), - ) - - def forward(self, x): - return self.conv_block(x) - - -def _make_scratch(in_shape, out_shape, groups=1, expand=False): - scratch = nn.Module() - - out_shape1 = out_shape - out_shape2 = out_shape - out_shape3 = out_shape - if len(in_shape) >= 4: - out_shape4 = out_shape - - if expand: - out_shape1 = out_shape - out_shape2 = out_shape * 2 - out_shape3 = out_shape * 4 - if len(in_shape) >= 4: - out_shape4 = out_shape * 8 - - scratch.layer1_rn = nn.Conv2d( - in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - scratch.layer2_rn = nn.Conv2d( - in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - scratch.layer3_rn = nn.Conv2d( - in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - if len(in_shape) >= 4: - scratch.layer4_rn = nn.Conv2d( - in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - - return scratch - - -class DPTHead(nn.Module): - def __init__( - self, - in_channels, - features=256, - use_bn=False, - out_channels=[256, 512, 1024, 1024], - use_clstoken=False, - ): - super(DPTHead, self).__init__() - - self.use_clstoken = use_clstoken - - self.projects = nn.ModuleList([ - nn.Conv2d( - in_channels=in_channels, - out_channels=out_channel, - kernel_size=1, - stride=1, - padding=0, - ) - for out_channel in out_channels - ]) - - self.resize_layers = nn.ModuleList([ - nn.ConvTranspose2d( - in_channels=out_channels[0], - out_channels=out_channels[0], - kernel_size=4, - stride=4, - padding=0, - ), - nn.ConvTranspose2d( - in_channels=out_channels[1], - out_channels=out_channels[1], - kernel_size=2, - stride=2, - padding=0, - ), - nn.Identity(), - nn.Conv2d( - in_channels=out_channels[3], - out_channels=out_channels[3], - kernel_size=3, - stride=2, - padding=1, - ), - ]) - - if use_clstoken: - self.readout_projects 
= nn.ModuleList() - for _ in range(len(self.projects)): - self.readout_projects.append( - nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()) - ) - - self.scratch = _make_scratch( - out_channels, - features, - groups=1, - expand=False, - ) - - self.scratch.stem_transpose = None - - self.scratch.refinenet1 = _make_fusion_block(features, use_bn) - self.scratch.refinenet2 = _make_fusion_block(features, use_bn) - self.scratch.refinenet3 = _make_fusion_block(features, use_bn) - self.scratch.refinenet4 = _make_fusion_block(features, use_bn) - - head_features_1 = features - head_features_2 = 32 - - self.scratch.output_conv1 = nn.Conv2d( - head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 - ) - self.scratch.output_conv2 = nn.Sequential( - nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), - nn.ReLU(True), - nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), - nn.ReLU(True), - nn.Identity(), - ) - - def forward(self, out_features, patch_h, patch_w): - out = [] - for i, x in enumerate(out_features): - if self.use_clstoken: - x, cls_token = x[0], x[1] - readout = cls_token.unsqueeze(1).expand_as(x) - x = self.readout_projects[i](torch.cat((x, readout), -1)) - else: - x = x[0] - - x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) - - x = self.projects[i](x) - x = self.resize_layers[i](x) - - out.append(x) - - layer_1, layer_2, layer_3, layer_4 = out - - layer_1_rn = self.scratch.layer1_rn(layer_1) - layer_2_rn = self.scratch.layer2_rn(layer_2) - layer_3_rn = self.scratch.layer3_rn(layer_3) - layer_4_rn = self.scratch.layer4_rn(layer_4) - - path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) - path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) - path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) - path_1 = self.scratch.refinenet1(path_2, layer_1_rn) - - out = self.scratch.output_conv1(path_1) - out = F.interpolate( - out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True - ) - out = self.scratch.output_conv2(out) - - return out - - -def test_dpt_head(): - in_channels = 4 - features = 6 - h, w = 8, 8 - module = DPTHead(in_channels=in_channels, features=features, use_clstoken=False) - state = generate_state(module.state_dict()) - module.load_state_dict(state) - module.eval() - - x = [input_tensor(1, 1 + h * w, in_channels) for _ in range(4)] - x_tuples = [(t[:, 1:, :], None) for t in x] - expected = module(x_tuples, h, w) - - state = convert_to_nhwc(state, key="projects") - result = workbench.invoke_test("depthany_head", x, state) - - assert torch.allclose(result, expected, atol=1e-3) From 466c561c7b5d98f0709a1b52e4ea1530410631fe Mon Sep 17 00:00:00 2001 From: Acly Date: Tue, 14 Oct 2025 12:43:26 +0200 Subject: [PATCH 19/24] ml: add public interface for swin & dino backbones in vision.h --- include/visp/ml.h | 37 ---- include/visp/vision.h | 38 ++++- src/visp/CMakeLists.txt | 1 + src/visp/arch/birefnet.cpp | 332 ++---------------------------------- src/visp/arch/birefnet.h | 47 +----- src/visp/arch/dino.h | 1 + src/visp/arch/swin.cpp | 334 +++++++++++++++++++++++++++++++++++++ src/visp/arch/swin.h | 39 +++++ tests/test_birefnet.py | 4 +- tests/workbench.cpp | 48 +++--- 10 files changed, 453 insertions(+), 428 deletions(-) create mode 100644 src/visp/arch/swin.cpp create mode 100644 src/visp/arch/swin.h diff --git a/include/visp/ml.h b/include/visp/ml.h index 5a3e0dc..efb70e1 100644 --- 
a/include/visp/ml.h +++ b/include/visp/ml.h @@ -277,43 +277,6 @@ VISP_API tensor concat(model_ref const&, std::array src, i // Up- or downsample a 2D tensor (WHCN) to target width x height. VISP_API tensor interpolate(model_ref const&, tensor x, i64x2 target, int32_t mode); -// -// SWIN Transformer - -struct swin_layer_t { - int depth; - int n_heads; - int n_features; - bool downsample; -}; - -struct swin_params { - static constexpr int n_layers = 4; - - int embed_dim; - int window_size; - std::array layers; -}; - -extern swin_params const swin_t_params; -extern swin_params const swin_l_params; -VISP_API swin_params swin_detect_params(model_file const&); - -// -// DINO - -struct dino_params { - int patch_size = 16; - int embed_dim = 768; - int n_layers = 12; - int n_heads = 12; -}; - -VISP_API dino_params dino_detect_params(model_file const&); - -VISP_API std::vector dino_get_intermediate_layers( - model_ref, tensor image, std::span layers, dino_params const&); - // // implementation diff --git a/include/visp/vision.h b/include/visp/vision.h index e529754..94d257a 100644 --- a/include/visp/vision.h +++ b/include/visp/vision.h @@ -79,6 +79,42 @@ namespace visp { +// SWIN - vision transformer for feature extraction + +constexpr int swin_n_layers = 4; + +struct swin_layer_t { + int depth; + int n_heads; + int n_features; +}; + +struct swin_params { + int embed_dim; + int window_size; + std::array layers; +}; + +using swin_buffers = std::array; +using swin_result = std::array; + +VISP_API swin_params swin_detect_params(model_file const&); +VISP_API swin_buffers swin_precompute(model_ref, i32x2 image_extent, swin_params const&); +VISP_API swin_result swin_encode(model_ref, tensor image, swin_params const&); + +// DINO - vision transformer for feature extraction + +struct dino_params { + int patch_size = 16; + int embed_dim = 768; + int n_layers = 12; + int n_heads = 12; +}; + +VISP_API dino_params dino_detect_params(model_file const&); +VISP_API std::vector dino_get_intermediate_layers( + model_ref, tensor image, span layers_ids, dino_params const&); + // // Mobile SAM - image segmentation with prompt (point or box) @@ -148,7 +184,7 @@ struct birefnet_params { swin_params encoder; }; -using birefnet_buffers = std::array; +using birefnet_buffers = swin_buffers; VISP_API birefnet_params birefnet_detect_params( model_file const&, i32x2 dynamic_extent = {}, size_t max_alloc = SIZE_MAX); diff --git a/src/visp/CMakeLists.txt b/src/visp/CMakeLists.txt index 39d321b..14d7964 100644 --- a/src/visp/CMakeLists.txt +++ b/src/visp/CMakeLists.txt @@ -7,6 +7,7 @@ target_sources(visioncpp PRIVATE arch/esrgan.cpp arch/migan.cpp arch/mobile-sam.cpp + arch/swin.cpp image.cpp ml.cpp nn.cpp diff --git a/src/visp/arch/birefnet.cpp b/src/visp/arch/birefnet.cpp index 0bcd3e2..37915db 100644 --- a/src/visp/arch/birefnet.cpp +++ b/src/visp/arch/birefnet.cpp @@ -1,6 +1,7 @@ #include "visp/arch/birefnet.h" #include "util/math.h" #include "util/string.h" +#include "visp/arch/swin.h" #include "visp/nn.h" #include "visp/vision.h" @@ -9,278 +10,9 @@ namespace visp { namespace birefnet { -tensor mlp(model_ref m, tensor x) { - x = linear(m["fc1"], x); - x = ggml_gelu_inplace(m, x); - x = linear(m["fc2"], x); - return named(m, x); -} - -// Ensures that the tensor's data is not overwritten during computation. 
-tensor make_constant(tensor x, tensor_name name) { - ggml_set_name(x, name.c_str()); - ggml_set_input(x); // allocate at the beginning of the graph buffer - ggml_set_output(x); // don't reuse memory for computations - return x; -} - -void compute_relative_position_index(span dst, int window_size) { - int n = window_size; - int n2 = n * n; - int n4 = n2 * n2; - for (int i = 0; i < n4; ++i) { - int x0 = i % n; - int y0 = (i / n) % n; - int x1 = (i / n2) % n; - int y1 = (i / n2 / n) % n; - dst[i] = (y1 - y0 + n - 1) * (2 * n - 1) + (x1 - x0 + n - 1); - } -} - -tensor_data create_relative_position_index(ggml_context* ctx, int window_size) { - int n = window_size; - auto result = tensor_alloc(ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n * n * n * n)); - auto name = format("window_attention_{}.rel_pos_index", n); - compute_relative_position_index(result.as_i32(), n); - make_constant(result.x, name); - return result; -} - -tensor window_partition(model_ref m, tensor x, int window) { - auto [c, w, h, b] = nelements(x); - ASSERT(w % window == 0 && h % window == 0, "Expecting padded input"); - - x = ggml_reshape_4d(m, x, c * window, w / window, window, (h / window) * b); - x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); - x = ggml_reshape_3d(m, x, c, window * window, (w / window) * (h / window) * b); - return x; -} - -tensor window_reverse(model_ref m, tensor x, int64_t w, int64_t h, int window) { - int64_t c = x->ne[0]; - int64_t b = x->ne[2] / (w / window) / (h / window); - ASSERT(x->ne[2] % (w / window) == 0, "Expecting ne[2] to be multiple of window count"); - - x = ggml_reshape_4d(m, x, c * window, window, w / window, (h / window) * b); - x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); - x = ggml_reshape_4d(m, x, c, w, h, b); - return x; -} - -tensor window_attention(model_ref m, tensor x, tensor mask, int num_heads, int window) { - auto [c, n, b, _] = nelements(x); - - tensor qkv = linear(m["qkv"], x); - qkv = ggml_reshape_4d(m, qkv, c / num_heads, num_heads, 3, n * b); - qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2)); - - auto split = [=](tensor tensor, size_t index, bool transpose = false) mutable { - tensor = slice(m, tensor, {}, {}, {}, index); - tensor = ggml_reshape_4d(m, tensor, c / num_heads, num_heads, n, b); - if (transpose) { - tensor = ggml_cont(m, ggml_permute(m, tensor, 1, 2, 0, 3)); - } else { - tensor = ggml_cont(m, ggml_permute(m, tensor, 0, 2, 1, 3)); - } - return tensor; - }; - tensor q = split(qkv, 0); - tensor k = split(qkv, 1); - tensor v = split(qkv, 2, true); - - q = ggml_scale_inplace(m, q, 1.0f / std::sqrt(float(c / num_heads))); - - tensor attn = ggml_mul_mat(m, k, q); - - tensor_name rel_pos_name = format("window_attention_{}.rel_pos_index", window); - tensor rel_pos_index = ggml_get_tensor(m, rel_pos_name.c_str()); - tensor rel_pos_table = m.weights("relative_position_bias_table"); - tensor rel_pos_bias = ggml_get_rows(m, rel_pos_table, rel_pos_index); - rel_pos_bias = ggml_reshape_4d(m, rel_pos_bias, num_heads, window * window, window * window, 1); - rel_pos_bias = ggml_cont(m, ggml_permute(m, rel_pos_bias, 2, 0, 1, 3)); - attn = ggml_add_inplace(m, attn, rel_pos_bias); - - if (mask) { - int64_t nw = mask->ne[2]; - attn = ggml_reshape_4d(m, attn, n * n, num_heads, nw, b / nw); - mask = ggml_reshape_4d(m, mask, n * n, 1, nw, 1); - attn = ggml_add_inplace(m, attn, mask); - attn = ggml_reshape_4d(m, attn, n, n, num_heads, b); - } - attn = ggml_soft_max(m, attn); - - x = ggml_mul_mat(m, v, attn); - x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); - x = 
ggml_reshape_3d(m, x, c, n, b); - - x = linear(m["proj"], x); - return named(m, x); -} - -tensor swin_block(model_ref m, tensor x, tensor mask, swin_block_params const& p) { - auto [c, n, b, _] = nelements(x); - auto [num_heads, window, w, h, shift] = p; - ASSERT(n == w * h && "Spatial dimensions do not match"); - - tensor shortcut = x; - x = layer_norm(m["norm1"], x); - x = ggml_reshape_4d(m, x, c, w, h, b); - - int pad_r = (window - w % window) % window; - int pad_b = (window - h % window) % window; - if (pad_r > 0 || pad_b > 0) { - x = ggml_pad(m, x, 0, pad_r, pad_b, 0); - } - - ASSERT(shift == 0 || mask != nullptr); - if (shift > 0) { - x = ggml_roll(m, x, 0, -shift, -shift, 0); - } - - x = window_partition(m, x, window); - x = window_attention(m["attn"], x, mask, num_heads, window); - x = window_reverse(m, x, w + pad_r, h + pad_b, window); - - if (shift > 0) { // undo shift - x = ggml_roll(m, x, 0, shift, shift, 0); - } - - if (pad_r > 0 || pad_b > 0) { // undo padding - x = ggml_reshape_4d(m, x, c, w + pad_r, h + pad_b, b); - x = slice(m, x, {}, {0, w}, {0, h}, {}); - x = ggml_cont(m, x); - } - - x = ggml_reshape_3d(m, x, c, n, b); - x = ggml_add_inplace(m, x, shortcut); - - tensor x_mlp = layer_norm(m["norm2"], x); - x_mlp = mlp(m["mlp"], x_mlp); - x = ggml_add_inplace(m, x, x_mlp); - - return named(m, x); -} - -tensor patch_merging(model_ref m, tensor x, int64_t w, int64_t h) { - auto [c, n, b, _] = nelements(x); - ASSERT(n == w * h, "Spatial dimensions do not match"); - ASSERT(w % 2 == 0 && h % 2 == 0, "Expecting even spatial dimensions"); - - x = ggml_reshape_4d(m, x, c, w, h, b); - // clang-format off - x = concat(m, { - slice(m, x, {}, {0, w, 2}, {0, h, 2}, {}), - slice(m, x, {}, {0, w, 2}, {1, h, 2}, {}), - slice(m, x, {}, {1, w, 2}, {0, h, 2}, {}), - slice(m, x, {}, {1, w, 2}, {1, h, 2}, {})}, 0); - // clang-format on - x = ggml_reshape_3d(m, x, c * 4, n / 4, b); - - x = layer_norm(m["norm"], x); - x = linear(m["reduction"], x); - return named(m, x); -} - -void compute_attention_mask(span out, int64_t w, int64_t h, int window_size) { - int n = window_size; - int n2 = n * n; - int n4 = n2 * n2; - int shift = window_size / 2; - int64_t nw_x = (w + n - 1) / n; - int64_t nw_y = (h + n - 1) / n; - int64_t w_pad = nw_x * n; - int64_t h_pad = nw_y * n; - - std::fill(out.begin(), out.end(), 0.0f); - - for (int iw_y = 0; iw_y < nw_y; ++iw_y) { - for (int iw_x = 0; iw_x < nw_x; ++iw_x) { - // Skip all windows that aren't at the right or bottom edges of the image - if (iw_y < nw_y - 1 && iw_x < nw_x - 1) { - continue; - } - int64_t base = iw_y * nw_x * n4 + iw_x * n4; - - for (int y0 = 0; y0 < n; ++y0) { - for (int x0 = 0; x0 < n; ++x0) { - for (int y1 = 0; y1 < n; ++y1) { - for (int x1 = 0; x1 < n; ++x1) { - // Window-local coordinates to global image coordinates - int yy0 = iw_y * n + y0; - int xx0 = iw_x * n + x0; - int yy1 = iw_y * n + y1; - int xx1 = iw_x * n + x1; - // Check if two patches being matched belong to the same window - // that is: they are both in the shift zone, or both outside - bool match_y = (yy0 < h_pad - shift) == (yy1 < h_pad - shift); - bool match_x = (xx0 < w_pad - shift) == (xx1 < w_pad - shift); - // If not, set mask to -100 (added to attention before softmax) - if (!match_y || !match_x) { - int64_t idx = base + (y0 * n + x0) * n2 + (y1 * n + x1); - out[idx] = -100.f; - } - } - } - } - } - } - } -} - -tensor_data create_attention_mask(ggml_context* ctx, int64_t w, int64_t h, int window_size) { - int n = window_size; - int64_t nw_x = (w + n - 1) / n; - 
int64_t nw_y = (h + n - 1) / n; - auto result = tensor_alloc(ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n * n, n * n, nw_x * nw_y)); - auto name = format("swin_layer_{}x{}.attn_mask", w, h); - compute_attention_mask(result.as_f32(), w, h, window_size); - make_constant(result.x, name); - return result; -} - -swin_layer_result swin_layer( - model_ref m, tensor x, int64_t w, int64_t h, swin_layer_t const& p, int window_size) { - // Attention masks need to be precomputed - tensor_name attn_mask_name = format("swin_layer_{}x{}.attn_mask", w, h); - tensor attn_mask = ggml_get_tensor(m, attn_mask_name.c_str()); - - model_ref blocks = m["blocks"]; - for (int i = 0; i < p.depth; ++i) { - x = swin_block( - blocks[i], x, attn_mask, - {.n_heads = p.n_heads, - .window_size = window_size, - .w = w, - .h = h, - .shift = i % 2 == 0 ? 0 : window_size / 2}); - } - if (p.downsample) { - tensor x_down = patch_merging(m["downsample"], x, w, h); - return {x, w, h, x_down, (w + 1) / 2, (h + 1) / 2}; - } - return {x, w, h, x, w, h}; -} - -swin_result swin_transformer(model_ref m, tensor x, swin_params const& p) { - x = patch_embed(m["patch_embed"], x, 4); - - auto [c, w, h, b] = nelements(x); - x = ggml_reshape_3d(m, x, c, w * h, b); - - swin_layer_result r{x, w, h, x, w, h}; - swin_result outs = {}; - - for (int i = 0; i < swin_params::n_layers; ++i) { - model_ref layer = m["layers"][i]; - r = swin_layer(layer, r.x_down, r.w_down, r.h_down, p.layers[i], p.window_size); - - tensor_name norm_layer = format("norm{}", i); - tensor out = layer_norm(m[norm_layer], r.x_out); - out = ggml_reshape_4d(m, out, p.layers[i].n_features, r.w_out, r.h_out, b); - outs[i] = out; - } - return outs; -} +// +// Encoder +// constexpr int32_t bilinear_align_corners = GGML_SCALE_MODE_BILINEAR | (int)GGML_SCALE_FLAG_ALIGN_CORNERS; @@ -333,9 +65,9 @@ swin_result encode_concat(model_ref m, swin_result& xs, swin_result& xs_low) { } swin_result encode(model_ref m, tensor x, swin_params const& p) { - auto xs = swin_transformer(m["bb"], x, p); + auto xs = swin_encode(m["bb"], x, p); auto x_low = downscale_by(m, x, 2); - auto xs_low = swin_transformer(m["bb"], x_low, p); + auto xs_low = swin_encode(m["bb"], x_low, p); encode_concat(m, xs, xs_low); return xs; } @@ -519,7 +251,7 @@ tensor decode(model_ref m, tensor x, swin_result const& features) { tensor birefnet_predict(model_ref m, tensor image, birefnet_params const& p) { // Encoder - birefnet::swin_result features = birefnet::encode(m, image, p.encoder); + swin_result features = birefnet::encode(m, image, p.encoder); // Squeeze block features[3] = birefnet::basic_decoder_block(m["squeeze_module.0"], features[3]); // Decoder @@ -553,52 +285,6 @@ image_data birefnet_process_output( return image_f32_to_u8(mask_output, image_format::alpha_u8); } -birefnet_buffers birefnet_precompute(model_ref m, birefnet_params const& params) { - int w = params.encoder.window_size; - int width = params.image_extent[0] / 4; - int height = params.image_extent[1] / 4; - - birefnet_buffers b; - b[0] = birefnet::create_relative_position_index(m, w); - for (int i = 0; i < swin_params::n_layers + 1; ++i) { - b[i + 1] = birefnet::create_attention_mask(m, width >> i, height >> i, w); - } - return b; -} - -// clang-format off -const swin_params swin_t_params = { - .embed_dim = 96, - .window_size = 7, - .layers = { - // depth n_heads n_features downsample - swin_layer_t{2, 3, 96 * 1, true}, - swin_layer_t{2, 6, 96 * 2, true}, - swin_layer_t{6, 12, 96 * 4, true}, - swin_layer_t{2, 24, 96 * 8, false}}}; - -const swin_params 
swin_l_params = { - .embed_dim = 192, - .window_size = 12, - .layers = { - // depth n_heads n_features downsample - swin_layer_t{2, 6, 192 * 1, true}, - swin_layer_t{2, 12, 192 * 2, true}, - swin_layer_t{18, 24, 192 * 4, true}, - swin_layer_t{2, 48, 192 * 8, false}}}; -// clang-format on - -swin_params swin_detect_params(model_file const& f) { - int embed_dim = f.get_int("swin.embed_dim"); - if (embed_dim == 96) { - return swin_t_params; - } else if (embed_dim == 192) { - return swin_l_params; - } else { - throw except("Unsupported Swin Transformer embed dim: {}", embed_dim); - } -} - i32x2 birefnet_image_extent(i32x2 input_extent, birefnet_params const& p, size_t max_alloc) { i32x2 extent{p.image_size, p.image_size}; if (p.image_size == -1) { @@ -632,4 +318,8 @@ birefnet_params birefnet_detect_params( return p; } +birefnet_buffers birefnet_precompute(model_ref m, birefnet_params const& p) { + return swin_precompute(m, p.image_extent, p.encoder); +} + } // namespace visp diff --git a/src/visp/arch/birefnet.h b/src/visp/arch/birefnet.h index 2ad1b3f..90d855c 100644 --- a/src/visp/arch/birefnet.h +++ b/src/visp/arch/birefnet.h @@ -1,48 +1,10 @@ #pragma once -#include "visp/ml.h" #include "visp/image.h" +#include "visp/ml.h" +#include "visp/vision.h" -#include - -namespace visp { -namespace birefnet { - -// SWIN Transformer - -struct swin_block_params { - int n_heads = 6; - int window_size = 7; - int64_t w = 0; - int64_t h = 0; - int shift = 0; -}; - -struct swin_layer_result { - tensor x_out; - int64_t w_out; - int64_t h_out; - tensor x_down; - int64_t w_down; - int64_t h_down; -}; - -using swin_result = std::array; - -void compute_relative_position_index(span dst, int window_size); -tensor_data create_relative_position_index(ggml_context* ctx, int window_size); -void compute_attention_mask(std::span out, int64_t w, int64_t h, int window_size); -tensor_data create_attention_mask(ggml_context* ctx, int64_t w, int64_t h, int window_size); - -tensor mlp(model_ref m, tensor x); -tensor patch_merging(model_ref m, tensor x, int64_t w, int64_t h); -tensor window_partition(model_ref m, tensor x, int window); -tensor window_reverse(model_ref m, tensor x, int w, int h, int window); -tensor window_attention(model_ref m, tensor x, tensor mask, int num_heads, int window); -tensor swin_block(model_ref m, tensor x, tensor mask, swin_block_params const&); -swin_layer_result swin_layer( - model_ref m, tensor x, int64_t w, int64_t h, swin_layer_t const&, int window_size); -swin_result swin_transformer(model_ref m, tensor x, swin_params const& p); +namespace visp::birefnet { // Encoder @@ -62,5 +24,4 @@ tensor image_to_patches(model_ref m, tensor x, int64_t out_w, int64_t out_h); tensor gdt_conv(model_ref m, tensor x); tensor decode(model_ref m, tensor x, swin_result const& features); -} // namespace birefnet -} // namespace visp \ No newline at end of file +} // namespace visp::birefnet \ No newline at end of file diff --git a/src/visp/arch/dino.h b/src/visp/arch/dino.h index fe65afa..43d915b 100644 --- a/src/visp/arch/dino.h +++ b/src/visp/arch/dino.h @@ -2,6 +2,7 @@ #include "util/math.h" #include "visp/ml.h" +#include "visp/vision.h" #include diff --git a/src/visp/arch/swin.cpp b/src/visp/arch/swin.cpp new file mode 100644 index 0000000..5d2caf0 --- /dev/null +++ b/src/visp/arch/swin.cpp @@ -0,0 +1,334 @@ +#include "visp/arch/swin.h" +#include "util/string.h" +#include "visp/nn.h" + +namespace visp { +namespace swin { + +tensor mlp(model_ref m, tensor x) { + x = linear(m["fc1"], x); + x = 
ggml_gelu_inplace(m, x); + x = linear(m["fc2"], x); + return named(m, x); +} + +// Ensures that the tensor's data is not overwritten during computation. +tensor make_constant(tensor x, tensor_name name) { + ggml_set_name(x, name.c_str()); + ggml_set_input(x); // allocate at the beginning of the graph buffer + ggml_set_output(x); // don't reuse memory for computations + return x; +} + +void compute_relative_position_index(span dst, int window_size) { + int n = window_size; + int n2 = n * n; + int n4 = n2 * n2; + for (int i = 0; i < n4; ++i) { + int x0 = i % n; + int y0 = (i / n) % n; + int x1 = (i / n2) % n; + int y1 = (i / n2 / n) % n; + dst[i] = (y1 - y0 + n - 1) * (2 * n - 1) + (x1 - x0 + n - 1); + } +} + +tensor_data create_relative_position_index(ggml_context* ctx, int window_size) { + int n = window_size; + auto result = tensor_alloc(ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n * n * n * n)); + auto name = format("window_attention_{}.rel_pos_index", n); + compute_relative_position_index(result.as_i32(), n); + make_constant(result.x, name); + return result; +} + +tensor window_partition(model_ref m, tensor x, int window) { + auto [c, w, h, b] = nelements(x); + ASSERT(w % window == 0 && h % window == 0, "Expecting padded input"); + + x = ggml_reshape_4d(m, x, c * window, w / window, window, (h / window) * b); + x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); + x = ggml_reshape_3d(m, x, c, window * window, (w / window) * (h / window) * b); + return x; +} + +tensor window_reverse(model_ref m, tensor x, int64_t w, int64_t h, int window) { + int64_t c = x->ne[0]; + int64_t b = x->ne[2] / (w / window) / (h / window); + ASSERT(x->ne[2] % (w / window) == 0, "Expecting ne[2] to be multiple of window count"); + + x = ggml_reshape_4d(m, x, c * window, window, w / window, (h / window) * b); + x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); + x = ggml_reshape_4d(m, x, c, w, h, b); + return x; +} + +tensor window_attention(model_ref m, tensor x, tensor mask, int num_heads, int window) { + auto [c, n, b, _] = nelements(x); + + tensor qkv = linear(m["qkv"], x); + qkv = ggml_reshape_4d(m, qkv, c / num_heads, num_heads, 3, n * b); + qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2)); + + auto split = [=](tensor tensor, size_t index, bool transpose = false) mutable { + tensor = slice(m, tensor, {}, {}, {}, index); + tensor = ggml_reshape_4d(m, tensor, c / num_heads, num_heads, n, b); + if (transpose) { + tensor = ggml_cont(m, ggml_permute(m, tensor, 1, 2, 0, 3)); + } else { + tensor = ggml_cont(m, ggml_permute(m, tensor, 0, 2, 1, 3)); + } + return tensor; + }; + tensor q = split(qkv, 0); + tensor k = split(qkv, 1); + tensor v = split(qkv, 2, true); + + q = ggml_scale_inplace(m, q, 1.0f / std::sqrt(float(c / num_heads))); + + tensor attn = ggml_mul_mat(m, k, q); + + tensor_name rel_pos_name = format("window_attention_{}.rel_pos_index", window); + tensor rel_pos_index = ggml_get_tensor(m, rel_pos_name.c_str()); + tensor rel_pos_table = m.weights("relative_position_bias_table"); + tensor rel_pos_bias = ggml_get_rows(m, rel_pos_table, rel_pos_index); + rel_pos_bias = ggml_reshape_4d(m, rel_pos_bias, num_heads, window * window, window * window, 1); + rel_pos_bias = ggml_cont(m, ggml_permute(m, rel_pos_bias, 2, 0, 1, 3)); + attn = ggml_add_inplace(m, attn, rel_pos_bias); + + if (mask) { + int64_t nw = mask->ne[2]; + attn = ggml_reshape_4d(m, attn, n * n, num_heads, nw, b / nw); + mask = ggml_reshape_4d(m, mask, n * n, 1, nw, 1); + attn = ggml_add_inplace(m, attn, mask); + attn = ggml_reshape_4d(m, attn, n, 
n, num_heads, b); + } + attn = ggml_soft_max(m, attn); + + x = ggml_mul_mat(m, v, attn); + x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); + x = ggml_reshape_3d(m, x, c, n, b); + + x = linear(m["proj"], x); + return named(m, x); +} + +tensor block(model_ref m, tensor x, tensor mask, block_params const& p) { + auto [c, n, b, _] = nelements(x); + auto [num_heads, window, w, h, shift] = p; + ASSERT(n == w * h && "Spatial dimensions do not match"); + + tensor shortcut = x; + x = layer_norm(m["norm1"], x); + x = ggml_reshape_4d(m, x, c, w, h, b); + + int pad_r = (window - w % window) % window; + int pad_b = (window - h % window) % window; + if (pad_r > 0 || pad_b > 0) { + x = ggml_pad(m, x, 0, pad_r, pad_b, 0); + } + + ASSERT(shift == 0 || mask != nullptr); + if (shift > 0) { + x = ggml_roll(m, x, 0, -shift, -shift, 0); + } + + x = window_partition(m, x, window); + x = window_attention(m["attn"], x, mask, num_heads, window); + x = window_reverse(m, x, w + pad_r, h + pad_b, window); + + if (shift > 0) { // undo shift + x = ggml_roll(m, x, 0, shift, shift, 0); + } + + if (pad_r > 0 || pad_b > 0) { // undo padding + x = ggml_reshape_4d(m, x, c, w + pad_r, h + pad_b, b); + x = slice(m, x, {}, {0, w}, {0, h}, {}); + x = ggml_cont(m, x); + } + + x = ggml_reshape_3d(m, x, c, n, b); + x = ggml_add_inplace(m, x, shortcut); + + tensor x_mlp = layer_norm(m["norm2"], x); + x_mlp = mlp(m["mlp"], x_mlp); + x = ggml_add_inplace(m, x, x_mlp); + + return named(m, x); +} + +tensor patch_merging(model_ref m, tensor x, int64_t w, int64_t h) { + auto [c, n, b, _] = nelements(x); + ASSERT(n == w * h, "Spatial dimensions do not match"); + ASSERT(w % 2 == 0 && h % 2 == 0, "Expecting even spatial dimensions"); + + x = ggml_reshape_4d(m, x, c, w, h, b); + // clang-format off + x = concat(m, { + slice(m, x, {}, {0, w, 2}, {0, h, 2}, {}), + slice(m, x, {}, {0, w, 2}, {1, h, 2}, {}), + slice(m, x, {}, {1, w, 2}, {0, h, 2}, {}), + slice(m, x, {}, {1, w, 2}, {1, h, 2}, {})}, 0); + // clang-format on + x = ggml_reshape_3d(m, x, c * 4, n / 4, b); + + x = layer_norm(m["norm"], x); + x = linear(m["reduction"], x); + return named(m, x); +} + +void compute_attention_mask(span out, int64_t w, int64_t h, int window_size) { + int n = window_size; + int n2 = n * n; + int n4 = n2 * n2; + int shift = window_size / 2; + int64_t nw_x = (w + n - 1) / n; + int64_t nw_y = (h + n - 1) / n; + int64_t w_pad = nw_x * n; + int64_t h_pad = nw_y * n; + + std::fill(out.begin(), out.end(), 0.0f); + + for (int iw_y = 0; iw_y < nw_y; ++iw_y) { + for (int iw_x = 0; iw_x < nw_x; ++iw_x) { + // Skip all windows that aren't at the right or bottom edges of the image + if (iw_y < nw_y - 1 && iw_x < nw_x - 1) { + continue; + } + int64_t base = iw_y * nw_x * n4 + iw_x * n4; + + for (int y0 = 0; y0 < n; ++y0) { + for (int x0 = 0; x0 < n; ++x0) { + for (int y1 = 0; y1 < n; ++y1) { + for (int x1 = 0; x1 < n; ++x1) { + // Window-local coordinates to global image coordinates + int yy0 = iw_y * n + y0; + int xx0 = iw_x * n + x0; + int yy1 = iw_y * n + y1; + int xx1 = iw_x * n + x1; + // Check if two patches being matched belong to the same window + // that is: they are both in the shift zone, or both outside + bool match_y = (yy0 < h_pad - shift) == (yy1 < h_pad - shift); + bool match_x = (xx0 < w_pad - shift) == (xx1 < w_pad - shift); + // If not, set mask to -100 (added to attention before softmax) + if (!match_y || !match_x) { + int64_t idx = base + (y0 * n + x0) * n2 + (y1 * n + x1); + out[idx] = -100.f; + } + } + } + } + } + } + } +} + +tensor_data 
create_attention_mask(ggml_context* ctx, int64_t w, int64_t h, int window_size) { + int n = window_size; + int64_t nw_x = (w + n - 1) / n; + int64_t nw_y = (h + n - 1) / n; + auto result = tensor_alloc(ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n * n, n * n, nw_x * nw_y)); + auto name = format("swin_layer_{}x{}.attn_mask", w, h); + compute_attention_mask(result.as_f32(), w, h, window_size); + make_constant(result.x, name); + return result; +} + +layer_result layer( + model_ref m, tensor x, int64_t w, int64_t h, swin_layer_t const& p, int window, bool down) { + // Attention masks need to be precomputed + tensor_name attn_mask_name = format("swin_layer_{}x{}.attn_mask", w, h); + tensor attn_mask = ggml_get_tensor(m, attn_mask_name.c_str()); + + model_ref blocks = m["blocks"]; + for (int i = 0; i < p.depth; ++i) { + x = block( + blocks[i], x, attn_mask, + {.n_heads = p.n_heads, + .window_size = window, + .w = w, + .h = h, + .shift = i % 2 == 0 ? 0 : window / 2}); + } + if (down) { + tensor x_down = patch_merging(m["downsample"], x, w, h); + return {x, w, h, x_down, (w + 1) / 2, (h + 1) / 2}; + } + return {x, w, h, x, w, h}; +} + +swin_result encode(model_ref m, tensor x, swin_params const& p) { + x = patch_embed(m["patch_embed"], x, 4); + + auto [c, w, h, b] = nelements(x); + x = ggml_reshape_3d(m, x, c, w * h, b); + + layer_result r{x, w, h, x, w, h}; + swin_result outs = {}; + + for (int i = 0; i < swin_n_layers; ++i) { + bool downsample = (i < swin_n_layers - 1); + r = layer( + m["layers"][i], r.x_down, r.w_down, r.h_down, p.layers[i], p.window_size, downsample); + + tensor_name norm_layer = format("norm{}", i); + tensor out = layer_norm(m[norm_layer], r.x_out); + out = ggml_reshape_4d(m, out, p.layers[i].n_features, r.w_out, r.h_out, b); + outs[i] = out; + } + return outs; +} + +} // namespace swin + +// clang-format off +const swin_params swin_t_params = { + .embed_dim = 96, + .window_size = 7, + .layers = { + // depth n_heads n_features + swin_layer_t{2, 3, 96 * 1}, + swin_layer_t{2, 6, 96 * 2}, + swin_layer_t{6, 12, 96 * 4}, + swin_layer_t{2, 24, 96 * 8}}}; + +const swin_params swin_l_params = { + .embed_dim = 192, + .window_size = 12, + .layers = { + // depth n_heads n_features + swin_layer_t{2, 6, 192 * 1}, + swin_layer_t{2, 12, 192 * 2}, + swin_layer_t{18, 24, 192 * 4}, + swin_layer_t{2, 48, 192 * 8}}}; +// clang-format on + +swin_params swin_detect_params(model_file const& f) { + int embed_dim = f.get_int("swin.embed_dim"); + if (embed_dim == 96) { + return swin_t_params; + } else if (embed_dim == 192) { + return swin_l_params; + } else { + throw except("Unsupported Swin Transformer embed dim: {}", embed_dim); + } +} + +swin_buffers swin_precompute(model_ref m, i32x2 image_extent, swin_params const& p) { + int w = p.window_size; + int width = image_extent[0] / 4; + int height = image_extent[1] / 4; + + swin_buffers b; + b[0] = swin::create_relative_position_index(m, w); + for (int i = 0; i < swin_n_layers + 1; ++i) { + b[i + 1] = swin::create_attention_mask(m, width >> i, height >> i, w); + } + return b; +} + +swin_result swin_encode(model_ref m, tensor image, swin_params const& p) { + return swin::encode(m, image, p); +} + +} // namespace visp \ No newline at end of file diff --git a/src/visp/arch/swin.h b/src/visp/arch/swin.h new file mode 100644 index 0000000..53a8033 --- /dev/null +++ b/src/visp/arch/swin.h @@ -0,0 +1,39 @@ +#pragma once + +#include "visp/ml.h" +#include "visp/vision.h" + +namespace visp::swin { + +struct block_params { + int n_heads = 6; + int window_size = 7; + 
int64_t w = 0; + int64_t h = 0; + int shift = 0; +}; + +struct layer_result { + tensor x_out; + int64_t w_out; + int64_t h_out; + tensor x_down; + int64_t w_down; + int64_t h_down; +}; + +void compute_relative_position_index(span dst, int window_size); +tensor_data create_relative_position_index(ggml_context* ctx, int window_size); +void compute_attention_mask(std::span out, int64_t w, int64_t h, int window_size); +tensor_data create_attention_mask(ggml_context* ctx, int64_t w, int64_t h, int window_size); + +tensor mlp(model_ref m, tensor x); +tensor patch_merging(model_ref m, tensor x, int64_t w, int64_t h); +tensor window_partition(model_ref m, tensor x, int window); +tensor window_reverse(model_ref m, tensor x, int w, int h, int window); +tensor window_attention(model_ref m, tensor x, tensor mask, int num_heads, int window); +tensor block(model_ref m, tensor x, tensor mask, block_params const&); +layer_result layer( + model_ref, tensor, int64_t w, int64_t h, swin_layer_t const&, int window_size, bool downsample); + +} // namespace visp::swin \ No newline at end of file diff --git a/tests/test_birefnet.py b/tests/test_birefnet.py index 353bb0d..3afa35d 100644 --- a/tests/test_birefnet.py +++ b/tests/test_birefnet.py @@ -740,8 +740,8 @@ def test_encode(): expected = forward_enc(x, xs, xs_low) state = {} - state.update({f"input{i}": to_nhwc(xs[i]) for i in range(4)}) - state.update({f"input_low{i}": to_nhwc(xs_low[i]) for i in range(4)}) + state.update({f"xs{i}": to_nhwc(xs[i]) for i in range(4)}) + state.update({f"xs_low{i}": to_nhwc(xs_low[i]) for i in range(4)}) results = workbench.invoke_test("biref_encode", x, state, nhwc_layout) diff --git a/tests/workbench.cpp b/tests/workbench.cpp index 2b877b9..15d8b73 100644 --- a/tests/workbench.cpp +++ b/tests/workbench.cpp @@ -5,6 +5,7 @@ #include "visp/arch/esrgan.h" #include "visp/arch/migan.h" #include "visp/arch/mobile-sam.h" +#include "visp/arch/swin.h" #include "visp/nn.h" #include @@ -247,40 +248,40 @@ DEF(biref_patch_embed)(model_ref m, span input, param_dict const& p) { DEF(biref_relative_position_index)(model_ref m, span input, param_dict const& p) { auto dst = span(reinterpret_cast(input[0]->data), ggml_nelements(input[0])); - birefnet::compute_relative_position_index(dst, 3); + swin::compute_relative_position_index(dst, 3); return {input[0]}; } DEF(biref_window_attention)(model_ref m, span input, param_dict const& p) { int window_size = 3; tensor mask = m.find("mask"); - auto rel_pos_index = birefnet::create_relative_position_index(m, window_size); + auto rel_pos_index = swin::create_relative_position_index(m, window_size); ggml_backend_alloc_ctx_tensors(m, workbench_backend()); transfer_to_backend(rel_pos_index); - return {birefnet::window_attention(m, input[0], mask, 2, window_size)}; + return {swin::window_attention(m, input[0], mask, 2, window_size)}; } DEF(biref_swin_block)(model_ref m, span input, param_dict const& p) { - birefnet::swin_block_params block; + swin::block_params block; block.n_heads = 2; block.window_size = 3; block.w = 6; block.h = 6; block.shift = 0; tensor mask = m.find("mask"); - auto rel_pos_index = birefnet::create_relative_position_index(m, 3); + auto rel_pos_index = swin::create_relative_position_index(m, 3); ggml_backend_alloc_ctx_tensors(m, workbench_backend()); transfer_to_backend(rel_pos_index); - return {birefnet::swin_block(m, input[0], mask, block)}; + return {swin::block(m, input[0], mask, block)}; } DEF(biref_patch_merging)(model_ref m, span input, param_dict const& p) { - return 
{birefnet::patch_merging(m, input[0], 6, 4)}; + return {swin::patch_merging(m, input[0], 6, 4)}; } DEF(biref_attention_mask)(model_ref m, span input, param_dict const& p) { auto dst = span((float*)input[0]->data, ggml_nelements(input[0])); - birefnet::compute_attention_mask(dst, 18, 18, 6); + swin::compute_attention_mask(dst, 18, 18, 6); return {input[0]}; } @@ -289,13 +290,12 @@ DEF(biref_swin_layer)(model_ref m, span input, param_dict const& p) { layer.depth = 2; layer.n_heads = 2; layer.n_features = 8; - layer.downsample = true; - auto rel_pos_index = birefnet::create_relative_position_index(m, 3); - auto attn_mask = birefnet::create_attention_mask(m, 6, 6, 3); + auto rel_pos_index = swin::create_relative_position_index(m, 3); + auto attn_mask = swin::create_attention_mask(m, 6, 6, 3); ggml_backend_alloc_ctx_tensors(m, workbench_backend()); transfer_to_backend(rel_pos_index); transfer_to_backend(attn_mask); - auto result = birefnet::swin_layer(m, input[0], 6, 6, layer, 3); + auto result = swin::layer(m, input[0], 6, 6, layer, 3, true); ASSERT(result.w_down == 3 && result.h_down == 3); return {result.x_down}; } @@ -305,29 +305,29 @@ DEF(biref_swin_transformer)(model_ref m, span input, param_dict const& p .embed_dim = 8, .window_size = 3, .layers = { - swin_layer_t{2, 2, 8 * 1, true}, - swin_layer_t{2, 2, 8 * 2, true}, - swin_layer_t{2, 4, 8 * 4, true}, - swin_layer_t{2, 2, 8 * 8, false}, + swin_layer_t{2, 2, 8 * 1}, + swin_layer_t{2, 2, 8 * 2}, + swin_layer_t{2, 4, 8 * 4}, + swin_layer_t{2, 2, 8 * 8}, }}; - auto rel_pos_index = birefnet::create_relative_position_index(m, 3); + auto rel_pos_index = swin::create_relative_position_index(m, 3); auto attn_masks = std::array{ - birefnet::create_attention_mask(m, 8, 8, 3), birefnet::create_attention_mask(m, 4, 4, 3), - birefnet::create_attention_mask(m, 2, 2, 3), birefnet::create_attention_mask(m, 1, 1, 3)}; + swin::create_attention_mask(m, 8, 8, 3), swin::create_attention_mask(m, 4, 4, 3), + swin::create_attention_mask(m, 2, 2, 3), swin::create_attention_mask(m, 1, 1, 3)}; ggml_backend_alloc_ctx_tensors(m, workbench_backend()); transfer_to_backend(rel_pos_index); for (auto&& attn_mask : attn_masks) { transfer_to_backend(attn_mask); } - auto result = birefnet::swin_transformer(m, input[0], swinp); + auto result = swin_encode(m, input[0], swinp); return {result[0], result[1], result[2], result[3]}; } DEF(biref_encode)(model_ref m, span input, param_dict const& p) { - birefnet::swin_result xs, xs_low; + swin_result xs, xs_low; for (int i = 0; i < 4; ++i) { - xs[i] = m.find(format("input{}", i).c_str()); - xs_low[i] = m.find(format("input_low{}", i).c_str()); + xs[i] = m.find(format("xs{}", i).c_str()); + xs_low[i] = m.find(format("xs_low{}", i).c_str()); } birefnet::encode_concat(m, xs, xs_low); return std::vector{xs[0], xs[1], xs[2], xs[3]}; @@ -354,7 +354,7 @@ DEF(biref_image_to_patches_2)(model_ref m, span input, param_dict const& } DEF(biref_decode)(model_ref m, span input, param_dict const& p) { - birefnet::swin_result features; + swin_result features; for (int i = 0; i < 4; ++i) { features[i] = m.find(format("x{}", i + 1).c_str()); } From 01cfeb002ab1c6c73ef3d95858a8a65593c721ac Mon Sep 17 00:00:00 2001 From: Acly Date: Wed, 15 Oct 2025 19:15:39 +0200 Subject: [PATCH 20/24] birefnet: support flash attention in swin-v1 * redoing the measurements yields higher values than before for pytorch, not 100% sure if previous times were incorrect * comparison was against official BiRefNet repo, a3bb3efe2f824ec66644ca5941583c4c90c6e027 * tried with 
SDPA on/off --- README.md | 6 ++-- include/visp/ml.h | 2 ++ src/cli/cli.cpp | 9 +++-- src/visp/arch/swin.cpp | 82 +++++++++++++++++++++++------------------- src/visp/arch/swin.h | 2 +- src/visp/ml.cpp | 23 ++++++++++-- tests/test_birefnet.py | 12 +++++-- tests/workbench.cpp | 13 ++++--- tests/workbench.py | 1 + 9 files changed, 98 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index 34a501a..2d02a2c 100644 --- a/README.md +++ b/README.md @@ -191,10 +191,10 @@ as other frameworks for inference speed, but with: | Model | | | _vision.cpp_ | PyTorch | ONNX Runtime | | :---- | :--- | :--- | -----------: | -------: | -----------: | -| Full | cpu | f32 | 16333 ms | 18800 ms | | -| Full | gpu | f16 | 243 ms | 140 ms | | +| Full | cpu | f32 | 16333 ms | 18290 ms | | +| Full | gpu | f16 | 208 ms | 190 ms | | | Lite | cpu | f32 | 4505 ms | 10900 ms | 6978 ms | -| Lite | gpu | f16 | 86 ms | 59 ms | | +| Lite | gpu | f16 | 85 ms | 84 ms | | #### MI-GAN, 512x512 diff --git a/include/visp/ml.h b/include/visp/ml.h index efb70e1..93a0af1 100644 --- a/include/visp/ml.h +++ b/include/visp/ml.h @@ -218,8 +218,10 @@ struct VISP_API tensor_data { span as_f32(); span as_i32(); + span as_bytes(); span as_f32() const; span as_i32() const; + span as_bytes() const; }; // Allocates data for a tensor in main memory, outside of context and backend buffers. diff --git a/src/cli/cli.cpp b/src/cli/cli.cpp index e7600f5..fc7f2a1 100644 --- a/src/cli/cli.cpp +++ b/src/cli/cli.cpp @@ -271,6 +271,11 @@ std::tuple load_model_weights( return {std::move(file), std::move(weights)}; } +void print_model_flags(model_ref const& m) { + bool flash_attn = !!(m.flags & model_build_flag::flash_attention); + printf("- flash attention: %s\n", flash_attn ? "on" : "off"); +} + void compute_timed(compute_graph const& g, backend_device const& b) { timer t; printf("Running inference... "); @@ -414,6 +419,7 @@ void run_birefnet(cli_args const& args) { compute_graph graph = compute_graph_init(6 * 1024); model_ref m(weights, graph); + print_model_flags(m); birefnet_buffers buffers = birefnet_precompute(m, params); tensor input = compute_graph_input(m, GGML_TYPE_F32, {3, extent[0], extent[1], 1}); @@ -456,8 +462,7 @@ void run_depth_anything(cli_args const& args) { compute_graph graph = compute_graph_init(); model_ref m(weights, graph); - bool flash_attn = !!(m.flags & model_build_flag::flash_attention); - printf("- flash attention: %s\n", flash_attn ? "on" : "off"); + print_model_flags(m); tensor input = compute_graph_input(m, GGML_TYPE_F32, {3, extent[0], extent[1], 1}); tensor output = depthany_predict(m, input, params); diff --git a/src/visp/arch/swin.cpp b/src/visp/arch/swin.cpp index 5d2caf0..b46483d 100644 --- a/src/visp/arch/swin.cpp +++ b/src/visp/arch/swin.cpp @@ -63,52 +63,59 @@ tensor window_reverse(model_ref m, tensor x, int64_t w, int64_t h, int window) { return x; } -tensor window_attention(model_ref m, tensor x, tensor mask, int num_heads, int window) { +tensor window_attention(model_ref m, tensor x, tensor mask, int n_heads, int window) { auto [c, n, b, _] = nelements(x); + float scale = 1.0f / std::sqrt(float(c / n_heads)); + bool flash_attn = bool(m.flags & model_build_flag::flash_attention); + ggml_type kv_type = flash_attn ? 
GGML_TYPE_F16 : GGML_TYPE_F32;

 tensor qkv = linear(m["qkv"], x);
- qkv = ggml_reshape_4d(m, qkv, c / num_heads, num_heads, 3, n * b);
+ qkv = ggml_reshape_4d(m, qkv, c / n_heads, n_heads, 3, n * b);
 qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2));

- auto split = [=](tensor tensor, size_t index, bool transpose = false) mutable {
- tensor = slice(m, tensor, {}, {}, {}, index);
- tensor = ggml_reshape_4d(m, tensor, c / num_heads, num_heads, n, b);
- if (transpose) {
- tensor = ggml_cont(m, ggml_permute(m, tensor, 1, 2, 0, 3));
- } else {
- tensor = ggml_cont(m, ggml_permute(m, tensor, 0, 2, 1, 3));
- }
- return tensor;
+ auto split = [=](tensor t, size_t index, ggml_type type, bool transpose = false) mutable {
+ t = slice(m, t, {}, {}, {}, index);
+ t = ggml_reshape_4d(m, t, c / n_heads, n_heads, n, b);
+ t = transpose ? ggml_permute(m, t, 1, 2, 0, 3) : ggml_permute(m, t, 0, 2, 1, 3);
+ t = ggml_cast(m, t, type); // TODO: drop this cast once flash attention supports f32 and permuted inputs
+ return t;
 };
- tensor q = split(qkv, 0);
- tensor k = split(qkv, 1);
- tensor v = split(qkv, 2, true);
-
- q = ggml_scale_inplace(m, q, 1.0f / std::sqrt(float(c / num_heads)));
-
- tensor attn = ggml_mul_mat(m, k, q);
+ tensor q = split(qkv, 0, GGML_TYPE_F32);
+ tensor k = split(qkv, 1, kv_type);
+ tensor v = split(qkv, 2, kv_type, !flash_attn);

 tensor_name rel_pos_name = format("window_attention_{}.rel_pos_index", window);
 tensor rel_pos_index = ggml_get_tensor(m, rel_pos_name.c_str());
 tensor rel_pos_table = m.weights("relative_position_bias_table");
 tensor rel_pos_bias = ggml_get_rows(m, rel_pos_table, rel_pos_index);
- rel_pos_bias = ggml_reshape_4d(m, rel_pos_bias, num_heads, window * window, window * window, 1);
- rel_pos_bias = ggml_cont(m, ggml_permute(m, rel_pos_bias, 2, 0, 1, 3));
- attn = ggml_add_inplace(m, attn, rel_pos_bias);
+ rel_pos_bias = ggml_reshape_4d(m, rel_pos_bias, n_heads, n, n, 1);
+ rel_pos_bias = ggml_permute(m, rel_pos_bias, 2, 0, 1, 3); // [n, n, n_heads, 1]
+ rel_pos_bias = ggml_cast(m, rel_pos_bias, GGML_TYPE_F16); // get_rows result is always f32

+ tensor attn_mask = rel_pos_bias;
 if (mask) {
- int64_t nw = mask->ne[2];
- attn = ggml_reshape_4d(m, attn, n * n, num_heads, nw, b / nw);
- mask = ggml_reshape_4d(m, mask, n * n, 1, nw, 1);
- attn = ggml_add_inplace(m, attn, mask);
- attn = ggml_reshape_4d(m, attn, n, n, num_heads, b);
+ int64_t n_windows = mask->ne[2];
+ if (b > n_windows) { // if there are multiple images in the batch
+ mask = ggml_reshape_4d(m, mask, n, n, n_windows, 1);
+ mask = ggml_repeat_4d(m, mask, n, n, n_windows, b / n_windows);
+ }
+ mask = ggml_reshape_4d(m, mask, n, n, 1, b);
+ mask = ggml_repeat_4d(m, mask, n, n, n_heads, b); // can only broadcast one operand in add
+ attn_mask = ggml_add(m, mask, attn_mask); // [n, n, n_heads, b] + [n, n, n_heads, 1]
 }
- attn = ggml_soft_max(m, attn);

- x = ggml_mul_mat(m, v, attn);
- x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
+ if (flash_attn) {
+ x = ggml_flash_attn_ext(m, q, k, v, attn_mask, scale, 0.0f, 0.0f);
+ ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);
+ } else {
+ tensor attn = ggml_mul_mat(m, k, q);
+ attn = ggml_soft_max_ext(m, attn, attn_mask, scale, 0.0f);
+
+ x = ggml_mul_mat(m, v, attn);
+ x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
+ }

 x = ggml_reshape_3d(m, x, c, n, b);
 x = linear(m["proj"], x);
 return named(m, x);
 }
@@ -177,7 +184,10 @@ tensor patch_merging(model_ref m, tensor x, int64_t w, int64_t h) {
 return named(m, x);
 }

-void 
compute_attention_mask(span out, int64_t w, int64_t h, int window_size) { +constexpr uint16_t neg_inf_f16 = 0xfc00; // -infinity in IEEE 754 half-precision + +void compute_attention_mask(span out_bytes, int64_t w, int64_t h, int window_size) { + uint16_t* out = reinterpret_cast(out_bytes.data()); int n = window_size; int n2 = n * n; int n4 = n2 * n2; @@ -187,7 +197,7 @@ void compute_attention_mask(span out, int64_t w, int64_t h, int window_si int64_t w_pad = nw_x * n; int64_t h_pad = nw_y * n; - std::fill(out.begin(), out.end(), 0.0f); + std::memset(out, 0, out_bytes.size()); for (int iw_y = 0; iw_y < nw_y; ++iw_y) { for (int iw_x = 0; iw_x < nw_x; ++iw_x) { @@ -210,10 +220,10 @@ void compute_attention_mask(span out, int64_t w, int64_t h, int window_si // that is: they are both in the shift zone, or both outside bool match_y = (yy0 < h_pad - shift) == (yy1 < h_pad - shift); bool match_x = (xx0 < w_pad - shift) == (xx1 < w_pad - shift); - // If not, set mask to -100 (added to attention before softmax) + // If not, set attention mask to -inf so it is ignored by softmax if (!match_y || !match_x) { int64_t idx = base + (y0 * n + x0) * n2 + (y1 * n + x1); - out[idx] = -100.f; + out[idx] = neg_inf_f16; } } } @@ -227,9 +237,9 @@ tensor_data create_attention_mask(ggml_context* ctx, int64_t w, int64_t h, int w int n = window_size; int64_t nw_x = (w + n - 1) / n; int64_t nw_y = (h + n - 1) / n; - auto result = tensor_alloc(ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n * n, n * n, nw_x * nw_y)); + auto result = tensor_alloc(ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n * n, n * n, nw_x * nw_y)); auto name = format("swin_layer_{}x{}.attn_mask", w, h); - compute_attention_mask(result.as_f32(), w, h, window_size); + compute_attention_mask(result.as_bytes(), w, h, window_size); make_constant(result.x, name); return result; } diff --git a/src/visp/arch/swin.h b/src/visp/arch/swin.h index 53a8033..6b1195b 100644 --- a/src/visp/arch/swin.h +++ b/src/visp/arch/swin.h @@ -24,7 +24,7 @@ struct layer_result { void compute_relative_position_index(span dst, int window_size); tensor_data create_relative_position_index(ggml_context* ctx, int window_size); -void compute_attention_mask(std::span out, int64_t w, int64_t h, int window_size); +void compute_attention_mask(std::span out, int64_t w, int64_t h, int window_size); tensor_data create_attention_mask(ggml_context* ctx, int64_t w, int64_t h, int window_size); tensor mlp(model_ref m, tensor x); diff --git a/src/visp/ml.cpp b/src/visp/ml.cpp index 287ba7d..e926c69 100644 --- a/src/visp/ml.cpp +++ b/src/visp/ml.cpp @@ -138,13 +138,21 @@ void backend_set_n_threads(backend_device& b, int n_threads) { // // model_build_flags +model_build_flags flash_attn_flag() { + static model_build_flags const flag = []() { + char const* env = getenv("VISP_NO_FLASH_ATTENTION"); + return !env || env[0] == '0' ? 
model_build_flag::flash_attention : model_build_flags{};
+ }();
+ return flag;
+}
+
 model_build_flags backend_default_flags(backend_type type) {
 using enum model_build_flag;
 switch (type) {
 case backend_type::cpu:
 return conv_2d_direct_cwhn | concat_n | f16_conv_transpose | window_partition |
- flash_attention;
- case backend_type::gpu: return flash_attention;
+ flash_attn_flag();
+ case backend_type::gpu: return flash_attn_flag();
 }
 return {};
 }
@@ -206,7 +214,8 @@ void model_file::get_array(char const* key_name, span out_values) const {
 throw except("Array size mismatch for key '{}' in model file {}", key_name, path);
 }
 if (gguf_get_arr_type(gguf.get(), key_id) != GGUF_TYPE_INT32) {
- throw except("Array type mismatch for key '{}' in model file {}, expected int32", key_name, path);
+ throw except(
+ "Array type mismatch for key '{}' in model file {}, expected int32", key_name, path);
 }
 auto ptr = (int const*)gguf_get_arr_data(gguf.get(), key_id);
 std::copy(ptr, ptr + out_values.size(), out_values.data());
@@ -632,6 +641,14 @@ std::span tensor_data::as_i32() const {
 return span(reinterpret_cast(data.get()), ggml_nelements(x));
 }

+std::span tensor_data::as_bytes() {
+ return span(data.get(), ggml_nbytes(x));
+}
+
+std::span tensor_data::as_bytes() const {
+ return span(data.get(), ggml_nbytes(x));
+}
+
 void transfer_to_backend(tensor_data const& d) {
 ggml_backend_tensor_set(d.x, d.data.get(), 0, ggml_nbytes(d.x));
 }
diff --git a/tests/test_birefnet.py b/tests/test_birefnet.py
index 3afa35d..b57586a 100644
--- a/tests/test_birefnet.py
+++ b/tests/test_birefnet.py
@@ -118,7 +118,9 @@ def test_relative_position_index():

 @pytest.mark.parametrize("masking", ["mask", "no_mask"])
-def test_window_attention(masking: bool):
+@pytest.mark.parametrize("backend", ["cpu", "gpu"])
+@pytest.mark.parametrize("attn", ["default", "flash_attn"])
+def test_window_attention(masking: str, backend: str, attn: str):
 num_heads = 2
 window_attention = WindowAttention(dim=8, window_size=(3, 3), num_heads=num_heads)
 state = generate_state(window_attention.state_dict())
@@ -132,9 +134,13 @@ def test_window_attention(masking: str, backend: str, attn: str):
 state["mask"] = mask
 expected = window_attention(x, mask)

- result = workbench.invoke_test("biref_window_attention", x, state)
+ del state["relative_position_index"] # computed in C++
+ if mask is not None:
+ state["mask"] = mask.half()
+ state["relative_position_bias_table"] = state["relative_position_bias_table"].half()
+ result = workbench.invoke_test("biref_window_attention", x, state, {"attn": attn}, backend)

- assert torch.allclose(result, expected)
+ assert torch.allclose(result, expected, rtol=1e-3)

 def window_partition(x, window_size):
diff --git a/tests/workbench.cpp b/tests/workbench.cpp
index 15d8b73..b3dc1d4 100644
--- a/tests/workbench.cpp
+++ b/tests/workbench.cpp
@@ -253,6 +253,11 @@ DEF(biref_relative_position_index)(model_ref m, span input, param_dict c
 }

 DEF(biref_window_attention)(model_ref m, span input, param_dict const& p) {
+ if (p.get("attn", "default") == "flash_attn"sv) {
+ m.flags = m.flags | model_build_flag::flash_attention;
+ } else {
+ m.flags = m.flags & ~model_build_flag::flash_attention;
+ }
 int window_size = 3;
 tensor mask = m.find("mask");
 auto rel_pos_index = swin::create_relative_position_index(m, window_size);
@@ -280,7 +285,7 @@ DEF(biref_patch_merging)(model_ref m, span input, param_dict const& p) {
 }

 DEF(biref_attention_mask)(model_ref m, span input, param_dict const& p) {
- auto 
dst = span((byte*)input[0]->data, ggml_nbytes(input[0])); swin::compute_attention_mask(dst, 18, 18, 6); return {input[0]}; } @@ -540,7 +545,7 @@ char const* param_dict::get(char const* name, char const* default_value) const { struct raw_tensor { char const* name; - float* data; + byte* data; int32_t type_; int32_t ne[4]; @@ -610,7 +615,7 @@ void workbench_run( model_allocate(weights, w.current_backend); for (raw_tensor const& raw : tensors) { - transfer_to_backend(m.weights(raw.name), span(raw.data, raw.size())); + transfer_to_backend(m.weights(raw.name), span(raw.data, raw.size_bytes())); } param_dict test_params = build_dict(params); @@ -644,7 +649,7 @@ void workbench_run( ggml_backend_tensor_get(outputs[i], data_ptr, 0, ggml_nbytes(outputs[i])); output_raw[i].name = ggml_get_name(outputs[i]); - output_raw[i].data = reinterpret_cast(data_ptr); + output_raw[i].data = reinterpret_cast(data_ptr); output_raw[i].type_ = int32_t(outputs[i]->type); output_raw[i].ne[0] = outputs[i]->ne[0]; output_raw[i].ne[1] = outputs[i]->ne[1]; diff --git a/tests/workbench.py b/tests/workbench.py index a0b51ae..0095fd0 100644 --- a/tests/workbench.py +++ b/tests/workbench.py @@ -32,6 +32,7 @@ class RawParam(ctypes.Structure): def torch_to_raw_tensor(name: str, tensor: torch.Tensor): tensor_types = { torch.float32: 0, # GGML_TYPE_F32 + torch.float16: 1, # GGML_TYPE_F16 torch.int32: 26, # GGML_TYPE_I32 } t = tensor.contiguous() From 195f2f08d6f2bbbe5ac341b419ff20f08fa18746 Mon Sep 17 00:00:00 2001 From: Acly Date: Thu, 16 Oct 2025 11:41:09 +0200 Subject: [PATCH 21/24] depth-anything: documentation, readme, fix license for base/large models --- README.md | 25 +++++++++++++++++++------ include/visp/vision.h | 21 +++++++++++++++------ models/CMakeLists.txt | 7 +++++++ scripts/convert.py | 5 ++++- 4 files changed, 45 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 2d02a2c..99eaa94 100644 --- a/README.md +++ b/README.md @@ -12,14 +12,17 @@ Based on [ggml](https://github.com/ggml-org/ggml) similar to the [llama.cpp](htt ### Features -| Model | Task | Backends | -| :-------------------------- | :--------------- | :---------- | -| [**MobileSAM**](#mobilesam) | Segmentation | CPU, Vulkan | -| [**BiRefNet**](#birefnet) | Segmentation | CPU, Vulkan | -| [**MI-GAN**](#mi-gan) | Inpainting | CPU, Vulkan | -| [**ESRGAN**](#real-esrgan) | Super-resolution | CPU, Vulkan | +| Model | Task | Backends | +| :--------------------------------------- | :----------------------- | :---------- | +| [**MobileSAM**](#mobilesam) | Promptable segmentation | CPU, Vulkan | +| [**BiRefNet**](#birefnet) | Dichotomous segmentation | CPU, Vulkan | +| [**Depth-Anything**](#depth-anything-v2) | Depth estimation | CPU, Vulkan | +| [**MI-GAN**](#mi-gan) | Inpainting | CPU, Vulkan | +| [**ESRGAN**](#real-esrgan) | Super-resolution | CPU, Vulkan | | [_Implement a model [**Guide**]_](docs/model-implementation-guide.md) | | | +**Backbones:** SWIN (v1), DINO (v2), TinyViT + ## Get Started Get the library and executables: @@ -92,6 +95,16 @@ vision-cli sam -m MobileSAM-F16.gguf -i input.png -p 300 200 -o mask.png --compo vision-cli birefnet -m BiRefNet-lite-F16.gguf -i input.png -o mask.png --composite comp.png ``` +#### Depth-Anything V2 + +example-depth-anything + +[Model download](https://huggingface.co/Acly/Depth-Anything-GGUF/tree/main) | [Paper (arXiv)](https://arxiv.org/abs/2406.09414) | [Repository (GitHub)](https://github.com/DepthAnything/Depth-Anything-V2) | License: Apache-2 / CC-BY-NC-4 + +```sh +vision-cli depth-anything 
-m Depth-Anything-V2-Small-F16.gguf -i input.png -o depth.png +``` + #### MI-GAN example-migan diff --git a/include/visp/vision.h b/include/visp/vision.h index 94d257a..1e22096 100644 --- a/include/visp/vision.h +++ b/include/visp/vision.h @@ -57,8 +57,9 @@ // 7. Run the compute graph. // 8. Transfer the output to the host and post-process it. // -// Custom pipelines are simply functions which call the individual steps and extend them -// where needed. The implementation of the high-level API functions is a good starting point. +// Custom pipelines can be created simply by writing a function that calls the +// individual steps. As a starting point, check out or copy the implementation +// of the high-level API functions. Then adapt them as needed. // This allows to: // * load model weights from a different source // * control exactly when allocation happens @@ -76,10 +77,11 @@ #include #include +#include namespace visp { -// SWIN - vision transformer for feature extraction +// SWIN v1 - vision transformer for feature extraction constexpr int swin_n_layers = 4; @@ -102,7 +104,7 @@ VISP_API swin_params swin_detect_params(model_file const&); VISP_API swin_buffers swin_precompute(model_ref, i32x2 image_extent, swin_params const&); VISP_API swin_result swin_encode(model_ref, tensor image, swin_params const&); -// DINO - vision transformer for feature extraction +// DINO v2 - vision transformer for feature extraction struct dino_params { int patch_size = 16; @@ -169,7 +171,9 @@ VISP_API image_data sam_process_mask( struct birefnet_model; // Loads a BiRefNet model from GGUF file onto the backend device. -// * supports BiRefNet, BiRefNet_lite, BiRefNet_Matting variants at 1024px resolution +// * supports BiRefNet, BiRefNet-lite, BiRefNet-Matting variants at 1024px resolution +// * supports BiRefNet-HR variant at 2048px resolution +// * supports BiRefNet-dynamic variant at arbitrary resolution VISP_API birefnet_model birefnet_load_model(char const* filepath, backend_device const&); // Takes RGB input and computes an alpha mask with foreground as 1.0 and background as 0.0. @@ -203,7 +207,12 @@ VISP_API tensor birefnet_predict(model_ref, tensor image, birefnet_params const& struct depthany_model; +// Loads a Depth Anything V2 model from GGUF file onto the backend device. +// * supports Small/Base/Large variants with flexible input resolution VISP_API depthany_model depthany_load_model(char const* filepath, backend_device const&); + +// Takes RGB input and computes estimated depth (distance from camera). +// Output is a single-channel float32 image in range [0, 1.0]. 
VISP_API image_data depthany_compute(depthany_model&, image_view image);

// --- Depth Anything pipeline

@@ -222,7 +231,7 @@ VISP_API i32x2 depthany_image_extent(i32x2 input_extent, depthany_params const&)
 VISP_API image_data depthany_process_input(image_view image, depthany_params const&);

 image_data depthany_process_output(
- span output_data, i32x2 target_extent, depthany_params const&);
+ std::span output_data, i32x2 target_extent, depthany_params const&);

 VISP_API tensor depthany_predict(model_ref, tensor image, depthany_params const&);

diff --git a/models/CMakeLists.txt b/models/CMakeLists.txt
index d1afb96..a5ad052 100644
--- a/models/CMakeLists.txt
+++ b/models/CMakeLists.txt
@@ -14,6 +14,13 @@ file(DOWNLOAD
 EXPECTED_HASH "SHA256=7b5397a2c98d66677f8f74317774bbeac49dbb321b8a3dc744af913db71d4fa5"
 SHOW_PROGRESS
)
+message(STATUS "Checking for models/Depth-Anything-V2-Small-F16.gguf")
+file(DOWNLOAD
+ "https://huggingface.co/Acly/Depth-Anything-V2-GGUF/resolve/main/Depth-Anything-V2-Small-F16.gguf"
+ ${CMAKE_CURRENT_LIST_DIR}/Depth-Anything-V2-Small-F16.gguf
+ EXPECTED_HASH "SHA256=0f83332d6a8b4375cd7fdcc168f3e3636f474f8e84b0959e903f513aace782f5"
+ SHOW_PROGRESS
+)
 message(STATUS "Checking for models/MIGAN-512-places2-F16.gguf")
 file(DOWNLOAD
 "https://huggingface.co/Acly/MIGAN-GGUF/resolve/main/MIGAN-512-places2-F16.gguf"
diff --git a/scripts/convert.py b/scripts/convert.py
index d99476c..cc91d63 100644
--- a/scripts/convert.py
+++ b/scripts/convert.py
@@ -354,7 +354,10 @@ def convert_birefnet(input_filepath: Path, writer: Writer):

 def convert_depth_anything(input_filepath: Path, writer: Writer):
- writer.add_license("apache-2.0")
+ if "small" in input_filepath.name.lower():
+ writer.add_license("apache-2.0")
+ else:
+ writer.add_license("cc-by-nc-4.0")
 writer.set_tensor_layout_default(TensorLayout.nchw)
 model: dict[str, Tensor] = load_model(input_filepath)

From 254f7d5b8d8027b8d96cee84fae565e6a80f197b Mon Sep 17 00:00:00 2001
From: Acly
Date: Thu, 16 Oct 2025 11:47:34 +0200
Subject: [PATCH 22/24] depth-anything: update image and model download url

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 99eaa94..d598afa 100644
--- a/README.md
+++ b/README.md
@@ -97,9 +97,9 @@ vision-cli birefnet -m BiRefNet-lite-F16.gguf -i input.png -o mask.png --composi

#### Depth-Anything V2

-example-depth-anything
+example-depth-anything

[Model download](https://huggingface.co/Acly/Depth-Anything-V2-GGUF/tree/main) | [Paper (arXiv)](https://arxiv.org/abs/2406.09414) | [Repository (GitHub)](https://github.com/DepthAnything/Depth-Anything-V2) | License: Apache-2 / CC-BY-NC-4

```sh
vision-cli depth-anything -m Depth-Anything-V2-Small-F16.gguf -i input.png -o depth.png

From 57b1d0b8b83b64b8decf63f091dc63be47bb3f42 Mon Sep 17 00:00:00 2001
From: Acly
Date: Thu, 16 Oct 2025 18:26:16 +0200
Subject: [PATCH 23/24] ml: only use flash attention on gpu by default

* env VISP_FLASH_ATTENTION=0 always disables it
* env VISP_FLASH_ATTENTION=1 always enables it
* all other values use default

---
 src/visp/ml.cpp | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/visp/ml.cpp b/src/visp/ml.cpp
index e926c69..ad5ae9e 100644
--- a/src/visp/ml.cpp
+++ b/src/visp/ml.cpp
@@ -138,12 +138,14 @@ 
void backend_set_n_threads(backend_device& b, int n_threads) {
 //
 // model_build_flags

-model_build_flags flash_attn_flag() {
- static model_build_flags const flag = []() {
- char const* env = getenv("VISP_NO_FLASH_ATTENTION");
- return !env || env[0] == '0' ? model_build_flag::flash_attention : model_build_flags{};
- }();
- return flag;
+model_build_flags flash_attn_flag(bool default_enabled) {
+ static char const* const env = getenv("VISP_FLASH_ATTENTION");
+ if (env && env[0] == '1') {
+ return model_build_flag::flash_attention;
+ } else if (env && env[0] == '0') {
+ return model_build_flags{};
+ }
+ return default_enabled ? model_build_flag::flash_attention : model_build_flags{};
 }

 model_build_flags backend_default_flags(backend_type type) {
@@ -151,8 +153,8 @@ model_build_flags backend_default_flags(backend_type type) {
 switch (type) {
 case backend_type::cpu:
 return conv_2d_direct_cwhn | concat_n | f16_conv_transpose | window_partition |
- flash_attn_flag();
- case backend_type::gpu: return flash_attn_flag();
+ flash_attn_flag(false);
+ case backend_type::gpu: return flash_attn_flag(true);
 }
 return {};
 }

From d381eaf731c1146f736e7e8da673a234deaa4d41 Mon Sep 17 00:00:00 2001
From: Acly
Date: Thu, 16 Oct 2025 18:26:31 +0200
Subject: [PATCH 24/24] depth-anything: fix benchmark and add numbers, bump version

---
 CMakeLists.txt | 4 ++--
 README.md | 9 ++++++++-
 tests/benchmark.cpp | 6 +++---
 3 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ef72053..e913dcf 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.28)

-project(vision.cpp VERSION 0.1.0 LANGUAGES CXX)
+project(vision.cpp VERSION 0.2.0 LANGUAGES CXX)

 option(VISP_VULKAN "Enable Vulkan support" OFF)
 option(VISP_DEV "Enable development mode" OFF)
@@ -30,7 +30,7 @@ elseif(CMAKE_BUILD_TYPE)
 endif()
 endif()

-# Configure address sanitizer (Clang only)
+# Configure address sanitizer

 if(VISP_ASAN)
 if(MSVC)
diff --git a/README.md b/README.md
index d598afa..12dc260 100644
--- a/README.md
+++ b/README.md
@@ -209,6 +209,13 @@ as other frameworks for inference speed, but with:
 | Lite | cpu | f32 | 4505 ms | 10900 ms | 6978 ms |
 | Lite | gpu | f16 | 85 ms | 84 ms | |

+#### Depth-Anything, 518x714
+
+| Model | | | _vision.cpp_ | PyTorch |
+| :---- | :--- | :--- | -----------: | ------: |
+| Small | gpu | f16 | 11 ms | 10 ms |
+| Base | gpu | f16 | 24 ms | 22 ms |
+
 #### MI-GAN, 512x512

 | Model | | | _vision.cpp_ | PyTorch |
@@ -218,7 +225,7 @@ as other frameworks for inference speed, but with:

 #### Setup

-* vision.cpp: using vision-bench, GPU via Vulkan, eg. `vision-bench -m sam -b cpu`
+* vision.cpp: using vision-bench, GPU via Vulkan, e.g. 
`vision-bench -m sam` * PyTorch: v2.7.1+cu128, eager eval, GPU via CUDA, average n iterations after warm-up ## Dependencies (integrated) diff --git a/tests/benchmark.cpp b/tests/benchmark.cpp index 7d12ed2..d10bcfb 100644 --- a/tests/benchmark.cpp +++ b/tests/benchmark.cpp @@ -94,13 +94,13 @@ bench_timings benchmark_birefnet(path model_path, backend_device& backend) { } bench_timings benchmark_depth_anything(path model_path, backend_device& backend) { - path input_path = test_dir().input / "cat-and-hat.jpg"; + path input_path = test_dir().input / "wardrobe.jpg"; depthany_model model = depthany_load_model(model_path.string().c_str(), backend); image_data input = image_load(input_path.string().c_str()); - image_data input_data = depthany_process_input(input, model.params); - depthany_compute(model, input); + + image_data input_data = depthany_process_input(input, model.params); return run_benchmark(model.graph, backend, 12, {{model.input, input_data}}); }
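
A usage sketch for the flash-attention toggle, combining the `VISP_FLASH_ATTENTION` variable from PATCH 23 (which replaces the earlier `VISP_NO_FLASH_ATTENTION` from PATCH 20) with CLI invocations taken from the README sections above. Pairing the variable with these particular commands is an assumption for illustration; the patches themselves only define the variable's semantics.

```sh
# Default after PATCH 23: flash attention on for the gpu backend, off for cpu
vision-bench -m sam

# VISP_FLASH_ATTENTION=0 always disables it, regardless of backend
VISP_FLASH_ATTENTION=0 vision-cli depth-anything -m Depth-Anything-V2-Small-F16.gguf -i input.png -o depth.png

# VISP_FLASH_ATTENTION=1 always enables it, even on cpu
VISP_FLASH_ATTENTION=1 vision-bench -m sam
```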