@@ -492,37 +492,51 @@ int main(int argc, const char** argv) {
492492
493493 svr.Post (" /v1/images/edits" , [&](const httplib::Request& req, httplib::Response& res) {
494494 try {
495- if (req.body . empty ()) {
495+ if (! req.is_multipart_form_data ()) {
496496 res.status = 400 ;
497- res.set_content (R"( {"error":"empty body "})" , " application/json" );
497+ res.set_content (R"( {"error":"Content-Type must be multipart/form-data "})" , " application/json" );
498498 return ;
499499 }
500500
501- json j = json::parse (req.body );
502-
503- std::string prompt = j.value (" prompt" , " " );
504- int n = std::max (1 , j.value (" n" , 1 ));
505- std::string size = j.value (" size" , " " );
506- std::string output_format = j.value (" output_format" , " png" );
507- int output_compression = j.value (" output_compression" , 100 );
508-
509- std::string ref_image_b64 = j.value (" image" , " " );
510- std::string mask_image_b64 = j.value (" mask" , " " );
511-
501+ std::string prompt = req.form .get_field (" prompt" );
512502 if (prompt.empty ()) {
513503 res.status = 400 ;
514504 res.set_content (R"( {"error":"prompt required"})" , " application/json" );
515505 return ;
516506 }
517507
518- if (ref_image_b64.empty ()) {
508+ std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args (prompt);
509+
510+ size_t image_count = req.form .get_file_count (" image[]" );
511+ if (image_count == 0 ) {
519512 res.status = 400 ;
520- res.set_content (R"( {"error":"image required"})" , " application/json" );
513+ res.set_content (R"( {"error":"at least one image[] required"})" , " application/json" );
521514 return ;
522515 }
523516
524- int width = 512 ;
525- int height = 512 ;
517+ std::vector<std::vector<uint8_t >> images_bytes;
518+ for (size_t i = 0 ; i < image_count; i++) {
519+ auto file = req.form .get_file (" image[]" , i);
520+ images_bytes.emplace_back (file.content .begin (), file.content .end ());
521+ }
522+
523+ std::vector<uint8_t > mask_bytes;
524+ if (req.form .has_field (" mask" )) {
525+ auto file = req.form .get_file (" mask" );
526+ mask_bytes.assign (file.content .begin (), file.content .end ());
527+ }
528+
529+ int n = 1 ;
530+ if (req.form .has_field (" n" )) {
531+ try {
532+ n = std::stoi (req.form .get_field (" n" ));
533+ } catch (...) {
534+ }
535+ }
536+ n = std::clamp (n, 1 , 8 );
537+
538+ std::string size = req.form .get_field (" size" );
539+ int width = 512 , height = 512 ;
526540 if (!size.empty ()) {
527541 auto pos = size.find (' x' );
528542 if (pos != std::string::npos) {
@@ -534,66 +548,84 @@ int main(int argc, const char** argv) {
534548 }
535549 }
536550
551+ std::string output_format = " png" ;
552+ if (req.form .has_field (" output_format" ))
553+ output_format = req.form .get_field (" output_format" );
537554 if (output_format != " png" && output_format != " jpeg" ) {
538555 res.status = 400 ;
539556 res.set_content (R"( {"error":"invalid output_format, must be one of [png, jpeg]"})" , " application/json" );
540557 return ;
541558 }
542559
543- if (n <= 0 )
544- n = 1 ;
545- if (n > 8 )
546- n = 8 ;
547- if (output_compression > 100 )
560+ std::string output_compression_str = req.form .get_field (" output_compression" );
561+ int output_compression = 100 ;
562+ try {
563+ output_compression = std::stoi (output_compression_str);
564+ } catch (...) {
565+ }
566+ if (output_compression > 100 ) {
548567 output_compression = 100 ;
549- if (output_compression < 0 )
568+ }
569+ if (output_compression < 0 ) {
550570 output_compression = 0 ;
571+ }
551572
552- // base64 -> raw image
553- std::vector<uint8_t > ref_image_bytes = base64_decode (ref_image_b64);
554- int img_w = width;
555- int img_h = height;
556- uint8_t * raw_pixels = load_image_from_memory (
557- reinterpret_cast <const char *>(ref_image_bytes.data ()),
558- img_w, img_h,
559- width, height, 3 );
560-
561- sd_image_t ref_image;
562- ref_image.width = img_w;
563- ref_image.height = img_h;
564- ref_image.channel = 3 ;
565- ref_image.data = raw_pixels;
573+ SDGenerationParams gen_params;
574+ gen_params.prompt = prompt;
575+ gen_params.width = width;
576+ gen_params.height = height;
577+ gen_params.batch_count = n;
578+
579+ if (!sd_cpp_extra_args_str.empty () && !gen_params.from_json_str (sd_cpp_extra_args_str)) {
580+ res.status = 400 ;
581+ res.set_content (R"( {"error":"invalid sd_cpp_extra_args"})" , " application/json" );
582+ return ;
583+ }
584+
585+ if (svr_params.verbose ) {
586+ printf (" %s\n " , gen_params.to_string ().c_str ());
587+ }
588+
589+ sd_image_t init_image = {(uint32_t )gen_params.width , (uint32_t )gen_params.height , 3 , nullptr };
590+ sd_image_t control_image = {(uint32_t )gen_params.width , (uint32_t )gen_params.height , 3 , nullptr };
591+ std::vector<sd_image_t > pmid_images;
592+
593+ std::vector<sd_image_t > ref_images;
594+ ref_images.reserve (images_bytes.size ());
595+ for (auto & bytes : images_bytes) {
596+ int img_w = width;
597+ int img_h = height;
598+ uint8_t * raw_pixels = load_image_from_memory (
599+ reinterpret_cast <const char *>(bytes.data ()),
600+ bytes.size (),
601+ img_w, img_h,
602+ width, height, 3 );
603+
604+ if (!raw_pixels) {
605+ continue ;
606+ }
607+
608+ sd_image_t img{(uint32_t )img_w, (uint32_t )img_h, 3 , raw_pixels};
609+ ref_images.push_back (img);
610+ }
566611
567612 sd_image_t mask_image = {0 };
568- if (!mask_image_b64 .empty ()) {
569- std::vector< uint8_t > mask_bytes = base64_decode (mask_image_b64) ;
570- int mask_w = width, mask_h = height;
613+ if (!mask_bytes .empty ()) {
614+ int mask_w = width ;
615+ int mask_h = height;
571616 uint8_t * mask_raw = load_image_from_memory (
572617 reinterpret_cast <const char *>(mask_bytes.data ()),
618+ mask_bytes.size (),
573619 mask_w, mask_h,
574620 width, height, 1 );
575- mask_image.width = mask_w;
576- mask_image.height = mask_h;
577- mask_image.channel = 1 ;
578- mask_image.data = mask_raw;
621+ mask_image = {(uint32_t )mask_w, (uint32_t )mask_h, 1 , mask_raw};
579622 } else {
580623 mask_image.width = width;
581624 mask_image.height = height;
582625 mask_image.channel = 1 ;
583626 mask_image.data = nullptr ;
584627 }
585628
586- SDGenerationParams gen_params;
587- gen_params.prompt = prompt;
588- gen_params.width = width;
589- gen_params.height = height;
590- gen_params.batch_count = n;
591-
592- sd_image_t init_image = {(uint32_t )gen_params.width , (uint32_t )gen_params.height , 3 , nullptr };
593- sd_image_t control_image = {(uint32_t )gen_params.width , (uint32_t )gen_params.height , 3 , nullptr };
594- std::vector<sd_image_t > ref_images = {ref_image};
595- std::vector<sd_image_t > pmid_images;
596-
597629 sd_img_gen_params_t img_gen_params = {
598630 gen_params.lora_vec .data (),
599631 static_cast <uint32_t >(gen_params.lora_vec .size ()),
@@ -662,6 +694,9 @@ int main(int argc, const char** argv) {
662694 if (mask_image.data ) {
663695 stbi_image_free (mask_image.data );
664696 }
697+ for (auto ref_image : ref_images) {
698+ stbi_image_free (ref_image.data );
699+ }
665700 } catch (const std::exception& e) {
666701 res.status = 500 ;
667702 json err;
0 commit comments