Skip to content

Commit 6b5dc9d

Browse files
committed
Chroma: Fix t5 chunk length
1 parent 957aec1 commit 6b5dc9d

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

conditioner.hpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,6 +1095,7 @@ struct FluxCLIPEmbedder : public Conditioner {
10951095
T5UniGramTokenizer t5_tokenizer;
10961096
std::shared_ptr<CLIPTextModelRunner> clip_l;
10971097
std::shared_ptr<T5Runner> t5;
1098+
size_t chunk_len = 256;
10981099

10991100
bool use_clip_l = false;
11001101
bool use_t5 = false;
@@ -1249,7 +1250,6 @@ struct FluxCLIPEmbedder : public Conditioner {
12491250
struct ggml_tensor* pooled = NULL; // [768,]
12501251
std::vector<float> hidden_states_vec;
12511252

1252-
size_t chunk_len = 256;
12531253
size_t chunk_count = std::max(clip_l_tokens.size() > 0 ? chunk_len : 0, t5_tokens.size()) / chunk_len;
12541254
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
12551255
// clip_l
@@ -1349,7 +1349,7 @@ struct FluxCLIPEmbedder : public Conditioner {
13491349
int height,
13501350
int adm_in_channels = -1,
13511351
bool force_zero_embeddings = false) {
1352-
auto tokens_and_weights = tokenize(text, 256, true);
1352+
auto tokens_and_weights = tokenize(text, chunk_len, true);
13531353
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
13541354
}
13551355

@@ -1374,6 +1374,7 @@ struct FluxCLIPEmbedder : public Conditioner {
13741374
struct PixArtCLIPEmbedder : public Conditioner {
13751375
T5UniGramTokenizer t5_tokenizer;
13761376
std::shared_ptr<T5Runner> t5;
1377+
size_t chunk_len = 512;
13771378

13781379
PixArtCLIPEmbedder(ggml_backend_t backend,
13791380
std::map<std::string, enum ggml_type>& tensor_types,
@@ -1457,8 +1458,18 @@ struct PixArtCLIPEmbedder : public Conditioner {
14571458

14581459
std::vector<float> hidden_states_vec;
14591460

1460-
size_t chunk_len = 256;
14611461
size_t chunk_count = t5_tokens.size() / chunk_len;
1462+
1463+
bool use_mask = true;
1464+
const char* SD_CHROMA_USE_T5_MASK = getenv("SD_CHROMA_USE_T5_MASK");
1465+
if (SD_CHROMA_USE_T5_MASK != nullptr) {
1466+
std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK;
1467+
if (sd_chroma_use_t5_mask_str == "OFF" || sd_chroma_use_t5_mask_str == "FALSE") {
1468+
use_mask = false;
1469+
} else if (sd_chroma_use_t5_mask_str != "ON" && sd_chroma_use_t5_mask_str != "TRUE") {
1470+
LOG_WARN("SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or \"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_T5_MASK);
1471+
}
1472+
}
14621473
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
14631474
// t5
14641475
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
@@ -1469,17 +1480,7 @@ struct PixArtCLIPEmbedder : public Conditioner {
14691480
t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);
14701481

14711482
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
1472-
auto t5_attn_mask_chunk = vector_to_ggml_tensor(work_ctx, chunk_mask);
1473-
1474-
const char* SD_CHROMA_USE_T5_MASK = getenv("SD_CHROMA_USE_T5_MASK");
1475-
if (SD_CHROMA_USE_T5_MASK != nullptr) {
1476-
std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK;
1477-
if (sd_chroma_use_t5_mask_str == "OFF" || sd_chroma_use_t5_mask_str == "FALSE") {
1478-
t5_attn_mask_chunk = NULL;
1479-
} else if (sd_chroma_use_t5_mask_str != "ON" && sd_chroma_use_t5_mask_str != "TRUE") {
1480-
LOG_WARN("SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or \"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_T5_MASK);
1481-
}
1482-
}
1483+
auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL;
14831484

14841485
t5->compute(n_threads,
14851486
input_ids,
@@ -1537,7 +1538,7 @@ struct PixArtCLIPEmbedder : public Conditioner {
15371538
int height,
15381539
int adm_in_channels = -1,
15391540
bool force_zero_embeddings = false) {
1540-
auto tokens_and_weights = tokenize(text, 512, true);
1541+
auto tokens_and_weights = tokenize(text, chunk_len, true);
15411542
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
15421543
}
15431544

0 commit comments

Comments
 (0)