@@ -658,38 +658,108 @@ struct SD3CLIPEmbedder : public Conditioner {
658658 std::shared_ptr<CLIPTextModelRunner> clip_l;
659659 std::shared_ptr<CLIPTextModelRunner> clip_g;
660660 std::shared_ptr<T5Runner> t5;
661+ bool use_clip_l = false ;
662+ bool use_clip_g = false ;
663+ bool use_t5 = false ;
661664
662665 SD3CLIPEmbedder (ggml_backend_t backend,
663666 bool offload_params_to_cpu,
664667 const String2GGMLType& tensor_types = {})
665668 : clip_g_tokenizer(0 ) {
666- clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, " text_encoders.clip_l.transformer.text_model" , OPENAI_CLIP_VIT_L_14, false );
667- clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, " text_encoders.clip_g.transformer.text_model" , OPEN_CLIP_VIT_BIGG_14, false );
668- t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, " text_encoders.t5xxl.transformer" );
669+ if (clip_skip <= 0 ) {
670+ clip_skip = 2 ;
671+ }
672+
673+ for (auto pair : tensor_types) {
674+ if (pair.first .find (" text_encoders.clip_l" ) != std::string::npos) {
675+ use_clip_l = true ;
676+ } else if (pair.first .find (" text_encoders.clip_g" ) != std::string::npos) {
677+ use_clip_g = true ;
678+ } else if (pair.first .find (" text_encoders.t5xxl" ) != std::string::npos) {
679+ use_t5 = true ;
680+ }
681+ }
682+ if (!use_clip_l && !use_clip_g && !use_t5) {
683+ LOG_WARN (" IMPORTANT NOTICE: No text encoders provided, cannot process prompts!" );
684+ return ;
685+ }
686+ if (use_clip_l) {
687+ clip_l = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, " text_encoders.clip_l.transformer.text_model" , OPENAI_CLIP_VIT_L_14, clip_skip, false );
688+ } else {
689+ LOG_WARN (" clip_l text encoder not found! Prompt adherence might be degraded." );
690+ }
691+ if (use_clip_g) {
692+ clip_g = std::make_shared<CLIPTextModelRunner>(backend, offload_params_to_cpu, tensor_types, " text_encoders.clip_g.transformer.text_model" , OPEN_CLIP_VIT_BIGG_14, clip_skip, false );
693+ } else {
694+ LOG_WARN (" clip_g text encoder not found! Prompt adherence might be degraded." );
695+ }
696+ if (use_t5) {
697+ t5 = std::make_shared<T5Runner>(backend, offload_params_to_cpu, tensor_types, " text_encoders.t5xxl.transformer" );
698+ } else {
699+ LOG_WARN (" t5xxl text encoder not found! Prompt adherence might be degraded." );
700+ }
701+ set_clip_skip (clip_skip);
702+ }
703+
704+ void set_clip_skip (int clip_skip) {
705+ if (clip_skip <= 0 ) {
706+ clip_skip = 2 ;
707+ }
708+ if (use_clip_l) {
709+ clip_l->set_clip_skip (clip_skip);
710+ }
711+ if (use_clip_g) {
712+ clip_g->set_clip_skip (clip_skip);
713+ }
669714 }
670715
671716 void get_param_tensors (std::map<std::string, struct ggml_tensor *>& tensors) {
672- clip_l->get_param_tensors (tensors, " text_encoders.clip_l.transformer.text_model" );
673- clip_g->get_param_tensors (tensors, " text_encoders.clip_g.transformer.text_model" );
674- t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
717+ if (use_clip_l) {
718+ clip_l->get_param_tensors (tensors, " text_encoders.clip_l.transformer.text_model" );
719+ }
720+ if (use_clip_g) {
721+ clip_g->get_param_tensors (tensors, " text_encoders.clip_g.transformer.text_model" );
722+ }
723+ if (use_t5) {
724+ t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
725+ }
675726 }
676727
677728 void alloc_params_buffer () {
678- clip_l->alloc_params_buffer ();
679- clip_g->alloc_params_buffer ();
680- t5->alloc_params_buffer ();
729+ if (use_clip_l) {
730+ clip_l->alloc_params_buffer ();
731+ }
732+ if (use_clip_g) {
733+ clip_g->alloc_params_buffer ();
734+ }
735+ if (use_t5) {
736+ t5->alloc_params_buffer ();
737+ }
681738 }
682739
683740 void free_params_buffer () {
684- clip_l->free_params_buffer ();
685- clip_g->free_params_buffer ();
686- t5->free_params_buffer ();
741+ if (use_clip_l) {
742+ clip_l->free_params_buffer ();
743+ }
744+ if (use_clip_g) {
745+ clip_g->free_params_buffer ();
746+ }
747+ if (use_t5) {
748+ t5->free_params_buffer ();
749+ }
687750 }
688751
689752 size_t get_params_buffer_size () {
690- size_t buffer_size = clip_l->get_params_buffer_size ();
691- buffer_size += clip_g->get_params_buffer_size ();
692- buffer_size += t5->get_params_buffer_size ();
753+ size_t buffer_size = 0 ;
754+ if (use_clip_l) {
755+ buffer_size += clip_l->get_params_buffer_size ();
756+ }
757+ if (use_clip_g) {
758+ buffer_size += clip_g->get_params_buffer_size ();
759+ }
760+ if (use_t5) {
761+ buffer_size += t5->get_params_buffer_size ();
762+ }
693763 return buffer_size;
694764 }
695765
@@ -721,23 +791,32 @@ struct SD3CLIPEmbedder : public Conditioner {
721791 for (const auto & item : parsed_attention) {
722792 const std::string& curr_text = item.first ;
723793 float curr_weight = item.second ;
724-
725- std::vector<int > curr_tokens = clip_l_tokenizer.encode (curr_text, on_new_token_cb);
726- clip_l_tokens.insert (clip_l_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
727- clip_l_weights.insert (clip_l_weights.end (), curr_tokens.size (), curr_weight);
728-
729- curr_tokens = clip_g_tokenizer.encode (curr_text, on_new_token_cb);
730- clip_g_tokens.insert (clip_g_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
731- clip_g_weights.insert (clip_g_weights.end (), curr_tokens.size (), curr_weight);
732-
733- curr_tokens = t5_tokenizer.Encode (curr_text, true );
734- t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
735- t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
794+ if (use_clip_l) {
795+ std::vector<int > curr_tokens = clip_l_tokenizer.encode (curr_text, on_new_token_cb);
796+ clip_l_tokens.insert (clip_l_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
797+ clip_l_weights.insert (clip_l_weights.end (), curr_tokens.size (), curr_weight);
798+ }
799+ if (use_clip_g) {
800+ std::vector<int > curr_tokens = clip_g_tokenizer.encode (curr_text, on_new_token_cb);
801+ clip_g_tokens.insert (clip_g_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
802+ clip_g_weights.insert (clip_g_weights.end (), curr_tokens.size (), curr_weight);
803+ }
804+ if (use_t5) {
805+ std::vector<int > curr_tokens = t5_tokenizer.Encode (curr_text, true );
806+ t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
807+ t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
808+ }
736809 }
737810
738- clip_l_tokenizer.pad_tokens (clip_l_tokens, clip_l_weights, max_length, padding);
739- clip_g_tokenizer.pad_tokens (clip_g_tokens, clip_g_weights, max_length, padding);
740- t5_tokenizer.pad_tokens (t5_tokens, t5_weights, NULL , max_length, padding);
811+ if (use_clip_l) {
812+ clip_l_tokenizer.pad_tokens (clip_l_tokens, clip_l_weights, max_length, padding);
813+ }
814+ if (use_clip_g) {
815+ clip_g_tokenizer.pad_tokens (clip_g_tokens, clip_g_weights, max_length, padding);
816+ }
817+ if (use_t5) {
818+ t5_tokenizer.pad_tokens (t5_tokens, t5_weights, NULL , max_length, padding);
819+ }
741820
742821 // for (int i = 0; i < clip_l_tokens.size(); i++) {
743822 // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -785,10 +864,10 @@ struct SD3CLIPEmbedder : public Conditioner {
785864 std::vector<float > hidden_states_vec;
786865
787866 size_t chunk_len = 77 ;
788- size_t chunk_count = clip_l_tokens.size () / chunk_len;
867+ size_t chunk_count = std::max ( std::max ( clip_l_tokens.size (), clip_g_tokens. size ()), t5_tokens. size () ) / chunk_len;
789868 for (int chunk_idx = 0 ; chunk_idx < chunk_count; chunk_idx++) {
790869 // clip_l
791- {
870+ if (use_clip_l) {
792871 std::vector<int > chunk_tokens (clip_l_tokens.begin () + chunk_idx * chunk_len,
793872 clip_l_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
794873 std::vector<float > chunk_weights (clip_l_weights.begin () + chunk_idx * chunk_len,
@@ -835,10 +914,17 @@ struct SD3CLIPEmbedder : public Conditioner {
835914 &pooled_l,
836915 work_ctx);
837916 }
917+ } else {
918+ chunk_hidden_states_l = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 768 , chunk_len);
919+ ggml_set_f32 (chunk_hidden_states_l, 0 .f );
920+ if (chunk_idx == 0 ) {
921+ pooled_l = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 768 );
922+ ggml_set_f32 (pooled_l, 0 .f );
923+ }
838924 }
839925
840926 // clip_g
841- {
927+ if (use_clip_g) {
842928 std::vector<int > chunk_tokens (clip_g_tokens.begin () + chunk_idx * chunk_len,
843929 clip_g_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
844930 std::vector<float > chunk_weights (clip_g_weights.begin () + chunk_idx * chunk_len,
@@ -886,10 +972,17 @@ struct SD3CLIPEmbedder : public Conditioner {
886972 &pooled_g,
887973 work_ctx);
888974 }
975+ } else {
976+ chunk_hidden_states_g = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 1280 , chunk_len);
977+ ggml_set_f32 (chunk_hidden_states_g, 0 .f );
978+ if (chunk_idx == 0 ) {
979+ pooled_g = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 1280 );
980+ ggml_set_f32 (pooled_g, 0 .f );
981+ }
889982 }
890983
891984 // t5
892- {
985+ if (use_t5) {
893986 std::vector<int > chunk_tokens (t5_tokens.begin () + chunk_idx * chunk_len,
894987 t5_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
895988 std::vector<float > chunk_weights (t5_weights.begin () + chunk_idx * chunk_len,
@@ -917,6 +1010,8 @@ struct SD3CLIPEmbedder : public Conditioner {
9171010 float new_mean = ggml_tensor_mean (tensor);
9181011 ggml_tensor_scale (tensor, (original_mean / new_mean));
9191012 }
1013+ } else {
1014+ chunk_hidden_states_t5 = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 4096 , 0 );
9201015 }
9211016
9221017 auto chunk_hidden_states_lg_pad = ggml_new_tensor_3d (work_ctx,
@@ -959,11 +1054,19 @@ struct SD3CLIPEmbedder : public Conditioner {
9591054 ((float *)chunk_hidden_states->data ) + ggml_nelements (chunk_hidden_states));
9601055 }
9611056
962- hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
963- hidden_states = ggml_reshape_2d (work_ctx,
964- hidden_states,
965- chunk_hidden_states->ne [0 ],
966- ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
1057+ if (hidden_states_vec.size () > 0 ) {
1058+ hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
1059+ hidden_states = ggml_reshape_2d (work_ctx,
1060+ hidden_states,
1061+ chunk_hidden_states->ne [0 ],
1062+ ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
1063+ } else {
1064+ hidden_states = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 4096 , 0 );
1065+ }
1066+ if (pooled == NULL ) {
1067+ pooled = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 2048 );
1068+ ggml_set_f32 (pooled, 0 .f );
1069+ }
9671070 return SDCondition (hidden_states, pooled, NULL );
9681071 }
9691072
0 commit comments