
Commit f6e9df9

slg support for flux (experimental)
1 parent 58a288f commit f6e9df9

File tree: diffusion_model.hpp, flux.hpp

2 files changed: +21 -8 lines changed


diffusion_model.hpp

Lines changed: 2 additions & 1 deletion
@@ -73,6 +73,7 @@ struct UNetModel : public DiffusionModel {
                         struct ggml_tensor** output = NULL,
                         struct ggml_context* output_ctx = NULL,
                         std::vector<int> skip_layers = std::vector<int>()) {
+        (void)skip_layers; // SLG doesn't work with UNet models
         return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
     }
 };
@@ -173,7 +174,7 @@ struct FluxModel : public DiffusionModel {
                         struct ggml_tensor** output = NULL,
                         struct ggml_context* output_ctx = NULL,
                         std::vector<int> skip_layers = std::vector<int>()) {
-        return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx);
+        return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx, skip_layers);
     }
 };
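The interface change is backward compatible: skip_layers defaults to an empty vector, so existing call sites compile unchanged, and UNetModel simply discards it (the (void) cast suppresses the unused-parameter warning). A caller that wants SLG has to build that vector somewhere; a minimal sketch of a parser for a comma-separated option string such as "7,8,9" is shown below. parse_skip_layers is a hypothetical helper name, not something added by this commit.

// Hypothetical helper, not part of this commit: turn a comma-separated
// option string such as "7,8,9" into the skip_layers vector that the new
// DiffusionModel::compute parameter accepts.
#include <sstream>
#include <string>
#include <vector>

std::vector<int> parse_skip_layers(const std::string& arg) {
    std::vector<int> layers;
    std::stringstream ss(arg);
    std::string item;
    while (std::getline(ss, item, ',')) {
        if (!item.empty()) {
            layers.push_back(std::stoi(item)); // assumes well-formed integers
        }
    }
    return layers;
}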

flux.hpp

Lines changed: 19 additions & 7 deletions
@@ -711,7 +711,8 @@ namespace Flux {
                                  struct ggml_tensor* timesteps,
                                  struct ggml_tensor* y,
                                  struct ggml_tensor* guidance,
-                                 struct ggml_tensor* pe) {
+                                 struct ggml_tensor* pe,
+                                 std::vector<int> skip_layers = std::vector<int>()) {
         auto img_in = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
         auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
         auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
@@ -733,6 +734,10 @@ namespace Flux {
         txt = txt_in->forward(ctx, txt);

         for (int i = 0; i < params.depth; i++) {
+            if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
+                continue;
+            }
+
             auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);

             auto img_txt = block->forward(ctx, img, txt, vec, pe);
@@ -742,6 +747,9 @@ namespace Flux {

         auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size]
         for (int i = 0; i < params.depth_single_blocks; i++) {
+            if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) {
+                continue;
+            }
             auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);

             txt_img = block->forward(ctx, txt_img, vec, pe);
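Note the indexing convention these two loops establish: skip_layers entries below params.depth select double-stream blocks, and entries at or above params.depth select single-stream blocks after subtracting params.depth. The sketch below illustrates the mapping; the depth values (19 double and 38 single blocks) are the usual FLUX.1 configuration and are assumed here rather than read from the diff.

// Illustrative only: resolve a global skip index to the block it disables,
// following the convention used in forward_orig above.
#include <cstdio>

int main() {
    const int depth               = 19; // assumed FLUX.1 double-block count
    const int depth_single_blocks = 38; // assumed FLUX.1 single-block count
    const int skip[]              = {8, 21, 99};

    for (int idx : skip) {
        if (idx < depth) {
            printf("skip index %d -> double_blocks.%d\n", idx, idx);
        } else if (idx < depth + depth_single_blocks) {
            printf("skip index %d -> single_blocks.%d\n", idx, idx - depth);
        } else {
            printf("skip index %d -> out of range, never matched\n", idx);
        }
    }
    return 0;
}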
@@ -769,7 +777,8 @@ namespace Flux {
                                  struct ggml_tensor* context,
                                  struct ggml_tensor* y,
                                  struct ggml_tensor* guidance,
-                                 struct ggml_tensor* pe) {
+                                 struct ggml_tensor* pe,
+                                 std::vector<int> skip_layers = std::vector<int>()) {
         // Forward pass of DiT.
         // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
         // timestep: (N,) tensor of diffusion timesteps
@@ -791,7 +800,7 @@ namespace Flux {
         // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
         auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]

-        auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe); // [N, h*w, C * patch_size * patch_size]
+        auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, h*w, C * patch_size * patch_size]

         // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
         out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w]
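For a concrete sense of the shapes around the patchify/unpatchify calls, here is a small worked example; the numbers assume a 1024x1024 FLUX.1 generation (VAE downsampling by 8, 16 latent channels, patch_size = 2) and are not taken from this diff.

// Assumed example, not from this commit: shape bookkeeping for a
// 1024x1024 generation.
#include <cstdio>

int main() {
    const int H = 1024 / 8, W = 1024 / 8; // latent is 128 x 128
    const int C = 16, patch_size = 2;
    const int h_len = H / patch_size, w_len = W / patch_size; // 64 x 64
    printf("img tokens: %d, values per token: %d\n",
           h_len * w_len, C * patch_size * patch_size); // 4096 tokens of 64 values
    return 0;
}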
@@ -829,7 +838,8 @@ namespace Flux {
                             struct ggml_tensor* timesteps,
                             struct ggml_tensor* context,
                             struct ggml_tensor* y,
-                            struct ggml_tensor* guidance) {
+                            struct ggml_tensor* guidance,
+                            std::vector<int> skip_layers = std::vector<int>()) {
         GGML_ASSERT(x->ne[3] == 1);
         struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);

@@ -856,7 +866,8 @@ namespace Flux {
                                context,
                                y,
                                guidance,
-                               pe);
+                               pe,
+                               skip_layers);

         ggml_build_forward_expand(gf, out);

@@ -870,14 +881,15 @@ namespace Flux {
                     struct ggml_tensor* y,
                     struct ggml_tensor* guidance,
                     struct ggml_tensor** output = NULL,
-                    struct ggml_context* output_ctx = NULL) {
+                    struct ggml_context* output_ctx = NULL,
+                    std::vector<int> skip_layers = std::vector<int>()) {
         // x: [N, in_channels, h, w]
         // timesteps: [N, ]
         // context: [N, max_position, hidden_size]
         // y: [N, adm_in_channels] or [1, adm_in_channels]
         // guidance: [N, ]
         auto get_graph = [&]() -> struct ggml_cgraph* {
-            return build_graph(x, timesteps, context, y, guidance);
+            return build_graph(x, timesteps, context, y, guidance, skip_layers);
         };

         GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
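With skip_layers threaded all the way down to forward_orig, a sampler can now request two predictions per step and blend them in the spirit of skip-layer guidance: push the prediction away from the output produced with a few blocks disabled. The sketch below is a hypothetical caller, not code from this commit; slg_scale, work_ctx, the chosen layer indices, and the assumption that both outputs are contiguous f32 CPU tensors are all illustrative, and the surrounding variables (x, timesteps, context, y, guidance) are taken to come from the existing sampler state.

// Hypothetical caller, not part of this commit: one full pass, one pass with
// selected blocks skipped, then steer the prediction away from the degraded
// output.
std::vector<int> skip_layers = {7, 8, 9}; // illustrative choice of blocks

struct ggml_tensor* full_out = NULL;
struct ggml_tensor* skip_out = NULL;

flux.compute(n_threads, x, timesteps, context, y, guidance, &full_out, work_ctx);
flux.compute(n_threads, x, timesteps, context, y, guidance, &skip_out, work_ctx, skip_layers);

const float slg_scale = 2.5f; // illustrative strength
float* f = (float*)full_out->data;
float* s = (float*)skip_out->data;
for (int64_t i = 0; i < ggml_nelements(full_out); i++) {
    f[i] += slg_scale * (f[i] - s[i]); // amplify what the skipped blocks contribute
}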
