@@ -258,24 +258,29 @@ class TinyVideoEncoder : public UnaryBlock {
258258 blocks[std::to_string (index++)] = std::shared_ptr<GGMLBlock>(new MemBlock (hidden, hidden));
259259 }
260260 }
261- blocks[std::to_string (index++ )] = std::shared_ptr<GGMLBlock>(new Conv2d (hidden, z_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
261+ blocks[std::to_string (index)] = std::shared_ptr<GGMLBlock>(new Conv2d (hidden, z_channels, {3 , 3 }, {1 , 1 }, {1 , 1 }));
262262 }
263263
264264 struct ggml_tensor * forward (GGMLRunnerContext* ctx, struct ggml_tensor * z) override {
265- // return z;
266265 auto first_conv = std::dynamic_pointer_cast<Conv2d>(blocks[" 0" ]);
267- auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string (num_layers * (num_blocks + 2 ) + 1 )]);
268266 auto h = first_conv->forward (ctx, z);
269-
270267 h = ggml_relu_inplace (ctx->ggml_ctx , h);
271-
272- for (int i = 2 ; i < num_layers * (num_blocks + 2 ) + 2 ; i++) {
273- if (blocks.find (std::to_string (i)) == blocks.end ()) {
274- continue ;
268+
269+ int index = 2 ;
270+ for (int i = 0 ; i < num_layers; i++) {
271+ auto pool = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string (index++)]);
272+ auto conv = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string (index++)]);
273+
274+ h = pool->forward (ctx, h);
275+ h = conv->forward (ctx, h);
276+ for (int j = 0 ; j < num_blocks; j++) {
277+ auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string (index++)]);
278+ auto mem = ggml_pad_ext (ctx->ggml_ctx , h, 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 );
279+ mem = ggml_view_4d (ctx->ggml_ctx , mem, h->ne [0 ], h->ne [1 ], h->ne [2 ], h->ne [3 ], h->nb [1 ], h->nb [2 ], h->nb [3 ], 0 );
280+ h = block->forward (ctx, h, mem);
275281 }
276- auto block = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string (i)]);
277- h = block->forward (ctx, h);
278282 }
283+ auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string (index)]);
279284 h = last_conv->forward (ctx, h);
280285 return h;
281286 }
@@ -322,7 +327,7 @@ class TinyVideoDecoder : public UnaryBlock {
322327 for (int i = 0 ; i < num_layers; i++) {
323328 for (int j = 0 ; j < num_blocks; j++) {
324329 auto block = std::dynamic_pointer_cast<MemBlock>(blocks[std::to_string (index++)]);
325- auto mem = ggml_pad_ext (ctx->ggml_ctx , h, 0 , 0 , 0 , 0 , 0 , 0 , 1 ,0 );
330+ auto mem = ggml_pad_ext (ctx->ggml_ctx , h, 0 , 0 , 0 , 0 , 0 , 0 , 1 , 0 );
326331 mem = ggml_view_4d (ctx->ggml_ctx , mem, h->ne [0 ], h->ne [1 ], h->ne [2 ], h->ne [3 ], h->nb [1 ], h->nb [2 ], h->nb [3 ], 0 );
327332 h = block->forward (ctx, h, mem);
328333 }
@@ -339,7 +344,7 @@ class TinyVideoDecoder : public UnaryBlock {
339344 auto last_conv = std::dynamic_pointer_cast<Conv2d>(blocks[std::to_string (++index)]);
340345 h = last_conv->forward (ctx, h);
341346
342- // shape(W, H, 3, T+3 ) => shape(W, H, 3, T)
347+ // shape(W, H, 3, 3 + T ) => shape(W, H, 3, T)
343348 h = ggml_view_4d (ctx->ggml_ctx , h, h->ne [0 ], h->ne [1 ], h->ne [2 ], h->ne [3 ] - 3 , h->nb [1 ], h->nb [2 ], h->nb [3 ], 3 * h->nb [3 ]);
344349 return h;
345350 }
@@ -376,9 +381,20 @@ class TAEHV : public GGMLBlock {
376381 }
377382
378383 struct ggml_tensor * encode (GGMLRunnerContext* ctx, struct ggml_tensor * x) {
379- return nullptr ;
380- // auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks["encoder"]);
381- // return encoder->forward(ctx, x);
384+ auto encoder = std::dynamic_pointer_cast<TinyVideoEncoder>(blocks[" encoder" ]);
385+ // (W, H, T, C) -> (W, H, C, T)
386+ x = ggml_cont (ctx->ggml_ctx , ggml_permute (ctx->ggml_ctx , x, 0 , 1 , 3 , 2 ));
387+ int64_t num_frames = x->ne [3 ];
388+ if (num_frames % 4 ) {
389+ // pad to multiple of 4 at the end
390+ auto last_frame = ggml_view_4d (ctx->ggml_ctx , x, x->ne [0 ], x->ne [1 ], x->ne [2 ], 1 , x->nb [1 ], x->nb [2 ], x->nb [3 ], (num_frames - 1 ) * x->nb [3 ]);
391+ for (int i = 0 ; i < 4 - num_frames % 4 ; i++) {
392+ x = ggml_concat (ctx->ggml_ctx , x, last_frame, 3 );
393+ }
394+ }
395+ x = encoder->forward (ctx, x);
396+ x = ggml_cont (ctx->ggml_ctx , ggml_permute (ctx->ggml_ctx , x, 0 , 1 , 3 , 2 ));
397+ return x;
382398 }
383399};
384400
0 commit comments