Skip to content

Commit 0835e5c

Browse files
committed
optimize ggml_ext_chunk
1 parent 11ab095 commit 0835e5c

File tree

2 files changed

+11 / -4 lines changed (11 additions, 4 deletions)

2 files changed

+11 / -4 lines changed (11 additions, 4 deletions)

common.hpp

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -194,10 +194,12 @@ class GEGLU : public UnaryBlock {
194194
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
195195

196196
x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2]
197-
auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0);
197+
auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0, false);
198198
x = x_vec[0]; // [ne3, ne2, ne1, dim_out]
199199
auto gate = x_vec[1]; // [ne3, ne2, ne1, dim_out]
200200

201+
gate = ggml_cont(ctx->ggml_ctx, gate);
202+
201203
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
202204

203205
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]

ggml_extend.hpp

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -732,7 +732,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
732732
__STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_context* ctx,
733733
struct ggml_tensor* x,
734734
int num,
735-
int64_t dim) {
735+
int64_t dim,
736+
bool cont = true) {
736737
GGML_ASSERT(dim >= 0 && dim < 4);
737738
GGML_ASSERT(x->ne[dim] % num == 0);
738739

@@ -747,7 +748,9 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_co
747748

748749
if (dim != 3) {
749750
x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
750-
x = ggml_cont(ctx, x);
751+
if (cont) {
752+
x = ggml_cont(ctx, x);
753+
}
751754
}
752755

753756
std::vector<struct ggml_tensor*> chunks;
@@ -760,7 +763,9 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_co
760763

761764
if (dim != 3) {
762765
chunk = ggml_ext_torch_permute(ctx, chunk, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
763-
chunk = ggml_cont(ctx, chunk);
766+
if (cont) {
767+
chunk = ggml_cont(ctx, chunk);
768+
}
764769
}
765770
chunks.push_back(chunk);
766771
}

0 commit comments

Comments
 (0)