33
44#include " ggml_extend.hpp"
55#include " model.h"
6+ #include " tokenize_util.h"
67
78/* ================================================== CLIPTokenizer ===================================================*/
89
// Extracts every "<lora:filename:multiplier>" tag from the prompt text.
// Returns {filename -> summed multiplier, prompt with all tags removed}.
// Tags with a zero or unparsable multiplier are still stripped from the
// text but produce no map entry; repeated references to the same lora
// accumulate their multipliers.
__STATIC_INLINE__ std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
    std::regex re("<lora:([^:]+):([^>]+)>");
    std::smatch matches;
    std::unordered_map<std::string, float> filename2multiplier;

    while (std::regex_search(text, matches, re)) {
        std::string filename = matches[1].str();

        float multiplier = 0.f;
        try {
            multiplier = std::stof(matches[2].str());
        } catch (const std::exception&) {
            // malformed multiplier (e.g. "<lora:foo:abc>") -- drop the tag instead of crashing
            multiplier = 0.f;
        }

        // remove exactly the matched tag, without re-scanning the whole string
        text = matches.prefix().str() + matches.suffix().str();

        if (multiplier == 0.f) {
            continue;
        }

        // operator[] value-initializes to 0.f on first use, so += covers both cases
        filename2multiplier[filename] += multiplier;
    }

    return std::make_pair(filename2multiplier, text);
}
33-
3410__STATIC_INLINE__ std::vector<std::pair<int , std::u32string>> bytes_to_unicode () {
3511 std::vector<std::pair<int , std::u32string>> byte_unicode_pairs;
3612 std::set<int > byte_set;
@@ -72,6 +48,8 @@ class CLIPTokenizer {
7248 int encoder_len;
7349 int bpe_len;
7450
51+ std::vector<std::string> special_tokens;
52+
7553public:
7654 const std::string UNK_TOKEN = " <|endoftext|>" ;
7755 const std::string BOS_TOKEN = " <|startoftext|>" ;
@@ -117,6 +95,15 @@ class CLIPTokenizer {
11795 return pairs;
11896 }
11997
98+ bool is_special_token (const std::string& token) {
99+ for (auto & special_token : special_tokens) {
100+ if (special_token == token) {
101+ return true ;
102+ }
103+ }
104+ return false ;
105+ }
106+
120107public:
121108 CLIPTokenizer (int pad_token_id = 49407 , const std::string& merges_utf8_str = " " )
122109 : PAD_TOKEN_ID(pad_token_id) {
@@ -125,6 +112,8 @@ class CLIPTokenizer {
125112 } else {
126113 load_from_merges (ModelLoader::load_merges ());
127114 }
115+ add_special_token (" <|startoftext|>" );
116+ add_special_token (" <|endoftext|>" );
128117 }
129118
130119 void load_from_merges (const std::string& merges_utf8_str) {
@@ -201,6 +190,10 @@ class CLIPTokenizer {
201190 }
202191 }
203192
193+ void add_special_token (const std::string& token) {
194+ special_tokens.push_back (token);
195+ }
196+
204197 std::u32string bpe (const std::u32string& token) {
205198 std::vector<std::u32string> word;
206199
@@ -379,25 +372,54 @@ class CLIPTokenizer {
379372 return trim (text);
380373 }
381374
375+ std::vector<std::string> token_split (const std::string& text) {
376+ std::regex pat (R"( 's|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)" ,
377+ std::regex::icase);
378+ std::sregex_iterator iter (text.begin (), text.end (), pat);
379+ std::sregex_iterator end;
380+
381+ std::vector<std::string> result;
382+ for (; iter != end; ++iter) {
383+ result.emplace_back (iter->str ());
384+ }
385+
386+ return result;
387+ }
388+
382389 std::vector<int > encode (std::string text, on_new_token_cb_t on_new_token_cb) {
383390 std::string original_text = text;
384391 std::vector<int32_t > bpe_tokens;
385392 text = whitespace_clean (text);
386393 std::transform (text.begin (), text.end (), text.begin (), [](unsigned char c) { return std::tolower (c); });
387394
388- std::regex pat (R"( <\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)" ,
389- std::regex::icase);
390-
391- std::smatch matches;
392395 std::string str = text;
393396 std::vector<std::string> token_strs;
394- while (std::regex_search (str, matches, pat)) {
395- bool skip = on_new_token_cb (str, bpe_tokens);
396- if (skip) {
397+
398+ auto splited_texts = split_with_special_tokens (text, special_tokens);
399+
400+ for (auto & splited_text : splited_texts) {
401+ LOG_DEBUG (" token %s" , splited_text.c_str ());
402+ if (is_special_token (splited_text)) {
403+ LOG_DEBUG (" special %s" , splited_text.c_str ());
404+ bool skip = on_new_token_cb (splited_text, bpe_tokens);
405+ if (skip) {
406+ token_strs.push_back (splited_text);
407+ continue ;
408+ }
397409 continue ;
398410 }
399- for (auto & token : matches) {
400- std::string token_str = token.str ();
411+
412+ auto tokens = token_split (splited_text);
413+ for (auto & token : tokens) {
414+ if (on_new_token_cb != nullptr ) {
415+ bool skip = on_new_token_cb (token, bpe_tokens);
416+ if (skip) {
417+ token_strs.push_back (token);
418+ continue ;
419+ }
420+ }
421+
422+ std::string token_str = token;
401423 std::u32string utf32_token;
402424 for (int i = 0 ; i < token_str.length (); i++) {
403425 unsigned char b = token_str[i];
@@ -417,14 +439,13 @@ class CLIPTokenizer {
417439 bpe_tokens.push_back (encoder[bpe_str]);
418440 token_strs.push_back (utf32_to_utf8 (bpe_str));
419441 }
420- str = matches.suffix ();
421442 }
422- std::stringstream ss;
423- ss << " [" ;
424- for (auto token : token_strs) {
425- ss << " \" " << token << " \" , " ;
426- }
427- ss << " ]" ;
443+ // std::stringstream ss;
444+ // ss << "[";
445+ // for (auto token : token_strs) {
446+ // ss << "\"" << token << "\", ";
447+ // }
448+ // ss << "]";
428449 // LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
429450 // printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
430451 return bpe_tokens;
@@ -963,7 +984,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
963984 return gf;
964985 }
965986
966- void compute (const int n_threads,
987+ bool compute (const int n_threads,
967988 struct ggml_tensor * input_ids,
968989 int num_custom_embeddings,
969990 void * custom_embeddings_data,
@@ -975,7 +996,7 @@ struct CLIPTextModelRunner : public GGMLRunner {
975996 auto get_graph = [&]() -> struct ggml_cgraph * {
976997 return build_graph(input_ids, num_custom_embeddings, custom_embeddings_data, max_token_idx, return_pooled, clip_skip);
977998 };
978- GGMLRunner::compute (get_graph, n_threads, true , output, output_ctx);
999+ return GGMLRunner::compute (get_graph, n_threads, true , output, output_ctx);
9791000 }
9801001};
9811002
0 commit comments