@@ -1087,20 +1087,28 @@ struct FluxCLIPEmbedder : public Conditioner {
         int64_t t0                              = ggml_time_ms();
         struct ggml_tensor* hidden_states       = NULL;  // [N, n_token, 4096]
         struct ggml_tensor* chunk_hidden_states = NULL;  // [n_token*2, 4096]
-        struct ggml_tensor* chunk_hidden_states_l  = NULL;  // [n_token, hidden_size_l]
-        struct ggml_tensor* chunk_hidden_states_t5 = NULL;  // [n_token, hidden_size_t5]
         struct ggml_tensor* pooled = NULL;  // [768,]
         std::vector<float> hidden_states_vec;
 
-        size_t chunk_len   = 77;
-        size_t chunk_count = clip_l_tokens.size() / chunk_len;
+        size_t chunk_len_l   = 77;
+        size_t chunk_count_l = clip_l_tokens.size() / chunk_len_l;
+
+        size_t chunk_len_t5   = 256;
+        size_t chunk_count_t5 = t5_tokens.size() / chunk_len_t5;
+
+        // TODO: I believe chunk_count_l is actually bigger than chunk_count_t5
+        // So this ignores some tokens for clip
+        size_t chunk_count = chunk_count_t5;
+
         for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
+            struct ggml_tensor* chunk_hidden_states_l  = NULL;  // [n_token, hidden_size_l]
+            struct ggml_tensor* chunk_hidden_states_t5 = NULL;  // [n_token, hidden_size_t5]
             // clip_l
-            {
-                std::vector<int> chunk_tokens(clip_l_tokens.begin() + chunk_idx * chunk_len,
-                                              clip_l_tokens.begin() + (chunk_idx + 1) * chunk_len);
-                std::vector<float> chunk_weights(clip_l_weights.begin() + chunk_idx * chunk_len,
-                                                 clip_l_weights.begin() + (chunk_idx + 1) * chunk_len);
+            if (chunk_idx < chunk_count_l) {
+                std::vector<int> chunk_tokens(clip_l_tokens.begin() + chunk_idx * chunk_len_l,
+                                              clip_l_tokens.begin() + (chunk_idx + 1) * chunk_len_l);
+                std::vector<float> chunk_weights(clip_l_weights.begin() + chunk_idx * chunk_len_l,
+                                                 clip_l_weights.begin() + (chunk_idx + 1) * chunk_len_l);
 
                 auto input_ids       = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
                 size_t max_token_idx = 0;
@@ -1129,7 +1137,6 @@ struct FluxCLIPEmbedder : public Conditioner {
                     ggml_tensor_scale(tensor, (original_mean / new_mean));
                 }
                 if (chunk_idx == 0) {
-                    size_t chunk_len_l = 77;
                     std::vector<int> chunk_tokens(clip_l_tokens.begin(),
                                                   clip_l_tokens.begin() + chunk_len_l);
                     std::vector<float> chunk_weights(clip_l_weights.begin(),
@@ -1157,11 +1164,11 @@ struct FluxCLIPEmbedder : public Conditioner {
             }
 
             // t5
-            {
-                std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
-                                              t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
-                std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
-                                                 t5_weights.begin() + (chunk_idx + 1) * chunk_len);
+            if (chunk_idx < chunk_count_t5) {
+                std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len_t5,
+                                              t5_tokens.begin() + (chunk_idx + 1) * chunk_len_t5);
+                std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len_t5,
+                                                 t5_weights.begin() + (chunk_idx + 1) * chunk_len_t5);
 
                 auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
 
@@ -1205,8 +1212,12 @@ struct FluxCLIPEmbedder : public Conditioner {
                     }
                 }
             }
-
-            chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states_l_pad, chunk_hidden_states_t5, 1);  // [n_token*2, 4096]
+
+            if (chunk_hidden_states_t5 == NULL) {
+                chunk_hidden_states = chunk_hidden_states_l_pad;
+            } else {
+                chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states_l_pad, chunk_hidden_states_t5, 1);  // [n_token*2, 4096]
+            }
 
 
             int64_t t1 = ggml_time_ms();
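
A minimal standalone sketch of the chunk arithmetic introduced above, not part of the diff. It assumes, as the loop does, that clip_l_tokens is padded to a whole number of 77-token chunks and t5_tokens to a whole number of 256-token chunks; the token counts are made up for illustration. It shows why chunk_count_l usually exceeds chunk_count_t5, so driving the loop with chunk_count_t5 skips the trailing clip_l chunks, as the TODO comment notes.

// chunk_math_sketch.cpp -- hypothetical sizes, only the arithmetic mirrors the diff
#include <cstddef>
#include <cstdio>

int main() {
    size_t clip_l_tokens_size = 308;  // hypothetical prompt, padded to 4 * 77
    size_t t5_tokens_size     = 512;  // same prompt, padded to 2 * 256

    size_t chunk_len_l    = 77;
    size_t chunk_count_l  = clip_l_tokens_size / chunk_len_l;   // 4
    size_t chunk_len_t5   = 256;
    size_t chunk_count_t5 = t5_tokens_size / chunk_len_t5;      // 2

    // The loop runs chunk_count_t5 times, so clip_l chunks past that point
    // are never encoded -- this is what the TODO comment points out.
    size_t chunk_count    = chunk_count_t5;
    size_t clip_l_dropped = (chunk_count_l > chunk_count)
                                ? (chunk_count_l - chunk_count) * chunk_len_l
                                : 0;

    printf("clip_l chunks: %zu, t5 chunks: %zu, loop iterations: %zu\n",
           chunk_count_l, chunk_count_t5, chunk_count);
    printf("clip_l tokens ignored: %zu\n", clip_l_dropped);  // 154
    return 0;
}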