Skip to content

Commit b304c14

Browse files
committed
implement chroma mask padding
1 parent 74e9a69 commit b304c14

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

flux.hpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,18 @@ namespace Flux {
711711
}
712712

713713
void chroma_modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
714-
// TODO: implement
714+
float* mask_data = (float*)mask->data;
715+
int num_pad = 0;
716+
for (int64_t i = 0; i < max_seq_length; i++) {
717+
if (num_pad >= num_extra_padding) {
718+
break;
719+
}
720+
if (isinf(mask_data[i])) {
721+
mask_data[i] = 0;
722+
++num_pad;
723+
}
724+
}
725+
// LOG_DEBUG("PAD: %d", num_pad);
715726
}
716727

717728
// Generate positional embeddings
@@ -1102,7 +1113,7 @@ namespace Flux {
11021113
c_concat = to_backend(c_concat);
11031114
}
11041115
if (flux_params.is_chroma) {
1105-
flux.chroma_modify_mask_to_attend_padding(y, context->ne[1], 1);
1116+
flux.chroma_modify_mask_to_attend_padding(y, ggml_nelements(y), 1);
11061117
// ggml_arrange is not working on some backends, and y isn't used, so let's reuse y to precompute it
11071118
range = arange(0, 344);
11081119
precompute_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size());

0 commit comments

Comments
 (0)