diff --git a/common.hpp b/common.hpp index 33d499fb1..c146d46d8 100644 --- a/common.hpp +++ b/common.hpp @@ -193,11 +193,12 @@ class GEGLU : public UnaryBlock { // return: [ne3, ne2, ne1, dim_out] auto proj = std::dynamic_pointer_cast(blocks["proj"]); - x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2] - auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0); - x = x_vec[0]; // [ne3, ne2, ne1, dim_out] - auto gate = x_vec[1]; // [ne3, ne2, ne1, dim_out] + x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2] + auto gate = ggml_view_4d(ctx->ggml_ctx, x, dim_out, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], dim_out * x->nb[0]); + x = ggml_view_4d(ctx->ggml_ctx, x, dim_out, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0); + + gate = ggml_cont(ctx->ggml_ctx, gate); gate = ggml_gelu_inplace(ctx->ggml_ctx, gate); x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]