@@ -106,6 +106,11 @@ struct SDParams {
106106 // Photomaker params
107107 std::string input_id_images_path;
108108
109+ std::vector<int > skip_layers = {7 , 8 , 9 };
110+ float slg_scale = 2.5 ;
111+ float skip_layer_start = 0.01 ;
112+ float skip_layer_end = 0.2 ;
113+
109114 // server things
110115 int port = 8080 ;
111116 std::string host = " 127.0.0.1" ;
@@ -169,6 +174,11 @@ void print_usage(int argc, const char* argv[]) {
169174 printf (" -p, --prompt [PROMPT] the prompt to render\n " );
170175 printf (" -n, --negative-prompt PROMPT the negative prompt (default: \"\" )\n " );
171176 printf (" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n " );
177+ printf (" --slg enable skip layer guidance (CFG variant)\n " );
178+ printf (" --skip_layers LAYERS Layers to skip for skip layer CFG (requires --slg): (default: [7,8,9])\n " );
179+ printf (" --slg-scale SCALE skip layer guidance scale (requires --slg): (default: 2.5)\n " );
180+ printf (" --skip_layer_start START skip layer enabling point (* steps) (requires --slg): (default: 0.01)\n " );
181+ printf (" --skip_layer_end END skip layer enabling point (* steps) (requires --slg): (default: 0.2)\n " );
172182 printf (" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n " );
173183 printf (" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n " );
174184 printf (" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n " );
@@ -195,6 +205,7 @@ void print_usage(int argc, const char* argv[]) {
195205
196206void parse_args (int argc, const char ** argv, SDParams& params) {
197207 bool invalid_arg = false ;
208+ bool cfg_skip = false ;
198209 std::string arg;
199210 for (int i = 1 ; i < argc; i++) {
200211 arg = argv[i];
@@ -420,6 +431,63 @@ void parse_args(int argc, const char** argv, SDParams& params) {
420431 params.verbose = true ;
421432 } else if (arg == " --color" ) {
422433 params.color = true ;
434+ } else if (arg == " --slg" ) {
435+ cfg_skip = true ;
436+ } else if (arg == " --skip-layers" ) {
437+ if (++i >= argc) {
438+ invalid_arg = true ;
439+ break ;
440+ }
441+ if (argv[i][0 ] != ' [' ) {
442+ invalid_arg = true ;
443+ break ;
444+ }
445+ std::string layers_str = argv[i];
446+ while (layers_str.back () != ' ]' ) {
447+ if (++i >= argc) {
448+ invalid_arg = true ;
449+ break ;
450+ }
451+ layers_str += " " + std::string (argv[i]);
452+ }
453+ layers_str = layers_str.substr (1 , layers_str.size () - 2 );
454+
455+ std::regex regex (" [, ]+" );
456+ std::sregex_token_iterator iter (layers_str.begin (), layers_str.end (), regex, -1 );
457+ std::sregex_token_iterator end;
458+ std::vector<std::string> tokens (iter, end);
459+ std::vector<int > layers;
460+ for (const auto & token : tokens) {
461+ try {
462+ layers.push_back (std::stoi (token));
463+ } catch (const std::invalid_argument& e) {
464+ invalid_arg = true ;
465+ break ;
466+ }
467+ }
468+ params.skip_layers = layers;
469+
470+ if (invalid_arg) {
471+ break ;
472+ }
473+ } else if (arg == " --slg-scale" ) {
474+ if (++i >= argc) {
475+ invalid_arg = true ;
476+ break ;
477+ }
478+ params.slg_scale = std::stof (argv[i]);
479+ } else if (arg == " --skip-layer-start" ) {
480+ if (++i >= argc) {
481+ invalid_arg = true ;
482+ break ;
483+ }
484+ params.skip_layer_start = std::stof (argv[i]);
485+ } else if (arg == " --skip-layer-end" ) {
486+ if (++i >= argc) {
487+ invalid_arg = true ;
488+ break ;
489+ }
490+ params.skip_layer_end = std::stof (argv[i]);
423491 } else if (arg == " --port" ) {
424492 if (++i >= argc) {
425493 invalid_arg = true ;
@@ -447,6 +515,11 @@ void parse_args(int argc, const char** argv, SDParams& params) {
447515 params.n_threads = get_num_physical_cores ();
448516 }
449517
518+ if (!cfg_skip) {
519+ // set skip_layers to empty
520+ params.skip_layers .clear ();
521+ }
522+
450523 if (params.mode != CONVERT && params.mode != IMG2VID && params.prompt .length () == 0 ) {
451524 fprintf (stderr, " error: the following arguments are required: prompt\n " );
452525 print_usage (argc, argv);
@@ -917,6 +990,10 @@ int main(int argc, const char* argv[]) {
917990 params.style_ratio ,
918991 params.normalize_input ,
919992 params.input_id_images_path .c_str (),
993+ params.skip_layers ,
994+ params.slg_scale ,
995+ params.skip_layer_start ,
996+ params.skip_layer_end ,
920997 (step_callback_t )step_callback);
921998
922999 if (results == NULL ) {
0 commit comments