Commit d631109

sam: enable flash attention
1 parent b766b0a

3 files changed: +55 -39 lines changed

src/visp/arch/mobile-sam.cpp

Lines changed: 25 additions & 13 deletions

@@ -1,8 +1,8 @@
 #include "visp/arch/mobile-sam.h"
-#include "visp/nn.h"
-#include "visp/vision.h"
 #include "util/math.h"
 #include "util/string.h"
+#include "visp/nn.h"
+#include "visp/vision.h"
 
 #include <ggml.h>
 
@@ -13,7 +13,7 @@ namespace visp {
 namespace sam {
 
 tensor conv_2d_batch_norm(model_ref m, tensor x, int stride = 1, int pad = 0) {
-    // batch_norm is fused into conv_2d when converting the model
+    // batch_norm is fused into conv_2d when converting the model
     return conv_2d(m["c"], x, stride, pad);
 }
 
@@ -68,7 +68,6 @@ tensor window_reverse(model_ref m, tensor x, int w, int h, int window) {
 // Image encoder
 //
 
-
 tensor patch_embed(model_ref m, tensor x) {
     x = conv_2d_batch_norm(m["seq.0"], x, 2, 1);
     x = ggml_gelu_inplace(m, x);
@@ -142,17 +141,30 @@ tensor attention_rel_bias(model_ref m, tensor x, int dim, int num_heads) {
     tensor q = split(m, qkv, 0);
     tensor k = split(m, qkv, 1);
     tensor v = split(m, qkv, 2);
-    q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
-    k = ggml_cont(m, ggml_permute(m, k, 0, 2, 1, 3));
-    v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3)); // transpose for mul_mat later
+    tensor mask = m.weights("attention_biases_indexed");
+    float scale = 1.0f / std::sqrt(float(key_dim));
+
+    if (m.flags & model_build_flag::flash_attention) {
+        q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
+        k = ggml_cast(m, ggml_permute(m, k, 0, 2, 1, 3), GGML_TYPE_F16);
+        v = ggml_cast(m, ggml_permute(m, v, 0, 2, 1, 3), GGML_TYPE_F16);
+        if (mask->type != GGML_TYPE_F16) {
+            mask = ggml_cast(m, mask, GGML_TYPE_F16);
+        }
 
-    tensor attn = ggml_mul_mat(m, k, q); // q @ k (k is transposed in mul_mat)
-    attn = ggml_scale_inplace(m, attn, 1.0f / std::sqrt(float(key_dim)));
-    attn = ggml_add_inplace(m, attn, m.weights("attention_biases_indexed"));
-    attn = ggml_soft_max(m, attn);
+        x = ggml_flash_attn_ext(m, q, k, v, mask, scale, 0.0f, 0.0f);
+        ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);
+    } else {
+        q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
+        k = ggml_cont(m, ggml_permute(m, k, 0, 2, 1, 3));
+        v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3)); // transpose for mul_mat later
 
-    x = ggml_mul_mat(m, v, attn); // attn @ v
-    x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); // transpose(1, 2)
+        tensor attn = ggml_mul_mat(m, k, q); // q @ k (k is transposed in mul_mat)
+        attn = ggml_soft_max_ext(m, attn, mask, scale, 0.0f);
+
+        x = ggml_mul_mat(m, v, attn); // attn @ v
+        x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); // transpose(1, 2)
+    }
     x = ggml_reshape_3d(m, x, key_dim * num_heads, n, b);
     x = linear(m["proj"], x);
 
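Note on the attention change above: both branches compute the same masked, scaled dot-product attention, softmax(q @ k^T * scale + bias) @ v with the indexed relative-position bias used as the mask; the flash path merely fuses it via ggml_flash_attn_ext with F16 K/V. A minimal PyTorch sketch of that equivalence (illustrative only, not part of this commit; shapes chosen arbitrarily):

import math
import torch
import torch.nn.functional as F

def attention_reference(q, k, v, bias):
    # explicit path: softmax(q @ k^T * scale + bias) @ v
    scale = 1.0 / math.sqrt(q.shape[-1])
    attn = (q @ k.transpose(-2, -1)) * scale + bias
    return attn.softmax(dim=-1) @ v

def attention_fused(q, k, v, bias):
    # fused path, analogous to the flash-attention branch: the bias acts as an additive mask
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

q, k, v = (torch.randn(1, 2, 9, 4) for _ in range(3))  # (batch, heads, tokens, key_dim)
bias = torch.randn(1, 2, 9, 9)                         # relative-position bias per head
assert torch.allclose(attention_reference(q, k, v, bias), attention_fused(q, k, v, bias), atol=1e-5)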

tests/test_mobile_sam.py

Lines changed: 25 additions & 26 deletions

@@ -6,7 +6,7 @@
 from torch import Tensor
 
 from . import workbench
-from .workbench import to_nhwc, to_nchw, convert_to_nhwc, fuse_conv_2d_batch_norm
+from .workbench import to_nhwc, to_nchw, convert_to_nhwc, fuse_conv_2d_batch_norm, tensors_match
 
 torch.set_printoptions(precision=2, linewidth=100, sci_mode=False)
 
@@ -53,7 +53,7 @@ def test_conv_2d_batch_norm(bias: bool):
     result = workbench.invoke_test("sam_conv_2d_batch_norm", x, state, nhwc_layout)
     result = to_nchw(result)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 class PatchEmbed(torch.nn.Module):
@@ -98,7 +98,7 @@ def test_patch_embed():
     result = workbench.invoke_test("sam_patch_embed", x, state, nhwc_layout)
     result = to_nchw(result)
 
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class LayerNorm2d(torch.nn.Module):
@@ -130,7 +130,7 @@ def test_layer_norm_2d():
     result = workbench.invoke_test("layer_norm", x, state, nhwc_layout)
     result = to_nchw(result)
 
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class MBConv(torch.nn.Module):
@@ -193,7 +193,7 @@ def test_mb_conv():
     result = to_nchw(result)
 
     # precision: ggml_gelu uses fp16 look-up table & tanh approximation
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class PatchMerging(torch.nn.Module):
@@ -244,7 +244,7 @@ def test_patch_merging():
     result = result.transpose(1, 2).reshape_as(expected)
 
     # precision: ggml_gelu uses fp16 look-up table & tanh approximation
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class Mlp(torch.nn.Module):
@@ -288,7 +288,7 @@ def test_mlp():
     result = workbench.invoke_test("sam_mlp", x, state)
 
     # precision: ggml_gelu uses fp16 look-up table & tanh approximation
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class AttentionRelBias(torch.nn.Module):
@@ -370,8 +370,8 @@ def forward(self, x): # x (B,N,C)
         x = self.proj(x)
         return x
 
-
-def test_attention_rel_bias():
+@pytest.mark.parametrize("attn", ["default", "flash_attn"])
+def test_attention_rel_bias(attn: str):
     attention = AttentionRelBias(4, 2, num_heads=2, attn_ratio=1, resolution=(3, 3))
     state = workbench.randomize(attention.state_dict())
     attention.load_state_dict(state)
@@ -381,9 +381,9 @@ def test_attention_rel_bias():
     expected = attention(x)
 
     state["attention_biases_indexed"] = state["attention_biases"][:, attention.attention_bias_idxs]
-    result = workbench.invoke_test("sam_attention_rel_bias", x, state)
+    result = workbench.invoke_test("sam_attention_rel_bias", x, state, {"attn": attn})
 
-    assert torch.allclose(result, expected, atol=0.001)
+    assert tensors_match(result, expected, atol=0.001)
 
 
 class TinyViTBlock(torch.nn.Module):
@@ -495,7 +495,7 @@ def test_tiny_vit_block():
     state = convert_to_nhwc(state)
     result = workbench.invoke_test("sam_tiny_vit_block", x, state, nhwc_layout)
 
-    assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 class ConvLayer(torch.nn.Module):
@@ -787,7 +787,7 @@ def test_tiny_vit():
     # result = torch.zeros_like(expected).contiguous()
    # result = workbench.invoke_test("sam_tiny_vit", x, state)
 
-    # assert torch.allclose(result, expected, rtol=0.001, atol=0.02)
+    # assert tensors_match(result, expected, rtol=0.001, atol=0.02)
 
 
 #
@@ -835,7 +835,7 @@ def test_position_embedding_random():
 
     result = workbench.invoke_test("sam_position_embedding_random", x, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 class PromptEncoder(torch.nn.Module):
@@ -951,7 +951,7 @@ def test_prompt_encoder_points():
     points = torch.cat([points, -torch.ones(1, 1, 2)], dim=1)
     result = workbench.invoke_test("sam_embed_points", points, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 def test_prompt_encoder_box():
@@ -970,7 +970,7 @@ def test_prompt_encoder_box():
 
     result = workbench.invoke_test("sam_embed_box", boxes, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 #
@@ -1046,7 +1046,7 @@ def test_attention():
     state["input_v"] = v
     result = workbench.invoke_test("sam_attention", q, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 class MLPBlock(torch.nn.Module):
@@ -1155,8 +1155,8 @@ def test_two_way_attention_block(mode):
         "sam_two_way_attention_block", queries, state, {"mode": mode}
     )
 
-    assert torch.allclose(result_queries, expected_queries)
-    assert torch.allclose(result_keys, expected_keys)
+    assert tensors_match(result_queries, expected_queries)
+    assert tensors_match(result_keys, expected_keys)
 
 
 class TwoWayTransformer(torch.nn.Module):
@@ -1257,8 +1257,8 @@ def test_two_way_transformer():
        "sam_two_way_transformer", image_embedding, state, nhwc_layout
     )
 
-    assert torch.allclose(result_queries, expected_queries, atol=1e-6, rtol=1e-4)
-    assert torch.allclose(result_keys, expected_keys, atol=1e-6, rtol=1e-4)
+    assert tensors_match(result_queries, expected_queries, atol=1e-6, rtol=1e-4)
+    assert tensors_match(result_keys, expected_keys, atol=1e-6, rtol=1e-4)
 
 
 class HypernetworkMLP(torch.nn.Module):
@@ -1297,7 +1297,7 @@ def test_hypernetwork_mlp():
 
     result = workbench.invoke_test("sam_hypernetwork_mlp", x, state)
 
-    assert torch.allclose(result, expected)
+    assert tensors_match(result, expected)
 
 
 def output_upscaling(transformer_dim: int, activation=torch.nn.GELU):
@@ -1325,8 +1325,7 @@ def test_output_upscaling():
     result = workbench.invoke_test("sam_output_upscaling", x, state, nhwc_layout, backend="vulkan")
     result = to_nchw(result)
 
-    workbench.print_results(result, expected)
-    assert torch.allclose(result, expected, rtol=0.1)  # fp16 weights
+    assert tensors_match(result, expected, rtol=0.1)  # fp16 weights
 
 
 class MaskDecoder(torch.nn.Module):
@@ -1465,5 +1464,5 @@ def test_predict_masks():
         "sam_predict_masks", image_embeddings, state, nhwc_layout, backend="vulkan"
     )
 
-    assert torch.allclose(result_masks, expected_masks, rtol=1e-2, atol=1e-2)
-    assert torch.allclose(result_iou_pred, iou_pred, rtol=1e-2)
+    assert tensors_match(result_masks, expected_masks, rtol=1e-2, atol=1e-2)
+    assert tensors_match(result_iou_pred, iou_pred, rtol=1e-2)
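The assertions now route through tensors_match from tests/workbench.py instead of calling torch.allclose directly. That helper is not part of this diff; a hypothetical minimal sketch of the contract the new asserts assume (torch.allclose semantics plus a short diagnostic on mismatch) follows, and the real implementation may differ:

import torch

def tensors_match(result: torch.Tensor, expected: torch.Tensor, rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    # hypothetical sketch only -- the actual tensors_match lives in tests/workbench.py
    if torch.allclose(result, expected, rtol=rtol, atol=atol):
        return True
    diff = (result - expected).abs()
    print(f"tensors differ: max abs diff {diff.max().item():.5g}, mean {diff.mean().item():.5g}")
    return False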

tests/workbench.cpp

Lines changed: 5 additions & 0 deletions

@@ -154,6 +154,11 @@ DEF(sam_mlp)(model_ref m, span<tensor> input, param_dict const& p) {
 }
 
 DEF(sam_attention_rel_bias)(model_ref m, span<tensor> input, param_dict const& p) {
+    if (p.get("attn", "default") == "flash_attn"sv) {
+        m.flags = m.flags | model_build_flag::flash_attention;
+    } else {
+        m.flags = m.flags & ~model_build_flag::flash_attention;
+    }
     return {sam::attention_rel_bias(m, input[0], 4, 2)};
 }
 