
Commit 0b600d7

Flux Lora: single_block
1 parent f8d33f3 commit 0b600d7

File tree

1 file changed (+55, −16)

lora.hpp

Lines changed: 55 additions & 16 deletions
@@ -87,6 +87,9 @@ struct LoraModel : public GGMLRunner {
                     break;
                 }
             }
+            // if (name.find(".transformer_blocks.0") != std::string::npos) {
+            //     LOG_INFO("%s", name.c_str());
+            // }
 
             if (dry_run) {
                 struct ggml_tensor* real = ggml_new_tensor(params_ctx,
@@ -104,7 +107,7 @@ struct LoraModel : public GGMLRunner {
 
         model_loader.load_tensors(on_new_tensor_cb, backend);
         alloc_params_buffer();
-
+        // exit(0);
         dry_run = false;
         model_loader.load_tensors(on_new_tensor_cb, backend);
 
@@ -171,32 +174,34 @@ struct LoraModel : public GGMLRunner {
             ggml_tensor* lora_down = NULL;
             // LOG_DEBUG("k_tensor %s", k_tensor.c_str());
             if (sd_version_is_flux(version)) {
-                size_t l1 = key.find("linear1");
+                size_t l1  = key.find("linear1");
+                size_t l2  = key.find("linear2");
+                size_t mod = key.find("modulation.lin");
                 if (l1 != std::string::npos) {
-                    l1 -= 1;
-                    auto split_q_u_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_q" + lora_ups[type] + ".weight";
-                    if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) {
+                    l1--;
+                    auto split_q_d_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_q" + lora_downs[type] + ".weight";
+                    if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) {
                         // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1]
                         // find qkv and mlp up parts in LoRA model
-                        auto split_k_u_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_k" + lora_ups[type] + ".weight";
-                        auto split_v_u_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_v" + lora_ups[type] + ".weight";
-
-                        auto split_q_d_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_q" + lora_downs[type] + ".weight";
                         auto split_k_d_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_k" + lora_downs[type] + ".weight";
                         auto split_v_d_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_v" + lora_downs[type] + ".weight";
 
-                        auto split_m_u_name = lora_pre[type] + key.substr(0, l1) + ".proj_mlp" + lora_ups[type] + ".weight";
+                        auto split_q_u_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_q" + lora_ups[type] + ".weight";
+                        auto split_k_u_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_k" + lora_ups[type] + ".weight";
+                        auto split_v_u_name = lora_pre[type] + key.substr(0, l1) + ".attn.to_v" + lora_ups[type] + ".weight";
+
                         auto split_m_d_name = lora_pre[type] + key.substr(0, l1) + ".proj_mlp" + lora_downs[type] + ".weight";
+                        auto split_m_u_name = lora_pre[type] + key.substr(0, l1) + ".proj_mlp" + lora_ups[type] + ".weight";
 
-                        ggml_tensor* lora_q_up = NULL;
                         ggml_tensor* lora_q_down = NULL;
-                        ggml_tensor* lora_k_up = NULL;
+                        ggml_tensor* lora_q_up   = NULL;
                         ggml_tensor* lora_k_down = NULL;
-                        ggml_tensor* lora_v_up = NULL;
+                        ggml_tensor* lora_k_up   = NULL;
                         ggml_tensor* lora_v_down = NULL;
+                        ggml_tensor* lora_v_up   = NULL;
 
-                        ggml_tensor* lora_m_up = NULL;
                         ggml_tensor* lora_m_down = NULL;
+                        ggml_tensor* lora_m_up   = NULL;
 
                         lora_q_up = lora_tensors[split_q_u_name];
 
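Note: the hunk above switches the probe for a fused Flux linear1 weight from the LoRA up-weight name (split_q_u_name) to the down-weight name (split_q_d_name) and gathers the down-weight lookups ahead of the up-weight lookups. As a minimal, self-contained sketch of the name-splitting idea only, the snippet below derives the four per-projection LoRA names from one fused key; the key string, the "transformer." prefix and the ".lora_down" suffix are illustrative assumptions, not the exact values of lora_pre[type]/lora_downs[type] in lora.hpp.

// Sketch only: shows how a fused "linear1" key is split into q/k/v/mlp
// LoRA tensor names. All concrete strings here are assumed for illustration.
#include <initializer_list>
#include <iostream>
#include <string>

int main() {
    std::string key       = "single_blocks.7.linear1";  // hypothetical fused key
    std::string lora_pre  = "transformer.";             // assumed prefix
    std::string lora_down = ".lora_down";               // assumed down-weight suffix

    size_t l1 = key.find("linear1");
    if (l1 != std::string::npos) {
        l1--;  // drop the '.' before "linear1" so substr keeps only the block prefix
        std::string prefix = key.substr(0, l1);          // "single_blocks.7"
        // One fused linear1 weight corresponds to four separate LoRA pairs.
        for (const char* part : {".attn.to_q", ".attn.to_k", ".attn.to_v", ".proj_mlp"}) {
            std::cout << lora_pre + prefix + part + lora_down + ".weight" << "\n";
        }
    }
    return 0;
}

The same prefix is then reused with the up-weight suffix to fetch the matching lora_*_up tensors.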
@@ -301,6 +306,38 @@ struct LoraModel : public GGMLRunner {
                         // // print_ggml_tensor(it.second, true); // [3072, 21504, 1, 1]
                         // }
                     }
+                } else if (l2 != std::string::npos) {
+                    l2--;
+                    std::string lora_down_name = lora_pre[type] + key.substr(0, l2) + ".proj_out" + lora_downs[type] + ".weight";
+                    if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
+                        std::string lora_up_name = lora_pre[type] + key.substr(0, l2) + ".proj_out" + lora_ups[type] + ".weight";
+                        if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
+                            lora_up = lora_tensors[lora_up_name];
+                        }
+
+                        if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
+                            lora_down = lora_tensors[lora_down_name];
+                        }
+
+                        applied_lora_tensors.insert(lora_up_name);
+                        applied_lora_tensors.insert(lora_down_name);
+                    }
+                } else if (mod != std::string::npos) {
+                    mod--;
+                    std::string lora_down_name = lora_pre[type] + key.substr(0, mod) + ".norm.linear" + lora_downs[type] + ".weight";
+                    if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
+                        std::string lora_up_name = lora_pre[type] + key.substr(0, mod) + ".norm.linear" + lora_ups[type] + ".weight";
+                        if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
+                            lora_up = lora_tensors[lora_up_name];
+                        }
+
+                        if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
+                            lora_down = lora_tensors[lora_down_name];
+                        }
+
+                        applied_lora_tensors.insert(lora_up_name);
+                        applied_lora_tensors.insert(lora_down_name);
+                    }
                 }
             }
 
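The two new else-if branches above extend the same idea to the remaining single_block tensors: a linear2 key is matched against a ".proj_out" LoRA pair and a modulation.lin key against a ".norm.linear" pair, and both names are recorded in applied_lora_tensors once found. A rough sketch of just the key remapping, under the same hypothetical naming assumptions as the sketch after the previous hunk:

// Sketch only: maps a Flux single_block key to the layer name its LoRA
// tensors are stored under. Real names also carry lora_pre/lora_downs/lora_ups.
#include <iostream>
#include <string>

static std::string remap_single_block(const std::string& key) {
    size_t l2  = key.find("linear2");
    size_t mod = key.find("modulation.lin");
    if (l2 != std::string::npos) {
        return key.substr(0, l2 - 1) + ".proj_out";      // strip ".linear2"
    }
    if (mod != std::string::npos) {
        return key.substr(0, mod - 1) + ".norm.linear";  // strip ".modulation.lin"
    }
    return key;  // anything else is left untouched here
}

int main() {
    std::cout << remap_single_block("single_blocks.7.linear2") << "\n";         // single_blocks.7.proj_out
    std::cout << remap_single_block("single_blocks.7.modulation.lin") << "\n";  // single_blocks.7.norm.linear
    return 0;
}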
@@ -380,8 +417,10 @@ struct LoraModel : public GGMLRunner {
         total_lora_tensors_count++;
         if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) {
             LOG_WARN("unused lora tensor %s", kv.first.c_str());
-            exit(0);
-
+            print_ggml_tensor(kv.second, true);
+            if (kv.first.find("B") != std::string::npos) {
+                exit(0);
+            }
         } else {
             applied_lora_tensors_count++;
         }
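The last hunk softens the failure path for unmatched LoRA tensors: instead of exiting on the first unused tensor, it prints the tensor and only aborts when the leftover name contains "B" (presumably an unapplied lora_B / up weight). A toy sketch of that bookkeeping, with illustrative tensor names:

// Sketch only: every LoRA tensor consumed by the apply pass is recorded in
// applied_lora_tensors; anything left over is reported so missing mappings
// are easy to spot. Names here are hypothetical.
#include <cstdio>
#include <map>
#include <set>
#include <string>

int main() {
    std::map<std::string, int> lora_tensors = {
        {"single_blocks.7.proj_out.lora_A.weight", 0},
        {"single_blocks.7.proj_out.lora_B.weight", 0},
        {"some.unmapped.layer.lora_B.weight", 0},
    };
    std::set<std::string> applied_lora_tensors = {
        "single_blocks.7.proj_out.lora_A.weight",
        "single_blocks.7.proj_out.lora_B.weight",
    };

    int total = 0, used = 0;
    for (const auto& kv : lora_tensors) {
        total++;
        if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) {
            std::printf("unused lora tensor %s\n", kv.first.c_str());  // flags a missing mapping
        } else {
            used++;
        }
    }
    std::printf("applied %d of %d lora tensors\n", used, total);
    return 0;
}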
