Skip to content

Commit 86d176b

Browse files
committed
add decay rate and relative threshold
1 parent c8cc665 commit 86d176b

File tree

5 files changed

+151
-42
lines changed

5 files changed

+151
-42
lines changed

examples/cli/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,5 +125,7 @@ Generation Options:
125125
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
126126
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
127127
--cache-mode caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL)
128-
--cache-option cache parameters "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache)
128+
--cache-option cache parameters: easycache uses "threshold,start,end" (default: 0.2,0.15,0.95).
129+
ucache uses "threshold,start,end[,decay,relative]" (default: 1.0,0.15,0.95,1.0,1).
130+
decay: error decay rate (0.0-1.0), relative: use relative threshold (0 or 1)
129131
```

examples/cli/main.cpp

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,7 +1441,7 @@ struct SDGenerationParams {
14411441
on_cache_mode_arg},
14421442
{"",
14431443
"--cache-option",
1444-
"cache parameters \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache)",
1444+
"cache parameters \"threshold,start,end[,decay,relative]\" (ucache extended: decay=1.0, relative=1)",
14451445
on_cache_option_arg},
14461446

14471447
};
@@ -1568,28 +1568,32 @@ struct SDGenerationParams {
15681568
}
15691569
}
15701570

1571-
float values[3] = {0.0f, 0.0f, 0.0f};
1571+
// Format: threshold,start,end[,decay,relative]
1572+
// - values[0-2]: threshold, start_percent, end_percent (required)
1573+
// - values[3]: error_decay_rate (optional, default: 1.0)
1574+
// - values[4]: use_relative_threshold (optional, 0 or 1, default: 1)
1575+
float values[5] = {0.0f, 0.0f, 0.0f, 1.0f, 1.0f};
15721576
std::stringstream ss(option_str);
15731577
std::string token;
15741578
int idx = 0;
1579+
auto trim = [](std::string& s) {
1580+
const char* whitespace = " \t\r\n";
1581+
auto start = s.find_first_not_of(whitespace);
1582+
if (start == std::string::npos) {
1583+
s.clear();
1584+
return;
1585+
}
1586+
auto end = s.find_last_not_of(whitespace);
1587+
s = s.substr(start, end - start + 1);
1588+
};
15751589
while (std::getline(ss, token, ',')) {
1576-
auto trim = [](std::string& s) {
1577-
const char* whitespace = " \t\r\n";
1578-
auto start = s.find_first_not_of(whitespace);
1579-
if (start == std::string::npos) {
1580-
s.clear();
1581-
return;
1582-
}
1583-
auto end = s.find_last_not_of(whitespace);
1584-
s = s.substr(start, end - start + 1);
1585-
};
15861590
trim(token);
15871591
if (token.empty()) {
15881592
fprintf(stderr, "error: invalid cache option '%s'\n", option_str.c_str());
15891593
return false;
15901594
}
1591-
if (idx >= 3) {
1592-
fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n");
1595+
if (idx >= 5) {
1596+
fprintf(stderr, "error: cache option expects 3-5 comma-separated values (threshold,start,end[,decay,relative])\n");
15931597
return false;
15941598
}
15951599
try {
@@ -1600,8 +1604,8 @@ struct SDGenerationParams {
16001604
}
16011605
idx++;
16021606
}
1603-
if (idx != 3) {
1604-
fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n");
1607+
if (idx < 3) {
1608+
fprintf(stderr, "error: cache option expects at least 3 comma-separated values (threshold,start,end)\n");
16051609
return false;
16061610
}
16071611
if (values[0] < 0.0f) {
@@ -1619,10 +1623,12 @@ struct SDGenerationParams {
16191623
easycache_params.start_percent = values[1];
16201624
easycache_params.end_percent = values[2];
16211625
} else {
1622-
ucache_params.enabled = true;
1623-
ucache_params.reuse_threshold = values[0];
1624-
ucache_params.start_percent = values[1];
1625-
ucache_params.end_percent = values[2];
1626+
ucache_params.enabled = true;
1627+
ucache_params.reuse_threshold = values[0];
1628+
ucache_params.start_percent = values[1];
1629+
ucache_params.end_percent = values[2];
1630+
ucache_params.error_decay_rate = values[3];
1631+
ucache_params.use_relative_threshold = (values[4] != 0.0f);
16261632
}
16271633
}
16281634

stable-diffusion.cpp

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,10 +1538,12 @@ class StableDiffusionGGML {
15381538
LOG_WARN("UCache requested but not supported for this model type (only UNET models)");
15391539
} else {
15401540
UCacheConfig ucache_config;
1541-
ucache_config.enabled = true;
1542-
ucache_config.reuse_threshold = std::max(0.0f, ucache_params->reuse_threshold);
1543-
ucache_config.start_percent = ucache_params->start_percent;
1544-
ucache_config.end_percent = ucache_params->end_percent;
1541+
ucache_config.enabled = true;
1542+
ucache_config.reuse_threshold = std::max(0.0f, ucache_params->reuse_threshold);
1543+
ucache_config.start_percent = ucache_params->start_percent;
1544+
ucache_config.end_percent = ucache_params->end_percent;
1545+
ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, ucache_params->error_decay_rate));
1546+
ucache_config.use_relative_threshold = ucache_params->use_relative_threshold;
15451547
bool percent_valid = ucache_config.start_percent >= 0.0f &&
15461548
ucache_config.start_percent < 1.0f &&
15471549
ucache_config.end_percent > 0.0f &&
@@ -1555,10 +1557,12 @@ class StableDiffusionGGML {
15551557
ucache_state.init(ucache_config, denoiser.get());
15561558
if (ucache_state.enabled()) {
15571559
ucache_enabled = true;
1558-
LOG_INFO("UCache enabled - threshold: %.3f, start_percent: %.2f, end_percent: %.2f",
1560+
LOG_INFO("UCache enabled - threshold: %.3f, start: %.2f, end: %.2f, decay: %.2f, relative: %s",
15591561
ucache_config.reuse_threshold,
15601562
ucache_config.start_percent,
1561-
ucache_config.end_percent);
1563+
ucache_config.end_percent,
1564+
ucache_config.error_decay_rate,
1565+
ucache_config.use_relative_threshold ? "true" : "false");
15621566
} else {
15631567
LOG_WARN("UCache requested but could not be initialized for this run");
15641568
}
@@ -2594,11 +2598,13 @@ void sd_easycache_params_init(sd_easycache_params_t* easycache_params) {
25942598
}
25952599

25962600
void sd_ucache_params_init(sd_ucache_params_t* ucache_params) {
2597-
*ucache_params = {};
2598-
ucache_params->enabled = false;
2599-
ucache_params->reuse_threshold = 1.0f;
2600-
ucache_params->start_percent = 0.15f;
2601-
ucache_params->end_percent = 0.95f;
2601+
*ucache_params = {};
2602+
ucache_params->enabled = false;
2603+
ucache_params->reuse_threshold = 1.0f;
2604+
ucache_params->start_percent = 0.15f;
2605+
ucache_params->end_percent = 0.95f;
2606+
ucache_params->error_decay_rate = 1.0f;
2607+
ucache_params->use_relative_threshold = true;
26022608
}
26032609

26042610
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {

stable-diffusion.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@ typedef struct {
246246
float reuse_threshold;
247247
float start_percent;
248248
float end_percent;
249+
float error_decay_rate;
250+
bool use_relative_threshold;
249251
} sd_ucache_params_t;
250252

251253
typedef struct {

ucache.hpp

Lines changed: 103 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,15 @@
1010
#include "ggml_extend.hpp"
1111

1212
struct UCacheConfig {
13-
bool enabled = false;
14-
float reuse_threshold = 1.0f;
15-
float start_percent = 0.15f;
16-
float end_percent = 0.95f;
13+
bool enabled = false;
14+
float reuse_threshold = 1.0f;
15+
float start_percent = 0.15f;
16+
float end_percent = 0.95f;
17+
float error_decay_rate = 1.0f;
18+
bool use_relative_threshold = true;
19+
bool adaptive_threshold = true;
20+
float early_step_multiplier = 0.5f;
21+
float late_step_multiplier = 1.5f;
1722
};
1823

1924
struct UCacheCacheEntry {
@@ -44,6 +49,45 @@ struct UCacheState {
4449
bool has_last_input_change = false;
4550
int total_steps_skipped = 0;
4651
int current_step_index = -1;
52+
int steps_computed_since_active = 0;
53+
float accumulated_error = 0.0f;
54+
float reference_output_norm = 0.0f;
55+
56+
struct BlockMetrics {
57+
float sum_transformation_rate = 0.0f;
58+
float sum_output_norm = 0.0f;
59+
int sample_count = 0;
60+
float min_change_rate = std::numeric_limits<float>::max();
61+
float max_change_rate = 0.0f;
62+
63+
void reset() {
64+
sum_transformation_rate = 0.0f;
65+
sum_output_norm = 0.0f;
66+
sample_count = 0;
67+
min_change_rate = std::numeric_limits<float>::max();
68+
max_change_rate = 0.0f;
69+
}
70+
71+
void record(float change_rate, float output_norm) {
72+
if (std::isfinite(change_rate) && change_rate > 0.0f) {
73+
sum_transformation_rate += change_rate;
74+
sum_output_norm += output_norm;
75+
sample_count++;
76+
if (change_rate < min_change_rate) min_change_rate = change_rate;
77+
if (change_rate > max_change_rate) max_change_rate = change_rate;
78+
}
79+
}
80+
81+
float avg_transformation_rate() const {
82+
return (sample_count > 0) ? (sum_transformation_rate / sample_count) : 0.0f;
83+
}
84+
85+
float avg_output_norm() const {
86+
return (sample_count > 0) ? (sum_output_norm / sample_count) : 0.0f;
87+
}
88+
};
89+
BlockMetrics block_metrics;
90+
int total_active_steps = 0;
4791

4892
void reset_runtime() {
4993
initial_step = true;
@@ -64,6 +108,11 @@ struct UCacheState {
64108
has_last_input_change = false;
65109
total_steps_skipped = 0;
66110
current_step_index = -1;
111+
steps_computed_since_active = 0;
112+
accumulated_error = 0.0f;
113+
reference_output_norm = 0.0f;
114+
block_metrics.reset();
115+
total_active_steps = 0;
67116
}
68117

69118
void init(const UCacheConfig& cfg, Denoiser* d) {
@@ -114,6 +163,7 @@ struct UCacheState {
114163
return;
115164
}
116165
step_active = true;
166+
total_active_steps++;
117167
}
118168

119169
bool step_is_active() const {
@@ -124,6 +174,31 @@ struct UCacheState {
124174
return enabled() && step_active && skip_current_step;
125175
}
126176

177+
float get_adaptive_threshold(int estimated_total_steps = 0) const {
178+
float base_threshold = config.reuse_threshold;
179+
180+
if (!config.adaptive_threshold) {
181+
return base_threshold;
182+
}
183+
184+
int effective_total = estimated_total_steps;
185+
if (effective_total <= 0) {
186+
effective_total = std::max(20, steps_computed_since_active * 2);
187+
}
188+
189+
float progress = (effective_total > 0) ?
190+
(static_cast<float>(steps_computed_since_active) / effective_total) : 0.0f;
191+
192+
float multiplier = 1.0f;
193+
if (progress < 0.2f) {
194+
multiplier = config.early_step_multiplier;
195+
} else if (progress > 0.8f) {
196+
multiplier = config.late_step_multiplier;
197+
}
198+
199+
return base_threshold * multiplier;
200+
}
201+
127202
bool has_cache(const SDCondition* cond) const {
128203
auto it = cache_diffs.find(cond);
129204
return it != cache_diffs.end() && !it->second.diff.empty();
@@ -212,15 +287,18 @@ struct UCacheState {
212287
last_input_change > 0.0f && output_prev_norm > 0.0f) {
213288

214289
float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm;
215-
cumulative_change_rate += approx_output_change_rate;
290+
accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;
291+
292+
float effective_threshold = get_adaptive_threshold();
293+
if (config.use_relative_threshold && reference_output_norm > 0.0f) {
294+
effective_threshold = effective_threshold * reference_output_norm;
295+
}
216296

217-
if (cumulative_change_rate < config.reuse_threshold) {
297+
if (accumulated_error < effective_threshold) {
218298
skip_current_step = true;
219299
total_steps_skipped++;
220300
apply_cache(cond, input, output);
221301
return true;
222-
} else {
223-
cumulative_change_rate = 0.0f;
224302
}
225303
}
226304

@@ -270,16 +348,31 @@ struct UCacheState {
270348
output_prev_norm = (ne > 0) ? (mean_abs / static_cast<float>(ne)) : 0.0f;
271349
has_output_prev_norm = output_prev_norm > 0.0f;
272350

351+
if (reference_output_norm == 0.0f) {
352+
reference_output_norm = output_prev_norm;
353+
}
354+
273355
if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) {
274356
float rate = output_change / last_input_change;
275357
if (std::isfinite(rate)) {
276358
relative_transformation_rate = rate;
277359
has_relative_transformation_rate = true;
360+
block_metrics.record(rate, output_prev_norm);
278361
}
279362
}
280363

281-
cumulative_change_rate = 0.0f;
282-
has_last_input_change = false;
364+
has_last_input_change = false;
365+
}
366+
367+
void log_block_metrics() const {
368+
if (block_metrics.sample_count > 0) {
369+
LOG_INFO("UCacheBlockMetrics: samples=%d, avg_rate=%.4f, min=%.4f, max=%.4f, avg_norm=%.4f",
370+
block_metrics.sample_count,
371+
block_metrics.avg_transformation_rate(),
372+
block_metrics.min_change_rate,
373+
block_metrics.max_change_rate,
374+
block_metrics.avg_output_norm());
375+
}
283376
}
284377
};
285378

0 commit comments

Comments
 (0)