Commit b8bb09d

nn: share attention code across sam, swin and dino
1 parent d631109 commit b8bb09d

8 files changed, +114 / -128 lines

src/visp/arch/dino.cpp

Lines changed: 12 additions & 27 deletions
@@ -55,42 +55,27 @@ tensor mlp(model_ref m, tensor x) {
     return x;
 }
 
-tensor attention(model_ref m, tensor x, int n_heads) {
+tensor self_attention(model_ref m, tensor x, int n_heads) {
     auto [c, n, b, _] = nelements(x);
-    float scale = 1.0f / std::sqrt(float(c) / float(n_heads));
-    bool flash_attn = bool(m.flags & model_build_flag::flash_attention);
-    ggml_type kv_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
-
-    auto split = [=](model_ref m, tensor x, ggml_type type, bool transpose = false) mutable {
-        x = linear(m, x);
-        x = ggml_reshape_4d(m, x, c / n_heads, n_heads, n, b);
-        x = transpose ? ggml_permute(m, x, 1, 2, 0, 3) : ggml_permute(m, x, 0, 2, 1, 3);
-        return ggml_cast(m, x, type);
+    auto project = [&](model_ref m, tensor t) {
+        t = linear(m, t);
+        t = ggml_reshape_4d(m, t, c / n_heads, n_heads, n, b);
+        return t;
     };
 
-    tensor q = split(m["attention.query"], x, GGML_TYPE_F32);
-    tensor k = split(m["attention.key"], x, kv_type);
-    tensor v = split(m["attention.value"], x, kv_type, !flash_attn);
-
-    if (flash_attn) {
-        x = ggml_flash_attn_ext(m, q, k, v, nullptr, scale, 0.0f, 0.0f);
-    } else {
-        tensor attn = ggml_mul_mat(m, k, q);
-        attn = ggml_soft_max_ext(m, attn, nullptr, scale, 0.0f);
-
-        x = ggml_mul_mat(m, v, attn);
-        x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
-    }
+    tensor q = project(m["attention.query"], x);
+    tensor k = project(m["attention.key"], x);
+    tensor v = project(m["attention.value"], x);
 
-    x = ggml_reshape_3d(m, x, c, n, b);
-    x = linear(m["output.dense"], x);
-    return named(m, x);
+    float scale = 1.0f / std::sqrt(float(c) / float(n_heads));
+    x = attention(m, q, k, v, nullptr, scale, m["output.dense"]);
+    return x;
 }
 
 tensor layer(model_ref m, tensor x, dino_params const& p) {
     tensor attn = x;
     attn = layer_norm(m["norm1"], attn, 1e-6f);
-    attn = attention(m["attention"], attn, p.n_heads);
+    attn = self_attention(m["attention"], attn, p.n_heads);
     attn = layer_scale(m["layer_scale1"], attn);
     x = ggml_add(m, x, attn);
 
src/visp/arch/dino.h

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ tensor interpolate_pos_encoding(model_ref m, tensor x, int64_t w, int64_t h, int
 tensor prepare_tokens(model_ref m, tensor x, int patch_size);
 tensor layer_scale(model_ref m, tensor x);
 tensor mlp(model_ref m, tensor x);
-tensor attention(model_ref m, tensor x, int n_heads);
+tensor self_attention(model_ref m, tensor x, int n_heads);
 tensor layer(model_ref m, tensor x, dino_params const& p);
 
 std::vector<tensor> get_intermediate_layers(

src/visp/arch/mobile-sam.cpp

Lines changed: 18 additions & 65 deletions
@@ -121,54 +121,14 @@ tensor mlp(model_ref m, tensor x) {
     return named(m, x);
 }
 
-tensor attention_rel_bias(model_ref m, tensor x, int dim, int num_heads) {
-    GGML_ASSERT(dim % num_heads == 0);
-    int key_dim = dim / num_heads;
-    auto [c, n, b, _] = nelements(x);
-
-    x = layer_norm(m["norm"], x);
-
-    tensor qkv = linear(m["qkv"], x);
-    qkv = ggml_reshape_4d(m, qkv, key_dim, 3, num_heads * n, b);
-    qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 3, 1, 2)); // ne = [key_dim, num_heads * n, b, 3]
-
-    auto split = [=](model_ref m, tensor tensor, int64_t index) {
-        tensor = slice(m, tensor, {}, {}, {}, index);
-        tensor = ggml_reshape_4d(m, tensor, key_dim, num_heads, n, b);
-        return tensor;
-    };
-
-    tensor q = split(m, qkv, 0);
-    tensor k = split(m, qkv, 1);
-    tensor v = split(m, qkv, 2);
+tensor attention_rel_bias(model_ref m, tensor x, int dim, int n_heads) {
+    float scale = 1.0f / std::sqrt(float(dim / n_heads));
     tensor mask = m.weights("attention_biases_indexed");
-    float scale = 1.0f / std::sqrt(float(key_dim));
-
-    if (m.flags & model_build_flag::flash_attention) {
-        q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
-        k = ggml_cast(m, ggml_permute(m, k, 0, 2, 1, 3), GGML_TYPE_F16);
-        v = ggml_cast(m, ggml_permute(m, v, 0, 2, 1, 3), GGML_TYPE_F16);
-        if (mask->type != GGML_TYPE_F16) {
-            mask = ggml_cast(m, mask, GGML_TYPE_F16);
-        }
-
-        x = ggml_flash_attn_ext(m, q, k, v, mask, scale, 0.0f, 0.0f);
-        ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);
-    } else {
-        q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
-        k = ggml_cont(m, ggml_permute(m, k, 0, 2, 1, 3));
-        v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3)); // transpose for mul_mat later
 
-        tensor attn = ggml_mul_mat(m, k, q); // q @ k (k is transposed in mul_mat)
-        attn = ggml_soft_max_ext(m, attn, mask, scale, 0.0f);
-
-        x = ggml_mul_mat(m, v, attn); // attn @ v
-        x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); // transpose(1, 2)
-    }
-    x = ggml_reshape_3d(m, x, key_dim * num_heads, n, b);
-    x = linear(m["proj"], x);
-
-    return named(m, x);
+    x = layer_norm(m["norm"], x);
+    auto [q, k, v] = split_qkv(m["qkv"], x, n_heads, 1);
+    x = attention(m, q, k, v, mask, scale, m["proj"]);
+    return x;
 }
 
 tensor tiny_vit_block(
@@ -344,25 +304,18 @@ tensor separate_attention_heads(model_ref m, tensor x, int num_heads) {
     return x;
 }
 
-tensor attention(model_ref m, tensor q, tensor k, tensor v, int num_heads) {
+tensor decoder_attention(model_ref m, tensor q, tensor k, tensor v, int n_heads) {
     q = linear(m["q_proj"], q);
     k = linear(m["k_proj"], k);
     v = linear(m["v_proj"], v);
 
-    q = separate_attention_heads(m, q, num_heads);
-    k = separate_attention_heads(m, k, num_heads);
-    v = ggml_reshape_4d(m, v, v->ne[0] / num_heads, num_heads, v->ne[1], v->ne[2]);
-    v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3)); // already transposed for mul_mat
+    q = ggml_reshape_4d(m, q, q->ne[0] / n_heads, n_heads, q->ne[1], q->ne[2]);
+    k = ggml_reshape_4d(m, k, k->ne[0] / n_heads, n_heads, k->ne[1], k->ne[2]);
+    v = ggml_reshape_4d(m, v, v->ne[0] / n_heads, n_heads, v->ne[1], v->ne[2]);
 
-    tensor attn = ggml_mul_mat(m, k, q);
-    attn = ggml_scale_inplace(m, attn, 1.0f / std::sqrt(float(q->ne[0])));
-    attn = ggml_soft_max(m, attn);
-
-    tensor out = ggml_mul_mat(m, v, attn);
-    out = ggml_cont(m, ggml_permute(m, out, 0, 2, 1, 3));
-    out = ggml_reshape_3d(m, out, out->ne[0] * out->ne[1], out->ne[2], out->ne[3]);
-    out = linear(m["out_proj"], out);
-    return out;
+    float scale = 1.0f / std::sqrt(float(q->ne[0]));
+    tensor x = attention(m, q, k, v, nullptr, scale, m["out_proj"]);
+    return x;
 }
 
 auto two_way_attention_block(
@@ -375,18 +328,18 @@ auto two_way_attention_block(
     bool skip_first_layer_pe) -> std::tuple<tensor, tensor> {
     // Self attention block
     if (skip_first_layer_pe) {
-        queries = attention(m["self_attn"], queries, queries, queries, num_heads);
+        queries = decoder_attention(m["self_attn"], queries, queries, queries, num_heads);
     } else {
         tensor q = ggml_add(m, queries, query_pe);
-        tensor attn_out = attention(m["self_attn"], q, q, queries, num_heads);
+        tensor attn_out = decoder_attention(m["self_attn"], q, q, queries, num_heads);
         queries = ggml_add(m, queries, attn_out);
     }
     queries = layer_norm(m["norm1"], queries);
 
     // Cross attention block, tokens attending to image embedding
     tensor q = ggml_add(m, queries, query_pe);
     tensor k = ggml_add(m, keys, key_pe);
-    tensor attn_out = attention(m["cross_attn_t2i"], q, k, keys, num_heads);
+    tensor attn_out = decoder_attention(m["cross_attn_t2i"], q, k, keys, num_heads);
     queries = ggml_add_inplace(m, queries, attn_out);
     queries = layer_norm(m["norm2"], queries);
 
@@ -401,7 +354,7 @@ auto two_way_attention_block(
     // Cross attention block, image embedding attending to tokens
     q = ggml_add(m, queries, query_pe);
     // k = ggml_add(m, keys, key_pe); // redundant, same as above
-    attn_out = attention(m["cross_attn_i2t"], k, q, queries, num_heads);
+    attn_out = decoder_attention(m["cross_attn_i2t"], k, q, queries, num_heads);
     keys = ggml_add_inplace(m, keys, attn_out);
     keys = layer_norm(m["norm4"], keys);
 
@@ -434,7 +387,7 @@ auto two_way_transformer(
     // Apply the final attention layer from the points to the image
     tensor q = ggml_add(m, queries, point_embedding);
     tensor k = ggml_add(m, keys, image_pe);
-    tensor attn_out = attention(m["final_attn_t2i"], q, k, keys, num_heads);
+    tensor attn_out = decoder_attention(m["final_attn_t2i"], q, k, keys, num_heads);
     queries = ggml_add_inplace(m, queries, attn_out);
     queries = layer_norm(m["norm_final_attn"], queries);
 
src/visp/arch/mobile-sam.h

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ tensor position_embedding_random(model_ref m, tensor coords);
 
 tensor mlp_block(model_ref m, tensor x);
 tensor separate_attention_heads(model_ref m, tensor x, int num_heads);
-tensor attention(model_ref m, tensor q, tensor k, tensor v, int num_heads);
+tensor decoder_attention(model_ref m, tensor q, tensor k, tensor v, int num_heads);
 std::tuple<tensor, tensor> two_way_attention_block(
     model_ref m,
     tensor queries,

src/visp/arch/swin.cpp

Lines changed: 3 additions & 30 deletions
@@ -65,24 +65,6 @@ tensor window_reverse(model_ref m, tensor x, int64_t w, int64_t h, int window) {
 
 tensor window_attention(model_ref m, tensor x, tensor mask, int n_heads, int window) {
     auto [c, n, b, _] = nelements(x);
-    float scale = 1.0f / std::sqrt(float(c / n_heads));
-    bool flash_attn = bool(m.flags & model_build_flag::flash_attention);
-    ggml_type kv_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;
-
-    tensor qkv = linear(m["qkv"], x);
-    qkv = ggml_reshape_4d(m, qkv, c / n_heads, n_heads, 3, n * b);
-    qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2));
-
-    auto split = [=](tensor t, size_t index, ggml_type type, bool transpose = false) mutable {
-        t = slice(m, t, {}, {}, {}, index);
-        t = ggml_reshape_4d(m, t, c / n_heads, n_heads, n, b);
-        t = transpose ? ggml_permute(m, t, 1, 2, 0, 3) : ggml_permute(m, t, 0, 2, 1, 3);
-        t = ggml_cast(m, t, type); // TODO: future flash attention supports f32 and permutations
-        return t;
-    };
-    tensor q = split(qkv, 0, GGML_TYPE_F32);
-    tensor k = split(qkv, 1, kv_type);
-    tensor v = split(qkv, 2, kv_type, !flash_attn);
 
     tensor_name rel_pos_name = format<tensor_name>("window_attention_{}.rel_pos_index", window);
     tensor rel_pos_index = ggml_get_tensor(m, rel_pos_name.c_str());
@@ -104,19 +86,10 @@ tensor window_attention(model_ref m, tensor x, tensor mask, int n_heads, int win
         attn_mask = ggml_add(m, mask, attn_mask); // [n, n, n_heads, b] + [n, n, n_heads, 1]
     }
 
-    if (flash_attn) {
-        x = ggml_flash_attn_ext(m, q, k, v, attn_mask, scale, 0.0f, 0.0f);
-        ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);
-    } else {
-        tensor attn = ggml_mul_mat(m, k, q);
-        attn = ggml_soft_max_ext(m, attn, attn_mask, scale, 0.0f);
-
-        x = ggml_mul_mat(m, v, attn);
-        x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
-    }
+    auto [q, k, v] = split_qkv(m["qkv"], x, n_heads, 2);
+    float scale = 1.0f / std::sqrt(float(c / n_heads));
+    x = attention(m, q, k, v, attn_mask, scale, m["proj"]);
 
-    x = ggml_reshape_3d(m, x, c, n, b);
-    x = linear(m["proj"], x);
     return named(m, x);
 }
 
src/visp/nn.cpp

Lines changed: 66 additions & 2 deletions
@@ -80,7 +80,7 @@ tensor conv_2d(model_ref m, tensor x, int stride, int pad) {
         x = ggml_mul_mat(m, weight, x);
         x = ggml_reshape_4d(m, x, weight->ne[1], w, h, b);
 
-    } else if (m.flags & model_build_flag::conv_2d_direct_cwhn) {
+    } else if (m.flags & model_build_flag::conv_2d_direct_cwhn) {
         weight = permute_cwhn_to_whcn(m, weight);
         x = permute_cwhn_to_whcn(m, x);
         x = ggml_conv_2d_direct(m, weight, x, stride, stride, pad, pad, 1, 1);
@@ -144,7 +144,7 @@ tensor conv_2d_deform(
         }
     }
     x = ggml_conv_2d_deform(m, weight, x, offset, mask, stride, stride, pad, pad);
-
+
     if (m.flags & model_build_flag::cwhn) {
         x = permute_whcn_to_cwhn(m, x);
     }
@@ -183,4 +183,68 @@ tensor patch_embed(model_ref m, tensor x, int patch_size) {
     return named(m, x);
 }
 
+attention_qkv split_qkv(model_ref m, tensor x, int n_heads, int split_dim) {
+    auto [c, n, b, _] = nelements(x);
+
+    tensor qkv = linear(m, x);
+    switch (split_dim) {
+    case 1:
+        qkv = ggml_reshape_4d(m, qkv, c / n_heads, 3, n_heads * n, b);
+        qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 3, 1, 2));
+        break;
+    case 2:
+        qkv = ggml_reshape_4d(m, qkv, c / n_heads, n_heads, 3, n * b);
+        qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2));
+        break;
+    default: ASSERT(false, "Unsupported split_dim");
+    }
+
+    auto split = [&](tensor t, size_t index) mutable {
+        t = slice(m, t, {}, {}, {}, index);
+        t = ggml_reshape_4d(m, t, c / n_heads, n_heads, n, b);
+        return t;
+    };
+
+    tensor q = split(qkv, 0);
+    tensor k = split(qkv, 1);
+    tensor v = split(qkv, 2);
+    return {q, k, v};
+}
+
+tensor attention(
+    model_ref m, tensor q, tensor k, tensor v, tensor mask, float scale, model_ref m_out) {
+
+    q = ggml_permute(m, q, 0, 2, 1, 3);
+    k = ggml_permute(m, k, 0, 2, 1, 3);
+
+    tensor x = nullptr;
+    if (m.flags & model_build_flag::flash_attention) {
+        v = ggml_permute(m, v, 0, 2, 1, 3);
+
+        k = ggml_cast(m, k, GGML_TYPE_F16);
+        v = ggml_cast(m, v, GGML_TYPE_F16);
+        if (mask && mask->type != GGML_TYPE_F16) {
+            mask = ggml_cast(m, mask, GGML_TYPE_F16);
+        }
+
+        x = ggml_flash_attn_ext(m, q, k, v, mask, scale, 0.0f, 0.0f);
+        ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);
+
+    } else {
+        v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3));
+
+        tensor attn = ggml_mul_mat(m, k, q);
+        attn = ggml_soft_max_ext(m, attn, mask, scale, 0.0f);
+        x = ggml_mul_mat(m, v, attn);
+
+        x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
+    }
+
+    // [head_dim, n_heads, n_patches, batch] -> [embed_dim, n_patches, batch]
+    x = ggml_reshape_3d(m, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
+    x = linear(m_out, x);
+
+    return named(m, x);
+}
+
 } // namespace visp
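
A reading note on the shared helper (editor's annotation, not part of the diff): assuming ggml's usual conventions that ne[0] is the innermost dimension and that ggml_mul_mat(a, b) contracts over ne[0] of both operands, the shapes in the non-flash branch of attention() work out as sketched below for q, k, v of shape [head_dim, n_heads, n, batch].

// Annotation sketch only; not code from the commit.
//   q, k : [head_dim, n_heads, n, b] --permute(0,2,1,3)--> [head_dim, n, n_heads, b]
//   v    : [head_dim, n_heads, n, b] --permute(1,2,0,3)--> [n, head_dim, n_heads, b]
//   attn = ggml_mul_mat(k, q)                          -> [n, n, n_heads, b]
//   attn = ggml_soft_max_ext(attn, mask, scale)           (per-head scores)
//   x    = ggml_mul_mat(v, attn)                       -> [head_dim, n, n_heads, b]
//   x    --permute(0,2,1,3)--> [head_dim, n_heads, n, b]
//   x    --reshape_3d--> [head_dim * n_heads, n, b], then linear(m_out, x)

The flash branch instead keeps v in the [head_dim, n, n_heads, b] layout and casts k, v and the mask to F16; the TODO removed from swin.cpp above suggests this is what the current ggml_flash_attn_ext kernel requires rather than a free choice.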

src/visp/nn.h

Lines changed: 11 additions & 0 deletions
@@ -41,4 +41,15 @@ tensor batch_norm_2d(model_ref, tensor x);
 // 2D image to patch embedding using convolution and optional norm. CWHN input and output.
 tensor patch_embed(model_ref, tensor x, int patch_size);
 
+struct attention_qkv {
+    tensor q, k, v;
+};
+// Input: x [head_dim*n_heads, n_patches, batch]
+// Output: q, k, v each of shape [head_dim, n_heads, n_patches, batch]
+attention_qkv split_qkv(model_ref m, tensor x, int n_heads, int split_dim);
+
+// Attention with optional mask and output linear layer.
+tensor attention(
+    model_ref m, tensor q, tensor k, tensor v, tensor mask, float scale, model_ref m_out);
+
 } // namespace visp
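
For reference, a minimal call-site sketch of this new API, modelled on the window_attention change in swin.cpp above. The function name example_self_attention is illustrative only; the "qkv"/"proj" weight names and the split_dim values mirror the call sites in this commit (split_dim = 2 for swin's fused-QKV layout, split_dim = 1 for TinyViT's in mobile-sam.cpp).

// Sketch, not part of the commit: fused QKV projection split into heads,
// then run through the shared attention helper with no mask.
tensor example_self_attention(model_ref m, tensor x, int n_heads) {
    auto [c, n, b, _] = nelements(x);              // x: [embed_dim, n_patches, batch]
    auto [q, k, v] = split_qkv(m["qkv"], x, n_heads, /*split_dim=*/2);
    float scale = 1.0f / std::sqrt(float(c / n_heads));
    // Pass an additive bias/mask tensor instead of nullptr for masked variants,
    // as window_attention and attention_rel_bias do above.
    return attention(m, q, k, v, nullptr, scale, m["proj"]);
}

DINO keeps separate query/key/value projections, so it skips split_qkv and builds q, k, v itself before calling attention, as shown in dino.cpp above.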

tests/workbench.cpp

Lines changed: 2 additions & 2 deletions
@@ -205,7 +205,7 @@ DEF(sam_attention)(model_ref m, span<tensor> input, param_dict const& p) {
     tensor q = input[0];
     tensor k = m.weights("input_k");
     tensor v = m.weights("input_v");
-    return {sam::attention(m, q, k, v, 2)};
+    return {sam::decoder_attention(m, q, k, v, 2)};
 }
 
 DEF(sam_two_way_attention_block)(model_ref m, span<tensor> input, param_dict const& p) {
@@ -443,7 +443,7 @@ DEF(dino_attention)(model_ref m, span<tensor> input, param_dict const& p) {
     if (p.get("flash_attn", 0) != 0) {
         m.flags |= model_build_flag::flash_attention;
     }
-    return {dino::attention(m, input[0], p.get("n_heads", 8))};
+    return {dino::self_attention(m, input[0], p.get("n_heads", 8))};
 }
 
 DEF(dino_block)(model_ref m, span<tensor> input, param_dict const& p) {
