@@ -121,54 +121,15 @@ tensor mlp(model_ref m, tensor x) {
     return named(m, x);
 }
 
-tensor attention_rel_bias(model_ref m, tensor x, int dim, int num_heads) {
-    GGML_ASSERT(dim % num_heads == 0);
-    int key_dim = dim / num_heads;
+tensor attention_rel_bias(model_ref m, tensor x, int dim, int n_heads) {
     auto [c, n, b, _] = nelements(x);
-
-    x = layer_norm(m["norm"], x);
-
-    tensor qkv = linear(m["qkv"], x);
-    qkv = ggml_reshape_4d(m, qkv, key_dim, 3, num_heads * n, b);
-    qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 3, 1, 2)); // ne = [key_dim, num_heads * n, b, 3]
-
-    auto split = [=](model_ref m, tensor tensor, int64_t index) {
-        tensor = slice(m, tensor, {}, {}, {}, index);
-        tensor = ggml_reshape_4d(m, tensor, key_dim, num_heads, n, b);
-        return tensor;
-    };
-
-    tensor q = split(m, qkv, 0);
-    tensor k = split(m, qkv, 1);
-    tensor v = split(m, qkv, 2);
+    float scale = 1.0f / std::sqrt(float(dim / n_heads));
     tensor mask = m.weights("attention_biases_indexed");
-    float scale = 1.0f / std::sqrt(float(key_dim));
-
-    if (m.flags & model_build_flag::flash_attention) {
-        q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
-        k = ggml_cast(m, ggml_permute(m, k, 0, 2, 1, 3), GGML_TYPE_F16);
-        v = ggml_cast(m, ggml_permute(m, v, 0, 2, 1, 3), GGML_TYPE_F16);
-        if (mask->type != GGML_TYPE_F16) {
-            mask = ggml_cast(m, mask, GGML_TYPE_F16);
-        }
-
-        x = ggml_flash_attn_ext(m, q, k, v, mask, scale, 0.0f, 0.0f);
-        ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);
-    } else {
-        q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
-        k = ggml_cont(m, ggml_permute(m, k, 0, 2, 1, 3));
-        v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3)); // transpose for mul_mat later
 
-        tensor attn = ggml_mul_mat(m, k, q); // q @ k (k is transposed in mul_mat)
-        attn = ggml_soft_max_ext(m, attn, mask, scale, 0.0f);
-
-        x = ggml_mul_mat(m, v, attn); // attn @ v
-        x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); // transpose(1, 2)
-    }
-    x = ggml_reshape_3d(m, x, key_dim * num_heads, n, b);
-    x = linear(m["proj"], x);
-
-    return named(m, x);
+    x = layer_norm(m["norm"], x);
+    auto [q, k, v] = split_qkv(m["qkv"], x, n_heads, 1);
+    x = attention(m, q, k, v, mask, scale, m["proj"]);
+    return x;
 }
 
 tensor tiny_vit_block(
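The refactored `attention_rel_bias` now delegates to two shared helpers, `split_qkv` and `attention`, neither of which is defined in this hunk. For reference, a minimal sketch of what `split_qkv` could look like, reconstructed from the reshape/permute/slice code removed above; the signature is inferred from the call site `split_qkv(m["qkv"], x, n_heads, 1)`, and the meaning of the final argument (treated here as a value/key head-dimension ratio) is an assumption, not the actual implementation.

// Sketch only, not part of this commit: reconstructed from the logic removed
// from attention_rel_bias. The last parameter is assumed to be the value/key
// head-dimension ratio (1 here, so q, k and v all share key_dim channels).
std::tuple<tensor, tensor, tensor> split_qkv(model_ref m, tensor x, int n_heads, int /*attn_ratio*/) {
    auto [c, n, b, _] = nelements(x);
    int key_dim = int(c) / n_heads;

    tensor qkv = linear(m, x); // call site passes m["qkv"], so m already points at the layer
    qkv = ggml_reshape_4d(m, qkv, key_dim, 3, n_heads * n, b);
    qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 3, 1, 2)); // ne = [key_dim, n_heads * n, b, 3]

    auto split = [=](tensor t, int64_t index) {
        t = slice(m, t, {}, {}, {}, index);
        return ggml_reshape_4d(m, t, key_dim, n_heads, n, b); // ne = [key_dim, n_heads, n, b]
    };
    return {split(qkv, 0), split(qkv, 1), split(qkv, 2)};
}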
@@ -344,25 +305,18 @@ tensor separate_attention_heads(model_ref m, tensor x, int num_heads) {
     return x;
 }
 
-tensor attention(model_ref m, tensor q, tensor k, tensor v, int num_heads) {
+tensor decoder_attention(model_ref m, tensor q, tensor k, tensor v, int n_heads) {
     q = linear(m["q_proj"], q);
     k = linear(m["k_proj"], k);
     v = linear(m["v_proj"], v);
 
-    q = separate_attention_heads(m, q, num_heads);
-    k = separate_attention_heads(m, k, num_heads);
-    v = ggml_reshape_4d(m, v, v->ne[0] / num_heads, num_heads, v->ne[1], v->ne[2]);
-    v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3)); // already transposed for mul_mat
+    q = ggml_reshape_4d(m, q, q->ne[0] / n_heads, n_heads, q->ne[1], q->ne[2]);
+    k = ggml_reshape_4d(m, k, k->ne[0] / n_heads, n_heads, k->ne[1], k->ne[2]);
+    v = ggml_reshape_4d(m, v, v->ne[0] / n_heads, n_heads, v->ne[1], v->ne[2]);
 
-    tensor attn = ggml_mul_mat(m, k, q);
-    attn = ggml_scale_inplace(m, attn, 1.0f / std::sqrt(float(q->ne[0])));
-    attn = ggml_soft_max(m, attn);
-
-    tensor out = ggml_mul_mat(m, v, attn);
-    out = ggml_cont(m, ggml_permute(m, out, 0, 2, 1, 3));
-    out = ggml_reshape_3d(m, out, out->ne[0] * out->ne[1], out->ne[2], out->ne[3]);
-    out = linear(m["out_proj"], out);
-    return out;
+    float scale = 1.0f / std::sqrt(float(q->ne[0]));
+    tensor x = attention(m, q, k, v, nullptr, scale, m["out_proj"]);
+    return x;
 }
 
 auto two_way_attention_block(
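Both `attention_rel_bias` and the renamed `decoder_attention` now route through a common `attention(m, q, k, v, mask, scale, proj)` helper that is not shown in this diff. A plausible sketch, assembled from the flash-attention and mul_mat/soft_max paths removed from `attention_rel_bias`, is given below; the parameter order follows the two call sites, and the body is an assumption rather than the actual implementation. Folding the scale into `ggml_soft_max_ext` reproduces the explicit `ggml_scale_inplace` + `ggml_soft_max` sequence the old decoder code used.

// Sketch only, not part of this commit: a shared scaled-dot-product attention
// over heads, mirroring the code removed from attention_rel_bias above.
// Expects q, k, v with ne = [head_dim, n_heads, n_tokens, batch]; mask may be
// nullptr (as in decoder_attention).
tensor attention(model_ref m, tensor q, tensor k, tensor v, tensor mask, float scale, model_ref proj) {
    int64_t head_dim = q->ne[0];
    int64_t n_heads = q->ne[1];
    int64_t n_tokens = q->ne[2];
    int64_t n_batch = q->ne[3];
    tensor x;

    if (m.flags & model_build_flag::flash_attention) {
        q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
        k = ggml_cast(m, ggml_permute(m, k, 0, 2, 1, 3), GGML_TYPE_F16);
        v = ggml_cast(m, ggml_permute(m, v, 0, 2, 1, 3), GGML_TYPE_F16);
        if (mask && mask->type != GGML_TYPE_F16) {
            mask = ggml_cast(m, mask, GGML_TYPE_F16);
        }
        x = ggml_flash_attn_ext(m, q, k, v, mask, scale, 0.0f, 0.0f);
        ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);
    } else {
        q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
        k = ggml_cont(m, ggml_permute(m, k, 0, 2, 1, 3));
        v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3)); // transpose for mul_mat later

        tensor attn = ggml_mul_mat(m, k, q);                  // q @ k (k is transposed in mul_mat)
        attn = ggml_soft_max_ext(m, attn, mask, scale, 0.0f); // scale + optional mask + softmax

        x = ggml_mul_mat(m, v, attn);                         // attn @ v
        x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));     // transpose(1, 2)
    }
    x = ggml_reshape_3d(m, x, head_dim * n_heads, n_tokens, n_batch);
    x = linear(proj, x);
    return named(m, x);
}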
@@ -375,18 +329,18 @@ auto two_way_attention_block(
     bool skip_first_layer_pe) -> std::tuple<tensor, tensor> {
     // Self attention block
     if (skip_first_layer_pe) {
-        queries = attention(m["self_attn"], queries, queries, queries, num_heads);
+        queries = decoder_attention(m["self_attn"], queries, queries, queries, num_heads);
     } else {
         tensor q = ggml_add(m, queries, query_pe);
-        tensor attn_out = attention(m["self_attn"], q, q, queries, num_heads);
+        tensor attn_out = decoder_attention(m["self_attn"], q, q, queries, num_heads);
         queries = ggml_add(m, queries, attn_out);
     }
     queries = layer_norm(m["norm1"], queries);
 
     // Cross attention block, tokens attending to image embedding
     tensor q = ggml_add(m, queries, query_pe);
     tensor k = ggml_add(m, keys, key_pe);
-    tensor attn_out = attention(m["cross_attn_t2i"], q, k, keys, num_heads);
+    tensor attn_out = decoder_attention(m["cross_attn_t2i"], q, k, keys, num_heads);
     queries = ggml_add_inplace(m, queries, attn_out);
     queries = layer_norm(m["norm2"], queries);
 
@@ -401,7 +355,7 @@ auto two_way_attention_block(
     // Cross attention block, image embedding attending to tokens
     q = ggml_add(m, queries, query_pe);
     // k = ggml_add(m, keys, key_pe); // redundant, same as above
-    attn_out = attention(m["cross_attn_i2t"], k, q, queries, num_heads);
+    attn_out = decoder_attention(m["cross_attn_i2t"], k, q, queries, num_heads);
     keys = ggml_add_inplace(m, keys, attn_out);
     keys = layer_norm(m["norm4"], keys);
 
@@ -434,7 +388,7 @@ auto two_way_transformer(
     // Apply the final attention layer from the points to the image
     tensor q = ggml_add(m, queries, point_embedding);
     tensor k = ggml_add(m, keys, image_pe);
-    tensor attn_out = attention(m["final_attn_t2i"], q, k, keys, num_heads);
+    tensor attn_out = decoder_attention(m["final_attn_t2i"], q, k, keys, num_heads);
     queries = ggml_add_inplace(m, queries, attn_out);
     queries = layer_norm(m["norm_final_attn"], queries);
 