Skip to content

Commit 7711efb

Browse files
committed
longcat rope ids
1 parent 00071aa commit 7711efb

File tree

4 files changed

+56
-23
lines changed

4 files changed

+56
-23
lines changed

conditioner.hpp

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1825,6 +1825,17 @@ struct LLMEmbedder : public Conditioner {
18251825
prompt_attn_range.second = prompt.size();
18261826

18271827
prompt += "[/INST]";
1828+
} else if (sd_version_is_longcat(version)) {
1829+
prompt_template_encode_start_idx = 36;
1830+
// prompt_template_encode_end_idx = 5;
1831+
1832+
prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n";
1833+
1834+
prompt_attn_range.first = static_cast<int>(prompt.size());
1835+
prompt += conditioner_params.text;
1836+
prompt_attn_range.second = static_cast<int>(prompt.size());
1837+
1838+
prompt += "<|im_end|>\n<|im_start|>assistant\n";
18281839
} else {
18291840
prompt_template_encode_start_idx = 34;
18301841

flux.hpp

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1282,7 +1282,7 @@ namespace Flux {
12821282
}
12831283

12841284
if (flux_params.diffusers_style) {
1285-
LOG_INFO("Using diffusers-style naming");
1285+
LOG_INFO("Using diffusers-style attention blocks");
12861286
}
12871287

12881288
flux = Flux(flux_params);
@@ -1388,7 +1388,6 @@ namespace Flux {
13881388
for (int i = 0; i < ref_latents.size(); i++) {
13891389
ref_latents[i] = to_backend(ref_latents[i]);
13901390
}
1391-
13921391
pe_vec = Rope::gen_flux_pe(x->ne[1],
13931392
x->ne[0],
13941393
flux_params.patch_size,
@@ -1398,9 +1397,9 @@ namespace Flux {
13981397
sd_version_is_flux2(version) ? true : increase_ref_index,
13991398
flux_params.ref_index_scale,
14001399
flux_params.theta,
1401-
flux_params.axes_dim);
1400+
flux_params.axes_dim,
1401+
sd_version_is_longcat(version));
14021402
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
1403-
// LOG_DEBUG("pos_len %d", pos_len);
14041403
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);
14051404
// pe->data = pe_vec.data();
14061405
// print_ggml_tensor(pe);

ggml_extend.hpp

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -2200,7 +2200,7 @@ class SplitLinear : public Linear {
22002200
in_features(in_features),
22012201
out_features_vec(out_features_vec),
22022202
bias(bias),
2203-
force_f32(true),
2203+
force_f32(force_f32),
22042204
force_prec_f32(force_prec_f32),
22052205
scale(scale) {}
22062206

@@ -2210,21 +2210,29 @@ class SplitLinear : public Linear {
22102210
if (bias) {
22112211
b = params["bias"];
22122212
}
2213-
// concat all weights and biases together
2214-
for (int i = 1; i < out_features_vec.size(); i++) {
2215-
w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1);
2216-
if (bias) {
2217-
b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0);
2218-
}
2219-
}
22202213
if (ctx->weight_adapter) {
2214+
// concat all weights and biases together so it runs in one linear layer
2215+
for (int i = 1; i < out_features_vec.size(); i++) {
2216+
w = ggml_concat(ctx->ggml_ctx, w, params["weight." + std::to_string(i)], 1);
2217+
if (bias) {
2218+
b = ggml_concat(ctx->ggml_ctx, b, params["bias." + std::to_string(i)], 0);
2219+
}
2220+
}
22212221
WeightAdapter::ForwardParams forward_params;
22222222
forward_params.op_type = WeightAdapter::ForwardParams::op_type_t::OP_LINEAR;
22232223
forward_params.linear.force_prec_f32 = force_prec_f32;
22242224
forward_params.linear.scale = scale;
22252225
return ctx->weight_adapter->forward_with_lora(ctx->ggml_ctx, x, w, b, prefix, forward_params);
22262226
}
2227-
return ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
2227+
auto x0 = ggml_ext_linear(ctx->ggml_ctx, x, w, b, force_prec_f32, scale);
2228+
for (int i = 1; i < out_features_vec.size(); i++) {
2229+
auto wi = params["weight." + std::to_string(i)];
2230+
auto bi = bias ? params["bias." + std::to_string(i)] : nullptr;
2231+
auto xi = ggml_ext_linear(ctx->ggml_ctx, x, wi, bi, force_prec_f32, scale);
2232+
x0 = ggml_concat(ctx->ggml_ctx, x0, xi, 0);
2233+
}
2234+
2235+
return x0;
22282236
}
22292237
};
22302238

rope.hpp

Lines changed: 25 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,16 @@ namespace Rope {
8282
return txt_ids;
8383
}
8484

85-
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_img_ids(int h,
85+
// Build the per-token position ids used for the text part of a LongCat sequence.
//
// Produces bs * context_len rows, each holding axes_dim_num floats that start at 0.0f.
// For row i, components 1 and 2 are both set to (i % context_len), so the position
// restarts at 0 every context_len rows; component 0 (and any component past index 2)
// stays 0.
//
// NOTE: the function writes to indices 1 and 2 of every row without checking, so
// axes_dim_num must be at least 3 whenever bs * context_len > 0.
//
// @param bs            number of repetitions of the context (rows = bs * context_len)
// @param context_len   number of positions per repetition
// @param axes_dim_num  number of float components stored per row
// @return              the generated id rows
__STATIC_INLINE__ std::vector<std::vector<float>> gen_longcat_txt_ids(int bs, int context_len, int axes_dim_num) {
86+
auto txt_ids = std::vector<std::vector<float>>(bs * context_len, std::vector<float>(axes_dim_num, 0.0f));
87+
// Same running position goes into both axis 1 and axis 2; axis 0 is left at zero.
for (int i = 0; i < bs * context_len; i++) {
88+
txt_ids[i][1] = (i % context_len);
89+
txt_ids[i][2] = (i % context_len);
90+
}
91+
return txt_ids;
92+
}
93+
94+
__STATIC_INLINE__ std::vector<std::vector<float>> gen_flux_img_ids(int h,
8695
int w,
8796
int patch_size,
8897
int bs,
@@ -92,7 +101,6 @@ namespace Rope {
92101
int w_offset = 0) {
93102
int h_len = (h + (patch_size / 2)) / patch_size;
94103
int w_len = (w + (patch_size / 2)) / patch_size;
95-
96104
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(axes_dim_num, 0.0));
97105

98106
std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
@@ -167,13 +175,14 @@ namespace Rope {
167175
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
168176
int bs,
169177
int axes_dim_num,
178+
int start_index,
170179
const std::vector<ggml_tensor*>& ref_latents,
171180
bool increase_ref_index,
172181
float ref_index_scale) {
173182
std::vector<std::vector<float>> ids;
174183
uint64_t curr_h_offset = 0;
175184
uint64_t curr_w_offset = 0;
176-
int index = 1;
185+
int index = start_index;
177186
for (ggml_tensor* ref : ref_latents) {
178187
uint64_t h_offset = 0;
179188
uint64_t w_offset = 0;
@@ -213,13 +222,17 @@ namespace Rope {
213222
int context_len,
214223
const std::vector<ggml_tensor*>& ref_latents,
215224
bool increase_ref_index,
216-
float ref_index_scale) {
217-
auto txt_ids = gen_flux_txt_ids(bs, context_len, axes_dim_num);
218-
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
225+
float ref_index_scale,
226+
bool is_longcat) {
227+
int start_index = is_longcat ? 1 : 0;
228+
229+
auto txt_ids = is_longcat ? gen_longcat_txt_ids(bs, context_len, axes_dim_num) : gen_flux_txt_ids(bs, context_len, axes_dim_num);
230+
int offset = is_longcat ? context_len : 0;
231+
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, start_index, offset, offset);
219232

220233
auto ids = concat_ids(txt_ids, img_ids, bs);
221234
if (ref_latents.size() > 0) {
222-
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale);
235+
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, start_index + 1, ref_latents, increase_ref_index, ref_index_scale);
223236
ids = concat_ids(ids, refs_ids, bs);
224237
}
225238
return ids;
@@ -235,7 +248,8 @@ namespace Rope {
235248
bool increase_ref_index,
236249
float ref_index_scale,
237250
int theta,
238-
const std::vector<int>& axes_dim) {
251+
const std::vector<int>& axes_dim,
252+
bool is_longcat) {
239253
std::vector<std::vector<float>> ids = gen_flux_ids(h,
240254
w,
241255
patch_size,
@@ -244,7 +258,8 @@ namespace Rope {
244258
context_len,
245259
ref_latents,
246260
increase_ref_index,
247-
ref_index_scale);
261+
ref_index_scale,
262+
is_longcat);
248263
return embed_nd(ids, bs, theta, axes_dim);
249264
}
250265

@@ -269,7 +284,7 @@ namespace Rope {
269284
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
270285
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
271286
if (ref_latents.size() > 0) {
272-
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f);
287+
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, 1, ref_latents, increase_ref_index, 1.f);
273288
ids = concat_ids(ids, refs_ids, bs);
274289
}
275290
return ids;

0 commit comments

Comments
 (0)