Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions conditioner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
int32_t num_custom_embeddings = 0;
int32_t num_custom_embeddings_2 = 0;
std::vector<uint8_t> token_embed_custom;
std::vector<std::string> readed_embeddings;
std::map<std::string, std::pair<int, int>> embedding_pos_map;

FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend,
bool offload_params_to_cpu,
Expand Down Expand Up @@ -123,14 +123,17 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
}

bool load_embedding(std::string embd_name, std::string embd_path, std::vector<int32_t>& bpe_tokens) {
// the order matters
ModelLoader model_loader;
if (!model_loader.init_from_file_and_convert_name(embd_path)) {
LOG_ERROR("embedding '%s' failed", embd_name.c_str());
return false;
}
if (std::find(readed_embeddings.begin(), readed_embeddings.end(), embd_name) != readed_embeddings.end()) {
auto iter = embedding_pos_map.find(embd_name);
if (iter != embedding_pos_map.end()) {
LOG_DEBUG("embedding already read in: %s", embd_name.c_str());
for (int i = iter->second.first; i < iter->second.second; i++) {
bpe_tokens.push_back(text_model->model.vocab_size + i);
}
return true;
}
struct ggml_init_params params;
Expand Down Expand Up @@ -161,7 +164,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
return true;
};
model_loader.load_tensors(on_load, 1);
readed_embeddings.push_back(embd_name);
int pos_start = num_custom_embeddings;
if (embd) {
int64_t hidden_size = text_model->model.hidden_size;
token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd));
Expand All @@ -188,6 +191,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
}
LOG_DEBUG("embedding '%s' applied, custom embeddings: %i (text model 2)", embd_name.c_str(), num_custom_embeddings_2);
}
int pos_end = num_custom_embeddings;
if (pos_end == pos_start) {
return false;
}
embedding_pos_map[embd_name] = std::pair{pos_start, pos_end};
return true;
}

Expand Down
Loading