Skip to content

Commit 5e4579c

Browse files
authored
feat: use image width and height when not explicitly set (#1206)
1 parent 3295711 commit 5e4579c

File tree

2 files changed

+92
-65
lines changed

2 files changed

+92
-65
lines changed

examples/cli/main.cpp

Lines changed: 58 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
245245
parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", ";
246246
parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", ";
247247
parameter_string += "Seed: " + std::to_string(seed) + ", ";
248-
parameter_string += "Size: " + std::to_string(gen_params.width) + "x" + std::to_string(gen_params.height) + ", ";
248+
parameter_string += "Size: " + std::to_string(gen_params.get_resolved_width()) + "x" + std::to_string(gen_params.get_resolved_height()) + ", ";
249249
parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", ";
250250
parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", ";
251251
if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
@@ -526,10 +526,10 @@ int main(int argc, const char* argv[]) {
526526
}
527527

528528
bool vae_decode_only = true;
529-
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
530-
sd_image_t end_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
531-
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
532-
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
529+
sd_image_t init_image = {0, 0, 3, nullptr};
530+
sd_image_t end_image = {0, 0, 3, nullptr};
531+
sd_image_t control_image = {0, 0, 3, nullptr};
532+
sd_image_t mask_image = {0, 0, 1, nullptr};
533533
std::vector<sd_image_t> ref_images;
534534
std::vector<sd_image_t> pmid_images;
535535
std::vector<sd_image_t> control_frames;
@@ -556,57 +556,79 @@ int main(int argc, const char* argv[]) {
556556
control_frames.clear();
557557
};
558558

559-
if (gen_params.init_image_path.size() > 0) {
560-
vae_decode_only = false;
559+
auto load_image_and_update_size = [&](const std::string& path,
560+
sd_image_t& image,
561+
bool resize_image = true,
562+
int expected_channel = 3) -> bool {
563+
int expected_width = 0;
564+
int expected_height = 0;
565+
if (resize_image && gen_params.width_and_height_are_set()) {
566+
expected_width = gen_params.width;
567+
expected_height = gen_params.height;
568+
}
561569

562-
int width = 0;
563-
int height = 0;
564-
init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height);
565-
if (init_image.data == nullptr) {
566-
LOG_ERROR("load image from '%s' failed", gen_params.init_image_path.c_str());
570+
if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) {
571+
LOG_ERROR("load image from '%s' failed", path.c_str());
567572
release_all_resources();
573+
return false;
574+
}
575+
576+
gen_params.set_width_and_height_if_unset(image.width, image.height);
577+
return true;
578+
};
579+
580+
if (gen_params.init_image_path.size() > 0) {
581+
vae_decode_only = false;
582+
if (!load_image_and_update_size(gen_params.init_image_path, init_image)) {
568583
return 1;
569584
}
570585
}
571586

572587
if (gen_params.end_image_path.size() > 0) {
573588
vae_decode_only = false;
574-
575-
int width = 0;
576-
int height = 0;
577-
end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height);
578-
if (end_image.data == nullptr) {
579-
LOG_ERROR("load image from '%s' failed", gen_params.end_image_path.c_str());
580-
release_all_resources();
589+
if (!load_image_and_update_size(gen_params.init_image_path, end_image)) {
581590
return 1;
582591
}
583592
}
584593

594+
if (gen_params.ref_image_paths.size() > 0) {
595+
vae_decode_only = false;
596+
for (auto& path : gen_params.ref_image_paths) {
597+
sd_image_t ref_image = {0, 0, 3, nullptr};
598+
if (!load_image_and_update_size(path, ref_image, false)) {
599+
return 1;
600+
}
601+
ref_images.push_back(ref_image);
602+
}
603+
}
604+
585605
if (gen_params.mask_image_path.size() > 0) {
586-
int c = 0;
587-
int width = 0;
588-
int height = 0;
589-
mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1);
590-
if (mask_image.data == nullptr) {
606+
if (load_sd_image_from_file(&mask_image,
607+
gen_params.mask_image_path.c_str(),
608+
gen_params.get_resolved_width(),
609+
gen_params.get_resolved_height(),
610+
1)) {
591611
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
592612
release_all_resources();
593613
return 1;
594614
}
595615
} else {
596-
mask_image.data = (uint8_t*)malloc(gen_params.width * gen_params.height);
616+
mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
597617
if (mask_image.data == nullptr) {
598618
LOG_ERROR("malloc mask image failed");
599619
release_all_resources();
600620
return 1;
601621
}
602-
memset(mask_image.data, 255, gen_params.width * gen_params.height);
622+
mask_image.width = gen_params.get_resolved_width();
623+
mask_image.height = gen_params.get_resolved_height();
624+
memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
603625
}
604626

605627
if (gen_params.control_image_path.size() > 0) {
606-
int width = 0;
607-
int height = 0;
608-
control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height);
609-
if (control_image.data == nullptr) {
628+
if (load_sd_image_from_file(&control_image,
629+
gen_params.control_image_path.c_str(),
630+
gen_params.get_resolved_width(),
631+
gen_params.get_resolved_height())) {
610632
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
611633
release_all_resources();
612634
return 1;
@@ -621,29 +643,11 @@ int main(int argc, const char* argv[]) {
621643
}
622644
}
623645

624-
if (gen_params.ref_image_paths.size() > 0) {
625-
vae_decode_only = false;
626-
for (auto& path : gen_params.ref_image_paths) {
627-
int width = 0;
628-
int height = 0;
629-
uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height);
630-
if (image_buffer == nullptr) {
631-
LOG_ERROR("load image from '%s' failed", path.c_str());
632-
release_all_resources();
633-
return 1;
634-
}
635-
ref_images.push_back({(uint32_t)width,
636-
(uint32_t)height,
637-
3,
638-
image_buffer});
639-
}
640-
}
641-
642646
if (!gen_params.control_video_path.empty()) {
643647
if (!load_images_from_dir(gen_params.control_video_path,
644648
control_frames,
645-
gen_params.width,
646-
gen_params.height,
649+
gen_params.get_resolved_width(),
650+
gen_params.get_resolved_height(),
647651
gen_params.video_frames,
648652
cli_params.verbose)) {
649653
release_all_resources();
@@ -717,8 +721,8 @@ int main(int argc, const char* argv[]) {
717721
gen_params.auto_resize_ref_image,
718722
gen_params.increase_ref_index,
719723
mask_image,
720-
gen_params.width,
721-
gen_params.height,
724+
gen_params.get_resolved_width(),
725+
gen_params.get_resolved_height(),
722726
gen_params.sample_params,
723727
gen_params.strength,
724728
gen_params.seed,
@@ -748,8 +752,8 @@ int main(int argc, const char* argv[]) {
748752
end_image,
749753
control_frames.data(),
750754
(int)control_frames.size(),
751-
gen_params.width,
752-
gen_params.height,
755+
gen_params.get_resolved_width(),
756+
gen_params.get_resolved_height(),
753757
gen_params.sample_params,
754758
gen_params.high_noise_sample_params,
755759
gen_params.moe_boundary,

examples/common/common.hpp

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,8 +1024,8 @@ struct SDGenerationParams {
10241024
std::string prompt_with_lora; // for metadata record only
10251025
std::string negative_prompt;
10261026
int clip_skip = -1; // <= 0 represents unspecified
1027-
int width = 512;
1028-
int height = 512;
1027+
int width = -1;
1028+
int height = -1;
10291029
int batch_count = 1;
10301030
std::string init_image_path;
10311031
std::string end_image_path;
@@ -1705,17 +1705,24 @@ struct SDGenerationParams {
17051705
}
17061706
}
17071707

1708-
bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
1709-
prompt_with_lora = prompt;
1710-
if (width <= 0) {
1711-
LOG_ERROR("error: the width must be greater than 0\n");
1712-
return false;
1713-
}
1708+
bool width_and_height_are_set() const {
1709+
return width > 0 && height > 0;
1710+
}
17141711

1715-
if (height <= 0) {
1716-
LOG_ERROR("error: the height must be greater than 0\n");
1717-
return false;
1712+
void set_width_and_height_if_unset(int w, int h) {
1713+
if (!width_and_height_are_set()) {
1714+
LOG_INFO("set width x height to %d x %d", w, h);
1715+
width = w;
1716+
height = h;
17181717
}
1718+
}
1719+
1720+
int get_resolved_width() const { return (width > 0) ? width : 512; }
1721+
1722+
int get_resolved_height() const { return (height > 0) ? height : 512; }
1723+
1724+
bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
1725+
prompt_with_lora = prompt;
17191726

17201727
if (sample_params.sample_steps <= 0) {
17211728
LOG_ERROR("error: the sample_steps must be greater than 0\n");
@@ -2083,6 +2090,22 @@ uint8_t* load_image_from_file(const char* image_path,
20832090
return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
20842091
}
20852092

2093+
bool load_sd_image_from_file(sd_image_t* image,
2094+
const char* image_path,
2095+
int expected_width = 0,
2096+
int expected_height = 0,
2097+
int expected_channel = 3) {
2098+
int width;
2099+
int height;
2100+
image->data = load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
2101+
if (image->data == nullptr) {
2102+
return false;
2103+
}
2104+
image->width = width;
2105+
image->height = height;
2106+
return true;
2107+
}
2108+
20862109
uint8_t* load_image_from_memory(const char* image_bytes,
20872110
int len,
20882111
int& width,

0 commit comments

Comments
 (0)