
Commit dc32392

Fix: preprocess tensor names in tensor types map
1 parent 15d2799 commit dc32392

File tree

1 file changed: +35 -6 lines changed


model.cpp

Lines changed: 35 additions & 6 deletions
@@ -559,6 +559,26 @@ std::string convert_tensor_name(std::string name) {
     return new_name;
 }
 
+void add_preprocess_tensor_storage_types(std::map<std::string, enum ggml_type>& tensor_storages_types, std::string name, enum ggml_type type) {
+    std::string new_name = convert_tensor_name(name);
+
+    if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) {
+        size_t prefix_size = new_name.find("attn.in_proj_weight");
+        std::string prefix = new_name.substr(0, prefix_size);
+        tensor_storages_types[prefix + "self_attn.q_proj.weight"] = type;
+        tensor_storages_types[prefix + "self_attn.k_proj.weight"] = type;
+        tensor_storages_types[prefix + "self_attn.v_proj.weight"] = type;
+    } else if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_bias")) {
+        size_t prefix_size = new_name.find("attn.in_proj_bias");
+        std::string prefix = new_name.substr(0, prefix_size);
+        tensor_storages_types[prefix + "self_attn.q_proj.bias"] = type;
+        tensor_storages_types[prefix + "self_attn.k_proj.bias"] = type;
+        tensor_storages_types[prefix + "self_attn.v_proj.bias"] = type;
+    } else {
+        tensor_storages_types[new_name] = type;
+    }
+}
+
 void preprocess_tensor(TensorStorage tensor_storage,
                        std::vector<TensorStorage>& processed_tensor_storages) {
     std::vector<TensorStorage> result;
@@ -920,7 +940,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
         GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
 
         tensor_storages.push_back(tensor_storage);
-        tensor_storages_types[tensor_storage.name] = tensor_storage.type;
+        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
     }
 
     gguf_free(ctx_gguf_);
@@ -1071,7 +1091,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
         }
 
         tensor_storages.push_back(tensor_storage);
-        tensor_storages_types[tensor_storage.name] = tensor_storage.type;
+        add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type);
 
         // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
     }
@@ -1402,7 +1422,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
                 // printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
                 reader.tensor_storage.name = prefix + reader.tensor_storage.name;
                 tensor_storages.push_back(reader.tensor_storage);
-                tensor_storages_types[reader.tensor_storage.name] = reader.tensor_storage.type;
+                add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type);
 
                 // LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
                 // reset
@@ -1634,11 +1654,20 @@ ggml_type ModelLoader::get_vae_wtype() {
 void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
     for (auto& pair : tensor_storages_types) {
         if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) {
+            bool found = false;
             for (auto& tensor_storage : tensor_storages) {
-                if (tensor_storage.name == pair.first) {
-                    if (tensor_should_be_converted(tensor_storage, wtype)) {
-                        pair.second = wtype;
+                std::map<std::string, ggml_type> temp;
+                add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type);
+                for (auto& preprocessed_name : temp) {
+                    if (preprocessed_name.first == pair.first) {
+                        if (tensor_should_be_converted(tensor_storage, wtype)) {
+                            pair.second = wtype;
+                        }
+                        found = true;
+                        break;
                     }
+                }
+                if (found) {
                     break;
                 }
             }
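
The effect of the new helper is easiest to see in isolation. The sketch below is a minimal standalone reimplementation of the key-expansion rule, not the project's actual code: ggml_type is replaced by a plain int tag, convert_tensor_name() is treated as the identity, ends_with() is a local stand-in for the helper model.cpp already defines, and the tensor names in main() are purely illustrative.

// Minimal sketch of the key expansion done by add_preprocess_tensor_storage_types().
#include <iostream>
#include <map>
#include <string>

static bool ends_with(const std::string& s, const std::string& suffix) {
    return s.size() >= suffix.size() &&
           s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Fused CLIP attention tensors under cond_stage_model are registered as three
// separate q/k/v entries; every other name is stored unchanged.
static void add_types(std::map<std::string, int>& types, const std::string& name, int type) {
    if (name.find("cond_stage_model") != std::string::npos && ends_with(name, "attn.in_proj_weight")) {
        std::string prefix = name.substr(0, name.find("attn.in_proj_weight"));
        types[prefix + "self_attn.q_proj.weight"] = type;
        types[prefix + "self_attn.k_proj.weight"] = type;
        types[prefix + "self_attn.v_proj.weight"] = type;
    } else if (name.find("cond_stage_model") != std::string::npos && ends_with(name, "attn.in_proj_bias")) {
        std::string prefix = name.substr(0, name.find("attn.in_proj_bias"));
        types[prefix + "self_attn.q_proj.bias"] = type;
        types[prefix + "self_attn.k_proj.bias"] = type;
        types[prefix + "self_attn.v_proj.bias"] = type;
    } else {
        types[name] = type;
    }
}

int main() {
    std::map<std::string, int> types;
    // Hypothetical already-converted tensor names, used only for illustration.
    add_types(types, "cond_stage_model.transformer.resblocks.0.attn.in_proj_weight", 1);
    add_types(types, "first_stage_model.decoder.conv_in.weight", 1);
    for (const auto& kv : types) {
        std::cout << kv.first << "\n";
    }
    // Prints the three split cond_stage_model entries
    // (self_attn.{k,q,v}_proj.weight) followed by the unchanged first_stage_model name.
}

Because tensor_storages_types is now keyed by these preprocessed names, set_wtype_override() can no longer compare pair.first against tensor_storage.name directly; the last hunk therefore runs each stored name through the same helper into a temporary map and matches the override against the resulting keys.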

0 commit comments
