Skip to content

Commit dbc1f18

Browse files
committed
Support Custom ESRGAN tile size
1 parent 0585e26 commit dbc1f18

File tree

4 files changed

+24
-8
lines changed

4 files changed

+24
-8
lines changed

esrgan.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,10 @@ struct ESRGAN : public GGMLRunner {
156156

157157
ESRGAN(ggml_backend_t backend,
158158
bool offload_params_to_cpu,
159+
int tile_size = 128,
159160
const String2GGMLType& tensor_types = {})
160161
: GGMLRunner(backend, offload_params_to_cpu) {
161-
// rrdb_net will be created in load_from_file
162+
this->tile_size = tile_size;
162163
}
163164

164165
void enable_conv2d_direct() {

examples/cli/main.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ struct SDParams {
116116
bool canny_preprocess = false;
117117
bool color = false;
118118
int upscale_repeats = 1;
119+
int upscale_tile = 128;
119120

120121
// Photo Maker
121122
std::string photo_maker_path;
@@ -201,6 +202,7 @@ void print_params(SDParams params) {
201202
printf(" vae_tiling: %s\n", params.vae_tiling_params.enabled ? "true" : "false");
202203
printf(" force_sdxl_vae_conv_scale: %s\n", params.force_sdxl_vae_conv_scale ? "true" : "false");
203204
printf(" upscale_repeats: %d\n", params.upscale_repeats);
205+
printf(" upscale_tile: %d\n", params.upscale_tile);
204206
printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
205207
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
206208
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
@@ -235,6 +237,7 @@ void print_usage(int argc, const char* argv[]) {
235237
printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n");
236238
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. For img_gen mode, upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n");
237239
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
240+
printf(" --upscale-tile Tile size for the ESRGAN upscaler (default 128)\n");
238241
printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
239242
printf(" If not specified, the default is the type of the weight file\n");
240243
printf(" --tensor-type-rules [EXPRESSION] weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n");
@@ -527,6 +530,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
527530
options.int_options = {
528531
{"-t", "--threads", "", &params.n_threads},
529532
{"", "--upscale-repeats", "", &params.upscale_repeats},
533+
{"","--upscale-tile", "", &params.upscale_tile},
530534
{"-H", "--height", "", &params.height},
531535
{"-W", "--width", "", &params.width},
532536
{"", "--steps", "", &params.sample_params.sample_steps},
@@ -917,6 +921,11 @@ void parse_args(int argc, const char** argv, SDParams& params) {
917921
exit(1);
918922
}
919923

924+
if (params.upscale_tile < 1) {
925+
fprintf(stderr, "error: upscale tile size must be at least 1\n");
926+
exit(1);
927+
}
928+
920929
if (params.mode == UPSCALE) {
921930
if (params.esrgan_path.length() == 0) {
922931
fprintf(stderr, "error: upscale mode needs an upscaler model (--upscale-model)\n");
@@ -1486,7 +1495,8 @@ int main(int argc, const char* argv[]) {
14861495
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
14871496
params.offload_params_to_cpu,
14881497
params.diffusion_conv_direct,
1489-
params.n_threads);
1498+
params.n_threads,
1499+
params.upscale_tile);
14901500

14911501
if (upscaler_ctx == NULL) {
14921502
printf("new_upscaler_ctx failed\n");

stable-diffusion.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@ typedef struct upscaler_ctx_t upscaler_ctx_t;
292292
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
293293
bool offload_params_to_cpu,
294294
bool direct,
295-
int n_threads);
295+
int n_threads,
296+
int tile_size);
296297
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
297298

298299
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx,

upscaler.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@ struct UpscalerGGML {
1010
std::string esrgan_path;
1111
int n_threads;
1212
bool direct = false;
13+
int tile_size = 128;
1314

1415
UpscalerGGML(int n_threads,
15-
bool direct = false)
16+
bool direct = false,
17+
int tile_size = 128)
1618
: n_threads(n_threads),
17-
direct(direct) {
19+
direct(direct),
20+
tile_size(tile_size) {
1821
}
1922

2023
bool load_from_file(const std::string& esrgan_path,
@@ -51,7 +54,7 @@ struct UpscalerGGML {
5154
backend = ggml_backend_cpu_init();
5255
}
5356
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
54-
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, model_loader.tensor_storages_types);
57+
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, tile_size, model_loader.tensor_storages_types);
5558
if (direct) {
5659
esrgan_upscaler->enable_conv2d_direct();
5760
}
@@ -113,14 +116,15 @@ struct upscaler_ctx_t {
113116
upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
114117
bool offload_params_to_cpu,
115118
bool direct,
116-
int n_threads) {
119+
int n_threads,
120+
int tile_size) {
117121
upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t));
118122
if (upscaler_ctx == NULL) {
119123
return NULL;
120124
}
121125
std::string esrgan_path(esrgan_path_c_str);
122126

123-
upscaler_ctx->upscaler = new UpscalerGGML(n_threads, direct);
127+
upscaler_ctx->upscaler = new UpscalerGGML(n_threads, direct, tile_size);
124128
if (upscaler_ctx->upscaler == NULL) {
125129
return NULL;
126130
}

0 commit comments

Comments
 (0)