Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions clip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,6 @@

/*================================================== CLIPTokenizer ===================================================*/

__STATIC_INLINE__ std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
std::regex re("<lora:([^:]+):([^>]+)>");
std::smatch matches;
std::unordered_map<std::string, float> filename2multiplier;

while (std::regex_search(text, matches, re)) {
std::string filename = matches[1].str();
float multiplier = std::stof(matches[2].str());

text = std::regex_replace(text, re, "", std::regex_constants::format_first_only);

if (multiplier == 0.f) {
continue;
}

if (filename2multiplier.find(filename) == filename2multiplier.end()) {
filename2multiplier[filename] = multiplier;
} else {
filename2multiplier[filename] += multiplier;
}
}

return std::make_pair(filename2multiplier, text);
}

__STATIC_INLINE__ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
std::set<int> byte_set;
Expand Down
146 changes: 138 additions & 8 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ struct SDContextParams {
std::string lora_model_dir;

std::map<std::string, std::string> embedding_map;
std::vector<sd_embedding_t> embedding_array;
std::vector<sd_embedding_t> embedding_vec;

rng_type_t rng_type = CUDA_RNG;
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
Expand Down Expand Up @@ -952,13 +952,13 @@ struct SDContextParams {
}

sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) {
embedding_array.clear();
embedding_array.reserve(embedding_map.size());
embedding_vec.clear();
embedding_vec.reserve(embedding_map.size());
for (const auto& kv : embedding_map) {
sd_embedding_t item;
item.name = kv.first.c_str();
item.path = kv.second.c_str();
embedding_array.emplace_back(item);
embedding_vec.emplace_back(item);
}

sd_ctx_params_t sd_ctx_params = {
Expand All @@ -975,8 +975,8 @@ struct SDContextParams {
taesd_path.c_str(),
control_net_path.c_str(),
lora_model_dir.c_str(),
embedding_array.data(),
static_cast<uint32_t>(embedding_array.size()),
embedding_vec.data(),
static_cast<uint32_t>(embedding_vec.size()),
photo_maker_path.c_str(),
tensor_type_rules.c_str(),
vae_decode_only,
Expand Down Expand Up @@ -1030,6 +1030,15 @@ static std::string vec_str_to_string(const std::vector<std::string>& v) {
return oss.str();
}

static bool is_absolute_path(const std::string& p) {
#ifdef _WIN32
// Windows: C:/path or C:\path
return p.size() > 1 && std::isalpha(static_cast<unsigned char>(p[0])) && p[1] == ':';
#else
return !p.empty() && p[0] == '/';
#endif
}

struct SDGenerationParams {
std::string prompt;
std::string negative_prompt;
Expand Down Expand Up @@ -1072,6 +1081,10 @@ struct SDGenerationParams {

int upscale_repeats = 1;

std::map<std::string, float> lora_map;
std::map<std::string, float> high_noise_lora_map;
std::vector<sd_lora_t> lora_vec;

SDGenerationParams() {
sd_sample_params_init(&sample_params);
sd_sample_params_init(&high_noise_sample_params);
Expand Down Expand Up @@ -1442,7 +1455,88 @@ struct SDGenerationParams {
return options;
}

bool process_and_check(SDMode mode) {
void extract_and_remove_lora(const std::string& lora_model_dir) {
static const std::regex re(R"(<lora:([^:>]+):([^>]+)>)");
static const std::vector<std::string> valid_ext = {".pt", ".safetensors", ".gguf"};
std::smatch m;

std::string tmp = prompt;

while (std::regex_search(tmp, m, re)) {
std::string raw_path = m[1].str();
const std::string raw_mul = m[2].str();

float mul = 0.f;
try {
mul = std::stof(raw_mul);
} catch (...) {
tmp = m.suffix().str();
prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only);
continue;
}

bool is_high_noise = false;
static const std::string prefix = "|high_noise|";
if (raw_path.rfind(prefix, 0) == 0) {
raw_path.erase(0, prefix.size());
is_high_noise = true;
}

fs::path final_path;
if (is_absolute_path(raw_path)) {
final_path = raw_path;
} else {
final_path = fs::path(lora_model_dir) / raw_path;
}
if (!fs::exists(final_path)) {
bool found = false;
for (const auto& ext : valid_ext) {
fs::path try_path = final_path;
try_path += ext;
if (fs::exists(try_path)) {
final_path = try_path;
found = true;
break;
}
}
if (!found) {
printf("can not found lora %s\n", final_path.lexically_normal().string().c_str());
tmp = m.suffix().str();
prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only);
continue;
}
}

const std::string key = final_path.lexically_normal().string();

if (is_high_noise)
high_noise_lora_map[key] += mul;
else
lora_map[key] += mul;

prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only);

tmp = m.suffix().str();
}

for (const auto& kv : lora_map) {
sd_lora_t item;
item.is_high_noise = false;
item.path = kv.first.c_str();
item.multiplier = kv.second;
lora_vec.emplace_back(item);
}

for (const auto& kv : high_noise_lora_map) {
sd_lora_t item;
item.is_high_noise = true;
item.path = kv.first.c_str();
item.multiplier = kv.second;
lora_vec.emplace_back(item);
}
}

bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
if (width <= 0) {
fprintf(stderr, "error: the width must be greater than 0\n");
return false;
Expand Down Expand Up @@ -1553,14 +1647,44 @@ struct SDGenerationParams {
seed = rand();
}

extract_and_remove_lora(lora_model_dir);

return true;
}

std::string to_string() const {
char* sample_params_str = sd_sample_params_to_str(&sample_params);
char* high_noise_sample_params_str = sd_sample_params_to_str(&high_noise_sample_params);

std::ostringstream lora_ss;
lora_ss << "{\n";
for (auto it = lora_map.begin(); it != lora_map.end(); ++it) {
lora_ss << " \"" << it->first << "\": \"" << it->second << "\"";
if (std::next(it) != lora_map.end()) {
lora_ss << ",";
}
lora_ss << "\n";
}
lora_ss << " }";
std::string loras_str = lora_ss.str();

lora_ss = std::ostringstream();
;
lora_ss << "{\n";
for (auto it = high_noise_lora_map.begin(); it != high_noise_lora_map.end(); ++it) {
lora_ss << " \"" << it->first << "\": \"" << it->second << "\"";
if (std::next(it) != high_noise_lora_map.end()) {
lora_ss << ",";
}
lora_ss << "\n";
}
lora_ss << " }";
std::string high_noise_loras_str = lora_ss.str();

std::ostringstream oss;
oss << "SDGenerationParams {\n"
<< " loras: \"" << loras_str << "\",\n"
<< " high_noise_loras: \"" << high_noise_loras_str << "\",\n"
<< " prompt: \"" << prompt << "\",\n"
<< " negative_prompt: \"" << negative_prompt << "\",\n"
<< " clip_skip: " << clip_skip << ",\n"
Expand Down Expand Up @@ -1626,7 +1750,9 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP
exit(cli_params.normal_exit ? 0 : 1);
}

if (!cli_params.process_and_check() || !ctx_params.process_and_check(cli_params.mode) || !gen_params.process_and_check(cli_params.mode)) {
if (!cli_params.process_and_check() ||
!ctx_params.process_and_check(cli_params.mode) ||
!gen_params.process_and_check(cli_params.mode, ctx_params.lora_model_dir)) {
print_usage(argc, argv, options_vec);
exit(1);
}
Expand Down Expand Up @@ -2139,6 +2265,8 @@ int main(int argc, const char* argv[]) {

if (cli_params.mode == IMG_GEN) {
sd_img_gen_params_t img_gen_params = {
gen_params.lora_vec.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()),
gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(),
gen_params.clip_skip,
Expand Down Expand Up @@ -2170,6 +2298,8 @@ int main(int argc, const char* argv[]) {
num_results = gen_params.batch_count;
} else if (cli_params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = {
gen_params.lora_vec.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()),
gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(),
gen_params.clip_skip,
Expand Down
51 changes: 21 additions & 30 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -937,28 +937,17 @@ class StableDiffusionGGML {
float multiplier,
ggml_backend_t backend,
LoraModel::filter_t lora_tensor_filter = nullptr) {
std::string lora_name = lora_id;
std::string high_noise_tag = "|high_noise|";
bool is_high_noise = false;
if (starts_with(lora_name, high_noise_tag)) {
lora_name = lora_name.substr(high_noise_tag.size());
std::string lora_path = lora_id;
static std::string high_noise_tag = "|high_noise|";
bool is_high_noise = false;
if (starts_with(lora_path, high_noise_tag)) {
lora_path = lora_path.substr(high_noise_tag.size());
is_high_noise = true;
LOG_DEBUG("high noise lora: %s", lora_name.c_str());
}
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
std::string file_path;
if (file_exists(st_file_path)) {
file_path = st_file_path;
} else if (file_exists(ckpt_file_path)) {
file_path = ckpt_file_path;
} else {
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
return nullptr;
LOG_DEBUG("high noise lora: %s", lora_path.c_str());
}
auto lora = std::make_shared<LoraModel>(lora_id, backend, file_path, is_high_noise ? "model.high_noise_" : "", version);
auto lora = std::make_shared<LoraModel>(lora_id, backend, lora_path, is_high_noise ? "model.high_noise_" : "", version);
if (!lora->load_from_file(n_threads, lora_tensor_filter)) {
LOG_WARN("load lora tensors from %s failed", file_path.c_str());
LOG_WARN("load lora tensors from %s failed", lora_path.c_str());
return nullptr;
}

Expand Down Expand Up @@ -1143,12 +1132,15 @@ class StableDiffusionGGML {
}
}

std::string apply_loras_from_prompt(const std::string& prompt) {
auto result_pair = extract_and_remove_lora(prompt);
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier

for (auto& kv : lora_f2m) {
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
void apply_loras(const sd_lora_t* loras, uint32_t lora_count) {
std::unordered_map<std::string, float> lora_f2m;
for (int i = 0; i < lora_count; i++) {
std::string lora_id = SAFE_STR(loras[i].path);
if (loras[i].is_high_noise) {
lora_id = "|high_noise|" + lora_id;
}
lora_f2m[lora_id] = loras[i].multiplier;
LOG_DEBUG("lora %s:%.2f", lora_id.c_str(), loras[i].multiplier);
}
int64_t t0 = ggml_time_ms();
if (apply_lora_immediately) {
Expand All @@ -1159,9 +1151,7 @@ class StableDiffusionGGML {
int64_t t1 = ggml_time_ms();
if (!lora_f2m.empty()) {
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", result_pair.second.c_str());
}
return result_pair.second;
}

ggml_tensor* id_encoder(ggml_context* work_ctx,
Expand Down Expand Up @@ -2815,8 +2805,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
int sample_steps = sigmas.size() - 1;

int64_t t0 = ggml_time_ms();
// Apply lora
prompt = sd_ctx->sd->apply_loras_from_prompt(prompt);

// Photo Maker
std::string prompt_text_only;
Expand Down Expand Up @@ -3188,6 +3176,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g

size_t t0 = ggml_time_ms();

// Apply lora
sd_ctx->sd->apply_loras(sd_img_gen_params->loras, sd_img_gen_params->lora_count);

enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
if (sample_method == SAMPLE_METHOD_COUNT) {
sample_method = sd_get_default_sample_method(sd_ctx);
Expand Down Expand Up @@ -3487,7 +3478,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
int64_t t0 = ggml_time_ms();

// Apply lora
prompt = sd_ctx->sd->apply_loras_from_prompt(prompt);
sd_ctx->sd->apply_loras(sd_vid_gen_params->loras, sd_vid_gen_params->lora_count);

ggml_tensor* init_latent = nullptr;
ggml_tensor* clip_vision_output = nullptr;
Expand Down
10 changes: 10 additions & 0 deletions stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,14 @@ typedef struct {
} sd_easycache_params_t;

typedef struct {
bool is_high_noise;
float multiplier;
const char* path;
} sd_lora_t;

typedef struct {
const sd_lora_t* loras;
uint32_t lora_count;
const char* prompt;
const char* negative_prompt;
int clip_skip;
Expand All @@ -265,6 +273,8 @@ typedef struct {
} sd_img_gen_params_t;

typedef struct {
const sd_lora_t* loras;
uint32_t lora_count;
const char* prompt;
const char* negative_prompt;
int clip_skip;
Expand Down
Loading
Loading