 #include <stdarg.h>
 #include <algorithm>
 #include <atomic>
+#include <array>
 #include <chrono>
 #include <fstream>
 #include <functional>
@@ -1034,15 +1035,19 @@ bool is_safetensors_file(const std::string& file_path) {
 }
 
 bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix) {
+    return init_from_file(file_path, prefix, 0);
+}
+
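+// Passing 0 for n_threads lets the loader pick a thread count automatically
+// (see init_from_safetensors_file below).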
+bool ModelLoader::init_from_file(const std::string& file_path, const std::string& prefix, int n_threads) {
     if (is_directory(file_path)) {
         LOG_INFO("load %s using diffusers format", file_path.c_str());
-        return init_from_diffusers_file(file_path, prefix);
+        return init_from_diffusers_file(file_path, prefix, n_threads);
     } else if (is_gguf_file(file_path)) {
         LOG_INFO("load %s using gguf format", file_path.c_str());
         return init_from_gguf_file(file_path, prefix);
     } else if (is_safetensors_file(file_path)) {
         LOG_INFO("load %s using safetensors format", file_path.c_str());
-        return init_from_safetensors_file(file_path, prefix);
+        return init_from_safetensors_file(file_path, prefix, n_threads);
     } else if (is_zip_file(file_path)) {
         LOG_INFO("load %s using checkpoint format", file_path.c_str());
         return init_from_ckpt_file(file_path, prefix);
@@ -1147,7 +1152,12 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
 }
 
 // https://huggingface.co/docs/safetensors/index
 bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix) {
+    return init_from_safetensors_file(file_path, prefix, 0);
+}
+
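+// n_threads_p <= 0 falls back to std::thread::hardware_concurrency().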
+bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::string& prefix, int n_threads_p) {
     LOG_DEBUG("init from '%s', prefix = '%s'", file_path.c_str(), prefix.c_str());
     file_paths_.push_back(file_path);
     size_t file_index = file_paths_.size() - 1;
@@ -1195,12 +1205,32 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         return false;
     }
 
     nlohmann::json header_ = nlohmann::json::parse(header_buf.data());
 
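+    // Per-tensor metadata lifted from the safetensors JSON header, staged so
+    // TensorStorage construction and the size checks can run on worker threads.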
+    struct SafetensorTask {
+        std::string name;
+        ggml_type type = GGML_TYPE_COUNT;
+        std::array<int64_t, SD_MAX_DIMS> ne{};
+        int n_dims              = 0;
+        size_t offset           = 0;
+        size_t tensor_data_size = 0;
+        bool is_bf16            = false;
+        bool is_f8_e4m3         = false;
+        bool is_f8_e5m2         = false;
+        bool is_f64             = false;
+        bool is_i64             = false;
+    };
+
+    std::vector<SafetensorTask> tasks;
+    tasks.reserve(header_.size());
+
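+    // Tensor payloads start after the header-length field (ST_HEADER_SIZE_LEN
+    // bytes) and the JSON header itself.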
+    size_t base_offset = ST_HEADER_SIZE_LEN + header_size_;
+
     for (auto& item : header_.items()) {
-        std::string name = item.key();
+        std::string name           = item.key();
         nlohmann::json tensor_info = item.value();
-        // LOG_DEBUG("%s %s\n", name.c_str(), tensor_info.dump().c_str());
 
         if (name == "__metadata__") {
             continue;
@@ -1210,96 +1240,177 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
             continue;
         }
 
-        std::string dtype    = tensor_info["dtype"];
-        nlohmann::json shape = tensor_info["shape"];
-
+        std::string dtype = tensor_info["dtype"];
         if (dtype == "U8") {
             continue;
         }
 
         size_t begin = tensor_info["data_offsets"][0].get<size_t>();
-        size_t end = tensor_info["data_offsets"][1].get<size_t>();
+        size_t end   = tensor_info["data_offsets"][1].get<size_t>();
 
         ggml_type type = str_to_ggml_type(dtype);
         if (type == GGML_TYPE_COUNT) {
             LOG_ERROR("unsupported dtype '%s' (tensor '%s')", dtype.c_str(), name.c_str());
             return false;
         }
 
+        nlohmann::json shape = tensor_info["shape"];
+
         if (shape.size() > SD_MAX_DIMS) {
             LOG_ERROR("invalid tensor '%s'", name.c_str());
             return false;
         }
 
-        int n_dims              = (int)shape.size();
-        int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
+        int n_dims                          = (int)shape.size();
+        std::array<int64_t, SD_MAX_DIMS> ne = {1, 1, 1, 1, 1};
         for (int i = 0; i < n_dims; i++) {
             ne[i] = shape[i].get<int64_t>();
         }
 
         if (n_dims == 5) {
             n_dims = 4;
-            ne[0] = ne[0] * ne[1];
-            ne[1] = ne[2];
-            ne[2] = ne[3];
-            ne[3] = ne[4];
+            ne[0]  = ne[0] * ne[1];
+            ne[1]  = ne[2];
+            ne[2]  = ne[3];
+            ne[3]  = ne[4];
         }
 
-        // ggml_n_dims returns 1 for scalars
         if (n_dims == 0) {
             n_dims = 1;
         }
 
-        if (!starts_with(name, prefix)) {
-            name = prefix + name;
+        std::string full_name = name;
+        if (!starts_with(full_name, prefix)) {
+            full_name = prefix + full_name;
         }
 
-        TensorStorage tensor_storage(name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
-        tensor_storage.reverse_ne();
+        SafetensorTask task;
+        task.name             = std::move(full_name);
+        task.type             = type;
+        task.ne               = ne;
+        task.n_dims           = n_dims;
+        task.offset           = base_offset + begin;
+        task.tensor_data_size = end - begin;
+        task.is_bf16          = (dtype == "BF16");
+        task.is_f8_e4m3       = (dtype == "F8_E4M3");
+        task.is_f8_e5m2       = (dtype == "F8_E5M2");
+        task.is_f64           = (dtype == "F64");
+        task.is_i64           = (dtype == "I64");
 
-        size_t tensor_data_size = end - begin;
-
-        if (dtype == "BF16") {
-            tensor_storage.is_bf16 = true;
-            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
-        } else if (dtype == "F8_E4M3") {
-            tensor_storage.is_f8_e4m3 = true;
-            // f8 -> f16
-            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
-        } else if (dtype == "F8_E5M2") {
-            tensor_storage.is_f8_e5m2 = true;
-            // f8 -> f16
-            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
-        } else if (dtype == "F64") {
-            tensor_storage.is_f64 = true;
-            // f64 -> f32
-            GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
-        } else if (dtype == "I64") {
-            tensor_storage.is_i64 = true;
-            // i64 -> i32
-            GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
-        } else {
-            GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
-        }
+        tasks.push_back(std::move(task));
     }
 
-        tensor_storages.push_back(tensor_storage);
-        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
+    if (tasks.empty()) {
+        return true;
+    }
 
-        // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
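+    // Pick a worker count: the caller's request if positive, otherwise the
+    // hardware concurrency, clamped to at least one thread and to the task count.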
+    int num_threads_to_use = n_threads_p > 0 ? n_threads_p : (int)std::thread::hardware_concurrency();
+    if (num_threads_to_use < 1) {
+        num_threads_to_use = 1;
     }
+    int n_threads = std::min(num_threads_to_use, (int)tasks.size());
+    if (n_threads < 1) {
+        n_threads = 1;
+    }
+
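+    // Results land in a pre-sized vector so each worker writes its own slots
+    // without synchronization.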
+    std::vector<TensorStorage> processed(tasks.size());
+
+    std::vector<std::thread> workers;
+    workers.reserve(n_threads);
+
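+    // Static striding: worker i handles tasks i, i + n_threads, i + 2 * n_threads, ...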
+    for (int i = 0; i < n_threads; ++i) {
+        workers.emplace_back([&, thread_id = i]() {
+            for (size_t idx = thread_id; idx < tasks.size(); idx += n_threads) {
+                const auto& task = tasks[idx];
+
+                TensorStorage tensor_storage(task.name, task.type, task.ne.data(), task.n_dims, file_index, task.offset);
+                tensor_storage.reverse_ne();
+
+                tensor_storage.is_bf16    = task.is_bf16;
+                tensor_storage.is_f8_e4m3 = task.is_f8_e4m3;
+                tensor_storage.is_f8_e5m2 = task.is_f8_e5m2;
+                tensor_storage.is_f64     = task.is_f64;
+                tensor_storage.is_i64     = task.is_i64;
+
+                size_t tensor_data_size = task.tensor_data_size;
+
+                if (tensor_storage.is_bf16) {
+                    GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+                } else if (tensor_storage.is_f8_e4m3) {
+                    // f8 -> f16
+                    GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+                } else if (tensor_storage.is_f8_e5m2) {
+                    // f8 -> f16
+                    GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
+                } else if (tensor_storage.is_f64) {
+                    // f64 -> f32
+                    GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
+                } else if (tensor_storage.is_i64) {
+                    // i64 -> i32
+                    GGML_ASSERT(tensor_storage.nbytes() * 2 == tensor_data_size);
+                } else {
+                    GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
+                }
+
+                processed[idx] = std::move(tensor_storage);
+            }
+        });
+    }
+
+    for (auto& worker : workers) {
+        worker.join();
+    }
+
+    const size_t prior_size = tensor_storages.size();
+    tensor_storages.resize(prior_size + processed.size());
+
+    int append_threads = std::min(num_threads_to_use, (int)processed.size());
+    if (append_threads < 1) {
+        append_threads = 1;
+    }
+
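+    // Each append worker fills its own String2GGMLType map; the maps are merged
+    // on the main thread afterwards, so no lock guards the shared map.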
+    std::vector<String2GGMLType> local_types(append_threads);
+    std::vector<std::thread> append_workers;
+    append_workers.reserve(append_threads);
+
+    for (int thread_id = 0; thread_id < append_threads; ++thread_id) {
+        append_workers.emplace_back([&, thread_id]() {
+            auto& local_map = local_types[thread_id];
+            for (size_t idx = thread_id; idx < processed.size(); idx += append_threads) {
+                size_t target_index = prior_size + idx;
+                tensor_storages[target_index] = std::move(processed[idx]);
+                add_preprocess_tensor_storage_types(local_map,
+                                                    tensor_storages[target_index].name,
+                                                    tensor_storages[target_index].type);
+            }
+        });
+    }
+
+    for (auto& worker : append_workers) {
+        worker.join();
+    }
+
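+    // Fold the per-thread maps into the shared tensor_storages_types map.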
+    for (auto& local_map : local_types) {
+        for (auto& kv : local_map) {
+            tensor_storages_types[kv.first] = kv.second;
+        }
+    }
+
+    processed.clear();
+    processed.shrink_to_fit();
 
     return true;
 }
-
 /*================================================= DiffusersModelLoader ==================================================*/
 
 bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) {
+    return init_from_diffusers_file(file_path, prefix, 0);
+}
+
+bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix, int n_threads) {
     std::string unet_path   = path_join(file_path, "unet/diffusion_pytorch_model.safetensors");
     std::string vae_path    = path_join(file_path, "vae/diffusion_pytorch_model.safetensors");
     std::string clip_path   = path_join(file_path, "text_encoder/model.safetensors");
     std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors");
 
-    if (!init_from_safetensors_file(unet_path, "unet.")) {
+    if (!init_from_safetensors_file(unet_path, "unet.", n_threads)) {
         return false;
     }
     for (auto ts : tensor_storages) {
@@ -1323,15 +1434,15 @@ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const s
         }
     }
 
-    if (!init_from_safetensors_file(vae_path, "vae.")) {
+    if (!init_from_safetensors_file(vae_path, "vae.", n_threads)) {
         LOG_WARN("Couldn't find working VAE in %s", file_path.c_str());
         // return false;
     }
-    if (!init_from_safetensors_file(clip_path, "te.")) {
+    if (!init_from_safetensors_file(clip_path, "te.", n_threads)) {
        LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str());
         // return false;
     }
-    if (!init_from_safetensors_file(clip_g_path, "te.1.")) {
+    if (!init_from_safetensors_file(clip_g_path, "te.1.", n_threads)) {
         LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str());
     }
     return true;