39 changes: 12 additions & 27 deletions src/visp/arch/dino.cpp
@@ -55,42 +55,27 @@ tensor mlp(model_ref m, tensor x) {
return x;
}

tensor attention(model_ref m, tensor x, int n_heads) {
tensor self_attention(model_ref m, tensor x, int n_heads) {
auto [c, n, b, _] = nelements(x);
float scale = 1.0f / std::sqrt(float(c) / float(n_heads));
bool flash_attn = bool(m.flags & model_build_flag::flash_attention);
ggml_type kv_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;

auto split = [=](model_ref m, tensor x, ggml_type type, bool transpose = false) mutable {
x = linear(m, x);
x = ggml_reshape_4d(m, x, c / n_heads, n_heads, n, b);
x = transpose ? ggml_permute(m, x, 1, 2, 0, 3) : ggml_permute(m, x, 0, 2, 1, 3);
return ggml_cast(m, x, type);
auto project = [&](model_ref m, tensor t) {
t = linear(m, t);
t = ggml_reshape_4d(m, t, c / n_heads, n_heads, n, b);
return t;
};

tensor q = split(m["attention.query"], x, GGML_TYPE_F32);
tensor k = split(m["attention.key"], x, kv_type);
tensor v = split(m["attention.value"], x, kv_type, !flash_attn);

if (flash_attn) {
x = ggml_flash_attn_ext(m, q, k, v, nullptr, scale, 0.0f, 0.0f);
} else {
tensor attn = ggml_mul_mat(m, k, q);
attn = ggml_soft_max_ext(m, attn, nullptr, scale, 0.0f);

x = ggml_mul_mat(m, v, attn);
x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
}
tensor q = project(m["attention.query"], x);
tensor k = project(m["attention.key"], x);
tensor v = project(m["attention.value"], x);

x = ggml_reshape_3d(m, x, c, n, b);
x = linear(m["output.dense"], x);
return named(m, x);
float scale = 1.0f / std::sqrt(float(c) / float(n_heads));
x = attention(m, q, k, v, nullptr, scale, m["output.dense"]);
return x;
}
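// Shape summary for self_attention: x is [c, n, b]; project() yields q, k, v of shape
// [c / n_heads, n_heads, n, b]; the shared attention() helper (see nn.cpp) merges the heads
// back to [c, n, b] and applies the m["output.dense"] projection.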

tensor layer(model_ref m, tensor x, dino_params const& p) {
tensor attn = x;
attn = layer_norm(m["norm1"], attn, 1e-6f);
attn = attention(m["attention"], attn, p.n_heads);
attn = self_attention(m["attention"], attn, p.n_heads);
attn = layer_scale(m["layer_scale1"], attn);
x = ggml_add(m, x, attn);

2 changes: 1 addition & 1 deletion src/visp/arch/dino.h
@@ -12,7 +12,7 @@ tensor interpolate_pos_encoding(model_ref m, tensor x, int64_t w, int64_t h, int
tensor prepare_tokens(model_ref m, tensor x, int patch_size);
tensor layer_scale(model_ref m, tensor x);
tensor mlp(model_ref m, tensor x);
tensor attention(model_ref m, tensor x, int n_heads);
tensor self_attention(model_ref m, tensor x, int n_heads);
tensor layer(model_ref m, tensor x, dino_params const& p);

std::vector<tensor> get_intermediate_layers(
77 changes: 21 additions & 56 deletions src/visp/arch/mobile-sam.cpp
@@ -1,8 +1,8 @@
#include "visp/arch/mobile-sam.h"
#include "visp/nn.h"
#include "visp/vision.h"
#include "util/math.h"
#include "util/string.h"
#include "visp/nn.h"
#include "visp/vision.h"

#include <ggml.h>

@@ -13,7 +13,7 @@ namespace visp {
namespace sam {

tensor conv_2d_batch_norm(model_ref m, tensor x, int stride = 1, int pad = 0) {
// batch_norm is fused into conv_2d when converting the model
return conv_2d(m["c"], x, stride, pad);
}

@@ -68,7 +68,6 @@ tensor window_reverse(model_ref m, tensor x, int w, int h, int window) {
// Image encoder
//


tensor patch_embed(model_ref m, tensor x) {
x = conv_2d_batch_norm(m["seq.0"], x, 2, 1);
x = ggml_gelu_inplace(m, x);
@@ -122,41 +121,14 @@ tensor mlp(model_ref m, tensor x) {
return named(m, x);
}

tensor attention_rel_bias(model_ref m, tensor x, int dim, int num_heads) {
GGML_ASSERT(dim % num_heads == 0);
int key_dim = dim / num_heads;
auto [c, n, b, _] = nelements(x);
tensor attention_rel_bias(model_ref m, tensor x, int dim, int n_heads) {
float scale = 1.0f / std::sqrt(float(dim / n_heads));
tensor mask = m.weights("attention_biases_indexed");

x = layer_norm(m["norm"], x);

tensor qkv = linear(m["qkv"], x);
qkv = ggml_reshape_4d(m, qkv, key_dim, 3, num_heads * n, b);
qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 3, 1, 2)); // ne = [key_dim, num_heads * n, b, 3]

auto split = [=](model_ref m, tensor tensor, int64_t index) {
tensor = slice(m, tensor, {}, {}, {}, index);
tensor = ggml_reshape_4d(m, tensor, key_dim, num_heads, n, b);
return tensor;
};

tensor q = split(m, qkv, 0);
tensor k = split(m, qkv, 1);
tensor v = split(m, qkv, 2);
q = ggml_cont(m, ggml_permute(m, q, 0, 2, 1, 3));
k = ggml_cont(m, ggml_permute(m, k, 0, 2, 1, 3));
v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3)); // transpose for mul_mat later

tensor attn = ggml_mul_mat(m, k, q); // q @ k (k is transposed in mul_mat)
attn = ggml_scale_inplace(m, attn, 1.0f / std::sqrt(float(key_dim)));
attn = ggml_add_inplace(m, attn, m.weights("attention_biases_indexed"));
attn = ggml_soft_max(m, attn);

x = ggml_mul_mat(m, v, attn); // attn @ v
x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3)); // transpose(1, 2)
x = ggml_reshape_3d(m, x, key_dim * num_heads, n, b);
x = linear(m["proj"], x);

return named(m, x);
auto [q, k, v] = split_qkv(m["qkv"], x, n_heads, 1);
x = attention(m, q, k, v, mask, scale, m["proj"]);
return x;
}
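// The learned relative-position biases ("attention_biases_indexed") are passed to attention()
// as the additive mask, so they are applied together with the scale inside ggml_soft_max_ext
// (or ggml_flash_attn_ext on the flash-attention path).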

tensor tiny_vit_block(
@@ -332,25 +304,18 @@ tensor separate_attention_heads(model_ref m, tensor x, int num_heads) {
return x;
}

tensor attention(model_ref m, tensor q, tensor k, tensor v, int num_heads) {
tensor decoder_attention(model_ref m, tensor q, tensor k, tensor v, int n_heads) {
q = linear(m["q_proj"], q);
k = linear(m["k_proj"], k);
v = linear(m["v_proj"], v);

q = separate_attention_heads(m, q, num_heads);
k = separate_attention_heads(m, k, num_heads);
v = ggml_reshape_4d(m, v, v->ne[0] / num_heads, num_heads, v->ne[1], v->ne[2]);
v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3)); // already transposed for mul_mat
q = ggml_reshape_4d(m, q, q->ne[0] / n_heads, n_heads, q->ne[1], q->ne[2]);
k = ggml_reshape_4d(m, k, k->ne[0] / n_heads, n_heads, k->ne[1], k->ne[2]);
v = ggml_reshape_4d(m, v, v->ne[0] / n_heads, n_heads, v->ne[1], v->ne[2]);

tensor attn = ggml_mul_mat(m, k, q);
attn = ggml_scale_inplace(m, attn, 1.0f / std::sqrt(float(q->ne[0])));
attn = ggml_soft_max(m, attn);

tensor out = ggml_mul_mat(m, v, attn);
out = ggml_cont(m, ggml_permute(m, out, 0, 2, 1, 3));
out = ggml_reshape_3d(m, out, out->ne[0] * out->ne[1], out->ne[2], out->ne[3]);
out = linear(m["out_proj"], out);
return out;
float scale = 1.0f / std::sqrt(float(q->ne[0]));
tensor x = attention(m, q, k, v, nullptr, scale, m["out_proj"]);
return x;
}
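// decoder_attention keeps separate q/k/v projections (q_proj, k_proj, v_proj) because its
// inputs differ per call site (queries vs. keys in the two-way blocks below), so the fused
// split_qkv helper does not apply; heads are split with a plain reshape and the shared
// attention() helper handles the rest.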

auto two_way_attention_block(
@@ -363,18 +328,18 @@ auto two_way_attention_block(
bool skip_first_layer_pe) -> std::tuple<tensor, tensor> {
// Self attention block
if (skip_first_layer_pe) {
queries = attention(m["self_attn"], queries, queries, queries, num_heads);
queries = decoder_attention(m["self_attn"], queries, queries, queries, num_heads);
} else {
tensor q = ggml_add(m, queries, query_pe);
tensor attn_out = attention(m["self_attn"], q, q, queries, num_heads);
tensor attn_out = decoder_attention(m["self_attn"], q, q, queries, num_heads);
queries = ggml_add(m, queries, attn_out);
}
queries = layer_norm(m["norm1"], queries);

// Cross attention block, tokens attending to image embedding
tensor q = ggml_add(m, queries, query_pe);
tensor k = ggml_add(m, keys, key_pe);
tensor attn_out = attention(m["cross_attn_t2i"], q, k, keys, num_heads);
tensor attn_out = decoder_attention(m["cross_attn_t2i"], q, k, keys, num_heads);
queries = ggml_add_inplace(m, queries, attn_out);
queries = layer_norm(m["norm2"], queries);

@@ -389,7 +354,7 @@ auto two_way_attention_block(
// Cross attention block, image embedding attending to tokens
q = ggml_add(m, queries, query_pe);
// k = ggml_add(m, keys, key_pe); // redundant, same as above
attn_out = attention(m["cross_attn_i2t"], k, q, queries, num_heads);
attn_out = decoder_attention(m["cross_attn_i2t"], k, q, queries, num_heads);
keys = ggml_add_inplace(m, keys, attn_out);
keys = layer_norm(m["norm4"], keys);

@@ -422,7 +387,7 @@ auto two_way_transformer(
// Apply the final attention layer from the points to the image
tensor q = ggml_add(m, queries, point_embedding);
tensor k = ggml_add(m, keys, image_pe);
tensor attn_out = attention(m["final_attn_t2i"], q, k, keys, num_heads);
tensor attn_out = decoder_attention(m["final_attn_t2i"], q, k, keys, num_heads);
queries = ggml_add_inplace(m, queries, attn_out);
queries = layer_norm(m["norm_final_attn"], queries);

2 changes: 1 addition & 1 deletion src/visp/arch/mobile-sam.h
@@ -66,7 +66,7 @@ tensor position_embedding_random(model_ref m, tensor coords);

tensor mlp_block(model_ref m, tensor x);
tensor separate_attention_heads(model_ref m, tensor x, int num_heads);
tensor attention(model_ref m, tensor q, tensor k, tensor v, int num_heads);
tensor decoder_attention(model_ref m, tensor q, tensor k, tensor v, int num_heads);
std::tuple<tensor, tensor> two_way_attention_block(
model_ref m,
tensor queries,
33 changes: 3 additions & 30 deletions src/visp/arch/swin.cpp
@@ -65,24 +65,6 @@ tensor window_reverse(model_ref m, tensor x, int64_t w, int64_t h, int window) {

tensor window_attention(model_ref m, tensor x, tensor mask, int n_heads, int window) {
auto [c, n, b, _] = nelements(x);
float scale = 1.0f / std::sqrt(float(c / n_heads));
bool flash_attn = bool(m.flags & model_build_flag::flash_attention);
ggml_type kv_type = flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32;

tensor qkv = linear(m["qkv"], x);
qkv = ggml_reshape_4d(m, qkv, c / n_heads, n_heads, 3, n * b);
qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2));

auto split = [=](tensor t, size_t index, ggml_type type, bool transpose = false) mutable {
t = slice(m, t, {}, {}, {}, index);
t = ggml_reshape_4d(m, t, c / n_heads, n_heads, n, b);
t = transpose ? ggml_permute(m, t, 1, 2, 0, 3) : ggml_permute(m, t, 0, 2, 1, 3);
t = ggml_cast(m, t, type); // TODO: future flash attention supports f32 and permutations
return t;
};
tensor q = split(qkv, 0, GGML_TYPE_F32);
tensor k = split(qkv, 1, kv_type);
tensor v = split(qkv, 2, kv_type, !flash_attn);

tensor_name rel_pos_name = format<tensor_name>("window_attention_{}.rel_pos_index", window);
tensor rel_pos_index = ggml_get_tensor(m, rel_pos_name.c_str());
@@ -104,19 +86,10 @@ tensor window_attention(model_ref m, tensor x, tensor mask, int n_heads, int window) {
attn_mask = ggml_add(m, mask, attn_mask); // [n, n, n_heads, b] + [n, n, n_heads, 1]
}

if (flash_attn) {
x = ggml_flash_attn_ext(m, q, k, v, attn_mask, scale, 0.0f, 0.0f);
ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);
} else {
tensor attn = ggml_mul_mat(m, k, q);
attn = ggml_soft_max_ext(m, attn, attn_mask, scale, 0.0f);

x = ggml_mul_mat(m, v, attn);
x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
}
auto [q, k, v] = split_qkv(m["qkv"], x, n_heads, 2);
float scale = 1.0f / std::sqrt(float(c / n_heads));
x = attention(m, q, k, v, attn_mask, scale, m["proj"]);

x = ggml_reshape_3d(m, x, c, n, b);
x = linear(m["proj"], x);
return named(m, x);
}

68 changes: 66 additions & 2 deletions src/visp/nn.cpp
@@ -80,7 +80,7 @@ tensor conv_2d(model_ref m, tensor x, int stride, int pad) {
x = ggml_mul_mat(m, weight, x);
x = ggml_reshape_4d(m, x, weight->ne[1], w, h, b);

} else if (m.flags & model_build_flag::conv_2d_direct_cwhn) {
weight = permute_cwhn_to_whcn(m, weight);
x = permute_cwhn_to_whcn(m, x);
x = ggml_conv_2d_direct(m, weight, x, stride, stride, pad, pad, 1, 1);
@@ -144,7 +144,7 @@ tensor conv_2d_deform(
}
}
x = ggml_conv_2d_deform(m, weight, x, offset, mask, stride, stride, pad, pad);

if (m.flags & model_build_flag::cwhn) {
x = permute_whcn_to_cwhn(m, x);
}
@@ -183,4 +183,68 @@ tensor patch_embed(model_ref m, tensor x, int patch_size) {
return named(m, x);
}

attention_qkv split_qkv(model_ref m, tensor x, int n_heads, int split_dim) {
auto [c, n, b, _] = nelements(x);

tensor qkv = linear(m, x);
switch (split_dim) {
case 1:
qkv = ggml_reshape_4d(m, qkv, c / n_heads, 3, n_heads * n, b);
qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 3, 1, 2));
break;
case 2:
qkv = ggml_reshape_4d(m, qkv, c / n_heads, n_heads, 3, n * b);
qkv = ggml_cont(m, ggml_permute(m, qkv, 0, 1, 3, 2));
break;
default: ASSERT(false, "Unsupported split_dim");
}

auto split = [&](tensor t, size_t index) mutable {
t = slice(m, t, {}, {}, {}, index);
t = ggml_reshape_4d(m, t, c / n_heads, n_heads, n, b);
return t;
};

tensor q = split(qkv, 0);
tensor k = split(qkv, 1);
tensor v = split(qkv, 2);
return {q, k, v};
}
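// The two split_dim layouts correspond to how the fused qkv projection is unpacked:
//   split_dim == 1: reshape to [c / n_heads, 3, n_heads * n, b]   (sam::attention_rel_bias)
//   split_dim == 2: reshape to [c / n_heads, n_heads, 3, n * b]   (swin window_attention)
// After the permute and slice, both layouts yield q, k, v of shape [c / n_heads, n_heads, n, b].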

tensor attention(
model_ref m, tensor q, tensor k, tensor v, tensor mask, float scale, model_ref m_out) {

q = ggml_permute(m, q, 0, 2, 1, 3);
k = ggml_permute(m, k, 0, 2, 1, 3);

tensor x = nullptr;
if (m.flags & model_build_flag::flash_attention) {
v = ggml_permute(m, v, 0, 2, 1, 3);

k = ggml_cast(m, k, GGML_TYPE_F16);
v = ggml_cast(m, v, GGML_TYPE_F16);
if (mask && mask->type != GGML_TYPE_F16) {
mask = ggml_cast(m, mask, GGML_TYPE_F16);
}

x = ggml_flash_attn_ext(m, q, k, v, mask, scale, 0.0f, 0.0f);
ggml_flash_attn_ext_set_prec(x, GGML_PREC_F32);

} else {
v = ggml_cont(m, ggml_permute(m, v, 1, 2, 0, 3));

tensor attn = ggml_mul_mat(m, k, q);
attn = ggml_soft_max_ext(m, attn, mask, scale, 0.0f);
x = ggml_mul_mat(m, v, attn);

x = ggml_cont(m, ggml_permute(m, x, 0, 2, 1, 3));
}

// [head_dim, n_heads, n_patches, batch] -> [embed_dim, n_patches, batch]
x = ggml_reshape_3d(m, x, x->ne[0] * x->ne[1], x->ne[2], x->ne[3]);
x = linear(m_out, x);

return named(m, x);
}
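// attention() computes softmax(scale * q @ k^T + mask) @ v per head, merges the heads and
// applies the m_out projection. With the flash_attention build flag, k, v (and the mask, if
// present) are cast to F16 for ggml_flash_attn_ext; otherwise v is permuted so that
// ggml_mul_mat(m, v, attn) performs attn @ v.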

} // namespace visp
11 changes: 11 additions & 0 deletions src/visp/nn.h
@@ -41,4 +41,15 @@ tensor batch_norm_2d(model_ref, tensor x);
// 2D image to patch embedding using convolution and optional norm. CWHN input and output.
tensor patch_embed(model_ref, tensor x, int patch_size);

struct attention_qkv {
tensor q, k, v;
};
// Input: x [head_dim*n_heads, n_patches, batch]
// Output: q, k, v each of shape [head_dim, n_heads, n_patches, batch]
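// split_dim: index of the reshaped dimension (1 or 2) that carries the fused q/k/v factor of 3;
// see split_qkv in nn.cpp for the two layouts.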
attention_qkv split_qkv(model_ref m, tensor x, int n_heads, int split_dim);

// Attention with optional mask and output linear layer.
tensor attention(
model_ref m, tensor q, tensor k, tensor v, tensor mask, float scale, model_ref m_out);
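// Typical composition, mirroring the swin window_attention call site:
//
//   auto [c, n, b, _] = nelements(x);                                 // x: [embed_dim, n_patches, batch]
//   auto [q, k, v] = split_qkv(m["qkv"], x, n_heads, /*split_dim=*/2);
//   float scale = 1.0f / std::sqrt(float(c / n_heads));
//   x = attention(m, q, k, v, /*mask=*/nullptr, scale, m["proj"]);    // -> [embed_dim, n_patches, batch]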

} // namespace visp