Commit 1bdc767

feat: force using f32 for some layers

1 parent: 79c9fe9

4 files changed, +26 -15 lines

flux.hpp

Lines changed: 4 additions & 4 deletions

@@ -13,7 +13,7 @@ namespace Flux {
     struct MLPEmbedder : public UnaryBlock {
     public:
         MLPEmbedder(int64_t in_dim, int64_t hidden_dim) {
-            blocks["in_layer"] = std::shared_ptr<GGMLBlock>(new Linear(in_dim, hidden_dim, true));
+            blocks["in_layer"] = std::shared_ptr<GGMLBlock>(new Linear(in_dim, hidden_dim, true, true));
             blocks["out_layer"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_dim, hidden_dim, true));
         }

@@ -449,7 +449,7 @@ namespace Flux {
                   int64_t patch_size,
                   int64_t out_channels) {
             blocks["norm_final"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
-            blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels));
+            blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels, true, true));
             blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
         }

@@ -634,13 +634,13 @@ namespace Flux {
             int64_t out_channels = params.in_channels;
             int64_t pe_dim = params.hidden_size / params.num_heads;

-            blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size));
+            blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true, true));
             blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
             blocks["vector_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(params.vec_in_dim, params.hidden_size));
             if (params.guidance_embed) {
                 blocks["guidance_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
             }
-            blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.context_in_dim, params.hidden_size));
+            blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.context_in_dim, params.hidden_size, true, true));

             for (int i = 0; i < params.depth; i++) {
                 blocks["double_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new DoubleStreamBlock(params.hidden_size,

ggml_extend.hpp

Lines changed: 6 additions & 3 deletions

@@ -1187,9 +1187,10 @@ class Linear : public UnaryBlock {
     int64_t in_features;
     int64_t out_features;
     bool bias;
+    bool force_f32;

     void init_params(struct ggml_context* ctx, ggml_type wtype) {
-        if (in_features % ggml_blck_size(wtype) != 0) {
+        if (in_features % ggml_blck_size(wtype) != 0 || force_f32) {
             wtype = GGML_TYPE_F32;
         }
         params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
@@ -1201,10 +1202,12 @@ class Linear : public UnaryBlock {
 public:
     Linear(int64_t in_features,
            int64_t out_features,
-           bool bias = true)
+           bool bias = true,
+           bool force_f32 = false)
         : in_features(in_features),
           out_features(out_features),
-          bias(bias) {}
+          bias(bias),
+          force_f32(force_f32) {}

     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
         struct ggml_tensor* w = params["weight"];
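
The new force_f32 flag only changes how the weight tensor's type is picked in init_params. A rough, self-contained sketch of that decision (WType and blck_size() below are stand-ins for the real ggml enums and ggml_blck_size(), which the actual code uses directly):

#include <cstdint>

// Stand-ins for ggml's tensor types and block sizes, so the sketch compiles on its own.
enum class WType { F32, F16, Q8_0 };

static int64_t blck_size(WType t) {
    // Quantized types pack weights into fixed-size blocks; f32/f16 do not.
    return t == WType::Q8_0 ? 32 : 1;
}

// Fall back to f32 when the row length cannot be split into whole quant blocks,
// or when the layer explicitly requests full precision.
static WType pick_weight_type(int64_t in_features, WType wtype, bool force_f32) {
    if (in_features % blck_size(wtype) != 0 || force_f32) {
        return WType::F32;
    }
    return wtype;
}

With the flag set, a call like new Linear(in_dim, hidden_dim, true, true) (as in the flux.hpp and mmdit.hpp hunks) keeps that layer's weight in f32 even when the rest of the model uses a quantized type, while the default force_f32 = false leaves existing call sites unchanged.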

mmdit.hpp

Lines changed: 6 additions & 6 deletions

@@ -101,8 +101,8 @@ struct TimestepEmbedder : public GGMLBlock {
     TimestepEmbedder(int64_t hidden_size,
                      int64_t frequency_embedding_size = 256)
         : frequency_embedding_size(frequency_embedding_size) {
-        blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(frequency_embedding_size, hidden_size));
-        blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
+        blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(frequency_embedding_size, hidden_size, true, true));
+        blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
     }

     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* t) {
@@ -125,8 +125,8 @@ struct VectorEmbedder : public GGMLBlock {
 public:
     VectorEmbedder(int64_t input_dim,
                    int64_t hidden_size) {
-        blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(input_dim, hidden_size));
-        blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size));
+        blocks["mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(input_dim, hidden_size, true, true));
+        blocks["mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
     }

     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
@@ -423,7 +423,7 @@ struct FinalLayer : public GGMLBlock {
                int64_t out_channels) {
         // total_out_channels is always None
         blocks["norm_final"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
-        blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels));
+        blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels, true, true));
         blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
     }

@@ -510,7 +510,7 @@ struct MMDiT : public GGMLBlock {
             blocks["y_embedder"] = std::shared_ptr<GGMLBlock>(new VectorEmbedder(adm_in_channels, hidden_size));
         }

-        blocks["context_embedder"] = std::shared_ptr<GGMLBlock>(new Linear(4096, 1536));
+        blocks["context_embedder"] = std::shared_ptr<GGMLBlock>(new Linear(4096, 1536, true, true));

         for (int i = 0; i < depth; i++) {
             blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new JointBlock(hidden_size,

model.cpp

Lines changed: 10 additions & 2 deletions

@@ -1740,9 +1740,17 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
         // Pass, do not convert
     } else if (ends_with(name, ".bias")) {
         // Pass, do not convert
-    } else if (contains(name, "img_in.") || contains(name, "time_in.in_layer.") || contains(name, "vector_in.in_layer.") || contains(name, "guidance_in.in_layer.") || contains(name, "final_layer.linear.")) {
+    } else if (contains(name, "img_in.") ||
+               contains(name, "time_in.in_layer.") ||
+               contains(name, "vector_in.in_layer.") ||
+               contains(name, "guidance_in.in_layer.") ||
+               contains(name, "final_layer.linear.")) {
         // Pass, do not convert. For FLUX
-    } else if (contains(name, "x_embedder.") || contains(name, "t_embedder.") || contains(name, "y_embedder.") || contains(name, "context_embedder.")) {
+    } else if (contains(name, "x_embedder.") ||
+               contains(name, "t_embedder.") ||
+               contains(name, "y_embedder.") ||
+               contains(name, "pos_embed") ||
+               contains(name, "context_embedder.")) {
         // Pass, do not convert. For MMDiT
     } else if (contains(name, "time_embed.") || contains(name, "label_emb.")) {
         // Pass, do not convert. For Unet
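
These name checks are the loader-side counterpart of the force_f32 constructor calls above: tensors whose names match are skipped when weights are converted to a lower-precision type, so the same layers stay in full precision when a model is converted or loaded with a quantized wtype. A small stand-alone sketch of the matching (the contains/ends_with helpers are re-implemented here so the snippet compiles on its own, and the pattern list is abridged):

#include <string>

// Minimal versions of the string helpers used by tensor_should_be_converted().
static bool contains(const std::string& s, const std::string& sub) {
    return s.find(sub) != std::string::npos;
}

static bool ends_with(const std::string& s, const std::string& suffix) {
    return s.size() >= suffix.size() &&
           s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Abridged check: biases and the forced-f32 embedder layers are left unconverted.
static bool keep_unconverted(const std::string& name) {
    return ends_with(name, ".bias") ||
           contains(name, "img_in.") ||
           contains(name, "context_embedder.") ||
           contains(name, "pos_embed");  // the newly added MMDiT check
}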
