Skip to content

Commit b78c7fa

Browse files
committed
implement tiling vae encode support
1 parent abb115c commit b78c7fa

File tree

2 files changed

+56
-34
lines changed

2 files changed

+56
-34
lines changed

ggml_extend.hpp

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -734,21 +734,31 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_contex
734734
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
735735

736736
// Tiling
737-
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
737+
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true) {
738738
output = ggml_set_f32(output, 0);
739739

740740
int input_width = (int)input->ne[0];
741741
int input_height = (int)input->ne[1];
742742
int output_width = (int)output->ne[0];
743743
int output_height = (int)output->ne[1];
744+
745+
int input_tile_size, output_tile_size;
746+
if (scaled_out) {
747+
input_tile_size = tile_size;
748+
output_tile_size = tile_size * scale;
749+
} else {
750+
input_tile_size = tile_size * scale;
751+
output_tile_size = tile_size;
752+
}
753+
744754
GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2
745755

746-
int tile_overlap = (int32_t)(tile_size * tile_overlap_factor);
747-
int non_tile_overlap = tile_size - tile_overlap;
756+
int tile_overlap = (int32_t)(input_tile_size * tile_overlap_factor);
757+
int non_tile_overlap = input_tile_size - tile_overlap;
748758

749759
struct ggml_init_params params = {};
750-
params.mem_size += tile_size * tile_size * input->ne[2] * sizeof(float); // input chunk
751-
params.mem_size += (tile_size * scale) * (tile_size * scale) * output->ne[2] * sizeof(float); // output chunk
760+
params.mem_size += input_tile_size * input_tile_size * input->ne[2] * sizeof(float); // input chunk
761+
params.mem_size += output_tile_size * output_tile_size * output->ne[2] * sizeof(float); // output chunk
752762
params.mem_size += 3 * ggml_tensor_overhead();
753763
params.mem_buffer = NULL;
754764
params.no_alloc = false;
@@ -763,8 +773,8 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
763773
}
764774

765775
// tiling
766-
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size, tile_size, input->ne[2], 1);
767-
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, tile_size * scale, tile_size * scale, output->ne[2], 1);
776+
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size, input->ne[2], 1);
777+
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size, output->ne[2], 1);
768778
on_processing(input_tile, NULL, true);
769779
int num_tiles = ceil((float)input_width / non_tile_overlap) * ceil((float)input_height / non_tile_overlap);
770780
LOG_INFO("processing %i tiles", num_tiles);
@@ -773,19 +783,23 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
773783
bool last_y = false, last_x = false;
774784
float last_time = 0.0f;
775785
for (int y = 0; y < input_height && !last_y; y += non_tile_overlap) {
776-
if (y + tile_size >= input_height) {
777-
y = input_height - tile_size;
786+
if (y + input_tile_size >= input_height) {
787+
y = input_height - input_tile_size;
778788
last_y = true;
779789
}
780790
for (int x = 0; x < input_width && !last_x; x += non_tile_overlap) {
781-
if (x + tile_size >= input_width) {
782-
x = input_width - tile_size;
791+
if (x + input_tile_size >= input_width) {
792+
x = input_width - input_tile_size;
783793
last_x = true;
784794
}
785795
int64_t t1 = ggml_time_ms();
786796
ggml_split_tensor_2d(input, input_tile, x, y);
787797
on_processing(input_tile, output_tile, false);
788-
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap * scale);
798+
if (scaled_out) {
799+
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap * scale);
800+
} else {
801+
ggml_merge_tensor_2d(output_tile, output, x / scale, y / scale, tile_overlap / scale);
802+
}
789803
int64_t t2 = ggml_time_ms();
790804
last_time = (t2 - t1) / 1000.0f;
791805
pretty_progress(tile_count, num_tiles, last_time);

stable-diffusion.cpp

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,8 @@ class StableDiffusionGGML {
351351
offload_params_to_cpu,
352352
model_loader.tensor_storages_types);
353353
diffusion_model = std::make_shared<MMDiTModel>(backend,
354-
offload_params_to_cpu,
355-
model_loader.tensor_storages_types);
354+
offload_params_to_cpu,
355+
model_loader.tensor_storages_types);
356356
} else if (sd_version_is_flux(version)) {
357357
bool is_chroma = false;
358358
for (auto pair : model_loader.tensor_storages_types) {
@@ -388,11 +388,11 @@ class StableDiffusionGGML {
388388
1,
389389
true);
390390
diffusion_model = std::make_shared<WanModel>(backend,
391-
offload_params_to_cpu,
392-
model_loader.tensor_storages_types,
393-
"model.diffusion_model",
394-
version,
395-
sd_ctx_params->diffusion_flash_attn);
391+
offload_params_to_cpu,
392+
model_loader.tensor_storages_types,
393+
"model.diffusion_model",
394+
version,
395+
sd_ctx_params->diffusion_flash_attn);
396396
if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
397397
high_noise_diffusion_model = std::make_shared<WanModel>(backend,
398398
offload_params_to_cpu,
@@ -1294,7 +1294,15 @@ class StableDiffusionGGML {
12941294
ggml_tensor* result = NULL;
12951295
if (!use_tiny_autoencoder) {
12961296
process_vae_input_tensor(x);
1297-
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
1297+
if (vae_tiling && !decode_video) {
1298+
// encode path: split the input image into 256x256 pixel tiles (32x32 in latent space, since scale = 8) and compute in several steps
1299+
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
1300+
first_stage_model->compute(n_threads, in, true, &out, NULL);
1301+
};
1302+
sd_tiling(x, result, 8, 32, 0.5f, on_tiling, false);
1303+
} else {
1304+
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
1305+
}
12981306
first_stage_model->free_compute_buffer();
12991307
} else {
13001308
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
@@ -1321,12 +1329,12 @@ class StableDiffusionGGML {
13211329
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
13221330
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
13231331
latents_std_vec = {
1324-
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
1325-
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
1326-
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
1327-
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
1328-
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
1329-
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
1332+
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
1333+
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
1334+
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
1335+
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
1336+
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
1337+
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
13301338
}
13311339
for (int i = 0; i < latent->ne[3]; i++) {
13321340
float mean = latents_mean_vec[i];
@@ -1361,12 +1369,12 @@ class StableDiffusionGGML {
13611369
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
13621370
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
13631371
latents_std_vec = {
1364-
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
1365-
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
1366-
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
1367-
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
1368-
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
1369-
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
1372+
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
1373+
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
1374+
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
1375+
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
1376+
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
1377+
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
13701378
}
13711379
for (int i = 0; i < latent->ne[3]; i++) {
13721380
float mean = latents_mean_vec[i];
@@ -1424,7 +1432,7 @@ class StableDiffusionGGML {
14241432
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14251433
first_stage_model->compute(n_threads, in, true, &out, NULL);
14261434
};
1427-
sd_tiling(x, result, 8, 32, 0.5f, on_tiling);
1435+
sd_tiling(x, result, 8, 32, 0.5f, on_tiling, true);
14281436
} else {
14291437
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
14301438
}
@@ -1436,7 +1444,7 @@ class StableDiffusionGGML {
14361444
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14371445
tae_first_stage->compute(n_threads, in, true, &out);
14381446
};
1439-
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
1447+
sd_tiling(x, result, 8, 64, 0.5f, on_tiling, true);
14401448
} else {
14411449
tae_first_stage->compute(n_threads, x, true, &result);
14421450
}

0 commit comments

Comments
 (0)