Skip to content

Commit 2af9294

Browse files
committed
vae tiling: refactor again, base on smaller buffer for alignment
1 parent 5054cf4 commit 2af9294

File tree

2 files changed

+77
-44
lines changed

2 files changed

+77
-44
lines changed

ggml_extend.hpp

Lines changed: 65 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -749,62 +749,67 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_chunk(struct ggml_contex
749749
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
750750

751751
// Tiling
752-
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing, bool scaled_out = true) {
752+
__STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
753753
output = ggml_set_f32(output, 0);
754754

755755
int input_width = (int)input->ne[0];
756756
int input_height = (int)input->ne[1];
757757
int output_width = (int)output->ne[0];
758758
int output_height = (int)output->ne[1];
759759

760-
int input_tile_size, output_tile_size;
761-
if (scaled_out) {
762-
input_tile_size = tile_size;
763-
output_tile_size = tile_size * scale;
764-
} else {
765-
input_tile_size = tile_size * scale;
766-
output_tile_size = tile_size;
760+
GGML_ASSERT(input_width / output_width == input_height / output_height && output_width / input_width == output_height / input_height);
761+
GGML_ASSERT(input_width / output_width == scale || output_width / input_width == scale);
762+
763+
int small_width = output_width;
764+
int small_height = output_height;
765+
766+
bool big_out = output_width > input_width;
767+
if (big_out) {
768+
// Ex: decode
769+
small_width = input_width;
770+
small_height = input_height;
767771
}
768-
int tile_overlap = (input_tile_size * tile_overlap_factor);
769-
int non_tile_overlap = input_tile_size - tile_overlap;
770772

771-
int num_tiles_x = (input_width - tile_overlap) / non_tile_overlap;
772-
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % input_width;
773+
int tile_overlap = (tile_size * tile_overlap_factor);
774+
int non_tile_overlap = tile_size - tile_overlap;
775+
776+
int num_tiles_x = (small_width - tile_overlap) / non_tile_overlap;
777+
int overshoot_x = ((num_tiles_x + 1) * non_tile_overlap + tile_overlap) % small_width;
773778

774-
if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (input_tile_size / 2 - tile_overlap))) {
779+
if ((overshoot_x != non_tile_overlap) && (overshoot_x <= num_tiles_x * (tile_size / 2 - tile_overlap))) {
775780
// if tiles don't fit perfectly using the desired overlap
776781
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
777782
num_tiles_x++;
778783
}
779784

780-
float tile_overlap_factor_x = (float)(input_tile_size * num_tiles_x - input_width) / (float)(input_tile_size * (num_tiles_x - 1));
785+
float tile_overlap_factor_x = (float)(tile_size * num_tiles_x - small_width) / (float)(tile_size * (num_tiles_x - 1));
781786
if (num_tiles_x <= 2) {
782-
if (input_width <= input_tile_size) {
787+
if (small_width <= tile_size) {
783788
num_tiles_x = 1;
784789
tile_overlap_factor_x = 0;
785790
} else {
786791
num_tiles_x = 2;
787-
tile_overlap_factor_x = (2 * input_tile_size - input_width) / (float)input_tile_size;
792+
tile_overlap_factor_x = (2 * tile_size - small_width) / (float)tile_size;
788793
}
789794
}
790795

791-
int num_tiles_y = (input_height - tile_overlap) / non_tile_overlap;
792-
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % input_height;
796+
int num_tiles_y = (small_height - tile_overlap) / non_tile_overlap;
797+
int overshoot_y = ((num_tiles_y + 1) * non_tile_overlap + tile_overlap) % small_height;
793798

794-
if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (input_tile_size / 2 - tile_overlap))) {
799+
if ((overshoot_y != non_tile_overlap) && (overshoot_y <= num_tiles_y * (tile_size / 2 - tile_overlap))) {
795800
// if tiles don't fit perfectly using the desired overlap
796801
// and there is enough room to squeeze an extra tile without overlap becoming >0.5
797802
num_tiles_y++;
798803
}
799804

800-
float tile_overlap_factor_y = (float)(input_tile_size * num_tiles_y - input_height) / (float)(input_tile_size * (num_tiles_y - 1));
805+
float tile_overlap_factor_y = (float)(tile_size * num_tiles_y - small_height) / (float)(tile_size * (num_tiles_y - 1));
801806
if (num_tiles_y <= 2) {
802-
if (input_height <= input_tile_size) {
807+
if (small_height <= tile_size) {
803808
num_tiles_y = 1;
804809
tile_overlap_factor_y = 0;
805810
} else {
806811
num_tiles_y = 2;
807-
tile_overlap_factor_y = (2 * input_tile_size - input_height) / (float)input_tile_size;
812+
tile_overlap_factor_y = (2 * tile_size - small_height) / (float)tile_size;
808813
}
809814
}
810815

@@ -813,11 +818,20 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
813818

814819
GGML_ASSERT(input_width % 2 == 0 && input_height % 2 == 0 && output_width % 2 == 0 && output_height % 2 == 0); // should be multiple of 2
815820

816-
int tile_overlap_x = (int32_t)(input_tile_size * tile_overlap_factor_x);
817-
int non_tile_overlap_x = input_tile_size - tile_overlap_x;
821+
int tile_overlap_x = (int32_t)(tile_size * tile_overlap_factor_x);
822+
int non_tile_overlap_x = tile_size - tile_overlap_x;
818823

819-
int tile_overlap_y = (int32_t)(input_tile_size * tile_overlap_factor_y);
820-
int non_tile_overlap_y = input_tile_size - tile_overlap_y;
824+
int tile_overlap_y = (int32_t)(tile_size * tile_overlap_factor_y);
825+
int non_tile_overlap_y = tile_size - tile_overlap_y;
826+
827+
int input_tile_size = tile_size;
828+
int output_tile_size = tile_size;
829+
830+
if (big_out) {
831+
output_tile_size *= scale;
832+
} else {
833+
input_tile_size *= scale;
834+
}
821835

822836
struct ggml_init_params params = {};
823837
params.mem_size += input_tile_size * input_tile_size * input->ne[2] * sizeof(float); // input chunk
@@ -838,37 +852,48 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const
838852
// tiling
839853
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size, input_tile_size, input->ne[2], 1);
840854
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size, output_tile_size, output->ne[2], 1);
841-
on_processing(input_tile, NULL, true);
842855
int num_tiles = num_tiles_x * num_tiles_y;
843856
LOG_INFO("processing %i tiles", num_tiles);
844-
pretty_progress(1, num_tiles, 0.0f);
857+
pretty_progress(0, num_tiles, 0.0f);
845858
int tile_count = 1;
846859
bool last_y = false, last_x = false;
847860
float last_time = 0.0f;
848-
for (int y = 0; y < input_height && !last_y; y += non_tile_overlap_y) {
861+
for (int y = 0; y < small_height && !last_y; y += non_tile_overlap_y) {
849862
int dy = 0;
850-
if (y + input_tile_size >= input_height) {
863+
if (y + tile_size >= small_height) {
851864
int _y = y;
852-
y = input_height - input_tile_size;
865+
y = small_height - tile_size;
853866
dy = _y - y;
867+
if (big_out) {
868+
dy *= scale;
869+
}
854870
last_y = true;
855871
}
856-
for (int x = 0; x < input_width && !last_x; x += non_tile_overlap_x) {
872+
for (int x = 0; x < small_width && !last_x; x += non_tile_overlap_x) {
857873
int dx = 0;
858-
if (x + input_tile_size >= input_width) {
874+
if (x + tile_size >= small_width) {
859875
int _x = x;
860-
x = input_width - input_tile_size;
876+
x = small_width - tile_size;
861877
dx = _x - x;
878+
if (big_out) {
879+
dx *= scale;
880+
}
862881
last_x = true;
863882
}
883+
884+
int x_in = big_out ? x : scale * x;
885+
int y_in = big_out ? y : scale * y;
886+
int x_out = big_out ? x * scale : x;
887+
int y_out = big_out ? y * scale : y;
888+
889+
int overlap_x_out = big_out ? tile_overlap_x * scale : tile_overlap_x;
890+
int overlap_y_out = big_out ? tile_overlap_y * scale : tile_overlap_y;
891+
864892
int64_t t1 = ggml_time_ms();
865-
ggml_split_tensor_2d(input, input_tile, x, y);
893+
ggml_split_tensor_2d(input, input_tile, x_in, y_in);
866894
on_processing(input_tile, output_tile, false);
867-
if (scaled_out) {
868-
ggml_merge_tensor_2d(output_tile, output, x * scale, y * scale, tile_overlap_x * scale, tile_overlap_y * scale, dx * scale, dy * scale);
869-
} else {
870-
ggml_merge_tensor_2d(output_tile, output, x / scale, y / scale, tile_overlap_x / scale, tile_overlap_y / scale, dx / scale, dy / scale);
871-
}
895+
ggml_merge_tensor_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy);
896+
872897
int64_t t2 = ggml_time_ms();
873898
last_time = (t2 - t1) / 1000.0f;
874899
pretty_progress(tile_count, num_tiles, last_time);

stable-diffusion.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,13 +1324,21 @@ class StableDiffusionGGML {
13241324
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
13251325
first_stage_model->compute(n_threads, in, true, &out, NULL);
13261326
};
1327-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling, false);
1327+
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
13281328
} else {
13291329
first_stage_model->compute(n_threads, x, false, &result, work_ctx);
13301330
}
13311331
first_stage_model->free_compute_buffer();
13321332
} else {
1333-
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
1333+
if (vae_tiling && !decode_video) {
1334+
// split latent in 32x32 tiles and compute in several steps
1335+
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
1336+
tae_first_stage->compute(n_threads, in, true, &out, NULL);
1337+
};
1338+
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
1339+
} else {
1340+
tae_first_stage->compute(n_threads, x, false, &result, work_ctx);
1341+
}
13341342
tae_first_stage->free_compute_buffer();
13351343
}
13361344

@@ -1469,7 +1477,7 @@ class StableDiffusionGGML {
14691477
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14701478
first_stage_model->compute(n_threads, in, true, &out, NULL);
14711479
};
1472-
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling, true);
1480+
sd_tiling(x, result, 8, tile_size, 0.5f, on_tiling);
14731481
} else {
14741482
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
14751483
}
@@ -1481,7 +1489,7 @@ class StableDiffusionGGML {
14811489
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
14821490
tae_first_stage->compute(n_threads, in, true, &out);
14831491
};
1484-
sd_tiling(x, result, 8, 64, 0.5f, on_tiling, true);
1492+
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
14851493
} else {
14861494
tae_first_stage->compute(n_threads, x, true, &result);
14871495
}

0 commit comments

Comments
 (0)