From d6311096c755f8befa103b524d7342881f7d0a2a Mon Sep 17 00:00:00 2001
From: Acly
Date: Mon, 8 Dec 2025 12:51:05 +0100
Subject: [PATCH 1/2] sam: enable flash attention

---
 src/visp/arch/mobile-sam.cpp | 38 ++++++++++++++++++---------
 tests/test_mobile_sam.py     | 51 ++++++++++++++++++------------------
 tests/workbench.cpp          |  5 ++++
 3 files changed, 55 insertions(+), 39 deletions(-)

diff --git a/src/visp/arch/mobile-sam.cpp b/src/visp/arch/mobile-sam.cpp
index 7beaef4..b8e9b24 100644
--- a/src/visp/arch/mobile-sam.cpp
+++ b/src/visp/arch/mobile-sam.cpp
@@ -1,8 +1,8 @@
 #include "visp/arch/mobile-sam.h"
-#include "visp/nn.h"
-#include "visp/vision.h"
 #include "util/math.h"
 #include "util/string.h"
+#include "visp/nn.h"
+#include "visp/vision.h"
 
 #include 
 
@@ -13,7 +13,7 @@ namespace visp {
 namespace sam {
 
 tensor conv_2d_batch_norm(model_ref m, tensor x, int stride = 1, int pad = 0) {
-    // batch_norm is fused into conv_2d when converting the model 
+    // batch_norm is fused into conv_2d when converting the model
     return conv_2d(m["c"], x, stride, pad);
 }
 
@@ -68,7 +68,6 @@ tensor window_reverse(model_ref m, tensor x, int w, int h, int window) {
 // Image encoder
 //
 
-
 tensor patch_embed(model_ref m, tensor x) {
     x = conv_2d_batch_norm(m["seq.0"], x, 2, 1);
     x = ggml_gelu_inplace(m, x);
@@ -142,17 +141,30 @@ tensor attention_rel_bias(model_ref m, tensor x, int dim, int num_heads) {
     tensor q = split(m, qkv, 0);
     tensor k = split(m, qkv, 1);
     tensor v = split(m, qkv, 2);
-    q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
-    k = ggml_cont(m, ggml_permute(m, k, 0, 2, 1, 3));
-    v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3)); // transpose for mul_mat later
+    tensor mask = m.weights("attention_biases_indexed");
+    float scale = 1.0f / std::sqrt(float(key_dim));
+
+    if (m.flags & model_build_flag::flash_attention) {
+        q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
+        k = ggml_cast(m, ggml_permute(m, k, 0, 2, 1, 3), GGML_TYPE_F16);
+        v = ggml_cast(m, ggml_permute(m, v, 0, 2, 1, 3), GGML_TYPE_F16);
+        if (mask->type != GGML_TYPE_F16) {
+            mask = ggml_cast(m, mask, GGML_TYPE_F16);
+        }
 
-    tensor attn = ggml_mul_mat(m, k, q); // q @ k (k is transposed in mul_mat)
-    attn = ggml_scale_inplace(m, attn, 1.0f / std::sqrt(float(key_dim)));
-    attn = ggml_add_inplace(m, attn, m.weights("attention_biases_indexed"));
-    attn = ggml_soft_max(m, attn);
+        x = ggml_flash_attn_ext(m, q, k, v, mask, scale, 0.0f, 0.0f);
+        ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);
+    } else {
+        q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
+        k = ggml_cont(m, ggml_permute(m, k, 0, 2, 1, 3));
+        v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3)); // transpose for mul_mat later
 
-    x = ggml_mul_mat(m, v, attn); // attn @ v
-    x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); // transpose(1, 2)
+        tensor attn = ggml_mul_mat(m, k, q); // q @ k (k is transposed in mul_mat)
+        attn = ggml_soft_max_ext(m, attn, mask, scale, 0.0f);
+
+        x = ggml_mul_mat(m, v, attn); // attn @ v
+        x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); // transpose(1, 2)
+    }
 
     x = ggml_reshape_3d(m, x, key_dim * num_heads, n, b);
     x = linear(m["proj"], x);
diff --git a/tests/test_mobile_sam.py b/tests/test_mobile_sam.py
index e46b7ec..77c0fbf 100644
--- a/tests/test_mobile_sam.py
+++ b/tests/test_mobile_sam.py
@@ -6,7 +6,7 @@
 from torch import Tensor
 
 from . import workbench
-from .workbench import to_nhwc, to_nchw, convert_to_nhwc, fuse_conv_2d_batch_norm
+from .workbench import to_nhwc, to_nchw, convert_to_nhwc, fuse_conv_2d_batch_norm, tensors_match
 
 torch.set_printoptions(precision=2, linewidth=100, sci_mode=False)
 
@@ -53,7 +53,7 @@ def test_conv_2d_batch_norm(bias: bool):
     result = workbench.invoke_test("sam_conv_2d_batch_norm", x, state, nhwc_layout)
     result = to_nchw(result)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 class PatchEmbed(torch.nn.Module):
@@ -98,7 +98,7 @@ def test_patch_embed():
     result = workbench.invoke_test("sam_patch_embed", x, state, nhwc_layout)
     result = to_nchw(result)
 
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class LayerNorm2d(torch.nn.Module):
@@ -130,7 +130,7 @@ def test_layer_norm_2d():
     result = workbench.invoke_test("layer_norm", x, state, nhwc_layout)
     result = to_nchw(result)
 
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class MBConv(torch.nn.Module):
@@ -193,7 +193,7 @@ def test_mb_conv():
     result = to_nchw(result)
 
     # precision: ggml_gelu uses fp16 look-up table & tanh approximation
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class PatchMerging(torch.nn.Module):
@@ -244,7 +244,7 @@ def test_patch_merging():
     result = result.transpose(1, 2).reshape_as(expected)
 
     # precision: ggml_gelu uses fp16 look-up table & tanh approximation
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class Mlp(torch.nn.Module):
@@ -288,7 +288,7 @@ def test_mlp():
     result = workbench.invoke_test("sam_mlp", x, state)
 
     # precision: ggml_gelu uses fp16 look-up table & tanh approximation
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class AttentionRelBias(torch.nn.Module):
@@ -370,8 +370,8 @@ def forward(self, x):  # x (B,N,C)
         x = self.proj(x)
         return x
 
-
-def test_attention_rel_bias():
+@pytest.mark.parametrize("attn", ["default", "flash_attn"])
+def test_attention_rel_bias(attn:str):
     attention = AttentionRelBias(4, 2, num_heads=2, attn_ratio=1, resolution=(3, 3))
     state = workbench.randomize(attention.state_dict())
     attention.load_state_dict(state)
@@ -381,9 +381,9 @@ def test_attention_rel_bias():
     expected = attention(x)
 
     state["attention_biases_indexed"] = state["attention_biases"][:, attention.attention_bias_idxs]
-    result = workbench.invoke_test("sam_attention_rel_bias", x, state)
+    result = workbench.invoke_test("sam_attention_rel_bias", x, state, {"attn": attn})
 
-    assert torch.allclose(result, expected, atol=0.001)
+    assert tensors_match(result, expected, atol=0.001)
 
 
 class TinyViTBlock(torch.nn.Module):
@@ -495,7 +495,7 @@ def test_tiny_vit_block():
     state = convert_to_nhwc(state)
     result = workbench.invoke_test("sam_tiny_vit_block", x, state, nhwc_layout)
 
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class ConvLayer(torch.nn.Module):
@@ -787,7 +787,7 @@ def test_tiny_vit():
 
     # result = torch.zeros_like(expected).contiguous()
     # result = workbench.invoke_test("sam_tiny_vit", x, state)
-    # assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    # assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 #
@@ -835,7 +835,7 @@ def test_position_embedding_random():
 
     result = workbench.invoke_test("sam_position_embedding_random", x, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 class PromptEncoder(torch.nn.Module):
@@ -951,7 +951,7 @@ def test_prompt_encoder_points():
     points = torch.cat([points, -torch.ones(1, 1, 2)], dim=1)
     result = workbench.invoke_test("sam_embed_points", points, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 def test_prompt_encoder_box():
@@ -970,7 +970,7 @@ def test_prompt_encoder_box():
 
     result = workbench.invoke_test("sam_embed_box", boxes, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 #
@@ -1046,7 +1046,7 @@ def test_attention():
     state["input_v"] = v
     result = workbench.invoke_test("sam_attention", q, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 class MLPBlock(torch.nn.Module):
@@ -1155,8 +1155,8 @@ def test_two_way_attention_block(mode):
         "sam_two_way_attention_block", queries, state, {"mode": mode}
     )
 
-    assert torch.allclose(result_queries, expected_queries)
-    assert torch.allclose(result_keys, expected_keys)
+    assert tensors_match(result_queries, expected_queries)
+    assert tensors_match(result_keys, expected_keys)
 
 
 class TwoWayTransformer(torch.nn.Module):
@@ -1257,8 +1257,8 @@ def test_two_way_transformer():
         "sam_two_way_transformer", image_embedding, state, nhwc_layout
    )
 
-    assert torch.allclose(result_queries, expected_queries, atol=1e-6, rtol=1e-4)
-    assert torch.allclose(result_keys, expected_keys, atol=1e-6, rtol=1e-4)
+    assert tensors_match(result_queries, expected_queries, atol=1e-6, rtol=1e-4)
+    assert tensors_match(result_keys, expected_keys, atol=1e-6, rtol=1e-4)
 
 
 class HypernetworkMLP(torch.nn.Module):
@@ -1297,7 +1297,7 @@ def test_hypernetwork_mlp():
 
     result = workbench.invoke_test("sam_hypernetwork_mlp", x, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 def output_upscaling(transformer_dim: int, activation=torch.nn.GELU):
@@ -1325,8 +1325,7 @@ def test_output_upscaling():
     result = workbench.invoke_test("sam_output_upscaling", x, state, nhwc_layout, backend="vulkan")
     result = to_nchw(result)
 
-    workbench.print_results(result, expected)
-    assert torch.allclose(result, expected, rtol=0.1)  # fp16 weights
+    assert tensors_match(result, expected, rtol=0.1)  # fp16 weights
 
 
 class MaskDecoder(torch.nn.Module):
@@ -1465,5 +1464,5 @@ def test_predict_masks():
         "sam_predict_masks", image_embeddings, state, nhwc_layout, backend="vulkan"
     )
 
-    assert torch.allclose(result_masks, expected_masks, rtol=1e-2, atol=1e-2)
-    assert torch.allclose(result_iou_pred, iou_pred, rtol=1e-2)
+    assert tensors_match(result_masks, expected_masks, rtol=1e-2, atol=1e-2)
+    assert tensors_match(result_iou_pred, iou_pred, rtol=1e-2)
diff --git a/tests/workbench.cpp b/tests/workbench.cpp
index d8ff24e..378785d 100644
--- a/tests/workbench.cpp
+++ b/tests/workbench.cpp
@@ -154,6 +154,11 @@ DEF(sam_mlp)(model_ref m, span input, param_dict const& p) {
 }
 
 DEF(sam_attention_rel_bias)(model_ref m, span input, param_dict const& p) {
+    if (p.get("attn", "default") == "flash_attn"sv) {
+        m.flags = m.flags | model_build_flag::flash_attention;
+    } else {
+        m.flags = m.flags & ~model_build_flag::flash_attention;
+    }
     return {sam::attention_rel_bias(m, input[0], 4, 2)};
 }
 

From b8bb09daceb9eecaf8df1b525bf9be82255e5b44 Mon Sep 17 00:00:00 2001
From: Acly
Date: Mon, 8 Dec 2025 18:18:56 +0100
Subject: [PATCH 2/2] nn: share attention code across sam, swin and dino

---
 src/visp/arch/dino.cpp       | 39 ++++++-----------
 src/visp/arch/dino.h         |  2 +-
 src/visp/arch/mobile-sam.cpp | 83 ++++++++----------------------------
 src/visp/arch/mobile-sam.h   |  2 +-
 src/visp/arch/swin.cpp       | 33 ++-------------
 src/visp/nn.cpp              | 68 ++++++++++++++++++++++++++++-
 src/visp/nn.h                | 11 +++++
 tests/workbench.cpp          |  4 +-
 8 files changed, 114 insertions(+), 128 deletions(-)

diff --git a/src/visp/arch/dino.cpp b/src/visp/arch/dino.cpp
index a1717c4..ed453da 100644
--- a/src/visp/arch/dino.cpp
+++ b/src/visp/arch/dino.cpp
@@ -55,42 +55,27 @@ tensor mlp(model_ref m, tensor x) {
     return x;
 }
 
-tensor attention(model_ref m, tensor x, int n_heads) {
+tensor self_attention(model_ref m, tensor x, int n_heads) {
     auto [c, n, b, _] = nelements(x);
-    float scale = 1.0f / std::sqrt(float(c) / float(n_heads));
-    bool flash_attn = bool(m.flags & model_build_flag::flash_attention);
-    ggml_type kv_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
-
-    auto split = [=](model_ref m, tensor x, ggml_type type, bool transpose = false) mutable {
-        x = linear(m, x);
-        x = ggml_reshape_4d(m, x, c / n_heads, n_heads, n, b);
-        x = transpose ? ggml_permute(m, x, 1, 2, 0, 3) : ggml_permute(m, x, 0, 2, 1, 3);
-        return ggml_cast(m, x, type);
+    auto project = [&](model_ref m, tensor t) {
+        t = linear(m, t);
+        t = ggml_reshape_4d(m, t, c / n_heads, n_heads, n, b);
+        return t;
     };
 
-    tensor q = split(m["attention.query"], x, GGML_TYPE_F32);
-    tensor k = split(m["attention.key"], x, kv_type);
-    tensor v = split(m["attention.value"], x, kv_type, !flash_attn);
-
-    if (flash_attn) {
-        x = ggml_flash_attn_ext(m, q, k, v, nullptr, scale, 0.0f, 0.0f);
-    } else {
-        tensor attn = ggml_mul_mat(m, k, q);
-        attn = ggml_soft_max_ext(m, attn, nullptr, scale, 0.0f);
-
-        x = ggml_mul_mat(m, v, attn);
-        x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
-    }
+    tensor q = project(m["attention.query"], x);
+    tensor k = project(m["attention.key"], x);
+    tensor v = project(m["attention.value"], x);
 
-    x = ggml_reshape_3d(m, x, c, n, b);
-    x = linear(m["output.dense"], x);
-    return named(m, x);
+    float scale = 1.0f / std::sqrt(float(c) / float(n_heads));
+    x = attention(m, q, k, v, nullptr, scale, m["output.dense"]);
+    return x;
 }
 
 tensor layer(model_ref m, tensor x, dino_params const& p) {
     tensor attn = x;
     attn = layer_norm(m["norm1"], attn, 1e-6f);
-    attn = attention(m["attention"], attn, p.n_heads);
+    attn = self_attention(m["attention"], attn, p.n_heads);
     attn = layer_scale(m["layer_scale1"], attn);
     x = ggml_add(m, x, attn);
 
diff --git a/src/visp/arch/dino.h b/src/visp/arch/dino.h
index 43d915b..5e46061 100644
--- a/src/visp/arch/dino.h
+++ b/src/visp/arch/dino.h
@@ -12,7 +12,7 @@ tensor interpolate_pos_encoding(model_ref m, tensor x, int64_t w, int64_t h, int
 tensor prepare_tokens(model_ref m, tensor x, int patch_size);
 tensor layer_scale(model_ref m, tensor x);
 tensor mlp(model_ref m, tensor x);
-tensor attention(model_ref m, tensor x, int n_heads);
+tensor self_attention(model_ref m, tensor x, int n_heads);
 tensor layer(model_ref m, tensor x, dino_params const& p);
 
 std::vector get_intermediate_layers(
diff --git a/src/visp/arch/mobile-sam.cpp b/src/visp/arch/mobile-sam.cpp
index b8e9b24..9e544e9 100644
--- a/src/visp/arch/mobile-sam.cpp
+++ b/src/visp/arch/mobile-sam.cpp
@@ -121,54 +121,14 @@ tensor mlp(model_ref m, tensor x) {
     return named(m, x);
 }
 
-tensor attention_rel_bias(model_ref m, tensor x, int dim, int num_heads) {
-    GGML_ASSERT(dim % num_heads == 0);
-    int key_dim = dim / num_heads;
-    auto [c, n, b, _] = nelements(x);
-
-    x = layer_norm(m["norm"], x);
-
-    tensor qkv = linear(m["qkv"], x);
-    qkv = ggml_reshape_4d(m, qkv, key_dim, 3, num_heads * n, b);
-    qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 3, 1, 2)); // ne = [key_dim, num_heads * n, b, 3]
-
-    auto split = [=](model_ref m, tensor tensor, int64_t index) {
-        tensor = slice(m, tensor, {}, {}, {}, index);
-        tensor = ggml_reshape_4d(m, tensor, key_dim, num_heads, n, b);
-        return tensor;
-    };
-
-    tensor q = split(m, qkv, 0);
-    tensor k = split(m, qkv, 1);
-    tensor v = split(m, qkv, 2);
+tensor attention_rel_bias(model_ref m, tensor x, int dim, int n_heads) {
+    float scale = 1.0f / std::sqrt(float(dim / n_heads));
     tensor mask = m.weights("attention_biases_indexed");
-    float scale = 1.0f / std::sqrt(float(key_dim));
 
-    if (m.flags & model_build_flag::flash_attention) {
-        q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
-        k = ggml_cast(m, ggml_permute(m, k, 0, 2, 1, 3), GGML_TYPE_F16);
-        v = ggml_cast(m, ggml_permute(m, v, 0, 2, 1, 3), GGML_TYPE_F16);
-        if (mask->type != GGML_TYPE_F16) {
-            mask = ggml_cast(m, mask, GGML_TYPE_F16);
-        }
-
-        x = ggml_flash_attn_ext(m, q, k, v, mask, scale, 0.0f, 0.0f);
-        ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);
-    } else {
-        q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
-        k = ggml_cont(m, ggml_permute(m, k, 0, 2, 1, 3));
-        v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3)); // transpose for mul_mat later
-
-        tensor attn = ggml_mul_mat(m, k, q); // q @ k (k is transposed in mul_mat)
-        attn = ggml_soft_max_ext(m, attn, mask, scale, 0.0f);
-
-        x = ggml_mul_mat(m, v, attn); // attn @ v
-        x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); // transpose(1, 2)
-    }
-
-    x = ggml_reshape_3d(m, x, key_dim * num_heads, n, b);
-    x = linear(m["proj"], x);
-    return named(m, x);
+    x = layer_norm(m["norm"], x);
+    auto [q, k, v] = split_qkv(m["qkv"], x, n_heads, 1);
+    x = attention(m, q, k, v, mask, scale, m["proj"]);
+    return x;
 }
 
 tensor tiny_vit_block(
@@ -344,25 +304,18 @@ tensor separate_attention_heads(model_ref m, tensor x, int num_heads) {
     return x;
 }
 
-tensor attention(model_ref m, tensor q, tensor k, tensor v, int num_heads) {
+tensor decoder_attention(model_ref m, tensor q, tensor k, tensor v, int n_heads) {
     q = linear(m["q_proj"], q);
     k = linear(m["k_proj"], k);
     v = linear(m["v_proj"], v);
 
-    q = separate_attention_heads(m, q, num_heads);
-    k = separate_attention_heads(m, k, num_heads);
-    v = ggml_reshape_4d(m, v, v->ne[0] / num_heads, num_heads, v->ne[1], v->ne[2]);
-    v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3)); // already transposed for mul_mat
+    q = ggml_reshape_4d(m, q, q->ne[0] / n_heads, n_heads, q->ne[1], q->ne[2]);
+    k = ggml_reshape_4d(m, k, k->ne[0] / n_heads, n_heads, k->ne[1], k->ne[2]);
+    v = ggml_reshape_4d(m, v, v->ne[0] / n_heads, n_heads, v->ne[1], v->ne[2]);
 
-    tensor attn = ggml_mul_mat(m, k, q);
-    attn = ggml_scale_inplace(m, attn, 1.0f / std::sqrt(float(q->ne[0])));
-    attn = ggml_soft_max(m, attn);
-
-    tensor out = ggml_mul_mat(m, v, attn);
-    out = ggml_cont(m, ggml_permute(m, out, 0, 2, 1, 3));
-    out = ggml_reshape_3d(m, out, out->ne[0] * out->ne[1], out->ne[2], out->ne[3]);
-    out = linear(m["out_proj"], out);
-    return out;
+    float scale = 1.0f / std::sqrt(float(q->ne[0]));
+    tensor x = attention(m, q, k, v, nullptr, scale, m["out_proj"]);
+    return x;
 }
 
 auto two_way_attention_block(
@@ -375,10 +328,10 @@ auto two_way_attention_block(
     bool skip_first_layer_pe) -> std::tuple {
     // Self attention block
     if (skip_first_layer_pe) {
-        queries = attention(m["self_attn"], queries, queries, queries, num_heads);
+        queries = decoder_attention(m["self_attn"], queries, queries, queries, num_heads);
     } else {
         tensor q = ggml_add(m, queries, query_pe);
-        tensor attn_out = attention(m["self_attn"], q, q, queries, num_heads);
+        tensor attn_out = decoder_attention(m["self_attn"], q, q, queries, num_heads);
         queries = ggml_add(m, queries, attn_out);
     }
     queries = layer_norm(m["norm1"], queries);
@@ -386,7 +339,7 @@ auto two_way_attention_block(
     // Cross attention block, tokens attending to image embedding
     tensor q = ggml_add(m, queries, query_pe);
     tensor k = ggml_add(m, keys, key_pe);
-    tensor attn_out = attention(m["cross_attn_t2i"], q, k, keys, num_heads);
+    tensor attn_out = decoder_attention(m["cross_attn_t2i"], q, k, keys, num_heads);
     queries = ggml_add_inplace(m, queries, attn_out);
     queries = layer_norm(m["norm2"], queries);
 
@@ -401,7 +354,7 @@ auto two_way_attention_block(
     // Cross attention block, image embedding attending to tokens
     q = ggml_add(m, queries, query_pe);
     // k = ggml_add(m, keys, key_pe); // redundant, same as above
-    attn_out = attention(m["cross_attn_i2t"], k, q, queries, num_heads);
+    attn_out = decoder_attention(m["cross_attn_i2t"], k, q, queries, num_heads);
     keys = ggml_add_inplace(m, keys, attn_out);
     keys = layer_norm(m["norm4"], keys);
 
@@ -434,7 +387,7 @@ auto two_way_transformer(
     // Apply the final attention layer from the points to the image
     tensor q = ggml_add(m, queries, point_embedding);
     tensor k = ggml_add(m, keys, image_pe);
-    tensor attn_out = attention(m["final_attn_t2i"], q, k, keys, num_heads);
+    tensor attn_out = decoder_attention(m["final_attn_t2i"], q, k, keys, num_heads);
     queries = ggml_add_inplace(m, queries, attn_out);
     queries = layer_norm(m["norm_final_attn"], queries);
 
diff --git a/src/visp/arch/mobile-sam.h b/src/visp/arch/mobile-sam.h
index 6e38868..e4d7881 100644
--- a/src/visp/arch/mobile-sam.h
+++ b/src/visp/arch/mobile-sam.h
@@ -66,7 +66,7 @@ tensor position_embedding_random(model_ref m, tensor coords);
 
 tensor mlp_block(model_ref m, tensor x);
 tensor separate_attention_heads(model_ref m, tensor x, int num_heads);
-tensor attention(model_ref m, tensor q, tensor k, tensor v, int num_heads);
+tensor decoder_attention(model_ref m, tensor q, tensor k, tensor v, int num_heads);
 std::tuple two_way_attention_block(
     model_ref m,
     tensor queries,
diff --git a/src/visp/arch/swin.cpp b/src/visp/arch/swin.cpp
index b46483d..1a83bcc 100644
--- a/src/visp/arch/swin.cpp
+++ b/src/visp/arch/swin.cpp
@@ -65,24 +65,6 @@ tensor window_reverse(model_ref m, tensor x, int64_t w, int64_t h, int window) {
 
 tensor window_attention(model_ref m, tensor x, tensor mask, int n_heads, int window) {
     auto [c, n, b, _] = nelements(x);
-    float scale = 1.0f / std::sqrt(float(c / n_heads));
-    bool flash_attn = bool(m.flags & model_build_flag::flash_attention);
-    ggml_type kv_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
-
-    tensor qkv = linear(m["qkv"], x);
-    qkv = ggml_reshape_4d(m, qkv, c / n_heads, n_heads, 3, n * b);
-    qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2));
-
-    auto split = [=](tensor t, size_t index, ggml_type type, bool transpose = false) mutable {
-        t = slice(m, t, {}, {}, {}, index);
-        t = ggml_reshape_4d(m, t, c / n_heads, n_heads, n, b);
-        t = transpose ? ggml_permute(m, t, 1, 2, 0, 3) : ggml_permute(m, t, 0, 2, 1, 3);
-        t = ggml_cast(m, t, type); // TODO: future flash attention supports f32 and permutations
-        return t;
-    };
-    tensor q = split(qkv, 0, GGML_TYPE_F32);
-    tensor k = split(qkv, 1, kv_type);
-    tensor v = split(qkv, 2, kv_type, !flash_attn);
 
     tensor_name rel_pos_name = format("window_attention_{}.rel_pos_index", window);
     tensor rel_pos_index = ggml_get_tensor(m, rel_pos_name.c_str());
@@ -104,19 +86,10 @@ tensor window_attention(model_ref m, tensor x, tensor mask, int n_heads, int win
         attn_mask = ggml_add(m, mask, attn_mask); // [n, n, n_heads, b] + [n, n, n_heads, 1]
     }
 
-    if (flash_attn) {
-        x = ggml_flash_attn_ext(m, q, k, v, attn_mask, scale, 0.0f, 0.0f);
-        ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);
-    } else {
-        tensor attn = ggml_mul_mat(m, k, q);
-        attn = ggml_soft_max_ext(m, attn, attn_mask, scale, 0.0f);
-
-        x = ggml_mul_mat(m, v, attn);
-        x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
-    }
+    auto [q, k, v] = split_qkv(m["qkv"], x, n_heads, 2);
+    float scale = 1.0f / std::sqrt(float(c / n_heads));
+    x = attention(m, q, k, v, attn_mask, scale, m["proj"]);
 
-    x = ggml_reshape_3d(m, x, c, n, b);
-    x = linear(m["proj"], x);
     return named(m, x);
 }
 
diff --git a/src/visp/nn.cpp b/src/visp/nn.cpp
index 6d3268c..ed54b66 100644
--- a/src/visp/nn.cpp
+++ b/src/visp/nn.cpp
@@ -80,7 +80,7 @@ tensor conv_2d(model_ref m, tensor x, int stride, int pad) {
         x = ggml_mul_mat(m, weight, x);
         x = ggml_reshape_4d(m, x, weight->ne[1], w, h, b);
 
-    } else if (m.flags & model_build_flag::conv_2d_direct_cwhn) { 
+    } else if (m.flags & model_build_flag::conv_2d_direct_cwhn) {
         weight = permute_cwhn_to_whcn(m, weight);
         x = permute_cwhn_to_whcn(m, x);
         x = ggml_conv_2d_direct(m, weight, x, stride, stride, pad, pad, 1, 1);
@@ -144,7 +144,7 @@ tensor conv_2d_deform(
         }
     }
     x = ggml_conv_2d_deform(m, weight, x, offset, mask, stride, stride, pad, pad);
-    
+
     if (m.flags & model_build_flag::cwhn) {
         x = permute_whcn_to_cwhn(m, x);
     }
@@ -183,4 +183,68 @@ tensor patch_embed(model_ref m, tensor x, int patch_size) {
     return named(m, x);
 }
 
+attention_qkv split_qkv(model_ref m, tensor x, int n_heads, int split_dim) {
+    auto [c, n, b, _] = nelements(x);
+
+    tensor qkv = linear(m, x);
+    switch (split_dim) {
+    case 1:
+        qkv = ggml_reshape_4d(m, qkv, c / n_heads, 3, n_heads * n, b);
+        qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 3, 1, 2));
+        break;
+    case 2:
+        qkv = ggml_reshape_4d(m, qkv, c / n_heads, n_heads, 3, n * b);
+        qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2));
+        break;
+    default: ASSERT(false, "Unsupported split_dim");
+    }
+
+    auto split = [&](tensor t, size_t index) mutable {
+        t = slice(m, t, {}, {}, {}, index);
+        t = ggml_reshape_4d(m, t, c / n_heads, n_heads, n, b);
+        return t;
+    };
+
+    tensor q = split(qkv, 0);
+    tensor k = split(qkv, 1);
+    tensor v = split(qkv, 2);
+    return {q, k, v};
+}
+
+tensor attention(
+    model_ref m, tensor q, tensor k, tensor v, tensor mask, float scale, model_ref m_out) {
+
+    q = ggml_permute(m, q, 0, 2, 1, 3);
+    k = ggml_permute(m, k, 0, 2, 1, 3);
+
+    tensor x = nullptr;
+    if (m.flags & model_build_flag::flash_attention) {
+        v = ggml_permute(m, v, 0, 2, 1, 3);
+
+        k = ggml_cast(m, k, GGML_TYPE_F16);
+        v = ggml_cast(m, v, GGML_TYPE_F16);
+        if (mask && mask->type != GGML_TYPE_F16) {
+            mask = ggml_cast(m, mask, GGML_TYPE_F16);
+        }
+
+        x = ggml_flash_attn_ext(m, q, k, v, mask, scale, 0.0f, 0.0f);
+        ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);
+
+    } else {
+        v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3));
+
+        tensor attn = ggml_mul_mat(m, k, q);
+        attn = ggml_soft_max_ext(m, attn, mask, scale, 0.0f);
+        x = ggml_mul_mat(m, v, attn);
+
+        x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
+    }
+
+    // [head_dim, n_heads, n_patches, batch] -> [embed_dim, n_patches, batch]
+    x = ggml_reshape_3d(m, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
+    x = linear(m_out, x);
+
+    return named(m, x);
+}
+
 } // namespace visp
\ No newline at end of file
diff --git a/src/visp/nn.h b/src/visp/nn.h
index 9b7e762..bb682ab 100644
--- a/src/visp/nn.h
+++ b/src/visp/nn.h
@@ -41,4 +41,15 @@ tensor batch_norm_2d(model_ref, tensor x);
 // 2D image to patch embedding using convolution and optional norm. CWHN input and output.
 tensor patch_embed(model_ref, tensor x, int patch_size);
 
+struct attention_qkv {
+    tensor q, k, v;
+};
+// Input: x [head_dim*n_heads, n_patches, batch]
+// Output: q, k, v each of shape [head_dim, n_heads, n_patches, batch]
+attention_qkv split_qkv(model_ref m, tensor x, int n_heads, int split_dim);
+
+// Attention with optional mask and output linear layer.
+tensor attention(
+    model_ref m, tensor q, tensor k, tensor v, tensor mask, float scale, model_ref m_out);
+
 } // namespace visp
diff --git a/tests/workbench.cpp b/tests/workbench.cpp
index 378785d..5ed0b83 100644
--- a/tests/workbench.cpp
+++ b/tests/workbench.cpp
@@ -205,7 +205,7 @@ DEF(sam_attention)(model_ref m, span input, param_dict const& p) {
     tensor q = input[0];
     tensor k = m.weights("input_k");
     tensor v = m.weights("input_v");
-    return {sam::attention(m, q, k, v, 2)};
+    return {sam::decoder_attention(m, q, k, v, 2)};
 }
 
 DEF(sam_two_way_attention_block)(model_ref m, span input, param_dict const& p) {
@@ -443,7 +443,7 @@ DEF(dino_attention)(model_ref m, span input, param_dict const& p) {
     if (p.get("flash_attn", 0) != 0) {
         m.flags |= model_build_flag::flash_attention;
    }
-    return {dino::attention(m, input[0], p.get("n_heads", 8))};
+    return {dino::self_attention(m, input[0], p.get("n_heads", 8))};
 }
 
 DEF(dino_block)(model_ref m, span input, param_dict const& p) {
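
For reference, a minimal sketch of how the helpers added to src/visp/nn.h compose in an architecture block. It mirrors the call pattern used by sam::attention_rel_bias and swin::window_attention in the patch above; the function name example_self_attention, the layer names "qkv"/"proj", and the absence of a mask are illustrative assumptions, not part of the patch.

    // Illustrative sketch only: fused QKV projection split into per-head tensors,
    // then the shared attention core, which picks flash attention or
    // mul_mat + soft_max based on m.flags and applies the output projection.
    tensor example_self_attention(model_ref m, tensor x, int n_heads) {
        auto [c, n, b, _] = nelements(x);
        float scale = 1.0f / std::sqrt(float(c / n_heads));
        auto [q, k, v] = split_qkv(m["qkv"], x, n_heads, /*split_dim=*/2);
        return attention(m, q, k, v, /*mask=*/nullptr, scale, m["proj"]);
    }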