@@ -1397,10 +1397,11 @@ ggml_type ModelLoader::get_sd_wtype() {
             continue;
         }
 
-        if (tensor_storage.name.find(".weight") != std::string::npos &&
-            (tensor_storage.name.find("time_embed") != std::string::npos ||
-             tensor_storage.name.find("context_embedder") != std::string::npos ||
-             tensor_storage.name.find("time_in") != std::string::npos)) {
+        if (ggml_is_quantized(tensor_storage.type)) {
+            return tensor_storage.type;
+        }
+
+        if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
             return tensor_storage.type;
         }
     }
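All four `get_*_wtype()` getters in this commit switch from ad-hoc name matching to the same two-step probe, so the pattern is worth spelling out once. A condensed sketch follows; the `is_unused_tensor` guard is inferred from the surrounding context lines, and `GGML_TYPE_Q4_K` serves only as a representative quantized type to ask "would the conversion policy touch this tensor at all?":

```cpp
// Sketch of the shared detection loop, assuming the loader's fields shown
// in the diff (tensor_storages, is_unused_tensor, tensor_should_be_converted).
ggml_type ModelLoader::get_wtype_sketch() {  // hypothetical name
    for (auto& tensor_storage : tensor_storages) {
        if (is_unused_tensor(tensor_storage.name)) {
            continue;  // tensor is never materialized, ignore it
        }
        if (ggml_is_quantized(tensor_storage.type)) {
            return tensor_storage.type;  // file already quantized: report that type
        }
        if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
            return tensor_storage.type;  // a "real" weight tensor: its type decides
        }
    }
    return GGML_TYPE_COUNT;  // nothing matched (assumed fallback)
}
```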
@@ -1420,7 +1421,11 @@ ggml_type ModelLoader::get_conditioner_wtype() {
             continue;
         }
 
-        if (tensor_storage.name.find(".weight") != std::string::npos) {
+        if (ggml_is_quantized(tensor_storage.type)) {
+            return tensor_storage.type;
+        }
+
+        if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
             return tensor_storage.type;
         }
     }
@@ -1437,10 +1442,11 @@ ggml_type ModelLoader::get_diffusion_model_wtype() {
             continue;
         }
 
-        if (tensor_storage.name.find(".weight") != std::string::npos &&
-            (tensor_storage.name.find("time_embed") != std::string::npos ||
-             tensor_storage.name.find("context_embedder") != std::string::npos ||
-             tensor_storage.name.find("time_in") != std::string::npos)) {
+        if (ggml_is_quantized(tensor_storage.type)) {
+            return tensor_storage.type;
+        }
+
+        if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
             return tensor_storage.type;
         }
     }
@@ -1458,7 +1464,11 @@ ggml_type ModelLoader::get_vae_wtype() {
             continue;
         }
 
-        if (tensor_storage.name.find(".weight")) {
+        if (ggml_is_quantized(tensor_storage.type)) {
+            return tensor_storage.type;
+        }
+
+        if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
             return tensor_storage.type;
         }
     }
@@ -1723,6 +1733,26 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
     return true;
 }
 
+bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
+    const std::string& name = tensor_storage.name;
+    if (type != GGML_TYPE_COUNT) {
+        if (ggml_is_quantized(type) && tensor_storage.ne[0] % ggml_blck_size(type) != 0) {
+            // Pass, do not convert
+        } else if (ends_with(name, ".bias")) {
+            // Pass, do not convert
+        } else if (contains(name, "img_in.") || contains(name, "time_in.in_layer.") || contains(name, "vector_in.in_layer.") || contains(name, "guidance_in.in_layer.") || contains(name, "final_layer.linear.")) {
+            // Pass, do not convert. For FLUX
+        } else if (contains(name, "x_embedder.") || contains(name, "t_embedder.") || contains(name, "y_embedder.") || contains(name, "context_embedder.")) {
+            // Pass, do not convert. For MMDiT
+        } else if (contains(name, "time_embed.") || contains(name, "label_emb.")) {
+            // Pass, do not convert. For Unet
+        } else {
+            return true;
+        }
+    }
+    return false;
+}
+
 bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type) {
     auto backend    = ggml_backend_cpu_init();
     size_t mem_size = 1 * 1024 * 1024;  // for padding
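The new `tensor_should_be_converted()` leans on `ends_with()` and `contains()` string helpers that are defined elsewhere in the repository and not part of this diff. Minimal implementations consistent with the call sites above (a sketch, not necessarily the repo's exact definitions):

```cpp
#include <string>

// Suffix match: does str end with `ending`?
bool ends_with(const std::string& str, const std::string& ending) {
    if (str.length() >= ending.length()) {
        return str.compare(str.length() - ending.length(), ending.length(), ending) == 0;
    }
    return false;
}

// Substring match: does str contain `substring` anywhere?
bool contains(const std::string& str, const std::string& substring) {
    return str.find(substring) != std::string::npos;
}
```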
@@ -1737,12 +1767,8 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
         const std::string& name = tensor_storage.name;
 
         ggml_type tensor_type = tensor_storage.type;
-        if (type != GGML_TYPE_COUNT) {
-            if (ggml_is_quantized(type) && tensor_storage.ne[0] % ggml_blck_size(type) != 0) {
-                tensor_type = GGML_TYPE_F16;
-            } else {
-                tensor_type = type;
-            }
+        if (tensor_should_be_converted(tensor_storage, type)) {
+            tensor_type = type;
         }
 
         ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
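Note the behavior change in this hunk: the old code force-converted to `GGML_TYPE_F16` any tensor whose leading dimension was not a multiple of the quantization block size, while the refactored code leaves such tensors at their original type, since `tensor_should_be_converted()` returns false for them. The same helper also supersedes the hard-coded `% 32` check in `get_params_mem_size()` below. The block-size constraint itself comes from ggml's packed layout; a standalone illustration (the dimension value is made up):

```cpp
#include "ggml.h"
#include <cstdio>

int main() {
    // Q4_K packs weights in fixed-size blocks (256 elements in current
    // ggml), so a row length that is not a multiple of the block size
    // cannot be quantized to that type.
    const int64_t ne0 = 320;  // hypothetical leading dimension of a weight
    const auto blck   = ggml_blck_size(GGML_TYPE_Q4_K);
    if (ne0 % blck != 0) {
        printf("ne0=%lld not divisible by block size %lld: skip conversion\n",
               (long long)ne0, (long long)blck);
    }
    return 0;
}
```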
@@ -1792,15 +1818,9 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
     }
 
     for (auto& tensor_storage : processed_tensor_storages) {
-        ggml_type tensor_type = tensor_storage.type;
-        if (type != GGML_TYPE_COUNT) {
-            if (ggml_is_quantized(type) && tensor_storage.ne[0] % 32 != 0) {
-                tensor_type = GGML_TYPE_F16;
-            } else {
-                tensor_type = type;
-            }
+        if (tensor_should_be_converted(tensor_storage, type)) {
+            tensor_storage.type = type;
         }
-        tensor_storage.type = tensor_type;
         mem_size += tensor_storage.nbytes() + alignment;
     }
 