@@ -813,7 +813,7 @@ struct SD3CLIPEmbedder : public Conditioner {
813813 clip_g_tokenizer.pad_tokens (clip_g_tokens, clip_g_weights, max_length, padding);
814814 }
815815 if (use_t5) {
816- t5_tokenizer.pad_tokens (t5_tokens, t5_weights, max_length, padding);
816+ t5_tokenizer.pad_tokens (t5_tokens, t5_weights, NULL , max_length, padding);
817817 }
818818
819819 // for (int i = 0; i < clip_l_tokens.size(); i++) {
@@ -1216,7 +1216,7 @@ struct FluxCLIPEmbedder : public Conditioner {
12161216 clip_l_tokenizer.pad_tokens (clip_l_tokens, clip_l_weights, 77 , padding);
12171217 }
12181218 if (use_t5) {
1219- t5_tokenizer.pad_tokens (t5_tokens, t5_weights, max_length, padding);
1219+ t5_tokenizer.pad_tokens (t5_tokens, t5_weights, NULL , max_length, padding);
12201220 }
12211221
12221222 // for (int i = 0; i < clip_l_tokens.size(); i++) {
@@ -1311,7 +1311,6 @@ struct FluxCLIPEmbedder : public Conditioner {
13111311 ggml_set_f32 (chunk_hidden_states, 0 .f );
13121312 }
13131313
1314-
13151314 int64_t t1 = ggml_time_ms ();
13161315 LOG_DEBUG (" computing condition graph completed, taking %" PRId64 " ms" , t1 - t0);
13171316 if (force_zero_embeddings) {
@@ -1320,12 +1319,11 @@ struct FluxCLIPEmbedder : public Conditioner {
13201319 vec[i] = 0 ;
13211320 }
13221321 }
1323-
13241322 hidden_states_vec.insert (hidden_states_vec.end (),
1325- (float *)chunk_hidden_states->data ,
1326- ((float *)chunk_hidden_states->data ) + ggml_nelements (chunk_hidden_states));
1323+ (float *)chunk_hidden_states->data ,
1324+ ((float *)chunk_hidden_states->data ) + ggml_nelements (chunk_hidden_states));
13271325 }
1328-
1326+
13291327 if (hidden_states_vec.size () > 0 ) {
13301328 hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
13311329 hidden_states = ggml_reshape_2d (work_ctx,
@@ -1373,4 +1371,182 @@ struct FluxCLIPEmbedder : public Conditioner {
13731371 }
13741372};
13751373
1374+ struct PixArtCLIPEmbedder : public Conditioner {
1375+ T5UniGramTokenizer t5_tokenizer;
1376+ std::shared_ptr<T5Runner> t5;
1377+
1378+ PixArtCLIPEmbedder (ggml_backend_t backend,
1379+ std::map<std::string, enum ggml_type>& tensor_types,
1380+ int clip_skip = -1 ) {
1381+ t5 = std::make_shared<T5Runner>(backend, tensor_types, " text_encoders.t5xxl.transformer" );
1382+ }
1383+
1384+ void set_clip_skip (int clip_skip) {
1385+ }
1386+
1387+ void get_param_tensors (std::map<std::string, struct ggml_tensor *>& tensors) {
1388+ t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
1389+ }
1390+
1391+ void alloc_params_buffer () {
1392+ t5->alloc_params_buffer ();
1393+ }
1394+
1395+ void free_params_buffer () {
1396+ t5->free_params_buffer ();
1397+ }
1398+
1399+ size_t get_params_buffer_size () {
1400+ size_t buffer_size = 0 ;
1401+
1402+ buffer_size += t5->get_params_buffer_size ();
1403+
1404+ return buffer_size;
1405+ }
1406+
1407+ std::tuple<std::vector<int >, std::vector<float >, std::vector<float >> tokenize (std::string text,
1408+ size_t max_length = 0 ,
1409+ bool padding = false ) {
1410+ auto parsed_attention = parse_prompt_attention (text);
1411+
1412+ {
1413+ std::stringstream ss;
1414+ ss << " [" ;
1415+ for (const auto & item : parsed_attention) {
1416+ ss << " ['" << item.first << " ', " << item.second << " ], " ;
1417+ }
1418+ ss << " ]" ;
1419+ LOG_DEBUG (" parse '%s' to %s" , text.c_str (), ss.str ().c_str ());
1420+ }
1421+
1422+ auto on_new_token_cb = [&](std::string& str, std::vector<int32_t >& bpe_tokens) -> bool {
1423+ return false ;
1424+ };
1425+
1426+ std::vector<int > t5_tokens;
1427+ std::vector<float > t5_weights;
1428+ std::vector<float > t5_mask;
1429+ for (const auto & item : parsed_attention) {
1430+ const std::string& curr_text = item.first ;
1431+ float curr_weight = item.second ;
1432+
1433+ std::vector<int > curr_tokens = t5_tokenizer.Encode (curr_text, true );
1434+ t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1435+ t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
1436+ }
1437+
1438+ t5_tokenizer.pad_tokens (t5_tokens, t5_weights, &t5_mask, max_length, padding);
1439+
1440+ return {t5_tokens, t5_weights, t5_mask};
1441+ }
1442+
1443+ SDCondition get_learned_condition_common (ggml_context* work_ctx,
1444+ int n_threads,
1445+ std::tuple<std::vector<int >, std::vector<float >, std::vector<float >> token_and_weights,
1446+ int clip_skip,
1447+ bool force_zero_embeddings = false ) {
1448+ auto & t5_tokens = std::get<0 >(token_and_weights);
1449+ auto & t5_weights = std::get<1 >(token_and_weights);
1450+ auto & t5_attn_mask_vec = std::get<2 >(token_and_weights);
1451+
1452+ int64_t t0 = ggml_time_ms ();
1453+ struct ggml_tensor * hidden_states = NULL ; // [N, n_token, 4096]
1454+ struct ggml_tensor * chunk_hidden_states = NULL ; // [n_token, 4096]
1455+ struct ggml_tensor * pooled = NULL ; // [768,]
1456+ struct ggml_tensor * t5_attn_mask = vector_to_ggml_tensor (work_ctx, t5_attn_mask_vec); // [768,]
1457+
1458+ std::vector<float > hidden_states_vec;
1459+
1460+ size_t chunk_len = 256 ;
1461+ size_t chunk_count = t5_tokens.size () / chunk_len;
1462+ for (int chunk_idx = 0 ; chunk_idx < chunk_count; chunk_idx++) {
1463+ // t5
1464+ std::vector<int > chunk_tokens (t5_tokens.begin () + chunk_idx * chunk_len,
1465+ t5_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
1466+ std::vector<float > chunk_weights (t5_weights.begin () + chunk_idx * chunk_len,
1467+ t5_weights.begin () + (chunk_idx + 1 ) * chunk_len);
1468+ std::vector<float > chunk_mask (t5_attn_mask_vec.begin () + chunk_idx * chunk_len,
1469+ t5_attn_mask_vec.begin () + (chunk_idx + 1 ) * chunk_len);
1470+
1471+ auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens);
1472+ auto t5_attn_mask_chunk = vector_to_ggml_tensor (work_ctx, chunk_mask);
1473+
1474+ t5->compute (n_threads,
1475+ input_ids,
1476+ &chunk_hidden_states,
1477+ work_ctx,
1478+ t5_attn_mask_chunk);
1479+ {
1480+ auto tensor = chunk_hidden_states;
1481+ float original_mean = ggml_tensor_mean (tensor);
1482+ for (int i2 = 0 ; i2 < tensor->ne [2 ]; i2++) {
1483+ for (int i1 = 0 ; i1 < tensor->ne [1 ]; i1++) {
1484+ for (int i0 = 0 ; i0 < tensor->ne [0 ]; i0++) {
1485+ float value = ggml_tensor_get_f32 (tensor, i0, i1, i2);
1486+ value *= chunk_weights[i1];
1487+ ggml_tensor_set_f32 (tensor, value, i0, i1, i2);
1488+ }
1489+ }
1490+ }
1491+ float new_mean = ggml_tensor_mean (tensor);
1492+ ggml_tensor_scale (tensor, (original_mean / new_mean));
1493+ }
1494+
1495+ int64_t t1 = ggml_time_ms ();
1496+ LOG_DEBUG (" computing condition graph completed, taking %" PRId64 " ms" , t1 - t0);
1497+ if (force_zero_embeddings) {
1498+ float * vec = (float *)chunk_hidden_states->data ;
1499+ for (int i = 0 ; i < ggml_nelements (chunk_hidden_states); i++) {
1500+ vec[i] = 0 ;
1501+ }
1502+ }
1503+
1504+ hidden_states_vec.insert (hidden_states_vec.end (),
1505+ (float *)chunk_hidden_states->data ,
1506+ ((float *)chunk_hidden_states->data ) + ggml_nelements (chunk_hidden_states));
1507+ }
1508+
1509+ if (hidden_states_vec.size () > 0 ) {
1510+ hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
1511+ hidden_states = ggml_reshape_2d (work_ctx,
1512+ hidden_states,
1513+ chunk_hidden_states->ne [0 ],
1514+ ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
1515+ } else {
1516+ hidden_states = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 4096 , 256 );
1517+ ggml_set_f32 (hidden_states, 0 .f );
1518+ }
1519+ return SDCondition (hidden_states, t5_attn_mask, NULL );
1520+ }
1521+
1522+ SDCondition get_learned_condition (ggml_context* work_ctx,
1523+ int n_threads,
1524+ const std::string& text,
1525+ int clip_skip,
1526+ int width,
1527+ int height,
1528+ int adm_in_channels = -1 ,
1529+ bool force_zero_embeddings = false ) {
1530+ auto tokens_and_weights = tokenize (text, 512 , true );
1531+ return get_learned_condition_common (work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
1532+ }
1533+
1534+ std::tuple<SDCondition, std::vector<bool >> get_learned_condition_with_trigger (ggml_context* work_ctx,
1535+ int n_threads,
1536+ const std::string& text,
1537+ int clip_skip,
1538+ int width,
1539+ int height,
1540+ int num_input_imgs,
1541+ int adm_in_channels = -1 ,
1542+ bool force_zero_embeddings = false ) {
1543+ GGML_ASSERT (0 && " Not implemented yet!" );
1544+ }
1545+
1546+ std::string remove_trigger_from_prompt (ggml_context* work_ctx,
1547+ const std::string& prompt) {
1548+ GGML_ASSERT (0 && " Not implemented yet!" );
1549+ }
1550+ };
1551+
13761552#endif