Skip to content

Commit e280695

Browse files
committed
fix image edit api
1 parent 9021534 commit e280695

File tree

2 files changed

+99
-58
lines changed

2 files changed

+99
-58
lines changed

examples/common/common.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,10 @@ struct SDGenerationParams {
13411341
load_if_exists("skip_layers", skip_layers);
13421342
load_if_exists("high_noise_skip_layers", high_noise_skip_layers);
13431343

1344+
load_if_exists("cfg_scale", sample_params.guidance.txt_cfg);
1345+
load_if_exists("img_cfg_scale", sample_params.guidance.img_cfg);
1346+
load_if_exists("guidance", sample_params.guidance.distilled_guidance);
1347+
13441348
return true;
13451349
}
13461350

@@ -1627,6 +1631,7 @@ static std::string version_string() {
16271631

16281632
uint8_t* load_image_common(bool from_memory,
16291633
const char* image_path_or_bytes,
1634+
int len,
16301635
int& width,
16311636
int& height,
16321637
int expected_width = 0,
@@ -1637,7 +1642,7 @@ uint8_t* load_image_common(bool from_memory,
16371642
uint8_t* image_buffer = nullptr;
16381643
if (from_memory) {
16391644
image_path = "memory";
1640-
image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel);
1645+
image_buffer = (uint8_t*)stbi_load_from_memory((const stbi_uc*)image_path_or_bytes, len, &width, &height, &c, expected_channel);
16411646
} else {
16421647
image_path = image_path_or_bytes;
16431648
image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel);
@@ -1733,14 +1738,15 @@ uint8_t* load_image_from_file(const char* image_path,
17331738
int expected_width = 0,
17341739
int expected_height = 0,
17351740
int expected_channel = 3) {
1736-
return load_image_common(false, image_path, width, height, expected_width, expected_height, expected_channel);
1741+
return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
17371742
}
17381743

17391744
uint8_t* load_image_from_memory(const char* image_bytes,
1745+
int len,
17401746
int& width,
17411747
int& height,
17421748
int expected_width = 0,
17431749
int expected_height = 0,
17441750
int expected_channel = 3) {
1745-
return load_image_common(true, image_bytes, width, height, expected_width, expected_height, expected_channel);
1751+
return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel);
17461752
}

examples/server/main.cpp

Lines changed: 90 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -492,37 +492,51 @@ int main(int argc, const char** argv) {
492492

493493
svr.Post("/v1/images/edits", [&](const httplib::Request& req, httplib::Response& res) {
494494
try {
495-
if (req.body.empty()) {
495+
if (!req.is_multipart_form_data()) {
496496
res.status = 400;
497-
res.set_content(R"({"error":"empty body"})", "application/json");
497+
res.set_content(R"({"error":"Content-Type must be multipart/form-data"})", "application/json");
498498
return;
499499
}
500500

501-
json j = json::parse(req.body);
502-
503-
std::string prompt = j.value("prompt", "");
504-
int n = std::max(1, j.value("n", 1));
505-
std::string size = j.value("size", "");
506-
std::string output_format = j.value("output_format", "png");
507-
int output_compression = j.value("output_compression", 100);
508-
509-
std::string ref_image_b64 = j.value("image", "");
510-
std::string mask_image_b64 = j.value("mask", "");
511-
501+
std::string prompt = req.form.get_field("prompt");
512502
if (prompt.empty()) {
513503
res.status = 400;
514504
res.set_content(R"({"error":"prompt required"})", "application/json");
515505
return;
516506
}
517507

518-
if (ref_image_b64.empty()) {
508+
std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(prompt);
509+
510+
size_t image_count = req.form.get_file_count("image[]");
511+
if (image_count == 0) {
519512
res.status = 400;
520-
res.set_content(R"({"error":"image required"})", "application/json");
513+
res.set_content(R"({"error":"at least one image[] required"})", "application/json");
521514
return;
522515
}
523516

524-
int width = 512;
525-
int height = 512;
517+
std::vector<std::vector<uint8_t>> images_bytes;
518+
for (size_t i = 0; i < image_count; i++) {
519+
auto file = req.form.get_file("image[]", i);
520+
images_bytes.emplace_back(file.content.begin(), file.content.end());
521+
}
522+
523+
std::vector<uint8_t> mask_bytes;
524+
if (req.form.has_field("mask")) {
525+
auto file = req.form.get_file("mask");
526+
mask_bytes.assign(file.content.begin(), file.content.end());
527+
}
528+
529+
int n = 1;
530+
if (req.form.has_field("n")) {
531+
try {
532+
n = std::stoi(req.form.get_field("n"));
533+
} catch (...) {
534+
}
535+
}
536+
n = std::clamp(n, 1, 8);
537+
538+
std::string size = req.form.get_field("size");
539+
int width = 512, height = 512;
526540
if (!size.empty()) {
527541
auto pos = size.find('x');
528542
if (pos != std::string::npos) {
@@ -534,66 +548,84 @@ int main(int argc, const char** argv) {
534548
}
535549
}
536550

551+
std::string output_format = "png";
552+
if (req.form.has_field("output_format"))
553+
output_format = req.form.get_field("output_format");
537554
if (output_format != "png" && output_format != "jpeg") {
538555
res.status = 400;
539556
res.set_content(R"({"error":"invalid output_format, must be one of [png, jpeg]"})", "application/json");
540557
return;
541558
}
542559

543-
if (n <= 0)
544-
n = 1;
545-
if (n > 8)
546-
n = 8;
547-
if (output_compression > 100)
560+
std::string output_compression_str = req.form.get_field("output_compression");
561+
int output_compression = 100;
562+
try {
563+
output_compression = std::stoi(output_compression_str);
564+
} catch (...) {
565+
}
566+
if (output_compression > 100) {
548567
output_compression = 100;
549-
if (output_compression < 0)
568+
}
569+
if (output_compression < 0) {
550570
output_compression = 0;
571+
}
551572

552-
// base64 -> raw image
553-
std::vector<uint8_t> ref_image_bytes = base64_decode(ref_image_b64);
554-
int img_w = width;
555-
int img_h = height;
556-
uint8_t* raw_pixels = load_image_from_memory(
557-
reinterpret_cast<const char*>(ref_image_bytes.data()),
558-
img_w, img_h,
559-
width, height, 3);
560-
561-
sd_image_t ref_image;
562-
ref_image.width = img_w;
563-
ref_image.height = img_h;
564-
ref_image.channel = 3;
565-
ref_image.data = raw_pixels;
573+
SDGenerationParams gen_params;
574+
gen_params.prompt = prompt;
575+
gen_params.width = width;
576+
gen_params.height = height;
577+
gen_params.batch_count = n;
578+
579+
if (!sd_cpp_extra_args_str.empty() && !gen_params.from_json_str(sd_cpp_extra_args_str)) {
580+
res.status = 400;
581+
res.set_content(R"({"error":"invalid sd_cpp_extra_args"})", "application/json");
582+
return;
583+
}
584+
585+
if (svr_params.verbose) {
586+
printf("%s\n", gen_params.to_string().c_str());
587+
}
588+
589+
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
590+
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
591+
std::vector<sd_image_t> pmid_images;
592+
593+
std::vector<sd_image_t> ref_images;
594+
ref_images.reserve(images_bytes.size());
595+
for (auto& bytes : images_bytes) {
596+
int img_w = width;
597+
int img_h = height;
598+
uint8_t* raw_pixels = load_image_from_memory(
599+
reinterpret_cast<const char*>(bytes.data()),
600+
bytes.size(),
601+
img_w, img_h,
602+
width, height, 3);
603+
604+
if (!raw_pixels) {
605+
continue;
606+
}
607+
608+
sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels};
609+
ref_images.push_back(img);
610+
}
566611

567612
sd_image_t mask_image = {0};
568-
if (!mask_image_b64.empty()) {
569-
std::vector<uint8_t> mask_bytes = base64_decode(mask_image_b64);
570-
int mask_w = width, mask_h = height;
613+
if (!mask_bytes.empty()) {
614+
int mask_w = width;
615+
int mask_h = height;
571616
uint8_t* mask_raw = load_image_from_memory(
572617
reinterpret_cast<const char*>(mask_bytes.data()),
618+
mask_bytes.size(),
573619
mask_w, mask_h,
574620
width, height, 1);
575-
mask_image.width = mask_w;
576-
mask_image.height = mask_h;
577-
mask_image.channel = 1;
578-
mask_image.data = mask_raw;
621+
mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw};
579622
} else {
580623
mask_image.width = width;
581624
mask_image.height = height;
582625
mask_image.channel = 1;
583626
mask_image.data = nullptr;
584627
}
585628

586-
SDGenerationParams gen_params;
587-
gen_params.prompt = prompt;
588-
gen_params.width = width;
589-
gen_params.height = height;
590-
gen_params.batch_count = n;
591-
592-
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
593-
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
594-
std::vector<sd_image_t> ref_images = {ref_image};
595-
std::vector<sd_image_t> pmid_images;
596-
597629
sd_img_gen_params_t img_gen_params = {
598630
gen_params.lora_vec.data(),
599631
static_cast<uint32_t>(gen_params.lora_vec.size()),
@@ -662,6 +694,9 @@ int main(int argc, const char** argv) {
662694
if (mask_image.data) {
663695
stbi_image_free(mask_image.data);
664696
}
697+
for (auto ref_image : ref_images) {
698+
stbi_image_free(ref_image.data);
699+
}
665700
} catch (const std::exception& e) {
666701
res.status = 500;
667702
json err;

0 commit comments

Comments
 (0)