Skip to content

Commit 9021534

Browse files
committed
wip image edit api
1 parent 96aea63 commit 9021534

File tree

3 files changed

+369
-117
lines changed

3 files changed

+369
-117
lines changed

examples/cli/main.cpp

Lines changed: 7 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,10 @@
1515
// #include "preprocessing.hpp"
1616
#include "stable-diffusion.h"
1717

18-
#define STB_IMAGE_IMPLEMENTATION
19-
#define STB_IMAGE_STATIC
20-
#include "stb_image.h"
21-
22-
#define STB_IMAGE_WRITE_IMPLEMENTATION
23-
#define STB_IMAGE_WRITE_STATIC
24-
#include "stb_image_write.h"
25-
26-
#define STB_IMAGE_RESIZE_IMPLEMENTATION
27-
#define STB_IMAGE_RESIZE_STATIC
28-
#include "stb_image_resize.h"
18+
#include "common/common.hpp"
2919

3020
#include "avi_writer.h"
3121

32-
#include "common/common.hpp"
33-
3422
const char* previews_str[] = {
3523
"none",
3624
"proj",
@@ -335,94 +323,6 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
335323
fflush(out_stream);
336324
}
337325

338-
uint8_t* load_image(const char* image_path, int& width, int& height, int expected_width = 0, int expected_height = 0, int expected_channel = 3) {
339-
int c = 0;
340-
uint8_t* image_buffer = (uint8_t*)stbi_load(image_path, &width, &height, &c, expected_channel);
341-
if (image_buffer == nullptr) {
342-
fprintf(stderr, "load image from '%s' failed\n", image_path);
343-
return nullptr;
344-
}
345-
if (c < expected_channel) {
346-
fprintf(stderr,
347-
"the number of channels for the input image must be >= %d,"
348-
"but got %d channels, image_path = %s\n",
349-
expected_channel,
350-
c,
351-
image_path);
352-
free(image_buffer);
353-
return nullptr;
354-
}
355-
if (width <= 0) {
356-
fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path);
357-
free(image_buffer);
358-
return nullptr;
359-
}
360-
if (height <= 0) {
361-
fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path);
362-
free(image_buffer);
363-
return nullptr;
364-
}
365-
366-
// Resize input image ...
367-
if ((expected_width > 0 && expected_height > 0) && (height != expected_height || width != expected_width)) {
368-
float dst_aspect = (float)expected_width / (float)expected_height;
369-
float src_aspect = (float)width / (float)height;
370-
371-
int crop_x = 0, crop_y = 0;
372-
int crop_w = width, crop_h = height;
373-
374-
if (src_aspect > dst_aspect) {
375-
crop_w = (int)(height * dst_aspect);
376-
crop_x = (width - crop_w) / 2;
377-
} else if (src_aspect < dst_aspect) {
378-
crop_h = (int)(width / dst_aspect);
379-
crop_y = (height - crop_h) / 2;
380-
}
381-
382-
if (crop_x != 0 || crop_y != 0) {
383-
printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path);
384-
uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel);
385-
if (cropped_image_buffer == nullptr) {
386-
fprintf(stderr, "error: allocate memory for crop\n");
387-
free(image_buffer);
388-
return nullptr;
389-
}
390-
for (int row = 0; row < crop_h; row++) {
391-
uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel;
392-
uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel;
393-
memcpy(dst, src, crop_w * expected_channel);
394-
}
395-
396-
width = crop_w;
397-
height = crop_h;
398-
free(image_buffer);
399-
image_buffer = cropped_image_buffer;
400-
}
401-
402-
printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height);
403-
int resized_height = expected_height;
404-
int resized_width = expected_width;
405-
406-
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel);
407-
if (resized_image_buffer == nullptr) {
408-
fprintf(stderr, "error: allocate memory for resize input image\n");
409-
free(image_buffer);
410-
return nullptr;
411-
}
412-
stbir_resize(image_buffer, width, height, 0,
413-
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
414-
expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0,
415-
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
416-
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
417-
STBIR_COLORSPACE_SRGB, nullptr);
418-
width = resized_width;
419-
height = resized_height;
420-
free(image_buffer);
421-
image_buffer = resized_image_buffer;
422-
}
423-
return image_buffer;
424-
}
425-
426326
bool load_images_from_dir(const std::string dir,
427327
std::vector<sd_image_t>& images,
428328
int expected_width = 0,
@@ -457,7 +357,7 @@ bool load_images_from_dir(const std::string dir,
457357
}
458358
int width = 0;
459359
int height = 0;
460-
uint8_t* image_buffer = load_image(path.c_str(), width, height, expected_width, expected_height);
360+
uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height, expected_width, expected_height);
461361
if (image_buffer == nullptr) {
462362
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
463363
return false;
@@ -593,7 +493,7 @@ int main(int argc, const char* argv[]) {
593493

594494
int width = 0;
595495
int height = 0;
596-
init_image.data = load_image(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height);
496+
init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height);
597497
if (init_image.data == nullptr) {
598498
fprintf(stderr, "load image from '%s' failed\n", gen_params.init_image_path.c_str());
599499
release_all_resources();
@@ -606,7 +506,7 @@ int main(int argc, const char* argv[]) {
606506

607507
int width = 0;
608508
int height = 0;
609-
end_image.data = load_image(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height);
509+
end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height);
610510
if (end_image.data == nullptr) {
611511
fprintf(stderr, "load image from '%s' failed\n", gen_params.end_image_path.c_str());
612512
release_all_resources();
@@ -618,7 +518,7 @@ int main(int argc, const char* argv[]) {
618518
int c = 0;
619519
int width = 0;
620520
int height = 0;
621-
mask_image.data = load_image(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1);
521+
mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1);
622522
if (mask_image.data == nullptr) {
623523
fprintf(stderr, "load image from '%s' failed\n", gen_params.mask_image_path.c_str());
624524
release_all_resources();
@@ -637,7 +537,7 @@ int main(int argc, const char* argv[]) {
637537
if (gen_params.control_image_path.size() > 0) {
638538
int width = 0;
639539
int height = 0;
640-
control_image.data = load_image(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height);
540+
control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height);
641541
if (control_image.data == nullptr) {
642542
fprintf(stderr, "load image from '%s' failed\n", gen_params.control_image_path.c_str());
643543
release_all_resources();
@@ -658,7 +558,7 @@ int main(int argc, const char* argv[]) {
658558
for (auto& path : gen_params.ref_image_paths) {
659559
int width = 0;
660560
int height = 0;
661-
uint8_t* image_buffer = load_image(path.c_str(), width, height);
561+
uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height);
662562
if (image_buffer == nullptr) {
663563
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
664564
release_all_resources();

examples/common/common.hpp

Lines changed: 134 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11

2+
#include <filesystem>
23
#include <iostream>
34
#include <map>
45
#include <random>
56
#include <regex>
67
#include <sstream>
78
#include <string>
89
#include <vector>
9-
#include <filesystem>
1010

1111
#include <json.hpp>
12-
using json = nlohmann::json;
12+
using json = nlohmann::json;
1313
namespace fs = std::filesystem;
1414

1515
#if defined(_WIN32)
@@ -19,6 +19,18 @@ namespace fs = std::filesystem;
1919

2020
#include "stable-diffusion.h"
2121

22+
#define STB_IMAGE_IMPLEMENTATION
23+
#define STB_IMAGE_STATIC
24+
#include "stb_image.h"
25+
26+
#define STB_IMAGE_WRITE_IMPLEMENTATION
27+
#define STB_IMAGE_WRITE_STATIC
28+
#include "stb_image_write.h"
29+
30+
#define STB_IMAGE_RESIZE_IMPLEMENTATION
31+
#define STB_IMAGE_RESIZE_STATIC
32+
#include "stb_image_resize.h"
33+
2234
#define SAFE_STR(s) ((s) ? (s) : "")
2335
#define BOOL_STR(b) ((b) ? "true" : "false")
2436

@@ -1612,3 +1624,123 @@ struct SDGenerationParams {
16121624
static std::string version_string() {
16131625
return std::string("stable-diffusion.cpp version ") + sd_version() + ", commit " + sd_commit();
16141626
}
1627+
1628+
uint8_t* load_image_common(bool from_memory,
1629+
const char* image_path_or_bytes,
1630+
int& width,
1631+
int& height,
1632+
int expected_width = 0,
1633+
int expected_height = 0,
1634+
int expected_channel = 3) {
1635+
int c = 0;
1636+
const char* image_path;
1637+
uint8_t* image_buffer = nullptr;
1638+
if (from_memory) {
1639+
image_path = "memory";
1640+
image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel);
1641+
} else {
1642+
image_path = image_path_or_bytes;
1643+
image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel);
1644+
}
1645+
if (image_buffer == nullptr) {
1646+
fprintf(stderr, "load image from '%s' failed\n", image_path);
1647+
return nullptr;
1648+
}
1649+
if (c < expected_channel) {
1650+
fprintf(stderr,
1651+
"the number of channels for the input image must be >= %d,"
1652+
"but got %d channels, image_path = %s\n",
1653+
expected_channel,
1654+
c,
1655+
image_path);
1656+
free(image_buffer);
1657+
return nullptr;
1658+
}
1659+
if (width <= 0) {
1660+
fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path);
1661+
free(image_buffer);
1662+
return nullptr;
1663+
}
1664+
if (height <= 0) {
1665+
fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path);
1666+
free(image_buffer);
1667+
return nullptr;
1668+
}
1669+
1670+
// Resize input image ...
1671+
if ((expected_width > 0 && expected_height > 0) && (height != expected_height || width != expected_width)) {
1672+
float dst_aspect = (float)expected_width / (float)expected_height;
1673+
float src_aspect = (float)width / (float)height;
1674+
1675+
int crop_x = 0, crop_y = 0;
1676+
int crop_w = width, crop_h = height;
1677+
1678+
if (src_aspect > dst_aspect) {
1679+
crop_w = (int)(height * dst_aspect);
1680+
crop_x = (width - crop_w) / 2;
1681+
} else if (src_aspect < dst_aspect) {
1682+
crop_h = (int)(width / dst_aspect);
1683+
crop_y = (height - crop_h) / 2;
1684+
}
1685+
1686+
if (crop_x != 0 || crop_y != 0) {
1687+
printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path);
1688+
uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel);
1689+
if (cropped_image_buffer == nullptr) {
1690+
fprintf(stderr, "error: allocate memory for crop\n");
1691+
free(image_buffer);
1692+
return nullptr;
1693+
}
1694+
for (int row = 0; row < crop_h; row++) {
1695+
uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel;
1696+
uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel;
1697+
memcpy(dst, src, crop_w * expected_channel);
1698+
}
1699+
1700+
width = crop_w;
1701+
height = crop_h;
1702+
free(image_buffer);
1703+
image_buffer = cropped_image_buffer;
1704+
}
1705+
1706+
printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height);
1707+
int resized_height = expected_height;
1708+
int resized_width = expected_width;
1709+
1710+
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel);
1711+
if (resized_image_buffer == nullptr) {
1712+
fprintf(stderr, "error: allocate memory for resize input image\n");
1713+
free(image_buffer);
1714+
return nullptr;
1715+
}
1716+
stbir_resize(image_buffer, width, height, 0,
1717+
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
1718+
expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0,
1719+
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
1720+
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
1721+
STBIR_COLORSPACE_SRGB, nullptr);
1722+
width = resized_width;
1723+
height = resized_height;
1724+
free(image_buffer);
1725+
image_buffer = resized_image_buffer;
1726+
}
1727+
return image_buffer;
1728+
}
1729+
1730+
uint8_t* load_image_from_file(const char* image_path,
1731+
int& width,
1732+
int& height,
1733+
int expected_width = 0,
1734+
int expected_height = 0,
1735+
int expected_channel = 3) {
1736+
return load_image_common(false, image_path, width, height, expected_width, expected_height, expected_channel);
1737+
}
1738+
1739+
uint8_t* load_image_from_memory(const char* image_bytes,
1740+
int& width,
1741+
int& height,
1742+
int expected_width = 0,
1743+
int expected_height = 0,
1744+
int expected_channel = 3) {
1745+
return load_image_common(true, image_bytes, width, height, expected_width, expected_height, expected_channel);
1746+
}

0 commit comments

Comments
 (0)