@@ -245,7 +245,7 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
245245 parameter_string += " Guidance: " + std::to_string (gen_params.sample_params .guidance .distilled_guidance ) + " , " ;
246246 parameter_string += " Eta: " + std::to_string (gen_params.sample_params .eta ) + " , " ;
247247 parameter_string += " Seed: " + std::to_string (seed) + " , " ;
248- parameter_string += " Size: " + std::to_string (gen_params.width ) + " x" + std::to_string (gen_params.height ) + " , " ;
248+ parameter_string += " Size: " + std::to_string (gen_params.get_resolved_width ()) + " x" + std::to_string (gen_params.get_resolved_height () ) + " , " ;
249249 parameter_string += " Model: " + sd_basename (ctx_params.model_path ) + " , " ;
250250 parameter_string += " RNG: " + std::string (sd_rng_type_name (ctx_params.rng_type )) + " , " ;
251251 if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
@@ -526,10 +526,10 @@ int main(int argc, const char* argv[]) {
526526 }
527527
528528 bool vae_decode_only = true ;
529- sd_image_t init_image = {( uint32_t )gen_params. width , ( uint32_t )gen_params. height , 3 , nullptr };
530- sd_image_t end_image = {( uint32_t )gen_params. width , ( uint32_t )gen_params. height , 3 , nullptr };
531- sd_image_t control_image = {( uint32_t )gen_params. width , ( uint32_t )gen_params. height , 3 , nullptr };
532- sd_image_t mask_image = {( uint32_t )gen_params. width , ( uint32_t )gen_params. height , 1 , nullptr };
529+ sd_image_t init_image = {0 , 0 , 3 , nullptr };
530+ sd_image_t end_image = {0 , 0 , 3 , nullptr };
531+ sd_image_t control_image = {0 , 0 , 3 , nullptr };
532+ sd_image_t mask_image = {0 , 0 , 1 , nullptr };
533533 std::vector<sd_image_t > ref_images;
534534 std::vector<sd_image_t > pmid_images;
535535 std::vector<sd_image_t > control_frames;
@@ -556,57 +556,79 @@ int main(int argc, const char* argv[]) {
556556 control_frames.clear ();
557557 };
558558
559- if (gen_params.init_image_path .size () > 0 ) {
560- vae_decode_only = false ;
559+ auto load_image_and_update_size = [&](const std::string& path,
560+ sd_image_t & image,
561+ bool resize_image = true ,
562+ int expected_channel = 3 ) -> bool {
563+ int expected_width = 0 ;
564+ int expected_height = 0 ;
565+ if (resize_image && gen_params.width_and_height_are_set ()) {
566+ expected_width = gen_params.width ;
567+ expected_height = gen_params.height ;
568+ }
561569
562- int width = 0 ;
563- int height = 0 ;
564- init_image.data = load_image_from_file (gen_params.init_image_path .c_str (), width, height, gen_params.width , gen_params.height );
565- if (init_image.data == nullptr ) {
566- LOG_ERROR (" load image from '%s' failed" , gen_params.init_image_path .c_str ());
570+ if (!load_sd_image_from_file (&image, path.c_str (), expected_width, expected_height, expected_channel)) {
571+ LOG_ERROR (" load image from '%s' failed" , path.c_str ());
567572 release_all_resources ();
573+ return false ;
574+ }
575+
576+ gen_params.set_width_and_height_if_unset (image.width , image.height );
577+ return true ;
578+ };
579+
580+ if (gen_params.init_image_path .size () > 0 ) {
581+ vae_decode_only = false ;
582+ if (!load_image_and_update_size (gen_params.init_image_path , init_image)) {
568583 return 1 ;
569584 }
570585 }
571586
572587 if (gen_params.end_image_path .size () > 0 ) {
573588 vae_decode_only = false ;
574-
575- int width = 0 ;
576- int height = 0 ;
577- end_image.data = load_image_from_file (gen_params.end_image_path .c_str (), width, height, gen_params.width , gen_params.height );
578- if (end_image.data == nullptr ) {
579- LOG_ERROR (" load image from '%s' failed" , gen_params.end_image_path .c_str ());
580- release_all_resources ();
589+ if (!load_image_and_update_size (gen_params.init_image_path , end_image)) {
581590 return 1 ;
582591 }
583592 }
584593
594+ if (gen_params.ref_image_paths .size () > 0 ) {
595+ vae_decode_only = false ;
596+ for (auto & path : gen_params.ref_image_paths ) {
597+ sd_image_t ref_image = {0 , 0 , 3 , nullptr };
598+ if (!load_image_and_update_size (path, ref_image, false )) {
599+ return 1 ;
600+ }
601+ ref_images.push_back (ref_image);
602+ }
603+ }
604+
585605 if (gen_params.mask_image_path .size () > 0 ) {
586- int c = 0 ;
587- int width = 0 ;
588- int height = 0 ;
589- mask_image. data = load_image_from_file (gen_params. mask_image_path . c_str (), width, height, gen_params. width , gen_params.height , 1 );
590- if (mask_image. data == nullptr ) {
606+ if ( load_sd_image_from_file (&mask_image,
607+ gen_params. mask_image_path . c_str (),
608+ gen_params. get_resolved_width (),
609+ gen_params.get_resolved_height (),
610+ 1 ) ) {
591611 LOG_ERROR (" load image from '%s' failed" , gen_params.mask_image_path .c_str ());
592612 release_all_resources ();
593613 return 1 ;
594614 }
595615 } else {
596- mask_image.data = (uint8_t *)malloc (gen_params.width * gen_params.height );
616+ mask_image.data = (uint8_t *)malloc (gen_params.get_resolved_width () * gen_params.get_resolved_height () );
597617 if (mask_image.data == nullptr ) {
598618 LOG_ERROR (" malloc mask image failed" );
599619 release_all_resources ();
600620 return 1 ;
601621 }
602- memset (mask_image.data , 255 , gen_params.width * gen_params.height );
622+ mask_image.width = gen_params.get_resolved_width ();
623+ mask_image.height = gen_params.get_resolved_height ();
624+ memset (mask_image.data , 255 , gen_params.get_resolved_width () * gen_params.get_resolved_height ());
603625 }
604626
605627 if (gen_params.control_image_path .size () > 0 ) {
606- int width = 0 ;
607- int height = 0 ;
608- control_image. data = load_image_from_file (gen_params. control_image_path . c_str (), width, height, gen_params. width , gen_params.height );
609- if (control_image. data == nullptr ) {
628+ if ( load_sd_image_from_file (&control_image,
629+ gen_params. control_image_path . c_str (),
630+ gen_params.get_resolved_width (),
631+ gen_params. get_resolved_height ()) ) {
610632 LOG_ERROR (" load image from '%s' failed" , gen_params.control_image_path .c_str ());
611633 release_all_resources ();
612634 return 1 ;
@@ -621,29 +643,11 @@ int main(int argc, const char* argv[]) {
621643 }
622644 }
623645
624- if (gen_params.ref_image_paths .size () > 0 ) {
625- vae_decode_only = false ;
626- for (auto & path : gen_params.ref_image_paths ) {
627- int width = 0 ;
628- int height = 0 ;
629- uint8_t * image_buffer = load_image_from_file (path.c_str (), width, height);
630- if (image_buffer == nullptr ) {
631- LOG_ERROR (" load image from '%s' failed" , path.c_str ());
632- release_all_resources ();
633- return 1 ;
634- }
635- ref_images.push_back ({(uint32_t )width,
636- (uint32_t )height,
637- 3 ,
638- image_buffer});
639- }
640- }
641-
642646 if (!gen_params.control_video_path .empty ()) {
643647 if (!load_images_from_dir (gen_params.control_video_path ,
644648 control_frames,
645- gen_params.width ,
646- gen_params.height ,
649+ gen_params.get_resolved_width () ,
650+ gen_params.get_resolved_height () ,
647651 gen_params.video_frames ,
648652 cli_params.verbose )) {
649653 release_all_resources ();
@@ -717,8 +721,8 @@ int main(int argc, const char* argv[]) {
717721 gen_params.auto_resize_ref_image ,
718722 gen_params.increase_ref_index ,
719723 mask_image,
720- gen_params.width ,
721- gen_params.height ,
724+ gen_params.get_resolved_width () ,
725+ gen_params.get_resolved_height () ,
722726 gen_params.sample_params ,
723727 gen_params.strength ,
724728 gen_params.seed ,
@@ -748,8 +752,8 @@ int main(int argc, const char* argv[]) {
748752 end_image,
749753 control_frames.data (),
750754 (int )control_frames.size (),
751- gen_params.width ,
752- gen_params.height ,
755+ gen_params.get_resolved_width () ,
756+ gen_params.get_resolved_height () ,
753757 gen_params.sample_params ,
754758 gen_params.high_noise_sample_params ,
755759 gen_params.moe_boundary ,
0 commit comments