@@ -520,7 +520,6 @@ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
 }
 
 // Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
-// TODO: implement bpe
 class CLIPTokenizer {
 private:
     SDVersion version = VERSION_1_x;
@@ -547,6 +546,22 @@ class CLIPTokenizer {
         return text;
     }
 
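+    // collect the set of adjacent subword pairs in a word, e.g. {"h", "e", "llo</w>"} -> {(h, e), (e, llo</w>)}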
+    static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) {
+        std::set<std::pair<std::u32string, std::u32string>> pairs;
+        if (subwords.size() == 0) {
+            return pairs;
+        }
+        std::u32string prev_subword = subwords[0];
+        for (int i = 1; i < subwords.size(); i++) {
+            std::u32string subword = subwords[i];
+            std::pair<std::u32string, std::u32string> pair(prev_subword, subword);
+            pairs.insert(pair);
+            prev_subword = subword;
+        }
+        return pairs;
+    }
+
 public:
     CLIPTokenizer(SDVersion version = VERSION_1_x)
         : version(version) {}
@@ -565,7 +580,10 @@ class CLIPTokenizer {
             merges.push_back(merges_utf32_str.substr(start, pos - start));
             start = pos + 1;
         }
-        merges = std::vector<std::u32string>(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1);
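+        // expect 48895 lines: one header line (skipped below) plus 48894 merge rules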
+        // LOG_DEBUG("merges size %llu", merges.size());
+        GGML_ASSERT(merges.size() == 48895);
+        merges = std::vector<std::u32string>(merges.begin() + 1, merges.end());
         std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
         for (const auto& merge : merges) {
             size_t space_pos = merge.find(' ');
@@ -596,14 +614,82 @@ class CLIPTokenizer {
         }
     };
 
-    std::u32string bpe(std::u32string token) {
-        std::u32string word = token + utf8_to_utf32("</w>");
-        if (encoder.find(word) != encoder.end()) {
-            return word;
-        } else if (encoder.find(token) != encoder.end()) {
-            return token;
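+    // byte-pair encode one token: split it into single characters (the last one tagged with "</w>"), then repeatedly apply the learned merges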
+    std::u32string bpe(const std::u32string& token) {
+        std::vector<std::u32string> word;
+
+        for (int i = 0; i < token.size() - 1; i++) {
+            word.emplace_back(1, token[i]);
+        }
+        word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("</w>"));
+
+        std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word);
+
+        if (pairs.empty()) {
+            return token + utf8_to_utf32("</w>");
         }
-        return utf8_to_utf32(UNK_TOKEN);
+
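+        // greedy merge loop: fuse the adjacent pair with the lowest (best) rank until no remaining pair appears in the merge table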
+        while (true) {
+            auto min_pair_iter = std::min_element(pairs.begin(),
+                                                  pairs.end(),
+                                                  [&](const std::pair<std::u32string, std::u32string>& a,
+                                                      const std::pair<std::u32string, std::u32string>& b) {
+                                                      if (bpe_ranks.find(a) == bpe_ranks.end()) {
+                                                          return false;
+                                                      } else if (bpe_ranks.find(b) == bpe_ranks.end()) {
+                                                          return true;
+                                                      }
+                                                      return bpe_ranks.at(a) < bpe_ranks.at(b);
+                                                  });
+
+            const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter;
+
+            if (bpe_ranks.find(bigram) == bpe_ranks.end()) {
+                break;
+            }
+
+            std::u32string first = bigram.first;
+            std::u32string second = bigram.second;
+            std::vector<std::u32string> new_word;
+            int32_t i = 0;
+
+            while (i < word.size()) {
+                auto it = std::find(word.begin() + i, word.end(), first);
+                if (it == word.end()) {
+                    new_word.insert(new_word.end(), word.begin() + i, word.end());
+                    break;
+                }
+                new_word.insert(new_word.end(), word.begin() + i, it);
+                i = static_cast<int32_t>(std::distance(word.begin(), it));
+
+                if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) {
+                    new_word.push_back(first + second);
+                    i += 2;
+                } else {
+                    new_word.push_back(word[i]);
+                    i += 1;
+                }
+            }
+
+            word = new_word;
+
+            if (word.size() == 1) {
+                break;
+            }
+            pairs = get_pairs(word);
+        }
+
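+        // join the merged subwords with single spaces (the Python reference's ' '.join(word))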
+        std::u32string result;
+        for (int i = 0; i < word.size(); i++) {
+            result += word[i];
+            if (i != word.size() - 1) {
+                result += utf8_to_utf32(" ");
+            }
+        }
+
+        return result;
     }
 
     std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
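
A toy walk-through of the merge loop above (not part of the commit): the sketch below re-implements the same greedy algorithm on plain std::string, with a hypothetical three-entry rank table standing in for the 48894 CLIP merge rules. A lower rank means the pair is merged earlier.

#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

int main() {
    // hypothetical merge table; the real one is loaded from merges.txt
    std::map<std::pair<std::string, std::string>, int> bpe_ranks = {
        {{"l", "l"}, 0},
        {{"ll", "o</w>"}, 1},
        {{"e", "llo</w>"}, 2},
    };

    // "hello" split into symbols, the last one tagged with </w>
    std::vector<std::string> word = {"h", "e", "l", "l", "o</w>"};

    while (true) {
        // gather the adjacent pairs currently present in the word
        std::set<std::pair<std::string, std::string>> pairs;
        for (size_t i = 1; i < word.size(); i++) {
            pairs.insert({word[i - 1], word[i]});
        }
        // pick the ranked pair with the lowest rank, if any
        auto best = pairs.end();
        for (auto it = pairs.begin(); it != pairs.end(); ++it) {
            if (bpe_ranks.count(*it) &&
                (best == pairs.end() || bpe_ranks.at(*it) < bpe_ranks.at(*best))) {
                best = it;
            }
        }
        if (best == pairs.end()) {
            break;  // nothing mergeable left
        }
        // fuse every occurrence of the chosen pair
        std::vector<std::string> new_word;
        for (size_t i = 0; i < word.size(); i++) {
            if (i + 1 < word.size() && word[i] == best->first && word[i + 1] == best->second) {
                new_word.push_back(word[i] + word[i + 1]);
                i++;
            } else {
                new_word.push_back(word[i]);
            }
        }
        word = new_word;
    }

    // prints "h ello</w>": (l,l) merges first, then (ll,o</w>), then (e,llo</w>)
    for (size_t i = 0; i < word.size(); i++) {
        std::cout << word[i] << (i + 1 < word.size() ? " " : "\n");
    }
}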