Skip to content

Commit 4d70e16

Browse files
committed
add support for sd3.5 model
1 parent 0e440c3 commit 4d70e16

File tree

4 files changed

+17
-12
lines changed

4 files changed

+17
-12
lines changed

model.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,9 @@ SDVersion ModelLoader::get_sd_version() {
13731373
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
13741374
is_flux = true;
13751375
}
1376+
if (tensor_storage.name.find("joint_blocks.0.x_block.attn2.ln_q.weight") != std::string::npos) {
1377+
return VERSION_SD3_5_2B;
1378+
}
13761379
if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) {
13771380
return VERSION_SD3_5_8B;
13781381
}

model.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ enum SDVersion {
2626
VERSION_FLUX_DEV,
2727
VERSION_FLUX_SCHNELL,
2828
VERSION_SD3_5_8B,
29+
VERSION_SD3_5_2B,
2930
VERSION_COUNT,
3031
};
3132

stable-diffusion.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ const char* model_version_to_str[] = {
3232
"SD3 2B",
3333
"Flux Dev",
3434
"Flux Schnell",
35-
"SD3.5 8B"};
35+
"SD3.5 8B",
36+
"SD3.5 2B"};
3637

3738
const char* sampling_methods_str[] = {
3839
"Euler A",
@@ -288,7 +289,7 @@ class StableDiffusionGGML {
288289
"try specifying SDXL VAE FP16 Fix with the --vae parameter. "
289290
"You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
290291
}
291-
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
292+
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
292293
scale_factor = 1.5305f;
293294
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
294295
scale_factor = 0.3611;
@@ -311,7 +312,7 @@ class StableDiffusionGGML {
311312
} else {
312313
clip_backend = backend;
313314
bool use_t5xxl = false;
314-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
315+
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
315316
use_t5xxl = true;
316317
}
317318
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
@@ -322,7 +323,7 @@ class StableDiffusionGGML {
322323
LOG_INFO("CLIP: Using CPU backend");
323324
clip_backend = ggml_backend_cpu_init();
324325
}
325-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
326+
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
326327
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
327328
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
328329
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
@@ -520,7 +521,7 @@ class StableDiffusionGGML {
520521
is_using_v_parameterization = true;
521522
}
522523

523-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
524+
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
524525
LOG_INFO("running in FLOW mode");
525526
denoiser = std::make_shared<DiscreteFlowDenoiser>();
526527
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
@@ -948,7 +949,7 @@ class StableDiffusionGGML {
948949
if (use_tiny_autoencoder) {
949950
C = 4;
950951
} else {
951-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) {
952+
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
952953
C = 32;
953954
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
954955
C = 32;
@@ -1281,7 +1282,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
12811282
// Sample
12821283
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
12831284
int C = 4;
1284-
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
1285+
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
12851286
C = 16;
12861287
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
12871288
C = 16;
@@ -1394,7 +1395,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
13941395

13951396
struct ggml_init_params params;
13961397
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
1397-
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
1398+
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
13981399
params.mem_size *= 3;
13991400
}
14001401
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
@@ -1420,15 +1421,15 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
14201421
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
14211422

14221423
int C = 4;
1423-
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
1424+
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
14241425
C = 16;
14251426
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
14261427
C = 16;
14271428
}
14281429
int W = width / 8;
14291430
int H = height / 8;
14301431
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
1431-
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
1432+
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
14321433
ggml_set_f32(init_latent, 0.0609f);
14331434
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {
14341435
ggml_set_f32(init_latent, 0.1159f);
@@ -1489,7 +1490,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
14891490

14901491
struct ggml_init_params params;
14911492
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
1492-
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) {
1493+
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
14931494
params.mem_size *= 2;
14941495
}
14951496
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) {

vae.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ class AutoencodingEngine : public GGMLBlock {
457457
bool use_video_decoder = false,
458458
SDVersion version = VERSION_SD1)
459459
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
460-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
460+
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
461461
dd_config.z_channels = 16;
462462
use_quant = false;
463463
}

0 commit comments

Comments
 (0)