@@ -451,65 +451,83 @@ class StableDiffusionGGML {
451451 int height,
452452 bool force_zero_embeddings = false ) {
453453 cond_stage_model->set_clip_skip (clip_skip);
454- auto tokens_and_weights = cond_stage_model->tokenize (text, true );
455- std::vector<int >& tokens = tokens_and_weights.first ;
456- std::vector<float >& weights = tokens_and_weights.second ;
457- int64_t t0 = ggml_time_ms ();
458- struct ggml_tensor * hidden_states = NULL ; // [N, n_token, hidden_size]
459- struct ggml_tensor * pooled = NULL ;
460-
461- auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, tokens);
462- struct ggml_tensor * input_ids2 = NULL ;
463- size_t max_token_idx = 0 ;
464- if (version == VERSION_XL) {
465- auto it = std::find (tokens.begin (), tokens.end (), EOS_TOKEN_ID);
466- if (it != tokens.end ()) {
467- std::fill (std::next (it), tokens.end (), 0 );
468- }
454+ auto tokens_and_weights = cond_stage_model->tokenize (text, true );
455+ std::vector<int >& tokens = tokens_and_weights.first ;
456+ std::vector<float >& weights = tokens_and_weights.second ;
457+ int64_t t0 = ggml_time_ms ();
458+ struct ggml_tensor * hidden_states = NULL ; // [N, n_token, hidden_size]
459+ struct ggml_tensor * chunk_hidden_states = NULL ; // [n_token, hidden_size]
460+ struct ggml_tensor * pooled = NULL ;
461+ std::vector<float > hidden_states_vec;
462+
463+ size_t chunk_len = 77 ;
464+ size_t chunk_count = tokens.size () / chunk_len;
465+ for (int chunk_idx = 0 ; chunk_idx < chunk_count; chunk_idx++) {
466+ std::vector<int > chunk_tokens (tokens.begin () + chunk_idx * chunk_len,
467+ tokens.begin () + (chunk_idx + 1 ) * chunk_len);
468+ std::vector<float > chunk_weights (weights.begin () + chunk_idx * chunk_len,
469+ weights.begin () + (chunk_idx + 1 ) * chunk_len);
470+
471+ auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens);
472+ struct ggml_tensor * input_ids2 = NULL ;
473+ size_t max_token_idx = 0 ;
474+ if (version == VERSION_XL) {
475+ auto it = std::find (chunk_tokens.begin (), chunk_tokens.end (), EOS_TOKEN_ID);
476+ if (it != chunk_tokens.end ()) {
477+ std::fill (std::next (it), chunk_tokens.end (), 0 );
478+ }
469479
470- max_token_idx = std::min<size_t >(std::distance (tokens .begin (), it), tokens .size () - 1 );
480+ max_token_idx = std::min<size_t >(std::distance (chunk_tokens .begin (), it), chunk_tokens .size () - 1 );
471481
472- input_ids2 = vector_to_ggml_tensor_i32 (work_ctx, tokens );
482+ input_ids2 = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens );
473483
474- // for (int i = 0; i < tokens .size(); i++) {
475- // printf("%d ", tokens [i]);
476- // }
477- // printf("\n");
478- }
484+ // for (int i = 0; i < chunk_tokens .size(); i++) {
485+ // printf("%d ", chunk_tokens [i]);
486+ // }
487+ // printf("\n");
488+ }
479489
480- cond_stage_model->compute (n_threads, input_ids, input_ids2, max_token_idx, false , &hidden_states , work_ctx);
481- if (version == VERSION_XL) {
482- cond_stage_model->compute (n_threads, input_ids, input_ids2, max_token_idx, true , &pooled, work_ctx);
483- }
484- // if (pooled != NULL) {
485- // print_ggml_tensor(hidden_states );
486- // print_ggml_tensor(pooled);
487- // }
490+ cond_stage_model->compute (n_threads, input_ids, input_ids2, max_token_idx, false , &chunk_hidden_states , work_ctx);
491+ if (version == VERSION_XL && chunk_idx == 0 ) {
492+ cond_stage_model->compute (n_threads, input_ids, input_ids2, max_token_idx, true , &pooled, work_ctx);
493+ }
494+ // if (pooled != NULL) {
495+ // print_ggml_tensor(chunk_hidden_states );
496+ // print_ggml_tensor(pooled);
497+ // }
488498
489- int64_t t1 = ggml_time_ms ();
490- LOG_DEBUG (" computing condition graph completed, taking %" PRId64 " ms" , t1 - t0);
491- ggml_tensor* result = ggml_dup_tensor (work_ctx, hidden_states);
492- {
493- float original_mean = ggml_tensor_mean (hidden_states);
494- for (int i2 = 0 ; i2 < hidden_states->ne [2 ]; i2++) {
495- for (int i1 = 0 ; i1 < hidden_states->ne [1 ]; i1++) {
496- for (int i0 = 0 ; i0 < hidden_states->ne [0 ]; i0++) {
497- float value = ggml_tensor_get_f32 (hidden_states, i0, i1, i2);
498- value *= weights[i1];
499- ggml_tensor_set_f32 (result, value, i0, i1, i2);
499+ int64_t t1 = ggml_time_ms ();
500+ LOG_DEBUG (" computing condition graph completed, taking %" PRId64 " ms" , t1 - t0);
501+ ggml_tensor* result = ggml_dup_tensor (work_ctx, chunk_hidden_states);
502+ {
503+ float original_mean = ggml_tensor_mean (chunk_hidden_states);
504+ for (int i2 = 0 ; i2 < chunk_hidden_states->ne [2 ]; i2++) {
505+ for (int i1 = 0 ; i1 < chunk_hidden_states->ne [1 ]; i1++) {
506+ for (int i0 = 0 ; i0 < chunk_hidden_states->ne [0 ]; i0++) {
507+ float value = ggml_tensor_get_f32 (chunk_hidden_states, i0, i1, i2);
508+ value *= chunk_weights[i1];
509+ ggml_tensor_set_f32 (result, value, i0, i1, i2);
510+ }
500511 }
501512 }
513+ float new_mean = ggml_tensor_mean (result);
514+ ggml_tensor_scale (result, (original_mean / new_mean));
502515 }
503- float new_mean = ggml_tensor_mean (result);
504- ggml_tensor_scale (result, (original_mean / new_mean));
505- }
506- if (force_zero_embeddings) {
507- float * vec = (float *)result->data ;
508- for (int i = 0 ; i < ggml_nelements (result); i++) {
509- vec[i] = 0 ;
516+ if (force_zero_embeddings) {
517+ float * vec = (float *)result->data ;
518+ for (int i = 0 ; i < ggml_nelements (result); i++) {
519+ vec[i] = 0 ;
520+ }
510521 }
522+ hidden_states_vec.insert (hidden_states_vec.end (), (float *)result->data , ((float *)result->data ) + ggml_nelements (result));
511523 }
512524
525+ hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
526+ hidden_states = ggml_reshape_2d (work_ctx,
527+ hidden_states,
528+ chunk_hidden_states->ne [0 ],
529+ ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
530+
513531 ggml_tensor* vec = NULL ;
514532 if (version == VERSION_XL) {
515533 int out_dim = 256 ;
@@ -547,7 +565,7 @@ class StableDiffusionGGML {
547565 GGML_ASSERT (offset == ggml_nbytes (vec));
548566 }
549567 // print_ggml_tensor(result);
550- return {result , vec};
568+ return {hidden_states , vec};
551569 }
552570
553571 std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> get_svd_condition (ggml_context* work_ctx,
0 commit comments