@@ -394,9 +394,9 @@ struct GenerationRequest {
394394 scheduler_t scheduler = DEFAULT;
395395 int batch_count = 1 ;
396396 int64_t seed = -1 ;
397- bool normalize_input = false ;
398397 float eta = 0 .0f ;
399398 bool has_eta = false ;
399+ int shifted_timestep = 0 ;
400400 sd_tiling_params_t vae_tiling_params = {false , 0 , 0 , 0 .5f , 0 .0f , 0 .0f };
401401 bool has_vae_tiling_override = false ;
402402};
@@ -700,6 +700,25 @@ bool parse_generation_request(const json& body, GenerationRequest& request, std:
700700 request.has_eta = true ;
701701 }
702702
703+ auto shift_it = body.find (" timestep_shift" );
704+ const char * shift_field_name = " timestep_shift" ;
705+ if (shift_it == body.end ()) {
706+ shift_it = body.find (" shifted_timestep" );
707+ shift_field_name = " shifted_timestep" ;
708+ }
709+ if (shift_it != body.end ()) {
710+ if (!shift_it->is_number_integer ()) {
711+ error = std::string (" field '" ) + shift_field_name + " ' must be an integer" ;
712+ return false ;
713+ }
714+ int value = static_cast <int >(shift_it->get <int64_t >());
715+ if (value < 0 || value > 1000 ) {
716+ error = std::string (" field '" ) + shift_field_name + " ' must be between 0 and 1000" ;
717+ return false ;
718+ }
719+ request.shifted_timestep = value;
720+ }
721+
703722 auto batch_it = body.find (" batch_count" );
704723 if (batch_it != body.end ()) {
705724 if (!batch_it->is_number_integer ()) {
@@ -722,15 +741,6 @@ bool parse_generation_request(const json& body, GenerationRequest& request, std:
722741 request.seed = seed_it->get <int64_t >();
723742 }
724743
725- auto norm_it = body.find (" normalize_input" );
726- if (norm_it != body.end ()) {
727- if (!norm_it->is_boolean ()) {
728- error = " field 'normalize_input' must be a boolean" ;
729- return false ;
730- }
731- request.normalize_input = norm_it->get <bool >();
732- }
733-
734744 auto method_it = body.find (" sample_method" );
735745 if (method_it != body.end ()) {
736746 if (!method_it->is_string ()) {
@@ -1056,7 +1066,6 @@ int main(int argc, char** argv) {
10561066 img_params.height = request_params.height ;
10571067 img_params.batch_count = request_params.batch_count ;
10581068 img_params.seed = effective_seed;
1059- img_params.normalize_input = request_params.normalize_input ;
10601069 if (request_params.has_vae_tiling_override ) {
10611070 img_params.vae_tiling_params = request_params.vae_tiling_params ;
10621071 }
@@ -1082,6 +1091,7 @@ int main(int argc, char** argv) {
10821091 if (request_params.has_eta ) {
10831092 sample_params.eta = request_params.eta ;
10841093 }
1094+ sample_params.shifted_timestep = request_params.shifted_timestep ;
10851095
10861096 auto start_time = std::chrono::steady_clock::now ();
10871097 sd_image_t * results = generate_image (state.ctx , &img_params);
0 commit comments