
Commit d22f183

Flux Lora working!

1 parent f8db4fa · commit d22f183
1 file changed, +62 −70 lines

lora.hpp (62 additions, 70 deletions)

@@ -173,8 +173,10 @@ struct LoraModel : public GGMLRunner {
            ggml_tensor* lora_up = NULL;
            ggml_tensor* lora_down = NULL;

-            std::string alpha_name = "";
-            std::string scale_name = "";
+            std::string alpha_name     = "";
+            std::string scale_name     = "";
+            std::string lora_down_name = "";
+            std::string lora_up_name   = "";
            // LOG_DEBUG("k_tensor %s", k_tensor.c_str());
            if (sd_version_is_flux(version)) {
                size_t linear1 = key.find("linear1");
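
Hoisting alpha_name, scale_name, lora_down_name and lora_up_name to a single declaration point means whichever branch below resolves the key can fill them in, and the shared bookkeeping at the end of the lookup still sees the values. A minimal stand-alone sketch of that pattern (hypothetical key and prefix, not the runner's code):

    #include <string>
    #include <unordered_set>

    int main() {
        std::unordered_set<std::string> applied_lora_tensors;
        std::string key     = "double_blocks_0_img_attn_qkv";  // hypothetical key
        bool        is_flux = true;

        std::string lora_up_name;  // declared once, at lookup scope
        if (is_flux) {
            lora_up_name = "lora_unet_" + key + ".lora_up.weight";  // placeholder prefix
        } else {
            lora_up_name = "lora." + key + ".lora_up.weight";
        }
        applied_lora_tensors.insert(lora_up_name);  // always sees the branch's value
        return 0;
    }
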
@@ -221,38 +223,38 @@ struct LoraModel : public GGMLRunner {
                        ggml_tensor* lora_m_down = NULL;
                        ggml_tensor* lora_m_up = NULL;

-                        lora_q_up = lora_tensors[split_q_u_name];
+                        lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);

                        if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
-                            lora_q_down = lora_tensors[split_q_d_name];
+                            lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);
                        }

                        if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
-                            lora_q_up = lora_tensors[split_q_u_name];
+                            lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
                        }

                        if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
-                            lora_k_down = lora_tensors[split_k_d_name];
+                            lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
                        }

                        if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
-                            lora_k_up = lora_tensors[split_k_u_name];
+                            lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
                        }

                        if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
-                            lora_v_down = lora_tensors[split_v_d_name];
+                            lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
                        }

                        if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
-                            lora_v_up = lora_tensors[split_v_u_name];
+                            lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
                        }

                        if (lora_tensors.find(split_m_d_name) != lora_tensors.end()) {
-                            lora_m_down = lora_tensors[split_m_d_name];
+                            lora_m_down = to_f32(compute_ctx, lora_tensors[split_m_d_name]);
                        }

                        if (lora_tensors.find(split_m_u_name) != lora_tensors.end()) {
-                            lora_m_up = lora_tensors[split_m_u_name];
+                            lora_m_up = to_f32(compute_ctx, lora_tensors[split_m_u_name]);
                        }

                        // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
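
Every split q/k/v/m factor is now routed through to_f32 before it is concatenated; LoRA files may store these tensors as f16 or quantized types, and fusing tensors of mixed types would otherwise fail or lose precision. The helper itself is not part of this hunk; a hedged sketch of what a to_f32-style helper typically looks like in ggml (the actual implementation in lora.hpp may differ):

    #include "ggml.h"

    // Cast a tensor to F32 inside the compute graph unless it already is F32.
    static ggml_tensor* to_f32(ggml_context* ctx, ggml_tensor* t) {
        if (t->type == GGML_TYPE_F32) {
            return t;  // no extra node needed
        }
        return ggml_cast(ctx, t, GGML_TYPE_F32);  // insert a conversion op
    }
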
@@ -300,28 +302,26 @@ struct LoraModel : public GGMLRunner {
                        lora_down = ggml_cont(compute_ctx, lora_down_concat);
                        lora_up = ggml_cont(compute_ctx, lora_up_concat);

-                        std::string lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
-                        std::string lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
+                        lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
+                        lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";

                        lora_tensors[lora_down_name] = lora_down;
                        lora_tensors[lora_up_name] = lora_up;

-                        lora_tensors.erase(split_q_u_name);
-                        lora_tensors.erase(split_k_u_name);
-                        lora_tensors.erase(split_v_u_name);
-                        lora_tensors.erase(split_m_u_name);
+                        // Would be nice to be able to clean up lora_tensors, but it breaks because this is called twice :/
+                        // lora_tensors.erase(split_q_u_name);
+                        // lora_tensors.erase(split_k_u_name);
+                        // lora_tensors.erase(split_v_u_name);
+                        // lora_tensors.erase(split_m_u_name);

-                        lora_tensors.erase(split_q_d_name);
-                        lora_tensors.erase(split_k_d_name);
-                        lora_tensors.erase(split_v_d_name);
-                        lora_tensors.erase(split_m_d_name);
-
-                        applied_lora_tensors.insert(lora_down_name);
-                        applied_lora_tensors.insert(lora_up_name);
+                        // lora_tensors.erase(split_q_d_name);
+                        // lora_tensors.erase(split_k_d_name);
+                        // lora_tensors.erase(split_v_d_name);
+                        // lora_tensors.erase(split_m_d_name);

                    } else {
-                        // std::string lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
-                        // std::string lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
+                        // lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
+                        // lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
                        // if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
                        //     // print_ggml_tensor(lora_tensors[lora_down_name], true); // [3072, R, 1, 1]
                        //     // print_ggml_tensor(lora_tensors[lora_up_name], true); // [R, 21504, 1, 1]
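
The erase calls are kept only as comments because, as the new note says, this lookup runs twice for the same LoRA (presumably once per model the weights are applied to): the first pass would consume the split q/k/v/m entries and the second pass would then find nothing to fuse. Re-registering the fused tensors under their final names is safe, since the map entries are simply reused. A stand-alone illustration of the failure mode (hypothetical map contents, not the runner's code):

    #include <cassert>
    #include <map>
    #include <string>

    int main() {
        std::map<std::string, int> lora_tensors = {{"q.lora_down", 1}, {"k.lora_down", 2}};

        auto can_fuse = [&]() {
            return lora_tensors.count("q.lora_down") && lora_tensors.count("k.lora_down");
        };

        assert(can_fuse());                 // pass 1: split entries are present
        lora_tensors.erase("q.lora_down");  // what the old code did after fusing
        assert(!can_fuse());                // pass 2 would no longer find them
        return 0;
    }
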
@@ -330,9 +330,9 @@ struct LoraModel : public GGMLRunner {
                    }
                } else if (linear2 != std::string::npos) {
                    linear2--;
-                    std::string lora_down_name = lora_pre[type] + key.substr(0, linear2) + ".proj_out" + lora_downs[type] + ".weight";
+                    lora_down_name = lora_pre[type] + key.substr(0, linear2) + ".proj_out" + lora_downs[type] + ".weight";
                    if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
-                        std::string lora_up_name = lora_pre[type] + key.substr(0, linear2) + ".proj_out" + lora_ups[type] + ".weight";
+                        lora_up_name = lora_pre[type] + key.substr(0, linear2) + ".proj_out" + lora_ups[type] + ".weight";
                        if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
                            lora_up = lora_tensors[lora_up_name];
                        }
@@ -346,9 +346,9 @@ struct LoraModel : public GGMLRunner {
                    }
                } else if (modulation != std::string::npos) {
                    modulation--;
-                    std::string lora_down_name = lora_pre[type] + key.substr(0, modulation) + ".norm.linear" + lora_downs[type] + ".weight";
+                    lora_down_name = lora_pre[type] + key.substr(0, modulation) + ".norm.linear" + lora_downs[type] + ".weight";
                    if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
-                        std::string lora_up_name = lora_pre[type] + key.substr(0, modulation) + ".norm.linear" + lora_ups[type] + ".weight";
+                        lora_up_name = lora_pre[type] + key.substr(0, modulation) + ".norm.linear" + lora_ups[type] + ".weight";
                        if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
                            lora_up = lora_tensors[lora_up_name];
                        }
@@ -391,28 +391,26 @@ struct LoraModel : public GGMLRunner {
                        ggml_tensor* lora_v_down = NULL;
                        ggml_tensor* lora_v_up = NULL;

-                        if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
-                            lora_q_down = lora_tensors[split_q_d_name];
-                        }
+                        lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]);

                        if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
-                            lora_q_up = lora_tensors[split_q_u_name];
+                            lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]);
                        }

                        if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) {
-                            lora_k_down = lora_tensors[split_k_d_name];
+                            lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]);
                        }

                        if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) {
-                            lora_k_up = lora_tensors[split_k_u_name];
+                            lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]);
                        }

                        if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) {
-                            lora_v_down = lora_tensors[split_v_d_name];
+                            lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]);
                        }

                        if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) {
-                            lora_v_up = lora_tensors[split_v_u_name];
+                            lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]);
                        }

                        // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1]
@@ -448,22 +446,20 @@ struct LoraModel : public GGMLRunner {
                        lora_down = ggml_cont(compute_ctx, lora_down_concat);
                        lora_up = ggml_cont(compute_ctx, lora_up_concat);

-                        std::string lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
-                        std::string lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
+                        lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
+                        lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";

                        lora_tensors[lora_down_name] = lora_down;
                        lora_tensors[lora_up_name] = lora_up;

-                        lora_tensors.erase(split_q_u_name);
-                        lora_tensors.erase(split_k_u_name);
-                        lora_tensors.erase(split_v_u_name);
-
-                        lora_tensors.erase(split_q_d_name);
-                        lora_tensors.erase(split_k_d_name);
-                        lora_tensors.erase(split_v_d_name);
+                        // Would be nice to be able to clean up lora_tensors, but it breaks because this is called twice :/
+                        // lora_tensors.erase(split_q_u_name);
+                        // lora_tensors.erase(split_k_u_name);
+                        // lora_tensors.erase(split_v_u_name);

-                        applied_lora_tensors.insert(lora_down_name);
-                        applied_lora_tensors.insert(lora_up_name);
+                        // lora_tensors.erase(split_q_d_name);
+                        // lora_tensors.erase(split_k_d_name);
+                        // lora_tensors.erase(split_v_d_name);
                    }
                } else if (txt_attn_proj != std::string::npos || img_attn_proj != std::string::npos) {
                    size_t match = txt_attn_proj;
@@ -474,9 +470,9 @@ struct LoraModel : public GGMLRunner {
                    }
                    match--;

-                    std::string lora_down_name = lora_pre[type] + key.substr(0, match) + new_name + lora_downs[type] + ".weight";
+                    lora_down_name = lora_pre[type] + key.substr(0, match) + new_name + lora_downs[type] + ".weight";
                    if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
-                        std::string lora_up_name = lora_pre[type] + key.substr(0, match) + new_name + lora_ups[type] + ".weight";
+                        lora_up_name = lora_pre[type] + key.substr(0, match) + new_name + lora_ups[type] + ".weight";
                        if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
                            lora_up = lora_tensors[lora_up_name];
                        }
@@ -507,9 +503,9 @@ struct LoraModel : public GGMLRunner {
                        match = img_mlp_2;
                    }
                    match--;
-                    std::string lora_down_name = lora_pre[type] + key.substr(0, match) + prefix + suffix + lora_downs[type] + ".weight";
+                    lora_down_name = lora_pre[type] + key.substr(0, match) + prefix + suffix + lora_downs[type] + ".weight";
                    if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
-                        std::string lora_up_name = lora_pre[type] + key.substr(0, match) + prefix + suffix + lora_ups[type] + ".weight";
+                        lora_up_name = lora_pre[type] + key.substr(0, match) + prefix + suffix + lora_ups[type] + ".weight";
                        if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
                            lora_up = lora_tensors[lora_up_name];
                        }
@@ -530,9 +526,9 @@ struct LoraModel : public GGMLRunner {
                    }
                    match--;

-                    std::string lora_down_name = lora_pre[type] + key.substr(0, match) + new_name + lora_downs[type] + ".weight";
+                    lora_down_name = lora_pre[type] + key.substr(0, match) + new_name + lora_downs[type] + ".weight";
                    if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
-                        std::string lora_up_name = lora_pre[type] + key.substr(0, match) + new_name + lora_ups[type] + ".weight";
+                        lora_up_name = lora_pre[type] + key.substr(0, match) + new_name + lora_ups[type] + ".weight";
                        if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
                            lora_up = lora_tensors[lora_up_name];
                        }
@@ -548,7 +544,7 @@ struct LoraModel : public GGMLRunner {
            }

            if (lora_up == NULL || lora_down == NULL) {
-                std::string lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
+                lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
                if (lora_tensors.find(lora_up_name) == lora_tensors.end()) {
                    if (key == "model_diffusion_model_output_blocks_2_2_conv") {
                        // fix for some sdxl lora, like lcm-lora-xl
@@ -557,9 +553,9 @@ struct LoraModel : public GGMLRunner {
                    }
                }

-                std::string lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
-                alpha_name = lora_pre[type] + key + ".alpha";
-                scale_name = lora_pre[type] + key + ".scale";
+                lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
+                alpha_name     = lora_pre[type] + key + ".alpha";
+                scale_name     = lora_pre[type] + key + ".scale";

                if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
                    lora_up = lora_tensors[lora_up_name];
@@ -568,14 +564,14 @@ struct LoraModel : public GGMLRunner {
                if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
                    lora_down = lora_tensors[lora_down_name];
                }
-                applied_lora_tensors.insert(lora_up_name);
-                applied_lora_tensors.insert(lora_down_name);
-                applied_lora_tensors.insert(alpha_name);
-                applied_lora_tensors.insert(scale_name);
+            }
+            applied_lora_tensors.insert(lora_up_name);
+            applied_lora_tensors.insert(lora_down_name);
+            applied_lora_tensors.insert(alpha_name);
+            applied_lora_tensors.insert(scale_name);

-                if (lora_up == NULL || lora_down == NULL) {
-                    continue;
-                }
+            if (lora_up == NULL || lora_down == NULL) {
+                continue;
            }
            // calc_scale
            int64_t dim = lora_down->ne[ggml_n_dims(lora_down) - 1];
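
The dim read here is the LoRA rank (the last dimension of lora_down); together with the optional .alpha / .scale tensors it yields the usual LoRA multiplier alpha / rank. A hedged sketch of the merge this scale feeds into (simplified shape handling; the graph actually built in lora.hpp is more involved):

    #include "ggml.h"

    // W' = W + multiplier * (alpha / rank) * (up · down) -- the standard LoRA merge.
    static ggml_tensor* apply_lora_delta(ggml_context* ctx,
                                         ggml_tensor* weight,     // original W
                                         ggml_tensor* lora_up,    // up factor
                                         ggml_tensor* lora_down,  // down factor
                                         float alpha,             // from the .alpha tensor, if present
                                         float multiplier) {      // user-chosen LoRA strength
        int64_t rank  = lora_down->ne[ggml_n_dims(lora_down) - 1];
        float   scale = multiplier;
        if (alpha > 0.0f) {
            scale *= alpha / (float)rank;  // classic LoRA scaling: alpha / r
        }
        ggml_tensor* delta_w = ggml_mul_mat(ctx, lora_up, lora_down);  // up · down
        delta_w              = ggml_scale(ctx, delta_w, scale);
        return ggml_add(ctx, weight, delta_w);
    }
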
@@ -622,10 +618,6 @@ struct LoraModel : public GGMLRunner {
            total_lora_tensors_count++;
            if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) {
                LOG_WARN("unused lora tensor %s", kv.first.c_str());
-                print_ggml_tensor(kv.second, true);
-                if (kv.first.find("B") != std::string::npos) {
-                    exit(0);
-                }
            } else {
                applied_lora_tensors_count++;
            }
