Skip to content

Commit 3760488

Browse files
committed
add skip layer guidance support (mmdit only)
1 parent 94f29c7 commit 3760488

File tree

5 files changed

+172
-19
lines changed

5 files changed

+172
-19
lines changed

diffusion_model.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ struct DiffusionModel {
1717
std::vector<struct ggml_tensor*> controls = {},
1818
float control_strength = 0.f,
1919
struct ggml_tensor** output = NULL,
20-
struct ggml_context* output_ctx = NULL) = 0;
20+
struct ggml_context* output_ctx = NULL,
21+
std::vector<int> skip_layers = std::vector<int>()) = 0;
2122
virtual void alloc_params_buffer() = 0;
2223
virtual void free_params_buffer() = 0;
2324
virtual void free_compute_buffer() = 0;
@@ -71,7 +72,8 @@ struct UNetModel : public DiffusionModel {
7172
std::vector<struct ggml_tensor*> controls = {},
7273
float control_strength = 0.f,
7374
struct ggml_tensor** output = NULL,
74-
struct ggml_context* output_ctx = NULL) {
75+
struct ggml_context* output_ctx = NULL,
76+
std::vector<int> skip_layers = std::vector<int>()) {
7577
return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx);
7678
}
7779
};
@@ -120,8 +122,9 @@ struct MMDiTModel : public DiffusionModel {
120122
std::vector<struct ggml_tensor*> controls = {},
121123
float control_strength = 0.f,
122124
struct ggml_tensor** output = NULL,
123-
struct ggml_context* output_ctx = NULL) {
124-
return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx);
125+
struct ggml_context* output_ctx = NULL,
126+
std::vector<int> skip_layers = std::vector<int>()) {
127+
return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers);
125128
}
126129
};
127130

@@ -169,7 +172,8 @@ struct FluxModel : public DiffusionModel {
169172
std::vector<struct ggml_tensor*> controls = {},
170173
float control_strength = 0.f,
171174
struct ggml_tensor** output = NULL,
172-
struct ggml_context* output_ctx = NULL) {
175+
struct ggml_context* output_ctx = NULL,
176+
std::vector<int> skip_layers = std::vector<int>()) {
173177
return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx);
174178
}
175179
};

examples/cli/main.cpp

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,19 @@ struct SDParams {
120120
bool canny_preprocess = false;
121121
bool color = false;
122122
int upscale_repeats = 1;
123+
124+
std::vector<int> skip_layers = {7, 8, 9};
125+
float slg_scale = 2.5;
126+
float skip_layer_start = 0.01;
127+
float skip_layer_end = 0.2;
123128
};
124129

125130
void print_params(SDParams params) {
126131
printf("Option: \n");
127132
printf(" n_threads: %d\n", params.n_threads);
128133
printf(" mode: %s\n", modes_str[params.mode]);
129134
printf(" model_path: %s\n", params.model_path.c_str());
130-
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
135+
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
131136
printf(" fallback_type: %s\n", params.ftype < SD_TYPE_COUNT ? sd_type_name(params.ftype) : "unspecified");
132137
printf(" clip_l_path: %s\n", params.clip_l_path.c_str());
133138
printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
@@ -201,6 +206,11 @@ void print_usage(int argc, const char* argv[]) {
201206
printf(" -p, --prompt [PROMPT] the prompt to render\n");
202207
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
203208
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
209+
printf(" --slg enable skip layer guidance (CFG variant)\n");
210+
printf(" --skip_layers LAYERS Layers to skip for skip layer CFG (requires --slg): (default: [7,8,9])\n");
211+
printf(" --slg-scale SCALE skip layer guidance scale (requires --slg): (default: 2.5)\n");
212+
printf(" --skip_layer_start START skip layer enabling point (* steps) (requires --slg): (default: 0.01)\n");
213+
printf(" --skip_layer_end END skip layer enabling point (* steps) (requires --slg): (default: 0.2)\n");
204214
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
205215
printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n");
206216
printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n");
@@ -227,6 +237,7 @@ void print_usage(int argc, const char* argv[]) {
227237

228238
void parse_args(int argc, const char** argv, SDParams& params) {
229239
bool invalid_arg = false;
240+
bool cfg_skip = false;
230241
std::string arg;
231242
for (int i = 1; i < argc; i++) {
232243
arg = argv[i];
@@ -563,6 +574,63 @@ void parse_args(int argc, const char** argv, SDParams& params) {
563574
params.verbose = true;
564575
} else if (arg == "--color") {
565576
params.color = true;
577+
} else if (arg == "--slg") {
578+
cfg_skip = true;
579+
} else if (arg == "--skip-layers") {
580+
if (++i >= argc) {
581+
invalid_arg = true;
582+
break;
583+
}
584+
if (argv[i][0] != '[') {
585+
invalid_arg = true;
586+
break;
587+
}
588+
std::string layers_str = argv[i];
589+
while (layers_str.back() != ']') {
590+
if (++i >= argc) {
591+
invalid_arg = true;
592+
break;
593+
}
594+
layers_str += " " + std::string(argv[i]);
595+
}
596+
layers_str = layers_str.substr(1, layers_str.size() - 2);
597+
598+
std::regex regex("[, ]+");
599+
std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1);
600+
std::sregex_token_iterator end;
601+
std::vector<std::string> tokens(iter, end);
602+
std::vector<int> layers;
603+
for (const auto& token : tokens) {
604+
try {
605+
layers.push_back(std::stoi(token));
606+
} catch (const std::invalid_argument& e) {
607+
invalid_arg = true;
608+
break;
609+
}
610+
}
611+
params.skip_layers = layers;
612+
613+
if (invalid_arg) {
614+
break;
615+
}
616+
} else if (arg == "--slg-scale") {
617+
if (++i >= argc) {
618+
invalid_arg = true;
619+
break;
620+
}
621+
params.slg_scale = std::stof(argv[i]);
622+
} else if (arg == "--skip-layer-start") {
623+
if (++i >= argc) {
624+
invalid_arg = true;
625+
break;
626+
}
627+
params.skip_layer_start = std::stof(argv[i]);
628+
} else if (arg == "--skip-layer-end") {
629+
if (++i >= argc) {
630+
invalid_arg = true;
631+
break;
632+
}
633+
params.skip_layer_end = std::stof(argv[i]);
566634
} else {
567635
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
568636
print_usage(argc, argv);
@@ -578,6 +646,11 @@ void parse_args(int argc, const char** argv, SDParams& params) {
578646
params.n_threads = get_num_physical_cores();
579647
}
580648

649+
if (!cfg_skip) {
650+
// set skip_layers to empty
651+
params.skip_layers.clear();
652+
}
653+
581654
if (params.mode != CONVERT && params.mode != IMG2VID && params.prompt.length() == 0) {
582655
fprintf(stderr, "error: the following arguments are required: prompt\n");
583656
print_usage(argc, argv);
@@ -771,7 +844,6 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
771844

772845
const float(*latent_rgb_proj)[channel];
773846

774-
775847
if (dim == 16) {
776848
// 16 channels VAE -> Flux or SD3
777849

@@ -990,6 +1062,10 @@ int main(int argc, const char* argv[]) {
9901062
params.style_ratio,
9911063
params.normalize_input,
9921064
params.input_id_images_path.c_str(),
1065+
params.skip_layers,
1066+
params.slg_scale,
1067+
params.skip_layer_start,
1068+
params.skip_layer_end,
9931069
(step_callback_t)step_callback);
9941070
} else {
9951071
sd_image_t input_image = {(uint32_t)params.width,

mmdit.hpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -803,14 +803,20 @@ struct MMDiT : public GGMLBlock {
803803
struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx,
804804
struct ggml_tensor* x,
805805
struct ggml_tensor* c_mod,
806-
struct ggml_tensor* context) {
806+
struct ggml_tensor* context,
807+
std::vector<int> skip_layers = std::vector<int>()) {
807808
// x: [N, H*W, hidden_size]
808809
// context: [N, n_context, d_context]
809810
// c: [N, hidden_size]
810811
// return: [N, H*W, patch_size * patch_size * out_channels]
811812
auto final_layer = std::dynamic_pointer_cast<FinalLayer>(blocks["final_layer"]);
812813

813814
for (int i = 0; i < depth; i++) {
815+
// skip iteration if i is in skip_layers
816+
if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) {
817+
continue;
818+
}
819+
814820
auto block = std::dynamic_pointer_cast<JointBlock>(blocks["joint_blocks." + std::to_string(i)]);
815821

816822
auto context_x = block->forward(ctx, context, x, c_mod);
@@ -826,8 +832,9 @@ struct MMDiT : public GGMLBlock {
826832
struct ggml_tensor* forward(struct ggml_context* ctx,
827833
struct ggml_tensor* x,
828834
struct ggml_tensor* t,
829-
struct ggml_tensor* y = NULL,
830-
struct ggml_tensor* context = NULL) {
835+
struct ggml_tensor* y = NULL,
836+
struct ggml_tensor* context = NULL,
837+
std::vector<int> skip_layers = std::vector<int>()) {
831838
// Forward pass of DiT.
832839
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
833840
// t: (N,) tensor of diffusion timesteps
@@ -858,7 +865,7 @@ struct MMDiT : public GGMLBlock {
858865
context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536]
859866
}
860867

861-
x = forward_core_with_concat(ctx, x, c, context); // (N, H*W, patch_size ** 2 * out_channels)
868+
x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels)
862869

863870
x = unpatchify(ctx, x, h, w); // [N, C, H, W]
864871

@@ -889,7 +896,8 @@ struct MMDiTRunner : public GGMLRunner {
889896
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
890897
struct ggml_tensor* timesteps,
891898
struct ggml_tensor* context,
892-
struct ggml_tensor* y) {
899+
struct ggml_tensor* y,
900+
std::vector<int> skip_layers = std::vector<int>()) {
893901
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, MMDIT_GRAPH_SIZE, false);
894902

895903
x = to_backend(x);
@@ -901,7 +909,8 @@ struct MMDiTRunner : public GGMLRunner {
901909
x,
902910
timesteps,
903911
y,
904-
context);
912+
context,
913+
skip_layers);
905914

906915
ggml_build_forward_expand(gf, out);
907916

@@ -914,13 +923,14 @@ struct MMDiTRunner : public GGMLRunner {
914923
struct ggml_tensor* context,
915924
struct ggml_tensor* y,
916925
struct ggml_tensor** output = NULL,
917-
struct ggml_context* output_ctx = NULL) {
926+
struct ggml_context* output_ctx = NULL,
927+
std::vector<int> skip_layers = std::vector<int>()) {
918928
// x: [N, in_channels, h, w]
919929
// timesteps: [N, ]
920930
// context: [N, max_position, hidden_size]([N, 154, 4096]) or [1, max_position, hidden_size]
921931
// y: [N, adm_in_channels] or [1, adm_in_channels]
922932
auto get_graph = [&]() -> struct ggml_cgraph* {
923-
return build_graph(x, timesteps, context, y);
933+
return build_graph(x, timesteps, context, y, skip_layers);
924934
};
925935

926936
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);

0 commit comments

Comments
 (0)