Skip to content

Commit e402958

Browse files
committed
add easycache
1 parent 28ffb6c commit e402958

File tree

3 files changed

+448
-10
lines changed

3 files changed

+448
-10
lines changed

examples/cli/main.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ struct SDParams {
105105
std::vector<int> high_noise_skip_layers = {7, 8, 9};
106106
sd_sample_params_t high_noise_sample_params;
107107

108+
std::string easycache_option;
109+
sd_easycache_params_t easycache_params;
110+
108111
float moe_boundary = 0.875f;
109112
int video_frames = 1;
110113
int fps = 16;
@@ -154,6 +157,7 @@ struct SDParams {
154157
sd_sample_params_init(&sample_params);
155158
sd_sample_params_init(&high_noise_sample_params);
156159
high_noise_sample_params.sample_steps = -1;
160+
sd_easycache_params_init(&easycache_params);
157161
}
158162
};
159163

@@ -225,6 +229,11 @@ void print_params(SDParams params) {
225229
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
226230
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
227231
printf(" video_frames: %d\n", params.video_frames);
232+
printf(" easycache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
233+
params.easycache_params.enabled ? "enabled" : "disabled",
234+
params.easycache_params.reuse_threshold,
235+
params.easycache_params.start_percent,
236+
params.easycache_params.end_percent);
228237
printf(" vace_strength: %.2f\n", params.vace_strength);
229238
printf(" fps: %d\n", params.fps);
230239
printf(" preview_mode: %s (%s)\n", previews_str[params.preview_method], params.preview_noisy ? "noisy" : "denoised");
@@ -616,6 +625,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
616625
"--upscale-model",
617626
"path to esrgan model.",
618627
&params.esrgan_path},
628+
{"",
629+
"--easycache",
630+
"enable EasyCache for DiT models with \"threshold,start_percent,end_percent\" (example: 0.2,0.15,0.95)",
631+
&params.easycache_option},
619632
};
620633

621634
options.int_options = {
@@ -1215,6 +1228,59 @@ void parse_args(int argc, const char** argv, SDParams& params) {
12151228
exit(1);
12161229
}
12171230

1231+
if (!params.easycache_option.empty()) {
1232+
float values[3] = {0.0f, 0.0f, 0.0f};
1233+
std::stringstream ss(params.easycache_option);
1234+
std::string token;
1235+
int idx = 0;
1236+
while (std::getline(ss, token, ',')) {
1237+
auto trim = [](std::string& s) {
1238+
const char* whitespace = " \t\r\n";
1239+
auto start = s.find_first_not_of(whitespace);
1240+
if (start == std::string::npos) {
1241+
s.clear();
1242+
return;
1243+
}
1244+
auto end = s.find_last_not_of(whitespace);
1245+
s = s.substr(start, end - start + 1);
1246+
};
1247+
trim(token);
1248+
if (token.empty()) {
1249+
fprintf(stderr, "error: invalid easycache option '%s'\n", params.easycache_option.c_str());
1250+
exit(1);
1251+
}
1252+
if (idx >= 3) {
1253+
fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
1254+
exit(1);
1255+
}
1256+
try {
1257+
values[idx] = std::stof(token);
1258+
} catch (const std::exception&) {
1259+
fprintf(stderr, "error: invalid easycache value '%s'\n", token.c_str());
1260+
exit(1);
1261+
}
1262+
idx++;
1263+
}
1264+
if (idx != 3) {
1265+
fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
1266+
exit(1);
1267+
}
1268+
if (values[0] < 0.0f) {
1269+
fprintf(stderr, "error: easycache threshold must be non-negative\n");
1270+
exit(1);
1271+
}
1272+
if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) {
1273+
fprintf(stderr, "error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
1274+
exit(1);
1275+
}
1276+
params.easycache_params.enabled = true;
1277+
params.easycache_params.reuse_threshold = values[0];
1278+
params.easycache_params.start_percent = values[1];
1279+
params.easycache_params.end_percent = values[2];
1280+
} else {
1281+
params.easycache_params.enabled = false;
1282+
}
1283+
12181284
if (params.n_threads <= 0) {
12191285
params.n_threads = get_num_physical_cores();
12201286
}
@@ -1852,6 +1918,7 @@ int main(int argc, const char* argv[]) {
18521918
params.pm_style_strength,
18531919
}, // pm_params
18541920
params.vae_tiling_params,
1921+
params.easycache_params,
18551922
};
18561923

18571924
results = generate_image(sd_ctx, &img_gen_params);
@@ -1874,6 +1941,7 @@ int main(int argc, const char* argv[]) {
18741941
params.seed,
18751942
params.video_frames,
18761943
params.vace_strength,
1944+
params.easycache_params,
18771945
};
18781946

18791947
results = generate_video(sd_ctx, &vid_gen_params, &num_results);

0 commit comments

Comments
 (0)