Skip to content

Commit bd9b0f1

Browse files
committed
server. basic slg support
1 parent 3760488 commit bd9b0f1

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

examples/server/main.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ struct SDParams {
106106
// Photomaker params
107107
std::string input_id_images_path;
108108

109+
std::vector<int> skip_layers = {7, 8, 9};
110+
float slg_scale = 2.5;
111+
float skip_layer_start = 0.01;
112+
float skip_layer_end = 0.2;
113+
109114
// server things
110115
int port = 8080;
111116
std::string host = "127.0.0.1";
@@ -169,6 +174,11 @@ void print_usage(int argc, const char* argv[]) {
169174
printf(" -p, --prompt [PROMPT] the prompt to render\n");
170175
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
171176
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
177+
printf(" --slg enable skip layer guidance (CFG variant)\n");
178+
printf(" --skip_layers LAYERS Layers to skip for skip layer CFG (requires --slg): (default: [7,8,9])\n");
179+
printf(" --slg-scale SCALE skip layer guidance scale (requires --slg): (default: 2.5)\n");
180+
printf(" --skip_layer_start START skip layer enabling point (* steps) (requires --slg): (default: 0.01)\n");
181+
printf(" --skip_layer_end END skip layer enabling point (* steps) (requires --slg): (default: 0.2)\n");
172182
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
173183
printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n");
174184
printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n");
@@ -195,6 +205,7 @@ void print_usage(int argc, const char* argv[]) {
195205

196206
void parse_args(int argc, const char** argv, SDParams& params) {
197207
bool invalid_arg = false;
208+
bool cfg_skip = false;
198209
std::string arg;
199210
for (int i = 1; i < argc; i++) {
200211
arg = argv[i];
@@ -420,6 +431,63 @@ void parse_args(int argc, const char** argv, SDParams& params) {
420431
params.verbose = true;
421432
} else if (arg == "--color") {
422433
params.color = true;
434+
} else if (arg == "--slg") {
435+
cfg_skip = true;
436+
} else if (arg == "--skip-layers") {
437+
if (++i >= argc) {
438+
invalid_arg = true;
439+
break;
440+
}
441+
if (argv[i][0] != '[') {
442+
invalid_arg = true;
443+
break;
444+
}
445+
std::string layers_str = argv[i];
446+
while (layers_str.back() != ']') {
447+
if (++i >= argc) {
448+
invalid_arg = true;
449+
break;
450+
}
451+
layers_str += " " + std::string(argv[i]);
452+
}
453+
layers_str = layers_str.substr(1, layers_str.size() - 2);
454+
455+
std::regex regex("[, ]+");
456+
std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1);
457+
std::sregex_token_iterator end;
458+
std::vector<std::string> tokens(iter, end);
459+
std::vector<int> layers;
460+
for (const auto& token : tokens) {
461+
try {
462+
layers.push_back(std::stoi(token));
463+
} catch (const std::invalid_argument& e) {
464+
invalid_arg = true;
465+
break;
466+
}
467+
}
468+
params.skip_layers = layers;
469+
470+
if (invalid_arg) {
471+
break;
472+
}
473+
} else if (arg == "--slg-scale") {
474+
if (++i >= argc) {
475+
invalid_arg = true;
476+
break;
477+
}
478+
params.slg_scale = std::stof(argv[i]);
479+
} else if (arg == "--skip-layer-start") {
480+
if (++i >= argc) {
481+
invalid_arg = true;
482+
break;
483+
}
484+
params.skip_layer_start = std::stof(argv[i]);
485+
} else if (arg == "--skip-layer-end") {
486+
if (++i >= argc) {
487+
invalid_arg = true;
488+
break;
489+
}
490+
params.skip_layer_end = std::stof(argv[i]);
423491
} else if (arg == "--port") {
424492
if (++i >= argc) {
425493
invalid_arg = true;
@@ -447,6 +515,11 @@ void parse_args(int argc, const char** argv, SDParams& params) {
447515
params.n_threads = get_num_physical_cores();
448516
}
449517

518+
if (!cfg_skip) {
519+
// set skip_layers to empty
520+
params.skip_layers.clear();
521+
}
522+
450523
if (params.mode != CONVERT && params.mode != IMG2VID && params.prompt.length() == 0) {
451524
fprintf(stderr, "error: the following arguments are required: prompt\n");
452525
print_usage(argc, argv);
@@ -917,6 +990,10 @@ int main(int argc, const char* argv[]) {
917990
params.style_ratio,
918991
params.normalize_input,
919992
params.input_id_images_path.c_str(),
993+
params.skip_layers,
994+
params.slg_scale,
995+
params.skip_layer_start,
996+
params.skip_layer_end,
920997
(step_callback_t)step_callback);
921998

922999
if (results == NULL) {

0 commit comments

Comments
 (0)