@@ -87,6 +87,9 @@ struct LoraModel : public GGMLRunner {
                     break;
                 }
             }
+            // if (name.find(".transformer_blocks.0") != std::string::npos) {
+            //     LOG_INFO("%s", name.c_str());
+            // }

             if (dry_run) {
                 struct ggml_tensor* real = ggml_new_tensor(params_ctx,
@@ -104,7 +107,7 @@ struct LoraModel : public GGMLRunner {

         model_loader.load_tensors(on_new_tensor_cb, backend);
         alloc_params_buffer();
-
+        // exit(0);
         dry_run = false;
         model_loader.load_tensors(on_new_tensor_cb, backend);

@@ -171,32 +174,34 @@ struct LoraModel : public GGMLRunner {
             ggml_tensor* lora_down = NULL;
             // LOG_DEBUG("k_tensor %s", k_tensor.c_str());
             if (sd_version_is_flux(version)) {
-                size_t l1 = key.find("linear1");
+                size_t l1  = key.find("linear1");
+                size_t l2  = key.find("linear2");
+                size_t mod = key.find("modulation.lin");
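+                // the LoRA may store the fused "linear1" weight split into .attn.to_q/.to_k/.to_v and
+                // .proj_mlp parts, "linear2" as .proj_out, and "modulation.lin" as .norm.linear;
+                // the branches below remap the model key to those split LoRA tensor names.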
                 if (l1 != std::string::npos) {
-                    l1 -= 1;
-                    auto split_q_u_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_q" + lora_ups[type] + ".weight";
-                    if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
+                    l1--;
+                    auto split_q_d_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_q" + lora_downs[type] + ".weight";
+                    if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
                         // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
                         // find qkv and mlp up parts in LoRA model
-                        auto split_k_u_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_k" + lora_ups[type] + ".weight";
-                        auto split_v_u_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_v" + lora_ups[type] + ".weight";
-
-                        auto split_q_d_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_q" + lora_downs[type] + ".weight";
                         auto split_k_d_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_k" + lora_downs[type] + ".weight";
                         auto split_v_d_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_v" + lora_downs[type] + ".weight";

-                        auto split_m_u_name = lora_pre[type] + key.substr(0, l1) + ".proj_mlp" + lora_ups[type] + ".weight";
+                        auto split_q_u_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_q" + lora_ups[type] + ".weight";
+                        auto split_k_u_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_k" + lora_ups[type] + ".weight";
+                        auto split_v_u_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_v" + lora_ups[type] + ".weight";
+
                         auto split_m_d_name = lora_pre[type] + key.substr(0, l1) + ".proj_mlp" + lora_downs[type] + ".weight";
+                        auto split_m_u_name = lora_pre[type] + key.substr(0, l1) + ".proj_mlp" + lora_ups[type] + ".weight";

-                        ggml_tensor* lora_q_up = NULL;
                         ggml_tensor* lora_q_down = NULL;
-                        ggml_tensor* lora_k_up = NULL;
+                        ggml_tensor* lora_q_up = NULL;
                         ggml_tensor* lora_k_down = NULL;
-                        ggml_tensor* lora_v_up = NULL;
+                        ggml_tensor* lora_k_up = NULL;
                         ggml_tensor* lora_v_down = NULL;
+                        ggml_tensor* lora_v_up = NULL;

-                        ggml_tensor* lora_m_up = NULL;
                         ggml_tensor* lora_m_down = NULL;
+                        ggml_tensor* lora_m_up = NULL;

                         lora_q_up = lora_tensors[split_q_u_name];

@@ -301,6 +306,38 @@ struct LoraModel : public GGMLRunner {
                         // // print_ggml_tensor(it.second, true); // [3072, 21504, 1, 1]
                         // }
                     }
+                } else if (l2 != std::string::npos) {
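+                    // the LoRA may store this block's "linear2" weight under a ".proj_out" name; look it up there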
+                    l2--;
+                    std::string lora_down_name = lora_pre[type] + key.substr(0, l2) + ".proj_out" + lora_downs[type] + ".weight";
+                    if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
+                        std::string lora_up_name = lora_pre[type] + key.substr(0, l2) + ".proj_out" + lora_ups[type] + ".weight";
+                        if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
+                            lora_up = lora_tensors[lora_up_name];
+                        }
+
+                        if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
+                            lora_down = lora_tensors[lora_down_name];
+                        }
+
+                        applied_lora_tensors.insert(lora_up_name);
+                        applied_lora_tensors.insert(lora_down_name);
+                    }
+                } else if (mod != std::string::npos) {
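+                    // the LoRA may store "modulation.lin" under a ".norm.linear" name; look it up there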
+                    mod--;
+                    std::string lora_down_name = lora_pre[type] + key.substr(0, mod) + ".norm.linear" + lora_downs[type] + ".weight";
+                    if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
+                        std::string lora_up_name = lora_pre[type] + key.substr(0, mod) + ".norm.linear" + lora_ups[type] + ".weight";
+                        if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
+                            lora_up = lora_tensors[lora_up_name];
+                        }
+
+                        if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
+                            lora_down = lora_tensors[lora_down_name];
+                        }
+
+                        applied_lora_tensors.insert(lora_up_name);
+                        applied_lora_tensors.insert(lora_down_name);
+                    }
                 }
             }

@@ -380,8 +417,10 @@ struct LoraModel : public GGMLRunner {
             total_lora_tensors_count++;
             if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) {
                 LOG_WARN("unused lora tensor %s", kv.first.c_str());
-                exit(0);
-
+                print_ggml_tensor(kv.second, true);
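+                // abort only when the unused tensor name contains "B" (presumably a lora_B/up weight); other leftovers are just warned about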
+                if (kv.first.find("B") != std::string::npos) {
+                    exit(0);
+                }
             } else {
                 applied_lora_tensors_count++;
             }