@@ -711,7 +711,8 @@ namespace Flux {
711711 struct ggml_tensor * timesteps,
712712 struct ggml_tensor * y,
713713 struct ggml_tensor * guidance,
714- struct ggml_tensor * pe) {
714+ struct ggml_tensor * pe,
715+ std::vector<int> skip_layers = std::vector<int>()) {
715716 auto img_in = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
716717 auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
717718 auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
@@ -733,6 +734,10 @@ namespace Flux {
733734 txt = txt_in->forward(ctx, txt);
734735
735736 for (int i = 0; i < params.depth; i++) {
737+ if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
738+ continue;
739+ }
740+
736741 auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);
737742
738743 auto img_txt = block->forward (ctx, img, txt, vec, pe);
@@ -742,6 +747,9 @@ namespace Flux {
742747
743748 auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
744749 for (int i = 0; i < params.depth_single_blocks; i++) {
750+ if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) {
751+ continue;
752+ }
745753 auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);
746754
747755 txt_img = block->forward (ctx, txt_img, vec, pe);
@@ -769,7 +777,8 @@ namespace Flux {
769777 struct ggml_tensor * context,
770778 struct ggml_tensor * y,
771779 struct ggml_tensor * guidance,
772- struct ggml_tensor * pe) {
780+ struct ggml_tensor * pe,
781+ std::vector<int> skip_layers = std::vector<int>()) {
773782 // Forward pass of DiT.
774783 // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
775784 // timestep: (N,) tensor of diffusion timesteps
@@ -791,7 +800,7 @@ namespace Flux {
791800 // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
792801 auto img = patchify (ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
793802
794- auto out = forward_orig (ctx, img, context, timestep, y, guidance, pe); // [N, h*w, C * patch_size * patch_size]
803+ auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]
795804
796805 // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
797806 out = unpatchify (ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w]
@@ -829,7 +838,8 @@ namespace Flux {
829838 struct ggml_tensor * timesteps,
830839 struct ggml_tensor * context,
831840 struct ggml_tensor * y,
832- struct ggml_tensor * guidance) {
841+ struct ggml_tensor * guidance,
842+ std::vector<int> skip_layers = std::vector<int>()) {
833843 GGML_ASSERT(x->ne[3] == 1);
834844 struct ggml_cgraph * gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
835845
@@ -856,7 +866,8 @@ namespace Flux {
856866 context,
857867 y,
858868 guidance,
859- pe);
869+ pe,
870+ skip_layers);
860871
861872 ggml_build_forward_expand (gf, out);
862873
@@ -870,14 +881,15 @@ namespace Flux {
870881 struct ggml_tensor * y,
871882 struct ggml_tensor * guidance,
872883 struct ggml_tensor ** output = NULL ,
873- struct ggml_context * output_ctx = NULL ) {
884+ struct ggml_context * output_ctx = NULL ,
885+ std::vector<int> skip_layers = std::vector<int>()) {
874886 // x: [N, in_channels, h, w]
875887 // timesteps: [N, ]
876888 // context: [N, max_position, hidden_size]
877889 // y: [N, adm_in_channels] or [1, adm_in_channels]
878890 // guidance: [N, ]
879891 auto get_graph = [&]() -> struct ggml_cgraph * {
880- return build_graph(x, timesteps, context, y, guidance);
892+ return build_graph(x, timesteps, context, y, guidance, skip_layers);
881893 };
882894
883895 GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
0 commit comments