@@ -355,6 +355,113 @@ class CLIPTokenizer {
355355 }
356356};
357357
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/prompt_parser.py#L345
//
// Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
// Accepted tokens are:
//   (abc) - increases attention to abc by a multiplier of 1.1
//   (abc:3.12) - increases attention to abc by a multiplier of 3.12
//   [abc] - decreases attention to abc by a multiplier of 1.1
//   \( - literal character '('
//   \[ - literal character '['
//   \) - literal character ')'
//   \] - literal character ']'
//   \\ - literal character '\'
//   anything else - just text
//
// >>> parse_prompt_attention('normal text')
// [['normal text', 1.0]]
// >>> parse_prompt_attention('an (important) word')
// [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
// >>> parse_prompt_attention('(unbalanced')
// [['unbalanced', 1.1]]
// >>> parse_prompt_attention('\(literal\]')
// [['(literal]', 1.0]]
// >>> parse_prompt_attention('(unnecessary)(parens)')
// [['unnecessaryparens', 1.1]]
// >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
// [['a ', 1.0],
//  ['house', 1.5730000000000004],
//  [' ', 1.1],
//  ['on', 1.0],
//  [' a ', 1.1],
//  ['hill', 0.55],
//  [', sun, ', 1.1],
//  ['sky', 1.4641000000000006],
//  ['.', 1.1]]
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text) {
    std::vector<std::pair<std::string, float>> res;
    std::vector<int> round_brackets;   // indices into res where each still-open '(' group started
    std::vector<int> square_brackets;  // indices into res where each still-open '[' group started

    float round_bracket_multiplier  = 1.1f;
    float square_bracket_multiplier = 1 / 1.1f;

    // Token alternatives, in priority order: escaped bracket/backslash, lone
    // backslash, opening brackets, an explicit ":<number>)" weight (number
    // captured in group 1), closing brackets, a run of plain text, a bare ':'.
    std::regex re_attention(R"(\\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|]|[^\\()\[\]:]+|:)");

    // Scale the weight of every chunk emitted at or after start_position.
    auto multiply_range = [&](int start_position, float multiplier) {
        for (size_t p = start_position; p < res.size(); ++p) {
            res[p].second *= multiplier;
        }
    };

    std::smatch m;
    std::string remaining_text = text;

    while (std::regex_search(remaining_text, m, re_attention)) {
        std::string curr_text = m[0];
        std::string weight    = m[1];  // non-empty only for the ":<number>)" token

        if (curr_text == "(") {
            round_brackets.push_back((int)res.size());
        } else if (curr_text == "[") {
            square_brackets.push_back((int)res.size());
        } else if (!weight.empty()) {
            // ":<number>)" closes the innermost '(' group with an explicit multiplier.
            if (!round_brackets.empty()) {
                multiply_range(round_brackets.back(), std::stof(weight));
                round_brackets.pop_back();
            }
        } else if (curr_text == ")" && !round_brackets.empty()) {
            multiply_range(round_brackets.back(), round_bracket_multiplier);
            round_brackets.pop_back();
        } else if (curr_text[0] == '\\') {
            // Escape sequence (\(, \), \[, \], \\): emit the literal character,
            // backslash stripped — matches the doc above and the Python reference.
            res.push_back({curr_text.substr(1), 1.0f});
        } else if (curr_text == "]" && !square_brackets.empty()) {
            multiply_range(square_brackets.back(), square_bracket_multiplier);
            square_brackets.pop_back();
        } else {
            res.push_back({curr_text, 1.0f});
        }

        remaining_text = m.suffix();
    }

    // Brackets left unclosed at end of prompt behave as if closed here.
    for (int pos : round_brackets) {
        multiply_range(pos, round_bracket_multiplier);
    }
    for (int pos : square_brackets) {
        multiply_range(pos, square_bracket_multiplier);
    }

    if (res.empty()) {
        res.push_back({" ", 1.0f});
    }

    // Merge adjacent chunks that ended up with identical weights.
    size_t i = 0;
    while (i + 1 < res.size()) {
        if (res[i].second == res[i + 1].second) {
            res[i].first += res[i + 1].first;
            res.erase(res.begin() + i + 1);
        } else {
            ++i;
        }
    }

    return res;
}
464+
358465/* ================================================ FrozenCLIPEmbedder ================================================*/
359466
360467struct ResidualAttentionBlock {
@@ -639,6 +746,61 @@ struct FrozenCLIPEmbedder {
639746 }
640747};
641748
749+ // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
750+ struct FrozenCLIPEmbedderWithCustomWords {
751+ CLIPTokenizer tokenizer;
752+ CLIPTextModel text_model;
753+
754+ std::pair<std::vector<int >, std::vector<float >> tokenize (std::string text,
755+ size_t max_length = 0 ,
756+ bool padding = false ) {
757+ auto parsed_attention = parse_prompt_attention (text);
758+
759+ {
760+ std::stringstream ss;
761+ ss << " [" ;
762+ for (const auto & item : parsed_attention) {
763+ ss << " ['" << item.first << " ', " << item.second << " ], " ;
764+ }
765+ ss << " ]" ;
766+ LOG_DEBUG (" parse '%s' to %s" , text.c_str (), ss.str ().c_str ());
767+ }
768+
769+ std::vector<int > tokens;
770+ std::vector<float > weights;
771+ for (const auto & item : parsed_attention) {
772+ const std::string& curr_text = item.first ;
773+ float curr_weight = item.second ;
774+ std::vector<int > curr_tokens = tokenizer.encode (curr_text);
775+ tokens.insert (tokens.end (), curr_tokens.begin (), curr_tokens.end ());
776+ weights.insert (weights.end (), curr_tokens.size (), curr_weight);
777+ }
778+ tokens.insert (tokens.begin (), BOS_TOKEN_ID);
779+ weights.insert (weights.begin (), 1.0 );
780+
781+ if (max_length > 0 ) {
782+ if (tokens.size () > max_length - 1 ) {
783+ tokens.resize (max_length - 1 );
784+ weights.resize (max_length - 1 );
785+ } else {
786+ if (padding) {
787+ tokens.insert (tokens.end (), max_length - 1 - tokens.size (), PAD_TOKEN_ID);
788+ weights.insert (weights.end (), max_length - 1 - weights.size (), 1.0 );
789+ }
790+ }
791+ }
792+ tokens.push_back (EOS_TOKEN_ID);
793+ weights.push_back (1.0 );
794+
795+ // for (int i = 0; i < tokens.size(); i++) {
796+ // std::cout << tokens[i] << ":" << weights[i] << ", ";
797+ // }
798+ // std::cout << std::endl;
799+
800+ return {tokens, weights};
801+ }
802+ };
803+
642804/* ==================================================== UnetModel =====================================================*/
643805
644806struct ResBlock {
@@ -2489,7 +2651,7 @@ class StableDiffusionGGML {
24892651 size_t max_params_mem_size = 0 ;
24902652 size_t max_rt_mem_size = 0 ;
24912653
2492- FrozenCLIPEmbedder cond_stage_model;
2654+ FrozenCLIPEmbedderWithCustomWords cond_stage_model;
24932655 UNetModel diffusion_model;
24942656 AutoEncoderKL first_stage_model;
24952657
@@ -2784,9 +2946,11 @@ class StableDiffusionGGML {
27842946 }
27852947
27862948 ggml_tensor* get_learned_condition (ggml_context* res_ctx, const std::string& text) {
2787- std::vector<int32_t > tokens = cond_stage_model.tokenizer .tokenize (text,
2788- cond_stage_model.text_model .max_position_embeddings ,
2789- true );
2949+ auto tokens_and_weights = cond_stage_model.tokenize (text,
2950+ cond_stage_model.text_model .max_position_embeddings ,
2951+ true );
2952+ std::vector<int >& tokens = tokens_and_weights.first ;
2953+ std::vector<float >& weights = tokens_and_weights.second ;
27902954 size_t ctx_size = 1 * 1024 * 1024 ; // 1MB
27912955 // calculate the amount of memory required
27922956 {
@@ -2848,10 +3012,39 @@ class StableDiffusionGGML {
28483012 int64_t t1 = ggml_time_ms ();
28493013 LOG_DEBUG (" computing condition graph completed, taking %.2fs" , (t1 - t0) * 1 .0f / 1000 );
28503014
2851- ggml_tensor* result = ggml_dup_tensor (res_ctx, hidden_states);
2852- copy_ggml_tensor (result, hidden_states);
3015+ ggml_tensor* result = ggml_dup_tensor (res_ctx, hidden_states); // [N, n_token, hidden_size]
3016+
3017+ {
3018+ int64_t nelements = ggml_nelements (hidden_states);
3019+ float original_mean = 0 .f ;
3020+ float new_mean = 0 .f ;
3021+ float * vec = (float *)hidden_states->data ;
3022+ for (int i = 0 ; i < nelements; i++) {
3023+ original_mean += vec[i] / nelements * 1 .0f ;
3024+ }
3025+
3026+ for (int i2 = 0 ; i2 < hidden_states->ne [2 ]; i2++) {
3027+ for (int i1 = 0 ; i1 < hidden_states->ne [1 ]; i1++) {
3028+ for (int i0 = 0 ; i0 < hidden_states->ne [0 ]; i0++) {
3029+ float value = ggml_tensor_get_f32 (hidden_states, i0, i1, i2);
3030+ value *= weights[i1];
3031+ ggml_tensor_set_f32 (result, value, i0, i1, i2);
3032+ }
3033+ }
3034+ }
3035+
3036+ vec = (float *)result->data ;
3037+ for (int i = 0 ; i < nelements; i++) {
3038+ new_mean += vec[i] / nelements * 1 .0f ;
3039+ }
3040+
3041+ for (int i = 0 ; i < nelements; i++) {
3042+ vec[i] = vec[i] * (original_mean / new_mean);
3043+ }
3044+ }
28533045
28543046 // print_ggml_tensor(result);
3047+
28553048 size_t rt_mem_size = ctx_size + ggml_curr_max_dynamic_size ();
28563049 if (rt_mem_size > max_rt_mem_size) {
28573050 max_rt_mem_size = rt_mem_size;
0 commit comments