Skip to content

Commit a797e70

Browse files
committed
reworked dit sampling
1 parent e402958 commit a797e70

File tree

1 file changed

+61
-30
lines changed

1 file changed

+61
-30
lines changed

stable-diffusion.cpp

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ struct EasyCacheState {
199199
return enabled() && step_active;
200200
}
201201

202+
bool is_step_skipped() const {
203+
return enabled() && step_active && skip_current_step;
204+
}
205+
202206
bool has_cache(const SDCondition* cond) const {
203207
auto it = cache_diffs.find(cond);
204208
return it != cache_diffs.end() && !it->second.diff.empty();
@@ -1865,12 +1869,38 @@ class StableDiffusionGGML {
18651869
pretty_progress(0, (int)steps, 0);
18661870
}
18671871

1872+
DiffusionParams diffusion_params;
1873+
18681874
const bool easycache_step_active = easycache_enabled && step > 0;
18691875
int easycache_step_index = easycache_step_active ? (step - 1) : -1;
18701876
if (easycache_step_active) {
18711877
easycache_state.begin_step(easycache_step_index, sigma);
18721878
}
18731879

1880+
auto easycache_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool {
1881+
if (!easycache_step_active || condition == nullptr || output_tensor == nullptr) {
1882+
return false;
1883+
}
1884+
return easycache_state.before_condition(condition,
1885+
diffusion_params.x,
1886+
output_tensor,
1887+
sigma,
1888+
easycache_step_index);
1889+
};
1890+
1891+
auto easycache_after_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) {
1892+
if (!easycache_step_active || condition == nullptr || output_tensor == nullptr) {
1893+
return;
1894+
}
1895+
easycache_state.after_condition(condition,
1896+
diffusion_params.x,
1897+
output_tensor);
1898+
};
1899+
1900+
auto easycache_step_is_skipped = [&]() {
1901+
return easycache_step_active && easycache_state.is_step_skipped();
1902+
};
1903+
18741904
std::vector<float> scaling = denoiser->get_scalings(sigma);
18751905
GGML_ASSERT(scaling.size() == 3);
18761906
float c_skip = scaling[0];
@@ -1916,7 +1946,6 @@ class StableDiffusionGGML {
19161946
// GGML_ASSERT(0);
19171947
}
19181948

1919-
DiffusionParams diffusion_params;
19201949
diffusion_params.x = noised_input;
19211950
diffusion_params.timesteps = timesteps;
19221951
diffusion_params.guidance = guidance_tensor;
@@ -1942,39 +1971,35 @@ class StableDiffusionGGML {
19421971
active_condition = &id_cond;
19431972
}
19441973

1945-
bool skip_model = false;
1946-
if (easycache_step_active && active_condition != nullptr) {
1947-
skip_model = easycache_state.before_condition(active_condition,
1948-
diffusion_params.x,
1949-
*active_output,
1950-
sigma,
1951-
easycache_step_index);
1952-
}
1974+
bool skip_model = easycache_before_condition(active_condition, *active_output);
19531975
if (!skip_model) {
19541976
work_diffusion_model->compute(n_threads,
19551977
diffusion_params,
19561978
active_output);
1957-
if (easycache_step_active && active_condition != nullptr) {
1958-
easycache_state.after_condition(active_condition,
1959-
diffusion_params.x,
1960-
*active_output);
1961-
}
1979+
easycache_after_condition(active_condition, *active_output);
19621980
}
19631981

1982+
bool current_step_skipped = easycache_step_is_skipped();
1983+
19641984
float* negative_data = nullptr;
19651985
if (has_unconditioned) {
19661986
// uncond
1967-
if (control_hint != nullptr && control_net != nullptr) {
1987+
if (!current_step_skipped && control_hint != nullptr && control_net != nullptr) {
19681988
control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
19691989
controls = control_net->controls;
19701990
}
1991+
current_step_skipped = easycache_step_is_skipped();
19711992
diffusion_params.controls = controls;
19721993
diffusion_params.context = uncond.c_crossattn;
19731994
diffusion_params.c_concat = uncond.c_concat;
19741995
diffusion_params.y = uncond.c_vector;
1975-
work_diffusion_model->compute(n_threads,
1976-
diffusion_params,
1977-
&out_uncond);
1996+
bool skip_uncond = easycache_before_condition(&uncond, out_uncond);
1997+
if (!skip_uncond) {
1998+
work_diffusion_model->compute(n_threads,
1999+
diffusion_params,
2000+
&out_uncond);
2001+
easycache_after_condition(&uncond, out_uncond);
2002+
}
19782003
negative_data = (float*)out_uncond->data;
19792004
}
19802005

@@ -1983,25 +2008,31 @@ class StableDiffusionGGML {
19832008
diffusion_params.context = img_cond.c_crossattn;
19842009
diffusion_params.c_concat = img_cond.c_concat;
19852010
diffusion_params.y = img_cond.c_vector;
1986-
work_diffusion_model->compute(n_threads,
1987-
diffusion_params,
1988-
&out_img_cond);
2011+
bool skip_img_cond = easycache_before_condition(&img_cond, out_img_cond);
2012+
if (!skip_img_cond) {
2013+
work_diffusion_model->compute(n_threads,
2014+
diffusion_params,
2015+
&out_img_cond);
2016+
easycache_after_condition(&img_cond, out_img_cond);
2017+
}
19892018
img_cond_data = (float*)out_img_cond->data;
19902019
}
19912020

19922021
int step_count = sigmas.size();
19932022
bool is_skiplayer_step = has_skiplayer && step > (int)(guidance.slg.layer_start * step_count) && step < (int)(guidance.slg.layer_end * step_count);
1994-
float* skip_layer_data = nullptr;
2023+
float* skip_layer_data = has_skiplayer ? (float*)out_skip->data : nullptr;
19952024
if (is_skiplayer_step) {
19962025
LOG_DEBUG("Skipping layers at step %d\n", step);
1997-
// skip layer (same as conditionned)
1998-
diffusion_params.context = cond.c_crossattn;
1999-
diffusion_params.c_concat = cond.c_concat;
2000-
diffusion_params.y = cond.c_vector;
2001-
diffusion_params.skip_layers = skip_layers;
2002-
work_diffusion_model->compute(n_threads,
2003-
diffusion_params,
2004-
&out_skip);
2026+
if (!easycache_step_is_skipped()) {
2027+
// skip layer (same as conditioned)
2028+
diffusion_params.context = cond.c_crossattn;
2029+
diffusion_params.c_concat = cond.c_concat;
2030+
diffusion_params.y = cond.c_vector;
2031+
diffusion_params.skip_layers = skip_layers;
2032+
work_diffusion_model->compute(n_threads,
2033+
diffusion_params,
2034+
&out_skip);
2035+
}
20052036
skip_layer_data = (float*)out_skip->data;
20062037
}
20072038
float* vec_denoised = (float*)denoised->data;

0 commit comments

Comments
 (0)