@@ -82,7 +82,16 @@ namespace Rope {
8282 return txt_ids;
8383 }
8484
85- __STATIC_INLINE__ std::vector<std::vector<float >> gen_flux_img_ids (int h,
// Builds the text-token position ids used for LongCat: one row of
// `axes_dim_num` floats per token, `bs * context_len` rows in total.
//
// Every row starts zero-filled. Axis 0 is left at 0, while axes 1 and 2 both
// receive the token's position inside its own sequence (i % context_len), so
// the positions restart at 0 for every batch entry.
//
// Rows only have `axes_dim_num` elements, so an axis is written only when it
// exists; with fewer than three axes the missing ones are skipped instead of
// being written out of bounds.
__STATIC_INLINE__ std::vector<std::vector<float>> gen_longcat_txt_ids(int bs, int context_len, int axes_dim_num) {
    auto txt_ids = std::vector<std::vector<float>>(bs * context_len, std::vector<float>(axes_dim_num, 0.0f));
    for (int i = 0; i < bs * context_len; i++) {
        // The loop body is never reached when context_len == 0, so the modulo is safe.
        const float pos = static_cast<float>(i % context_len);
        if (axes_dim_num > 1) {
            txt_ids[i][1] = pos;
        }
        if (axes_dim_num > 2) {
            txt_ids[i][2] = pos;
        }
    }
    return txt_ids;
}
93+
94+ __STATIC_INLINE__ std::vector<std::vector<float >> gen_flux_img_ids (int h,
8695 int w,
8796 int patch_size,
8897 int bs,
@@ -92,7 +101,6 @@ namespace Rope {
92101 int w_offset = 0 ) {
93102 int h_len = (h + (patch_size / 2 )) / patch_size;
94103 int w_len = (w + (patch_size / 2 )) / patch_size;
95-
96104 std::vector<std::vector<float >> img_ids (h_len * w_len, std::vector<float >(axes_dim_num, 0.0 ));
97105
98106 std::vector<float > row_ids = linspace<float >(h_offset, h_len - 1 + h_offset, h_len);
@@ -167,13 +175,14 @@ namespace Rope {
167175 __STATIC_INLINE__ std::vector<std::vector<float >> gen_refs_ids (int patch_size,
168176 int bs,
169177 int axes_dim_num,
178+ int start_index,
170179 const std::vector<ggml_tensor*>& ref_latents,
171180 bool increase_ref_index,
172181 float ref_index_scale) {
173182 std::vector<std::vector<float >> ids;
174183 uint64_t curr_h_offset = 0 ;
175184 uint64_t curr_w_offset = 0 ;
176- int index = 1 ;
185+ int index = start_index ;
177186 for (ggml_tensor* ref : ref_latents) {
178187 uint64_t h_offset = 0 ;
179188 uint64_t w_offset = 0 ;
@@ -213,13 +222,17 @@ namespace Rope {
213222 int context_len,
214223 const std::vector<ggml_tensor*>& ref_latents,
215224 bool increase_ref_index,
216- float ref_index_scale) {
217- auto txt_ids = gen_flux_txt_ids (bs, context_len, axes_dim_num);
218- auto img_ids = gen_flux_img_ids (h, w, patch_size, bs, axes_dim_num);
225+ float ref_index_scale,
226+ bool is_longcat) {
227+ int start_index = is_longcat ? 1 : 0 ;
228+
229+ auto txt_ids = is_longcat ? gen_longcat_txt_ids (bs, context_len, axes_dim_num) : gen_flux_txt_ids (bs, context_len, axes_dim_num);
230+ int offset = is_longcat ? context_len : 0 ;
231+ auto img_ids = gen_flux_img_ids (h, w, patch_size, bs, axes_dim_num, start_index, offset, offset);
219232
220233 auto ids = concat_ids (txt_ids, img_ids, bs);
221234 if (ref_latents.size () > 0 ) {
222- auto refs_ids = gen_refs_ids (patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale);
235+ auto refs_ids = gen_refs_ids (patch_size, bs, axes_dim_num, start_index + 1 , ref_latents, increase_ref_index, ref_index_scale);
223236 ids = concat_ids (ids, refs_ids, bs);
224237 }
225238 return ids;
@@ -235,7 +248,8 @@ namespace Rope {
235248 bool increase_ref_index,
236249 float ref_index_scale,
237250 int theta,
238- const std::vector<int >& axes_dim) {
251+ const std::vector<int >& axes_dim,
252+ bool is_longcat) {
239253 std::vector<std::vector<float >> ids = gen_flux_ids (h,
240254 w,
241255 patch_size,
@@ -244,7 +258,8 @@ namespace Rope {
244258 context_len,
245259 ref_latents,
246260 increase_ref_index,
247- ref_index_scale);
261+ ref_index_scale,
262+ is_longcat);
248263 return embed_nd (ids, bs, theta, axes_dim);
249264 }
250265
@@ -269,7 +284,7 @@ namespace Rope {
269284 auto img_ids = gen_flux_img_ids (h, w, patch_size, bs, axes_dim_num);
270285 auto ids = concat_ids (txt_ids_repeated, img_ids, bs);
271286 if (ref_latents.size () > 0 ) {
272- auto refs_ids = gen_refs_ids (patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1 .f );
287+ auto refs_ids = gen_refs_ids (patch_size, bs, axes_dim_num, 1 , ref_latents, increase_ref_index, 1 .f );
273288 ids = concat_ids (ids, refs_ids, bs);
274289 }
275290 return ids;
0 commit comments