Skip to content

Commit 74e9a69

Browse files
committed
Chroma: Attention masking (no pad)
1 parent 57b0557 commit 74e9a69

File tree

4 files changed

+269
-43
lines changed

4 files changed

+269
-43
lines changed

conditioner.hpp

Lines changed: 183 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,7 @@ struct SD3CLIPEmbedder : public Conditioner {
813813
clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
814814
}
815815
if (use_t5) {
816-
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
816+
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
817817
}
818818

819819
// for (int i = 0; i < clip_l_tokens.size(); i++) {
@@ -1216,7 +1216,7 @@ struct FluxCLIPEmbedder : public Conditioner {
12161216
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
12171217
}
12181218
if (use_t5) {
1219-
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
1219+
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
12201220
}
12211221

12221222
// for (int i = 0; i < clip_l_tokens.size(); i++) {
@@ -1311,7 +1311,6 @@ struct FluxCLIPEmbedder : public Conditioner {
13111311
ggml_set_f32(chunk_hidden_states, 0.f);
13121312
}
13131313

1314-
13151314
int64_t t1 = ggml_time_ms();
13161315
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
13171316
if (force_zero_embeddings) {
@@ -1320,12 +1319,11 @@ struct FluxCLIPEmbedder : public Conditioner {
13201319
vec[i] = 0;
13211320
}
13221321
}
1323-
13241322
hidden_states_vec.insert(hidden_states_vec.end(),
1325-
(float*)chunk_hidden_states->data,
1326-
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
1323+
(float*)chunk_hidden_states->data,
1324+
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
13271325
}
1328-
1326+
13291327
if (hidden_states_vec.size() > 0) {
13301328
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
13311329
hidden_states = ggml_reshape_2d(work_ctx,
@@ -1373,4 +1371,182 @@ struct FluxCLIPEmbedder : public Conditioner {
13731371
}
13741372
};
13751373

1374+
struct PixArtCLIPEmbedder : public Conditioner {
1375+
T5UniGramTokenizer t5_tokenizer;
1376+
std::shared_ptr<T5Runner> t5;
1377+
1378+
PixArtCLIPEmbedder(ggml_backend_t backend,
1379+
std::map<std::string, enum ggml_type>& tensor_types,
1380+
int clip_skip = -1) {
1381+
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
1382+
}
1383+
1384+
void set_clip_skip(int clip_skip) {
1385+
}
1386+
1387+
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
1388+
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
1389+
}
1390+
1391+
void alloc_params_buffer() {
1392+
t5->alloc_params_buffer();
1393+
}
1394+
1395+
void free_params_buffer() {
1396+
t5->free_params_buffer();
1397+
}
1398+
1399+
size_t get_params_buffer_size() {
1400+
size_t buffer_size = 0;
1401+
1402+
buffer_size += t5->get_params_buffer_size();
1403+
1404+
return buffer_size;
1405+
}
1406+
1407+
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
1408+
size_t max_length = 0,
1409+
bool padding = false) {
1410+
auto parsed_attention = parse_prompt_attention(text);
1411+
1412+
{
1413+
std::stringstream ss;
1414+
ss << "[";
1415+
for (const auto& item : parsed_attention) {
1416+
ss << "['" << item.first << "', " << item.second << "], ";
1417+
}
1418+
ss << "]";
1419+
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
1420+
}
1421+
1422+
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
1423+
return false;
1424+
};
1425+
1426+
std::vector<int> t5_tokens;
1427+
std::vector<float> t5_weights;
1428+
std::vector<float> t5_mask;
1429+
for (const auto& item : parsed_attention) {
1430+
const std::string& curr_text = item.first;
1431+
float curr_weight = item.second;
1432+
1433+
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
1434+
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1435+
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
1436+
}
1437+
1438+
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding);
1439+
1440+
return {t5_tokens, t5_weights, t5_mask};
1441+
}
1442+
1443+
SDCondition get_learned_condition_common(ggml_context* work_ctx,
1444+
int n_threads,
1445+
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
1446+
int clip_skip,
1447+
bool force_zero_embeddings = false) {
1448+
auto& t5_tokens = std::get<0>(token_and_weights);
1449+
auto& t5_weights = std::get<1>(token_and_weights);
1450+
auto& t5_attn_mask_vec = std::get<2>(token_and_weights);
1451+
1452+
int64_t t0 = ggml_time_ms();
1453+
struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096]
1454+
struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096]
1455+
struct ggml_tensor* pooled = NULL; // [768,]
1456+
struct ggml_tensor* t5_attn_mask = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec); // [768,]
1457+
1458+
std::vector<float> hidden_states_vec;
1459+
1460+
size_t chunk_len = 256;
1461+
size_t chunk_count = t5_tokens.size() / chunk_len;
1462+
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
1463+
// t5
1464+
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
1465+
t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
1466+
std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
1467+
t5_weights.begin() + (chunk_idx + 1) * chunk_len);
1468+
std::vector<float> chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len,
1469+
t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);
1470+
1471+
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
1472+
auto t5_attn_mask_chunk = vector_to_ggml_tensor(work_ctx, chunk_mask);
1473+
1474+
t5->compute(n_threads,
1475+
input_ids,
1476+
&chunk_hidden_states,
1477+
work_ctx,
1478+
t5_attn_mask_chunk);
1479+
{
1480+
auto tensor = chunk_hidden_states;
1481+
float original_mean = ggml_tensor_mean(tensor);
1482+
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
1483+
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
1484+
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
1485+
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
1486+
value *= chunk_weights[i1];
1487+
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
1488+
}
1489+
}
1490+
}
1491+
float new_mean = ggml_tensor_mean(tensor);
1492+
ggml_tensor_scale(tensor, (original_mean / new_mean));
1493+
}
1494+
1495+
int64_t t1 = ggml_time_ms();
1496+
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
1497+
if (force_zero_embeddings) {
1498+
float* vec = (float*)chunk_hidden_states->data;
1499+
for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
1500+
vec[i] = 0;
1501+
}
1502+
}
1503+
1504+
hidden_states_vec.insert(hidden_states_vec.end(),
1505+
(float*)chunk_hidden_states->data,
1506+
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
1507+
}
1508+
1509+
if (hidden_states_vec.size() > 0) {
1510+
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
1511+
hidden_states = ggml_reshape_2d(work_ctx,
1512+
hidden_states,
1513+
chunk_hidden_states->ne[0],
1514+
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
1515+
} else {
1516+
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
1517+
ggml_set_f32(hidden_states, 0.f);
1518+
}
1519+
return SDCondition(hidden_states, t5_attn_mask, NULL);
1520+
}
1521+
1522+
SDCondition get_learned_condition(ggml_context* work_ctx,
1523+
int n_threads,
1524+
const std::string& text,
1525+
int clip_skip,
1526+
int width,
1527+
int height,
1528+
int adm_in_channels = -1,
1529+
bool force_zero_embeddings = false) {
1530+
auto tokens_and_weights = tokenize(text, 512, true);
1531+
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
1532+
}
1533+
1534+
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
1535+
int n_threads,
1536+
const std::string& text,
1537+
int clip_skip,
1538+
int width,
1539+
int height,
1540+
int num_input_imgs,
1541+
int adm_in_channels = -1,
1542+
bool force_zero_embeddings = false) {
1543+
GGML_ASSERT(0 && "Not implemented yet!");
1544+
}
1545+
1546+
std::string remove_trigger_from_prompt(ggml_context* work_ctx,
1547+
const std::string& prompt) {
1548+
GGML_ASSERT(0 && "Not implemented yet!");
1549+
}
1550+
};
1551+
13761552
#endif

0 commit comments

Comments
 (0)