Skip to content

Commit 9dcea94

Browse files
committed
add timestep-shift and remove normalize-input
1 parent c020301 commit 9dcea94

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

examples/sd-server/main.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -394,9 +394,9 @@ struct GenerationRequest {
394394
scheduler_t scheduler = DEFAULT;
395395
int batch_count = 1;
396396
int64_t seed = -1;
397-
bool normalize_input = false;
398397
float eta = 0.0f;
399398
bool has_eta = false;
399+
int shifted_timestep = 0;
400400
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
401401
bool has_vae_tiling_override = false;
402402
};
@@ -700,6 +700,25 @@ bool parse_generation_request(const json& body, GenerationRequest& request, std:
700700
request.has_eta = true;
701701
}
702702

703+
auto shift_it = body.find("timestep_shift");
704+
const char* shift_field_name = "timestep_shift";
705+
if (shift_it == body.end()) {
706+
shift_it = body.find("shifted_timestep");
707+
shift_field_name = "shifted_timestep";
708+
}
709+
if (shift_it != body.end()) {
710+
if (!shift_it->is_number_integer()) {
711+
error = std::string("field '") + shift_field_name + "' must be an integer";
712+
return false;
713+
}
714+
int value = static_cast<int>(shift_it->get<int64_t>());
715+
if (value < 0 || value > 1000) {
716+
error = std::string("field '") + shift_field_name + "' must be between 0 and 1000";
717+
return false;
718+
}
719+
request.shifted_timestep = value;
720+
}
721+
703722
auto batch_it = body.find("batch_count");
704723
if (batch_it != body.end()) {
705724
if (!batch_it->is_number_integer()) {
@@ -722,15 +741,6 @@ bool parse_generation_request(const json& body, GenerationRequest& request, std:
722741
request.seed = seed_it->get<int64_t>();
723742
}
724743

725-
auto norm_it = body.find("normalize_input");
726-
if (norm_it != body.end()) {
727-
if (!norm_it->is_boolean()) {
728-
error = "field 'normalize_input' must be a boolean";
729-
return false;
730-
}
731-
request.normalize_input = norm_it->get<bool>();
732-
}
733-
734744
auto method_it = body.find("sample_method");
735745
if (method_it != body.end()) {
736746
if (!method_it->is_string()) {
@@ -1056,7 +1066,6 @@ int main(int argc, char** argv) {
10561066
img_params.height = request_params.height;
10571067
img_params.batch_count = request_params.batch_count;
10581068
img_params.seed = effective_seed;
1059-
img_params.normalize_input = request_params.normalize_input;
10601069
if (request_params.has_vae_tiling_override) {
10611070
img_params.vae_tiling_params = request_params.vae_tiling_params;
10621071
}
@@ -1082,6 +1091,7 @@ int main(int argc, char** argv) {
10821091
if (request_params.has_eta) {
10831092
sample_params.eta = request_params.eta;
10841093
}
1094+
sample_params.shifted_timestep = request_params.shifted_timestep;
10851095

10861096
auto start_time = std::chrono::steady_clock::now();
10871097
sd_image_t* results = generate_image(state.ctx, &img_params);

0 commit comments

Comments
 (0)