Skip to content

Commit 6a46206

Browse files
committed
add chroma radiance support
1 parent d05e46c commit 6a46206

File tree

6 files changed

+603
-223
lines changed

6 files changed

+603
-223
lines changed

flux.hpp

Lines changed: 459 additions & 111 deletions
Large diffs are not rendered by default.

model.cpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1766,7 +1766,6 @@ bool ModelLoader::model_is_unet() {
17661766

17671767
SDVersion ModelLoader::get_sd_version() {
17681768
TensorStorage token_embedding_weight, input_block_weight;
1769-
bool input_block_checked = false;
17701769

17711770
bool has_multiple_encoders = false;
17721771
bool is_unet = false;
@@ -1778,12 +1777,12 @@ SDVersion ModelLoader::get_sd_version() {
17781777
bool has_img_emb = false;
17791778

17801779
for (auto& tensor_storage : tensor_storages) {
1781-
if (!(is_xl || is_flux)) {
1780+
if (!(is_xl)) {
17821781
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
17831782
is_flux = true;
1784-
if (input_block_checked) {
1785-
break;
1786-
}
1783+
}
1784+
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
1785+
return VERSION_CHROMA_RADIANCE;
17871786
}
17881787
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
17891788
return VERSION_SD3;
@@ -1800,22 +1799,19 @@ SDVersion ModelLoader::get_sd_version() {
18001799
if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) {
18011800
has_img_emb = true;
18021801
}
1803-
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
1802+
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos ||
1803+
tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
18041804
is_unet = true;
18051805
if (has_multiple_encoders) {
18061806
is_xl = true;
1807-
if (input_block_checked) {
1808-
break;
1809-
}
18101807
}
18111808
}
1812-
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos || tensor_storage.name.find("te.1") != std::string::npos) {
1809+
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos ||
1810+
tensor_storage.name.find("cond_stage_model.1") != std::string::npos ||
1811+
tensor_storage.name.find("te.1") != std::string::npos) {
18131812
has_multiple_encoders = true;
18141813
if (is_unet) {
18151814
is_xl = true;
1816-
if (input_block_checked) {
1817-
break;
1818-
}
18191815
}
18201816
}
18211817
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
@@ -1831,12 +1827,10 @@ SDVersion ModelLoader::get_sd_version() {
18311827
token_embedding_weight = tensor_storage;
18321828
// break;
18331829
}
1834-
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
1835-
input_block_weight = tensor_storage;
1836-
input_block_checked = true;
1837-
if (is_xl || is_flux) {
1838-
break;
1839-
}
1830+
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" ||
1831+
tensor_storage.name == "model.diffusion_model.img_in.weight" ||
1832+
tensor_storage.name == "unet.conv_in.weight") {
1833+
input_block_weight = tensor_storage;
18401834
}
18411835
}
18421836
if (is_wan) {

model.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ enum SDVersion {
3434
VERSION_FLUX_FILL,
3535
VERSION_FLUX_CONTROLS,
3636
VERSION_FLEX_2,
37+
VERSION_CHROMA_RADIANCE,
3738
VERSION_WAN2,
3839
VERSION_WAN2_2_I2V,
3940
VERSION_WAN2_2_TI2V,
@@ -70,7 +71,11 @@ static inline bool sd_version_is_sd3(SDVersion version) {
7071
}
7172

7273
static inline bool sd_version_is_flux(SDVersion version) {
73-
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2) {
74+
if (version == VERSION_FLUX ||
75+
version == VERSION_FLUX_FILL ||
76+
version == VERSION_FLUX_CONTROLS ||
77+
version == VERSION_FLEX_2 ||
78+
version == VERSION_CHROMA_RADIANCE) {
7479
return true;
7580
}
7681
return false;

qwen_image.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -649,7 +649,7 @@ namespace Qwen {
649649

650650
static void load_from_file_and_test(const std::string& file_path) {
651651
// cuda q8: pass
652-
// cuda q8 fa: nan
652+
// cuda q8 fa: pass
653653
// ggml_backend_t backend = ggml_backend_cuda_init(0);
654654
ggml_backend_t backend = ggml_backend_cpu_init();
655655
ggml_type model_data_type = GGML_TYPE_Q8_0;

0 commit comments

Comments
 (0)