We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 11ab095 commit b1e354fCopy full SHA for b1e354f
common.hpp
@@ -193,11 +193,12 @@ class GEGLU : public UnaryBlock {
193
// return: [ne3, ne2, ne1, dim_out]
194
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
195
196
- x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2]
197
- auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0);
198
- x = x_vec[0]; // [ne3, ne2, ne1, dim_out]
199
- auto gate = x_vec[1]; // [ne3, ne2, ne1, dim_out]
+ x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2]
200
+ auto gate = ggml_view_4d(ctx->ggml_ctx, x, dim_out, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], dim_out * x->nb[0]);
+ x = ggml_view_4d(ctx->ggml_ctx, x, dim_out, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0);
+
201
+ gate = ggml_cont(ctx->ggml_ctx, gate);
202
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
203
204
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]
0 commit comments