@@ -559,6 +559,26 @@ std::string convert_tensor_name(std::string name) {
     return new_name;
 }
 
+void add_preprocess_tensor_storage_types(std::map<std::string, enum ggml_type>& tensor_storages_types, std::string name, enum ggml_type type) {
+    std::string new_name = convert_tensor_name(name);
+
+    if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) {
+        size_t prefix_size = new_name.find("attn.in_proj_weight");
+        std::string prefix = new_name.substr(0, prefix_size);
+        tensor_storages_types[prefix + "self_attn.q_proj.weight"] = type;
+        tensor_storages_types[prefix + "self_attn.k_proj.weight"] = type;
+        tensor_storages_types[prefix + "self_attn.v_proj.weight"] = type;
+    } else if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_bias")) {
+        size_t prefix_size = new_name.find("attn.in_proj_bias");
+        std::string prefix = new_name.substr(0, prefix_size);
+        tensor_storages_types[prefix + "self_attn.q_proj.bias"] = type;
+        tensor_storages_types[prefix + "self_attn.k_proj.bias"] = type;
+        tensor_storages_types[prefix + "self_attn.v_proj.bias"] = type;
+    } else {
+        tensor_storages_types[new_name] = type;
+    }
+}
+
 void preprocess_tensor(TensorStorage tensor_storage,
                        std::vector<TensorStorage>& processed_tensor_storages) {
     std::vector<TensorStorage> result;
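For context (not part of the diff): open_clip-style checkpoints store the text encoder's attention projections as one fused attn.in_proj_weight/attn.in_proj_bias tensor, while the runtime graph uses separate q/k/v projections, so the type map has to be keyed by the split names the loader will actually look up. A minimal standalone sketch of that key expansion, with hypothetical stand-ins (convert_tensor_name stubbed as the identity, int in place of enum ggml_type so it builds without ggml headers):

#include <cstdio>
#include <map>
#include <string>

// Stubs for the sketch only: the real convert_tensor_name() in model.cpp
// canonicalizes diffusers/open_clip names; here it is assumed to be the
// identity. ends_with() is the usual suffix check.
static std::string convert_tensor_name(const std::string& name) { return name; }
static bool ends_with(const std::string& s, const std::string& suffix) {
    return s.size() >= suffix.size() &&
           s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Same key expansion as in the hunk above, weights only for brevity.
static void expand_types(std::map<std::string, int>& types, const std::string& name, int type) {
    std::string new_name = convert_tensor_name(name);
    if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) {
        std::string prefix = new_name.substr(0, new_name.find("attn.in_proj_weight"));
        types[prefix + "self_attn.q_proj.weight"] = type;
        types[prefix + "self_attn.k_proj.weight"] = type;
        types[prefix + "self_attn.v_proj.weight"] = type;
    } else {
        types[new_name] = type;
    }
}

int main() {
    std::map<std::string, int> types;
    expand_types(types, "cond_stage_model.transformer.text_model.encoder.layers.0.attn.in_proj_weight", 0);
    for (const auto& kv : types) {
        std::printf("%s\n", kv.first.c_str());  // three split q/k/v keys, no fused key
    }
    return 0;
}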
@@ -920,7 +940,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
         GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
 
         tensor_storages.push_back(tensor_storage);
-        tensor_storages_types[tensor_storage.name] = tensor_storage.type;
+        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
     }
 
     gguf_free(ctx_gguf_);
@@ -1071,7 +1091,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         }
 
         tensor_storages.push_back(tensor_storage);
-        tensor_storages_types[tensor_storage.name] = tensor_storage.type;
+        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
 
         // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
     }
@@ -1402,7 +1422,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
                 // printf("  ZIP got tensor %s \n", reader.tensor_storage.name.c_str());
                 reader.tensor_storage.name = prefix + reader.tensor_storage.name;
                 tensor_storages.push_back(reader.tensor_storage);
-                tensor_storages_types[reader.tensor_storage.name] = reader.tensor_storage.type;
+                add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);
 
                 // LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
                 // reset
@@ -1634,11 +1654,20 @@ ggml_type ModelLoader::get_vae_wtype() {
 void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
     for (auto& pair : tensor_storages_types) {
         if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) {
+            bool found = false;
             for (auto& tensor_storage : tensor_storages) {
-                if (tensor_storage.name == pair.first) {
-                    if (tensor_should_be_converted(tensor_storage, wtype)) {
-                        pair.second = wtype;
+                std::map<std::string, ggml_type> temp;
+                add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type);
+                for (auto& preprocessed_name : temp) {
+                    if (preprocessed_name.first == pair.first) {
+                        if (tensor_should_be_converted(tensor_storage, wtype)) {
+                            pair.second = wtype;
+                        }
+                        found = true;
+                        break;
                     }
+                }
+                if (found) {
                     break;
                 }
             }
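Why set_wtype_override needs the temporary map: after the expansion above, tensor_storages_types can hold split keys such as ...self_attn.q_proj.weight that no longer literally equal any stored tensor name (the file only contains the fused ...attn.in_proj_weight), so the old tensor_storage.name == pair.first comparison would skip those entries and the override would never reach them. Re-running the expansion per tensor restores the match at the cost of rebuilding a small map for every (key, tensor) pair, which is fine at load time. If that quadratic walk ever mattered, a hypothetical alternative (not in this PR) would be to build a reverse index once:

// Hypothetical sketch, not part of the change: map each expanded key back
// to the TensorStorage it came from, once, then do cheap lookups in the
// override loop. Uses the same members and helpers as the method above.
std::map<std::string, TensorStorage*> expanded_to_storage;
for (auto& tensor_storage : tensor_storages) {
    std::map<std::string, ggml_type> temp;
    add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type);
    for (auto& kv : temp) {
        expanded_to_storage[kv.first] = &tensor_storage;
    }
}
for (auto& pair : tensor_storages_types) {
    if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) {
        auto it = expanded_to_storage.find(pair.first);
        if (it != expanded_to_storage.end() && tensor_should_be_converted(*it->second, wtype)) {
            pair.second = wtype;
        }
    }
}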