Commit e0d0edb

conditioner: make text encoders optional for SD3.x

1 parent fd693ac
1 file changed: conditioner.hpp (142 additions, 39 deletions)
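The diff below makes each of the three SD3 text encoders (clip_l, clip_g, t5xxl) optional: the constructor scans the loaded tensor names for the encoder prefixes and only builds the runners it finds. A minimal standalone sketch of that detection pattern, using a plain std::map<std::string, int> with hypothetical tensor names in place of the real String2GGMLType, looks roughly like this:

#include <cstdio>
#include <map>
#include <string>

int main() {
    // hypothetical tensor names, standing in for the keys of String2GGMLType
    std::map<std::string, int> tensor_types = {
        {"text_encoders.clip_l.transformer.text_model.final_layer_norm.weight", 0},
        {"text_encoders.t5xxl.transformer.shared.weight", 0},
    };

    bool use_clip_l = false, use_clip_g = false, use_t5 = false;
    for (const auto& pair : tensor_types) {
        if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
            use_clip_l = true;
        } else if (pair.first.find("text_encoders.clip_g") != std::string::npos) {
            use_clip_g = true;
        } else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
            use_t5 = true;
        }
    }

    // here: clip_l and t5xxl present, clip_g missing
    printf("clip_l=%d clip_g=%d t5xxl=%d\n", use_clip_l, use_clip_g, use_t5);
    return 0;
}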
@@ -658,38 +658,108 @@ struct SD3CLIPEmbedder : public Conditioner {
     std::shared_ptr<CLIPTextModelRunner> clip_l;
     std::shared_ptr<CLIPTextModelRunner> clip_g;
     std::shared_ptr<T5Runner> t5;
+    bool use_clip_l = false;
+    bool use_clip_g = false;
+    bool use_t5 = false;

     SD3CLIPEmbedder(ggml_backend_t backend,
                     bool offload_params_to_cpu,
                     const String2GGMLType& tensor_types = {})
         : clip_g_tokenizer(0) {
-        clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false);
-        clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false);
-        t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer");
+        if (clip_skip <= 0) {
+            clip_skip = 2;
+        }
+
+        for (auto pair : tensor_types) {
+            if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
+                use_clip_l = true;
+            } else if (pair.first.find("text_encoders.clip_g") != std::string::npos) {
+                use_clip_g = true;
+            } else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
+                use_t5 = true;
+            }
+        }
+        if (!use_clip_l && !use_clip_g && !use_t5) {
+            LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
+            return;
+        }
+        if (use_clip_l) {
+            clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, false);
+        } else {
+            LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
+        }
+        if (use_clip_g) {
+            clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, clip_skip, false);
+        } else {
+            LOG_WARN("clip_g text encoder not found! Prompt adherence might be degraded.");
+        }
+        if (use_t5) {
+            t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.t5xxl.transformer");
+        } else {
+            LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
+        }
+        set_clip_skip(clip_skip);
+    }
+
+    void set_clip_skip(int clip_skip) {
+        if (clip_skip <= 0) {
+            clip_skip = 2;
+        }
+        if (use_clip_l) {
+            clip_l->set_clip_skip(clip_skip);
+        }
+        if (use_clip_g) {
+            clip_g->set_clip_skip(clip_skip);
+        }
     }

     void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
-        clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
-        clip_g->get_param_tensors(tensors, "text_encoders.clip_g.transformer.text_model");
-        t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        if (use_clip_l) {
+            clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
+        }
+        if (use_clip_g) {
+            clip_g->get_param_tensors(tensors, "text_encoders.clip_g.transformer.text_model");
+        }
+        if (use_t5) {
+            t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
+        }
     }

     void alloc_params_buffer() {
-        clip_l->alloc_params_buffer();
-        clip_g->alloc_params_buffer();
-        t5->alloc_params_buffer();
+        if (use_clip_l) {
+            clip_l->alloc_params_buffer();
+        }
+        if (use_clip_g) {
+            clip_g->alloc_params_buffer();
+        }
+        if (use_t5) {
+            t5->alloc_params_buffer();
+        }
     }

     void free_params_buffer() {
-        clip_l->free_params_buffer();
-        clip_g->free_params_buffer();
-        t5->free_params_buffer();
+        if (use_clip_l) {
+            clip_l->free_params_buffer();
+        }
+        if (use_clip_g) {
+            clip_g->free_params_buffer();
+        }
+        if (use_t5) {
+            t5->free_params_buffer();
+        }
     }

     size_t get_params_buffer_size() {
-        size_t buffer_size = clip_l->get_params_buffer_size();
-        buffer_size += clip_g->get_params_buffer_size();
-        buffer_size += t5->get_params_buffer_size();
+        size_t buffer_size = 0;
+        if (use_clip_l) {
+            buffer_size += clip_l->get_params_buffer_size();
+        }
+        if (use_clip_g) {
+            buffer_size += clip_g->get_params_buffer_size();
+        }
+        if (use_t5) {
+            buffer_size += t5->get_params_buffer_size();
+        }
         return buffer_size;
     }

@@ -721,23 +791,32 @@ struct SD3CLIPEmbedder : public Conditioner {
         for (const auto& item : parsed_attention) {
             const std::string& curr_text = item.first;
             float curr_weight = item.second;
-
-            std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
-            clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
-
-            curr_tokens = clip_g_tokenizer.encode(curr_text, on_new_token_cb);
-            clip_g_tokens.insert(clip_g_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight);
-
-            curr_tokens = t5_tokenizer.Encode(curr_text, true);
-            t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
-            t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+            if (use_clip_l) {
+                std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
+                clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
+            }
+            if (use_clip_g) {
+                std::vector<int> curr_tokens = clip_g_tokenizer.encode(curr_text, on_new_token_cb);
+                clip_g_tokens.insert(clip_g_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                clip_g_weights.insert(clip_g_weights.end(), curr_tokens.size(), curr_weight);
+            }
+            if (use_t5) {
+                std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
+                t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
+                t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
+            }
         }

-        clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
-        clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
-        t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
+        if (use_clip_l) {
+            clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
+        }
+        if (use_clip_g) {
+            clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
+        }
+        if (use_t5) {
+            t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
+        }

         // for (int i = 0; i < clip_l_tokens.size(); i++) {
         //     std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -785,10 +864,10 @@ struct SD3CLIPEmbedder : public Conditioner {
         std::vector<float> hidden_states_vec;

         size_t chunk_len = 77;
-        size_t chunk_count = clip_l_tokens.size() / chunk_len;
+        size_t chunk_count = std::max(std::max(clip_l_tokens.size(), clip_g_tokens.size()), t5_tokens.size()) / chunk_len;
         for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
             // clip_l
-            {
+            if (use_clip_l) {
                 std::vector<int> chunk_tokens(clip_l_tokens.begin() + chunk_idx * chunk_len,
                                               clip_l_tokens.begin() + (chunk_idx + 1) * chunk_len);
                 std::vector<float> chunk_weights(clip_l_weights.begin() + chunk_idx * chunk_len,
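With the encoders optional, a missing encoder leaves its token list empty, so the chunk count in the hunk above is taken from the longest of the three token lists instead of from clip_l alone. A small self-contained sketch of that calculation, with hypothetical token counts:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const size_t chunk_len = 77;                    // one CLIP/T5 chunk, as in the diff
    std::vector<int> clip_l_tokens;                 // empty: clip_l checkpoint not provided
    std::vector<int> clip_g_tokens(2 * chunk_len);  // two chunks' worth of tokens
    std::vector<int> t5_tokens(2 * chunk_len);

    // the longest token list decides how many 77-token chunks get encoded
    size_t chunk_count = std::max(std::max(clip_l_tokens.size(), clip_g_tokens.size()),
                                  t5_tokens.size()) /
                         chunk_len;
    printf("chunk_count = %zu\n", chunk_count);     // prints 2
    return 0;
}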
@@ -835,10 +914,17 @@ struct SD3CLIPEmbedder : public Conditioner {
                                     &pooled_l,
                                     work_ctx);
                 }
+            } else {
+                chunk_hidden_states_l = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, chunk_len);
+                ggml_set_f32(chunk_hidden_states_l, 0.f);
+                if (chunk_idx == 0) {
+                    pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
+                    ggml_set_f32(pooled_l, 0.f);
+                }
             }

             // clip_g
-            {
+            if (use_clip_g) {
                 std::vector<int> chunk_tokens(clip_g_tokens.begin() + chunk_idx * chunk_len,
                                               clip_g_tokens.begin() + (chunk_idx + 1) * chunk_len);
                 std::vector<float> chunk_weights(clip_g_weights.begin() + chunk_idx * chunk_len,
@@ -886,10 +972,17 @@ struct SD3CLIPEmbedder : public Conditioner {
                                     &pooled_g,
                                     work_ctx);
                 }
+            } else {
+                chunk_hidden_states_g = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 1280, chunk_len);
+                ggml_set_f32(chunk_hidden_states_g, 0.f);
+                if (chunk_idx == 0) {
+                    pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280);
+                    ggml_set_f32(pooled_g, 0.f);
+                }
             }

             // t5
-            {
+            if (use_t5) {
                 std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
                                               t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
                 std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
@@ -917,6 +1010,8 @@ struct SD3CLIPEmbedder : public Conditioner {
                     float new_mean = ggml_tensor_mean(tensor);
                     ggml_tensor_scale(tensor, (original_mean / new_mean));
                 }
+            } else {
+                chunk_hidden_states_t5 = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 0);
             }

             auto chunk_hidden_states_lg_pad = ggml_new_tensor_3d(work_ctx,
@@ -959,11 +1054,19 @@ struct SD3CLIPEmbedder : public Conditioner {
                                      ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
         }

-        hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
-        hidden_states = ggml_reshape_2d(work_ctx,
-                                        hidden_states,
-                                        chunk_hidden_states->ne[0],
-                                        ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
+        if (hidden_states_vec.size() > 0) {
+            hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
+            hidden_states = ggml_reshape_2d(work_ctx,
+                                            hidden_states,
+                                            chunk_hidden_states->ne[0],
+                                            ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
+        } else {
+            hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 0);
+        }
+        if (pooled == NULL) {
+            pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 2048);
+            ggml_set_f32(pooled, 0.f);
+        }
         return SDCondition(hidden_states, pooled, NULL);
     }

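For reference, when an encoder is absent the diff substitutes zero-filled tensors of the matching width (768 for clip_l, 1280 for clip_g, 4096 for t5xxl, 2048 for the combined pooled vector). A rough standalone sketch of those stand-ins using the ggml C API, assuming a small throwaway context in place of the real work_ctx and hard-coded sizes for illustration:

#include "ggml.h"

#include <cstdio>

int main() {
    // small scratch pool, enough for the demo tensors below
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context* ctx = ggml_init(params);

    const int chunk_len = 77;  // one CLIP chunk, as in the diff

    // zero hidden states standing in for a missing clip_l (768-wide) and clip_g (1280-wide)
    struct ggml_tensor* hidden_l = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 768, chunk_len);
    struct ggml_tensor* hidden_g = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1280, chunk_len);
    ggml_set_f32(hidden_l, 0.f);
    ggml_set_f32(hidden_g, 0.f);

    // zero pooled vector when neither CLIP encoder produced one (768 + 1280 = 2048)
    struct ggml_tensor* pooled = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 2048);
    ggml_set_f32(pooled, 0.f);

    printf("hidden_l: %lld x %lld, pooled: %lld\n",
           (long long)hidden_l->ne[0], (long long)hidden_l->ne[1],
           (long long)pooled->ne[0]);

    ggml_free(ctx);
    return 0;
}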