22#define __CLIP_HPP__
33
44#include " ggml_extend.hpp"
5+ #include " model.h"
56
67/* ================================================== CLIPTokenizer ===================================================*/
78
@@ -67,6 +68,9 @@ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
6768}
6869
6970// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
71+
72+ typedef std::function<bool (std::string&, std::vector<int32_t >&)> on_new_token_cb_t ;
73+
7074class CLIPTokenizer {
7175private:
7276 SDVersion version = VERSION_1_x;
@@ -234,8 +238,11 @@ class CLIPTokenizer {
234238 return result;
235239 }
236240
237- std::vector<int > tokenize (std::string text, size_t max_length = 0 , bool padding = false ) {
238- std::vector<int32_t > tokens = encode (text);
241+ std::vector<int > tokenize (std::string text,
242+ on_new_token_cb_t on_new_token_cb,
243+ size_t max_length = 0 ,
244+ bool padding = false ) {
245+ std::vector<int32_t > tokens = encode (text, on_new_token_cb);
239246 tokens.insert (tokens.begin (), BOS_TOKEN_ID);
240247 if (max_length > 0 ) {
241248 if (tokens.size () > max_length - 1 ) {
@@ -255,7 +262,7 @@ class CLIPTokenizer {
255262 return tokens;
256263 }
257264
258- std::vector<int > encode (std::string text) {
265+ std::vector<int > encode (std::string text, on_new_token_cb_t on_new_token_cb ) {
259266 std::string original_text = text;
260267 std::vector<int32_t > bpe_tokens;
261268 text = whitespace_clean (text);
@@ -268,6 +275,10 @@ class CLIPTokenizer {
268275 std::string str = text;
269276 std::vector<std::string> token_strs;
270277 while (std::regex_search (str, matches, pat)) {
278+ bool skip = on_new_token_cb (str, bpe_tokens);
279+ if (skip) {
280+ continue ;
281+ }
271282 for (auto & token : matches) {
272283 std::string token_str = token.str ();
273284 std::u32string utf32_token;
@@ -444,13 +455,13 @@ struct ResidualAttentionBlock {
444455 struct ggml_tensor * ln2_b; // [hidden_size, ]
445456
446457 size_t calculate_mem_size (ggml_type wtype) {
447- double mem_size = 0 ;
448- mem_size += 4 * hidden_size * hidden_size * ggml_type_sizef (wtype ); // q_w/k_w/v_w/out_w
449- mem_size += 8 * hidden_size * ggml_type_sizef (GGML_TYPE_F32); // q_b/k_b/v_b/out_b/ln1_w/ln1_b/ln2_w/ln2_b
450- mem_size += 2 * hidden_size * intermediate_size * ggml_type_sizef (wtype ); // fc1_w/fc2_w
451- mem_size += intermediate_size * ggml_type_sizef (GGML_TYPE_F32); // fc1_b
452- mem_size += hidden_size * ggml_type_sizef (GGML_TYPE_F32); // fc2_b
453- return static_cast < size_t >( mem_size) ;
458+ size_t mem_size = 0 ;
459+ mem_size += 4 * ggml_row_size (wtype, hidden_size * hidden_size ); // q_w/k_w/v_w/out_w
460+ mem_size += 8 * ggml_row_size (GGML_TYPE_F32, hidden_size ); // q_b/k_b/v_b/out_b/ln1_w/ln1_b/ln2_w/ln2_b
461+ mem_size += 2 * ggml_row_size (wtype, hidden_size * intermediate_size); // fc1_w/fc2_w
462+ mem_size += ggml_row_size (GGML_TYPE_F32, intermediate_size ); // fc1_b
463+ mem_size += ggml_row_size (GGML_TYPE_F32, hidden_size ); // fc2_b
464+ return mem_size;
454465 }
455466
456467 void init_params (struct ggml_context * ctx, ggml_allocr* alloc, ggml_type wtype) {
@@ -597,13 +608,17 @@ struct CLIPTextModel {
597608 struct ggml_tensor * position_ids;
598609 struct ggml_tensor * token_embed_weight;
599610 struct ggml_tensor * position_embed_weight;
611+ struct ggml_tensor * token_embed_custom;
600612
601613 // transformer
602614 std::vector<ResidualAttentionBlock> resblocks;
603615 struct ggml_tensor * final_ln_w;
604616 struct ggml_tensor * final_ln_b;
605617
606618 struct ggml_tensor * text_projection;
619+ std::string embd_dir;
620+ int32_t num_custom_embeddings = 0 ;
621+ std::vector<std::string> readed_embeddings;
607622
608623 CLIPTextModel (CLIPVersion version = OPENAI_CLIP_VIT_L_14,
609624 int clip_skip = -1 ,
@@ -642,18 +657,21 @@ struct CLIPTextModel {
642657 }
643658
644659 size_t calculate_mem_size (ggml_type wtype) {
645- double mem_size = 0 ;
646- mem_size += hidden_size * max_position_embeddings * ggml_type_sizef (GGML_TYPE_I32); // position_ids
647- mem_size += hidden_size * vocab_size * ggml_type_sizef (wtype); // token_embed_weight
648- mem_size += hidden_size * max_position_embeddings * ggml_type_sizef (wtype); // position_embed_weight
660+ size_t mem_size = 0 ;
661+ mem_size += ggml_row_size (GGML_TYPE_I32, hidden_size * max_position_embeddings); // position_ids
662+ mem_size += ggml_row_size (wtype, hidden_size * vocab_size); // token_embed_weight
663+ mem_size += ggml_row_size (wtype, hidden_size * max_position_embeddings); // position_embed_weight
664+ if (version == OPENAI_CLIP_VIT_L_14) {
665+ mem_size += ggml_row_size (wtype, hidden_size * max_position_embeddings); // token_embed_custom
666+ }
649667 for (int i = 0 ; i < num_hidden_layers; i++) {
650668 mem_size += resblocks[i].calculate_mem_size (wtype);
651669 }
652- mem_size += 2 * hidden_size * ggml_type_sizef (GGML_TYPE_F32); // final_ln_w/b
670+ mem_size += 2 * ggml_row_size (GGML_TYPE_F32, hidden_size ); // final_ln_w/b
653671 if (version == OPEN_CLIP_VIT_BIGG_14) {
654- mem_size += hidden_size * projection_dim * ggml_type_sizef (GGML_TYPE_F32 ); // text_projection
672+ mem_size += ggml_row_size (GGML_TYPE_F32, hidden_size * projection_dim); // text_projection
655673 }
656- return static_cast < size_t >( mem_size) ;
674+ return mem_size;
657675 }
658676
659677 void map_by_name (std::map<std::string, struct ggml_tensor *>& tensors, const std::string prefix) {
@@ -670,14 +688,48 @@ struct CLIPTextModel {
670688 }
671689 }
672690
673- struct ggml_tensor * forward (struct ggml_context * ctx0, struct ggml_tensor * input_ids, size_t max_token_idx = 0 , bool return_pooled = false ) {
691+ bool load_embedding (std::string embd_name, std::string embd_path, std::vector<int32_t > &bpe_tokens) {
692+ // the order matters
693+ ModelLoader model_loader;
694+ if (!model_loader.init_from_file (embd_path)) {
695+ LOG_ERROR (" embedding '%s' failed" , embd_name.c_str ());
696+ return false ;
697+ }
698+ struct ggml_init_params params;
699+ params.mem_size = 32 * 1024 ; // max for custom embeddings 32 KB
700+ params.mem_buffer = NULL ;
701+ params.no_alloc = false ;
702+ struct ggml_context * embd_ctx = ggml_init (params);
703+ struct ggml_tensor * embd = NULL ;
704+ auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
705+ if (tensor_storage.ne [0 ] != hidden_size) {
706+ LOG_DEBUG (" embedding wrong hidden size, got %i, expected %i" , tensor_storage.ne [0 ], hidden_size);
707+ return false ;
708+ }
709+ embd = ggml_new_tensor_2d (embd_ctx, token_embed_weight->type , hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne [1 ] : 1 );
710+ *dst_tensor = embd;
711+ return true ;
712+ };
713+ model_loader.load_tensors (on_load, NULL );
714+ ggml_backend_tensor_set (token_embed_custom, embd->data , num_custom_embeddings * hidden_size * ggml_type_size (token_embed_custom->type ), ggml_nbytes (embd));
715+ readed_embeddings.push_back (embd_name);
716+ for (int i = 0 ; i < embd->ne [1 ]; i++) {
717+ bpe_tokens.push_back (vocab_size + num_custom_embeddings);
718+ // LOG_DEBUG("new custom token: %i", vocab_size + num_custom_embeddings);
719+ num_custom_embeddings++;
720+ }
721+ LOG_DEBUG (" embedding '%s' applied, custom embeddings: %i" , embd_name.c_str (), num_custom_embeddings);
722+ return true ;
723+ }
724+
725+ struct ggml_tensor * forward (struct ggml_context * ctx0, struct ggml_tensor * input_ids, struct ggml_tensor * tkn_embeddings, uint32_t max_token_idx = 0 , bool return_pooled = false ) {
674726 // input_ids: [N, n_token]
675727 GGML_ASSERT (input_ids->ne [0 ] <= position_ids->ne [0 ]);
676728
677729 // token_embedding + position_embedding
678730 struct ggml_tensor * x;
679731 x = ggml_add (ctx0,
680- ggml_get_rows (ctx0, token_embed_weight, input_ids),
732+ ggml_get_rows (ctx0, tkn_embeddings == NULL ? token_embed_weight : tkn_embeddings , input_ids),
681733 ggml_get_rows (ctx0,
682734 position_embed_weight,
683735 ggml_view_1d (ctx0, position_ids, input_ids->ne [0 ], 0 ))); // [N, n_token, hidden_size]
@@ -723,6 +775,10 @@ struct CLIPTextModel {
723775
724776 final_ln_b = ggml_new_tensor_1d (ctx, GGML_TYPE_F32, hidden_size);
725777
778+ if (version == OPENAI_CLIP_VIT_L_14) {
779+ token_embed_custom = ggml_new_tensor_2d (ctx, wtype, hidden_size, max_position_embeddings);
780+ }
781+
726782 if (version == OPEN_CLIP_VIT_BIGG_14) {
727783 text_projection = ggml_new_tensor_2d (ctx, GGML_TYPE_F32, projection_dim, hidden_size);
728784 }
@@ -805,11 +861,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
805861 }
806862 }
807863
808- struct ggml_tensor * forward (struct ggml_context * ctx0, struct ggml_tensor * input_ids, struct ggml_tensor * input_ids2, size_t max_token_idx = 0 , bool return_pooled = false ) {
864+ struct ggml_tensor * forward (struct ggml_context * ctx0, struct ggml_tensor * input_ids, struct ggml_tensor * input_ids2, struct ggml_tensor * embeddings, size_t max_token_idx = 0 , bool return_pooled = false ) {
809865 if (return_pooled) {
810- return text_model2.forward (ctx0, input_ids2, max_token_idx, return_pooled);
866+ return text_model2.forward (ctx0, input_ids2, NULL , max_token_idx, return_pooled);
811867 }
812- auto hidden_states = text_model.forward (ctx0, input_ids); // [N, n_token, hidden_size]
868+ auto hidden_states = text_model.forward (ctx0, input_ids, embeddings ); // [N, n_token, hidden_size]
813869 // LOG_DEBUG("hidden_states: %d %d %d %d %d", hidden_states->n_dims, hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]);
814870 if (version == VERSION_XL) {
815871 hidden_states = ggml_reshape_4d (ctx0,
@@ -820,7 +876,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
820876 hidden_states->ne [3 ]);
821877 hidden_states = ggml_cont (ctx0, ggml_permute (ctx0, hidden_states, 2 , 0 , 1 , 3 ));
822878
823- auto hidden_states2 = text_model2.forward (ctx0, input_ids2); // [N, n_token, hidden_size2]
879+ auto hidden_states2 = text_model2.forward (ctx0, input_ids2, NULL ); // [N, n_token, hidden_size2]
824880 hidden_states2 = ggml_reshape_4d (ctx0,
825881 hidden_states2,
826882 hidden_states2->ne [0 ],
@@ -857,12 +913,36 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
857913 LOG_DEBUG (" parse '%s' to %s" , text.c_str (), ss.str ().c_str ());
858914 }
859915
916+ auto on_new_token_cb = [&] (std::string& str, std::vector<int32_t > &bpe_tokens) -> bool {
917+ size_t word_end = str.find (" ," );
918+ std::string embd_name = word_end == std::string::npos ? str : str.substr (0 , word_end);
919+ embd_name = trim (embd_name);
920+ std::string embd_path = get_full_path (text_model.embd_dir , embd_name + " .pt" );
921+ if (embd_path.size () == 0 ) {
922+ embd_path = get_full_path (text_model.embd_dir , embd_name + " .ckpt" );
923+ }
924+ if (embd_path.size () == 0 ) {
925+ embd_path = get_full_path (text_model.embd_dir , embd_name + " .safetensors" );
926+ }
927+ if (embd_path.size () > 0 ) {
928+ if (text_model.load_embedding (embd_name, embd_path, bpe_tokens)) {
929+ if (word_end != std::string::npos) {
930+ str = str.substr (word_end);
931+ } else {
932+ str = " " ;
933+ }
934+ return true ;
935+ }
936+ }
937+ return false ;
938+ };
939+
860940 std::vector<int > tokens;
861941 std::vector<float > weights;
862942 for (const auto & item : parsed_attention) {
863943 const std::string& curr_text = item.first ;
864944 float curr_weight = item.second ;
865- std::vector<int > curr_tokens = tokenizer.encode (curr_text);
945+ std::vector<int > curr_tokens = tokenizer.encode (curr_text, on_new_token_cb );
866946 tokens.insert (tokens.end (), curr_tokens.begin (), curr_tokens.end ());
867947 weights.insert (weights.end (), curr_tokens.size (), curr_weight);
868948 }
@@ -951,7 +1031,26 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
9511031 }
9521032 }
9531033
954- struct ggml_tensor * hidden_states = forward (ctx0, input_ids, input_ids2, max_token_idx, return_pooled);
1034+ struct ggml_tensor * embeddings = NULL ;
1035+
1036+ if (text_model.num_custom_embeddings > 0 && version != VERSION_XL) {
1037+ embeddings = ggml_new_tensor_2d (ctx0, wtype, text_model.hidden_size , text_model.vocab_size + text_model.num_custom_embeddings /* custom placeholder */ );
1038+ ggml_allocr_alloc (allocr, embeddings);
1039+ if (!ggml_allocr_is_measure (allocr)) {
1040+ // really bad, there is memory inflexibility (this is for host<->device memory conflicts)
1041+ void * freeze_data = malloc (ggml_nbytes (text_model.token_embed_weight ));
1042+ ggml_backend_tensor_get_and_sync (backend, text_model.token_embed_weight , freeze_data, 0 , ggml_nbytes (text_model.token_embed_weight ));
1043+ ggml_backend_tensor_set (embeddings, freeze_data, 0 , ggml_nbytes (text_model.token_embed_weight ));
1044+ free (freeze_data);
1045+ // concatenate custom embeddings
1046+ void * custom_data = malloc (ggml_nbytes (text_model.token_embed_custom ));
1047+ ggml_backend_tensor_get_and_sync (backend, text_model.token_embed_custom , custom_data, 0 , ggml_nbytes (text_model.token_embed_custom ));
1048+ ggml_backend_tensor_set (embeddings, custom_data, ggml_nbytes (text_model.token_embed_weight ), text_model.num_custom_embeddings * text_model.hidden_size * ggml_type_size (wtype));
1049+ free (custom_data);
1050+ }
1051+ }
1052+
1053+ struct ggml_tensor * hidden_states = forward (ctx0, input_ids, input_ids2, embeddings, max_token_idx, return_pooled);
9551054
9561055 ggml_build_forward_expand (gf, hidden_states);
9571056 ggml_free (ctx0);
0 commit comments