Skip to content

Commit b1e354f

Browse files
committed
Fix GEGLU slowdown for unet models
1 parent 11ab095 commit b1e354f

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

common.hpp

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -193,11 +193,12 @@ class GEGLU : public UnaryBlock {
193 193
// return: [ne3, ne2, ne1, dim_out]
194 194
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
195 195

196-
x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2]
197-
auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0);
198-
x = x_vec[0]; // [ne3, ne2, ne1, dim_out]
199-
auto gate = x_vec[1]; // [ne3, ne2, ne1, dim_out]
196+
x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2]
200 197

198+
auto gate = ggml_view_4d(ctx->ggml_ctx, x, dim_out, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], dim_out * x->nb[0]);
199+
x = ggml_view_4d(ctx->ggml_ctx, x, dim_out, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0);
200+
201+
gate = ggml_cont(ctx->ggml_ctx, gate);
201 202
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
202 203

203 204
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]

0 commit comments

Comments
 (0)