1010#include " ggml_extend.hpp"
1111
1212struct UCacheConfig {
13- bool enabled = false ;
14- float reuse_threshold = 1 .0f ;
15- float start_percent = 0 .15f ;
16- float end_percent = 0 .95f ;
13+ bool enabled = false ;
14+ float reuse_threshold = 1 .0f ;
15+ float start_percent = 0 .15f ;
16+ float end_percent = 0 .95f ;
17+ float error_decay_rate = 1 .0f ;
18+ bool use_relative_threshold = true ;
19+ bool adaptive_threshold = true ;
20+ float early_step_multiplier = 0 .5f ;
21+ float late_step_multiplier = 1 .5f ;
1722};
1823
1924struct UCacheCacheEntry {
@@ -44,6 +49,45 @@ struct UCacheState {
4449 bool has_last_input_change = false ;
4550 int total_steps_skipped = 0 ;
4651 int current_step_index = -1 ;
52+ int steps_computed_since_active = 0 ;
53+ float accumulated_error = 0 .0f ;
54+ float reference_output_norm = 0 .0f ;
55+
56+ struct BlockMetrics {
57+ float sum_transformation_rate = 0 .0f ;
58+ float sum_output_norm = 0 .0f ;
59+ int sample_count = 0 ;
60+ float min_change_rate = std::numeric_limits<float >::max();
61+ float max_change_rate = 0 .0f ;
62+
63+ void reset () {
64+ sum_transformation_rate = 0 .0f ;
65+ sum_output_norm = 0 .0f ;
66+ sample_count = 0 ;
67+ min_change_rate = std::numeric_limits<float >::max ();
68+ max_change_rate = 0 .0f ;
69+ }
70+
71+ void record (float change_rate, float output_norm) {
72+ if (std::isfinite (change_rate) && change_rate > 0 .0f ) {
73+ sum_transformation_rate += change_rate;
74+ sum_output_norm += output_norm;
75+ sample_count++;
76+ if (change_rate < min_change_rate) min_change_rate = change_rate;
77+ if (change_rate > max_change_rate) max_change_rate = change_rate;
78+ }
79+ }
80+
81+ float avg_transformation_rate () const {
82+ return (sample_count > 0 ) ? (sum_transformation_rate / sample_count) : 0 .0f ;
83+ }
84+
85+ float avg_output_norm () const {
86+ return (sample_count > 0 ) ? (sum_output_norm / sample_count) : 0 .0f ;
87+ }
88+ };
89+ BlockMetrics block_metrics;
90+ int total_active_steps = 0 ;
4791
4892 void reset_runtime () {
4993 initial_step = true ;
@@ -64,6 +108,11 @@ struct UCacheState {
64108 has_last_input_change = false ;
65109 total_steps_skipped = 0 ;
66110 current_step_index = -1 ;
111+ steps_computed_since_active = 0 ;
112+ accumulated_error = 0 .0f ;
113+ reference_output_norm = 0 .0f ;
114+ block_metrics.reset ();
115+ total_active_steps = 0 ;
67116 }
68117
69118 void init (const UCacheConfig& cfg, Denoiser* d) {
@@ -114,6 +163,7 @@ struct UCacheState {
114163 return ;
115164 }
116165 step_active = true ;
166+ total_active_steps++;
117167 }
118168
119169 bool step_is_active () const {
@@ -124,6 +174,31 @@ struct UCacheState {
124174 return enabled () && step_active && skip_current_step;
125175 }
126176
177+ float get_adaptive_threshold (int estimated_total_steps = 0 ) const {
178+ float base_threshold = config.reuse_threshold ;
179+
180+ if (!config.adaptive_threshold ) {
181+ return base_threshold;
182+ }
183+
184+ int effective_total = estimated_total_steps;
185+ if (effective_total <= 0 ) {
186+ effective_total = std::max (20 , steps_computed_since_active * 2 );
187+ }
188+
189+ float progress = (effective_total > 0 ) ?
190+ (static_cast <float >(steps_computed_since_active) / effective_total) : 0 .0f ;
191+
192+ float multiplier = 1 .0f ;
193+ if (progress < 0 .2f ) {
194+ multiplier = config.early_step_multiplier ;
195+ } else if (progress > 0 .8f ) {
196+ multiplier = config.late_step_multiplier ;
197+ }
198+
199+ return base_threshold * multiplier;
200+ }
201+
127202 bool has_cache (const SDCondition* cond) const {
128203 auto it = cache_diffs.find (cond);
129204 return it != cache_diffs.end () && !it->second .diff .empty ();
@@ -212,15 +287,18 @@ struct UCacheState {
212287 last_input_change > 0 .0f && output_prev_norm > 0 .0f ) {
213288
214289 float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm;
215- cumulative_change_rate += approx_output_change_rate;
290+ accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;
291+
292+ float effective_threshold = get_adaptive_threshold ();
293+ if (config.use_relative_threshold && reference_output_norm > 0 .0f ) {
294+ effective_threshold = effective_threshold * reference_output_norm;
295+ }
216296
217- if (cumulative_change_rate < config. reuse_threshold ) {
297+ if (accumulated_error < effective_threshold ) {
218298 skip_current_step = true ;
219299 total_steps_skipped++;
220300 apply_cache (cond, input, output);
221301 return true ;
222- } else {
223- cumulative_change_rate = 0 .0f ;
224302 }
225303 }
226304
@@ -270,16 +348,31 @@ struct UCacheState {
270348 output_prev_norm = (ne > 0 ) ? (mean_abs / static_cast <float >(ne)) : 0 .0f ;
271349 has_output_prev_norm = output_prev_norm > 0 .0f ;
272350
351+ if (reference_output_norm == 0 .0f ) {
352+ reference_output_norm = output_prev_norm;
353+ }
354+
273355 if (has_last_input_change && last_input_change > 0 .0f && output_change > 0 .0f ) {
274356 float rate = output_change / last_input_change;
275357 if (std::isfinite (rate)) {
276358 relative_transformation_rate = rate;
277359 has_relative_transformation_rate = true ;
360+ block_metrics.record (rate, output_prev_norm);
278361 }
279362 }
280363
281- cumulative_change_rate = 0 .0f ;
282- has_last_input_change = false ;
364+ has_last_input_change = false ;
365+ }
366+
367+ void log_block_metrics () const {
368+ if (block_metrics.sample_count > 0 ) {
369+ LOG_INFO (" UCacheBlockMetrics: samples=%d, avg_rate=%.4f, min=%.4f, max=%.4f, avg_norm=%.4f" ,
370+ block_metrics.sample_count ,
371+ block_metrics.avg_transformation_rate (),
372+ block_metrics.min_change_rate ,
373+ block_metrics.max_change_rate ,
374+ block_metrics.avg_output_norm ());
375+ }
283376 }
284377};
285378
0 commit comments