Skip to content

Commit baf122d

Browse files
committed
taew2.1 encode support
1 parent fde734b commit baf122d

File tree

1 file changed

+31
-15
lines changed

1 file changed

+31
-15
lines changed

tae.hpp

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -258,24 +258,29 @@ class TinyVideoEncoder : public UnaryBlock {
258258
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new MemBlock(hidden, hidden));
259259
}
260260
}
261-
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, z_channels, {3, 3}, {1, 1}, {1, 1}));
261+
blocks[std::to_string(index)] = std::shared_ptr<GGMLBlock>(new Conv2d(hidden, z_channels, {3, 3}, {1, 1}, {1, 1}));
262262
}
263263

264264
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* z) override {
265-
// return z;
266265
auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks["0"]);
267-
auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(num_layers * (num_blocks + 2) + 1)]);
268266
auto h = first_conv->forward(ctx, z);
269-
270267
h = ggml_relu_inplace(ctx->ggml_ctx, h);
271-
272-
for (int i = 2; i < num_layers * (num_blocks + 2) + 2; i++) {
273-
if (blocks.find(std::to_string(i)) == blocks.end()) {
274-
continue;
268+
269+
int index = 2;
270+
for (int i = 0; i < num_layers; i++) {
271+
auto pool = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
272+
auto conv = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(index++)]);
273+
274+
h = pool->forward(ctx, h);
275+
h = conv->forward(ctx, h);
276+
for (int j = 0; j < num_blocks; j++) {
277+
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
278+
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
279+
mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
280+
h = block->forward(ctx, h, mem);
275281
}
276-
auto block = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(i)]);
277-
h = block->forward(ctx, h);
278282
}
283+
auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(index)]);
279284
h = last_conv->forward(ctx, h);
280285
return h;
281286
}
@@ -322,7 +327,7 @@ class TinyVideoDecoder : public UnaryBlock {
322327
for (int i = 0; i < num_layers; i++) {
323328
for (int j = 0; j < num_blocks; j++) {
324329
auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string(index++)]);
325-
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1,0);
330+
auto mem = ggml_pad_ext(ctx->ggml_ctx, h, 0, 0, 0, 0, 0, 0, 1, 0);
326331
mem = ggml_view_4d(ctx->ggml_ctx, mem, h->ne[0], h->ne[1], h->ne[2], h->ne[3], h->nb[1], h->nb[2], h->nb[3], 0);
327332
h = block->forward(ctx, h, mem);
328333
}
@@ -339,7 +344,7 @@ class TinyVideoDecoder : public UnaryBlock {
339344
auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string(++index)]);
340345
h = last_conv->forward(ctx, h);
341346

342-
// shape(W, H, 3, T+3) => shape(W, H, 3, T)
347+
// shape(W, H, 3, 3 + T) => shape(W, H, 3, T)
343348
h = ggml_view_4d(ctx->ggml_ctx, h, h->ne[0], h->ne[1], h->ne[2], h->ne[3] - 3, h->nb[1], h->nb[2], h->nb[3], 3 * h->nb[3]);
344349
return h;
345350
}
@@ -376,9 +381,20 @@ class TAEHV : public GGMLBlock {
376381
}
377382

378383
struct ggml_tensor* encode(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
379-
return nullptr;
380-
// auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks["encoder"]);
381-
// return encoder->forward(ctx, x);
384+
auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks["encoder"]);
385+
// (W, H, T, C) -> (W, H, C, T)
386+
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
387+
int64_t num_frames = x->ne[3];
388+
if (num_frames % 4) {
389+
// pad to multiple of 4 at the end
390+
auto last_frame = ggml_view_4d(ctx->ggml_ctx, x, x->ne[0], x->ne[1], x->ne[2], 1, x->nb[1], x->nb[2], x->nb[3], (num_frames - 1) * x->nb[3]);
391+
for (int i = 0; i < 4 - num_frames % 4; i++) {
392+
x = ggml_concat(ctx->ggml_ctx, x, last_frame, 3);
393+
}
394+
}
395+
x = encoder->forward(ctx, x);
396+
x = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 0, 1, 3, 2));
397+
return x;
382398
}
383399
};
384400

0 commit comments

Comments
 (0)