Skip to content

Commit a5b73ac

Browse files
committed
Use env variable to control chroma padding settings
1 parent b304c14 commit a5b73ac

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

conditioner.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,6 +1471,16 @@ struct PixArtCLIPEmbedder : public Conditioner {
14711471
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
14721472
auto t5_attn_mask_chunk = vector_to_ggml_tensor(work_ctx, chunk_mask);
14731473

1474+
const char* SD_CHROMA_USE_T5_MASK = getenv("SD_CHROMA_USE_T5_MASK");
1475+
if (SD_CHROMA_USE_T5_MASK != nullptr) {
1476+
std::string sd_chroma_use_t5_mask_str = SD_CHROMA_USE_T5_MASK;
1477+
if (sd_chroma_use_t5_mask_str == "OFF" || sd_chroma_use_t5_mask_str == "FALSE") {
1478+
t5_attn_mask_chunk = NULL;
1479+
} else if (sd_chroma_use_t5_mask_str != "ON" && sd_chroma_use_t5_mask_str != "TRUE") {
1480+
LOG_WARN("SD_CHROMA_USE_T5_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_T5_MASK);
1481+
}
1482+
}
1483+
14741484
t5->compute(n_threads,
14751485
input_ids,
14761486
&chunk_hidden_states,

flux.hpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1113,7 +1113,30 @@ namespace Flux {
11131113
c_concat = to_backend(c_concat);
11141114
}
11151115
if (flux_params.is_chroma) {
1116-
flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), 1);
1116+
int mask_pad = 1;
1117+
const char* SD_CHROMA_MASK_PAD_OVERRIDE = getenv("SD_CHROMA_MASK_PAD_OVERRIDE");
1118+
if (SD_CHROMA_MASK_PAD_OVERRIDE != nullptr) {
1119+
std::string mask_pad_str = SD_CHROMA_MASK_PAD_OVERRIDE;
1120+
try {
1121+
mask_pad = std::stoi(mask_pad_str);
1122+
} catch (const std::invalid_argument&) {
1123+
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable is not a valid integer (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1124+
} catch (const std::out_of_range&) {
1125+
LOG_WARN("SD_CHROMA_MASK_PAD_OVERRIDE environment variable value is out of range for `int` type (%s). Falling back to default (%d)", SD_CHROMA_MASK_PAD_OVERRIDE, mask_pad);
1126+
}
1127+
}
1128+
flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), mask_pad);
1129+
1130+
const char* SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK");
1131+
if (SD_CHROMA_USE_DIT_MASK != nullptr) {
1132+
std::string sd_chroma_use_DiT_mask_str = SD_CHROMA_USE_DIT_MASK;
1133+
if (sd_chroma_use_DiT_mask_str == "OFF" || sd_chroma_use_DiT_mask_str == "FALSE") {
1134+
y = NULL;
1135+
} else if (sd_chroma_use_DiT_mask_str != "ON" && sd_chroma_use_DiT_mask_str != "TRUE") {
1136+
LOG_WARN("SD_CHROMA_USE_DIT_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or\"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_DIT_MASK);
1137+
}
1138+
}
1139+
11171140
// ggml_arrange is not working on some backends, and y isn't used, so let's reuse y to precompute it
11181141
range = arange(0, 344);
11191142
precompute_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size());

0 commit comments

Comments
 (0)