Commit 2a5be18

optimize lora loading
1 parent 5b40537

File tree

3 files changed: +174, -60 lines

3 files changed

+174
-60
lines changed

model.cpp

Lines changed: 162 additions & 51 deletions
@@ -1,6 +1,7 @@
 #include <stdarg.h>
 #include <algorithm>
 #include <atomic>
+#include <array>
 #include <chrono>
 #include <fstream>
 #include <functional>
@@ -1034,15 +1035,19 @@ bool is_safetensors_file(const std::string& file_path) {
 }
 
 bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
+    return init_from_file(file_path, prefix, 0);
+}
+
+bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix, int n_threads) {
     if (is_directory(file_path)) {
         LOG_INFO("load %s using diffusers format", file_path.c_str());
-        return init_from_diffusers_file(file_path, prefix);
+        return init_from_diffusers_file(file_path, prefix, n_threads);
     } else if (is_gguf_file(file_path)) {
         LOG_INFO("load %s using gguf format", file_path.c_str());
         return init_from_gguf_file(file_path, prefix);
     } else if (is_safetensors_file(file_path)) {
         LOG_INFO("load %s using safetensors format", file_path.c_str());
-        return init_from_safetensors_file(file_path, prefix);
+        return init_from_safetensors_file(file_path, prefix, n_threads);
     } else if (is_zip_file(file_path)) {
         LOG_INFO("load %s using checkpoint format", file_path.c_str());
         return init_from_ckpt_file(file_path, prefix);
@@ -1147,7 +1152,12 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
 }
 
 // https://huggingface.co/docs/safetensors/index
+
 bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) {
+    return init_from_safetensors_file(file_path, prefix, 0);
+}
+
+bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix, int n_threads_p) {
     LOG_DEBUG("init from '%s', prefix = '%s'", file_path.c_str(), prefix.c_str());
     file_paths_.push_back(file_path);
     size_t file_index = file_paths_.size() - 1;
@@ -1195,12 +1205,32 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) {
         return false;
     }
 
+
     nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
 
+
+    struct SafetensorTask {
+        std::string name;
+        ggml_type type = GGML_TYPE_COUNT;
+        std::array<int64_t, SD_MAX_DIMS> ne{};
+        int n_dims              = 0;
+        size_t offset           = 0;
+        size_t tensor_data_size = 0;
+        bool is_bf16    = false;
+        bool is_f8_e4m3 = false;
+        bool is_f8_e5m2 = false;
+        bool is_f64     = false;
+        bool is_i64     = false;
+    };
+
+    std::vector<SafetensorTask> tasks;
+    tasks.reserve(header_.size());
+
+    size_t base_offset = ST_HEADER_SIZE_LEN + header_size_;
+
     for (auto& item : header_.items()) {
-        std::string name = item.key();
+        std::string name           = item.key();
         nlohmann::json tensor_info = item.value();
-        // LOG_DEBUG("%s %s\n", name.c_str(), tensor_info.dump().c_str());
 
         if (name == "__metadata__") {
             continue;
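One detail worth noting in the hunk above: SafetensorTask holds only plain value-type fields, and the shape is a std::array<int64_t, SD_MAX_DIMS> rather than the raw int64_t[SD_MAX_DIMS] the old code used. Unlike a raw array, a std::array copies and assigns as a whole, so fully-built tasks can be stored in a std::vector and later read by worker threads without pointer decay. A minimal standalone illustration of that property (the Shape struct and the values are hypothetical, not from this codebase):

#include <array>
#include <cstdint>
#include <vector>

struct Shape {
    std::array<int64_t, 5> ne{1, 1, 1, 1, 1};  // value type: copies element-wise
};

int main() {
    std::vector<Shape> shapes;
    Shape s;
    s.ne = {2, 3, 4, 1, 1};   // whole-array assignment; a raw int64_t[5] cannot do this
    shapes.push_back(s);      // stored by value, safe to hand to another thread
    Shape t = shapes.back();  // plain copy, no decay to a pointer
    return t.ne[0] == 2 ? 0 : 1;
}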
@@ -1210,96 +1240,177 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) {
             continue;
         }
 
-        std::string dtype    = tensor_info["dtype"];
-        nlohmann::json shape = tensor_info["shape"];
-
+        std::string dtype = tensor_info["dtype"];
         if (dtype == "U8") {
             continue;
         }
 
         size_t begin = tensor_info["data_offsets"][0].get<size_t>();
-        size_t end = tensor_info["data_offsets"][1].get<size_t>();
+        size_t end   = tensor_info["data_offsets"][1].get<size_t>();
 
         ggml_type type = str_to_ggml_type(dtype);
         if (type == GGML_TYPE_COUNT) {
             LOG_ERROR("unsupported dtype '%s' (tensor '%s')", dtype.c_str(), name.c_str());
             return false;
         }
 
+        nlohmann::json shape = tensor_info["shape"];
+
         if (shape.size() > SD_MAX_DIMS) {
             LOG_ERROR("invalid tensor '%s'", name.c_str());
             return false;
         }
 
-        int n_dims = (int)shape.size();
-        int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
+        int n_dims                          = (int)shape.size();
+        std::array<int64_t, SD_MAX_DIMS> ne = {1, 1, 1, 1, 1};
         for (int i = 0; i < n_dims; i++) {
             ne[i] = shape[i].get<int64_t>();
         }
 
         if (n_dims == 5) {
             n_dims = 4;
-            ne[0] = ne[0] * ne[1];
-            ne[1] = ne[2];
-            ne[2] = ne[3];
-            ne[3] = ne[4];
+            ne[0]  = ne[0] * ne[1];
+            ne[1]  = ne[2];
+            ne[2]  = ne[3];
+            ne[3]  = ne[4];
         }
 
-        // ggml_n_dims returns 1 for scalars
         if (n_dims == 0) {
             n_dims = 1;
         }
 
-        if (!starts_with(name, prefix)) {
-            name = prefix + name;
+        std::string full_name = name;
+        if (!starts_with(full_name, prefix)) {
+            full_name = prefix + full_name;
         }
 
-        TensorStorage tensor_storage(name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
-        tensor_storage.reverse_ne();
+        SafetensorTask task;
+        task.name             = std::move(full_name);
+        task.type             = type;
+        task.ne               = ne;
+        task.n_dims           = n_dims;
+        task.offset           = base_offset + begin;
+        task.tensor_data_size = end - begin;
+        task.is_bf16          = (dtype == "BF16");
+        task.is_f8_e4m3       = (dtype == "F8_E4M3");
+        task.is_f8_e5m2       = (dtype == "F8_E5M2");
+        task.is_f64           = (dtype == "F64");
+        task.is_i64           = (dtype == "I64");
 
-        size_t tensor_data_size = end - begin;
-
-        if (dtype == "BF16") {
-            tensor_storage.is_bf16 = true;
-            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
-        } else if (dtype == "F8_E4M3") {
-            tensor_storage.is_f8_e4m3 = true;
-            // f8 -> f16
-            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
-        } else if (dtype == "F8_E5M2") {
-            tensor_storage.is_f8_e5m2 = true;
-            // f8 -> f16
-            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
-        } else if (dtype == "F64") {
-            tensor_storage.is_f64 = true;
-            // f64 -> f32
-            GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
-        } else if (dtype == "I64") {
-            tensor_storage.is_i64 = true;
-            // i64 -> i32
-            GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
-        } else {
-            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
-        }
+        tasks.push_back(std::move(task));
+    }
 
-        tensor_storages.push_back(tensor_storage);
-        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
+    if (tasks.empty()) {
+        return true;
+    }
 
-        // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
+    int num_threads_to_use = n_threads_p > 0 ? n_threads_p : (int)std::thread::hardware_concurrency();
+    if (num_threads_to_use < 1) {
+        num_threads_to_use = 1;
     }
+    int n_threads = std::min(num_threads_to_use, (int)tasks.size());
+    if (n_threads < 1) {
+        n_threads = 1;
+    }
+
+    std::vector<TensorStorage> processed(tasks.size());
+
+    std::vector<std::thread> workers;
+    workers.reserve(n_threads);
+
+    for (int i = 0; i < n_threads; ++i) {
+        workers.emplace_back([&, thread_id = i]() {
+            for (size_t idx = thread_id; idx < tasks.size(); idx += n_threads) {
+                const auto& task = tasks[idx];
+
+                TensorStorage tensor_storage(task.name, task.type, task.ne.data(), task.n_dims, file_index, task.offset);
+                tensor_storage.reverse_ne();
+
+                tensor_storage.is_bf16    = task.is_bf16;
+                tensor_storage.is_f8_e4m3 = task.is_f8_e4m3;
+                tensor_storage.is_f8_e5m2 = task.is_f8_e5m2;
+                tensor_storage.is_f64     = task.is_f64;
+                tensor_storage.is_i64     = task.is_i64;
+
+                size_t tensor_data_size = task.tensor_data_size;
+
+                if (tensor_storage.is_bf16) {
+                    GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+                } else if (tensor_storage.is_f8_e4m3) {
+                    GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+                } else if (tensor_storage.is_f8_e5m2) {
+                    GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+                } else if (tensor_storage.is_f64) {
+                    GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
+                } else if (tensor_storage.is_i64) {
+                    GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
+                } else {
+                    GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
+                }
+
+                processed[idx] = std::move(tensor_storage);
+            }
+        });
+    }
+
+    for (auto& worker : workers) {
+        worker.join();
+    }
+
+
+    const size_t prior_size = tensor_storages.size();
+    tensor_storages.resize(prior_size + processed.size());
+
+    int append_threads = std::min(num_threads_to_use, (int)processed.size());
+    if (append_threads < 1) {
+        append_threads = 1;
+    }
+
+    std::vector<String2GGMLType> local_types(append_threads);
+    std::vector<std::thread> append_workers;
+    append_workers.reserve(append_threads);
+
+    for (int thread_id = 0; thread_id < append_threads; ++thread_id) {
+        append_workers.emplace_back([&, thread_id]() {
+            auto& local_map = local_types[thread_id];
+            for (size_t idx = thread_id; idx < processed.size(); idx += append_threads) {
+                size_t target_index           = prior_size + idx;
+                tensor_storages[target_index] = std::move(processed[idx]);
+                add_preprocess_tensor_storage_types(local_map,
+                                                    tensor_storages[target_index].name,
+                                                    tensor_storages[target_index].type);
+            }
+        });
+    }
+
+    for (auto& worker : append_workers) {
+        worker.join();
+    }
+
+    for (auto& local_map : local_types) {
+        for (auto& kv : local_map) {
+            tensor_storages_types[kv.first] = kv.second;
+        }
+    }
+
+    processed.clear();
+    processed.shrink_to_fit();
 
     return true;
 }
-
 /*================================================= DiffusersModelLoader ==================================================*/
 
 bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
+    return init_from_diffusers_file(file_path, prefix, 0);
+}
+
+bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix, int n_threads) {
     std::string unet_path   = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
     std::string vae_path    = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
     std::string clip_path   = path_join(file_path, "text_encoder/model.safetensors");
     std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
 
-    if (!init_from_safetensors_file(unet_path, "unet.")) {
+    if (!init_from_safetensors_file(unet_path, "unet.", n_threads)) {
         return false;
     }
     for (auto ts : tensor_storages) {
@@ -1323,15 +1434,15 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
         }
     }
 
-    if (!init_from_safetensors_file(vae_path, "vae.")) {
+    if (!init_from_safetensors_file(vae_path, "vae.", n_threads)) {
         LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
         // return false;
     }
-    if (!init_from_safetensors_file(clip_path, "te.")) {
+    if (!init_from_safetensors_file(clip_path, "te.", n_threads)) {
         LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
         // return false;
     }
-    if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
+    if (!init_from_safetensors_file(clip_g_path, "te.1.", n_threads)) {
         LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
     }
     return true;
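The loading path above follows a common lock-free fan-out pattern: parse the JSON header single-threaded into plain-data tasks, stride the tasks across T workers (worker k takes indices k, k+T, k+2T, ...), let each worker write only to its own pre-sized output slots and its own thread-local type map, and merge the maps single-threaded after join(), so no mutex is needed anywhere. A self-contained sketch of that pattern, assuming nothing from this codebase (all names illustrative):

#include <algorithm>
#include <map>
#include <string>
#include <thread>
#include <vector>

int main() {
    std::vector<int> tasks(100, 1);
    std::vector<int> results(tasks.size());  // pre-sized: each worker owns disjoint slots

    int n_threads = (int)std::max(1u, std::thread::hardware_concurrency());
    n_threads     = std::min(n_threads, (int)tasks.size());

    std::vector<std::map<std::string, int>> local_maps(n_threads);
    std::vector<std::thread> workers;
    workers.reserve(n_threads);

    for (int t = 0; t < n_threads; ++t) {
        workers.emplace_back([&, t]() {
            // strided partition: thread t handles indices t, t+T, t+2T, ...
            for (size_t i = (size_t)t; i < tasks.size(); i += (size_t)n_threads) {
                results[i] = tasks[i] * 2;    // disjoint output slots: no locking needed
                local_maps[t]["count"] += 1;  // thread-private map: no locking needed
            }
        });
    }
    for (auto& w : workers) {
        w.join();
    }

    std::map<std::string, int> merged;  // merge per-thread maps single-threaded, after join
    for (auto& m : local_maps) {
        for (auto& kv : m) {
            merged[kv.first] += kv.second;
        }
    }
    return merged["count"] == (int)tasks.size() ? 0 : 1;
}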

model.h

Lines changed: 3 additions & 0 deletions
@@ -233,13 +233,16 @@ class ModelLoader {
 
     bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = "");
     bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = "");
+    bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix, int n_threads);
     bool init_from_ckpt_file(const std::string& file_path, const std::string& prefix = "");
     bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = "");
+    bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix, int n_threads);
 
 public:
     String2GGMLType tensor_storages_types;
 
     bool init_from_file(const std::string& file_path, const std::string& prefix = "");
+    bool init_from_file(const std::string& file_path, const std::string& prefix, int n_threads);
     bool model_is_unet();
     SDVersion get_sd_version();
     ggml_type get_sd_wtype();
