Commit 3d86c6c

ngxson and ggerganov authored

model: support GLM4V vision encoder (#18042)

* convert ok
* no deepstack
* less new tensors
* cgraph ok
* add mrope for text model
* faster patch merger
* add GGML_ROPE_TYPE_MRNORM
* add support for metal
* move glm4v to dedicated graph
* convert: add norm_embd
* clip: add debugging fn
* working correctly
* fix style
* use bicubic
* fix mrope metal
* improve cpu
* convert to neox ordering on conversion
* revert backend changes
* force stop if using old weight
* support moe variant
* fix conversion
* fix convert (2)
* Update tools/mtmd/clip-graph.h

  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* process mrope_section on TextModel base class
* resolve conflict merge

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

1 parent 9963b81 commit 3d86c6c

File tree: 17 files changed, +413 -80 lines changed


convert_hf_to_gguf.py

Lines changed: 76 additions & 35 deletions
```diff
@@ -862,6 +862,14 @@ def set_gguf_parameters(self):
             logger.warning(f"Unknown RoPE type: {rope_type}")
         logger.info(f"gguf: rope scaling type = {rope_gguf_type.name}")

+        if "mrope_section" in self.rope_parameters:
+            mrope_section = self.rope_parameters["mrope_section"]
+            # Pad to 4 dimensions [time, height, width, extra]
+            while len(mrope_section) < 4:
+                mrope_section.append(0)
+            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
+            logger.info(f"gguf: mrope sections: {mrope_section[:4]}")
+
         if (rope_theta := rope_params.get("rope_theta")) is not None:
             self.gguf_writer.add_rope_freq_base(rope_theta)
             logger.info(f"gguf: rope theta = {rope_theta}")
```
```diff
@@ -3739,9 +3747,6 @@ class Qwen2VLModel(TextModel):

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        mrope_section = self.hparams["rope_scaling"]["mrope_section"]
-        mrope_section += [0] * max(0, 4 - len(mrope_section))
-        self.gguf_writer.add_rope_dimension_sections(mrope_section)

     def set_vocab(self):
         try:
@@ -4377,6 +4382,30 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return super().modify_tensors(data_torch, name, bid)


+@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration")
+class Glm4VVisionModel(Qwen3VLVisionModel):
+    def set_gguf_parameters(self):
+        MmprojModel.set_gguf_parameters(self)  # skip Qwen3VLVisionModel parameters
+        assert self.hparams_vision is not None
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLM4V)
+
+        hidden_act = str(self.hparams_vision.get("hidden_act", "")).lower()
+        if hidden_act == "gelu":
+            self.gguf_writer.add_vision_use_gelu(True)
+        elif hidden_act == "silu":
+            self.gguf_writer.add_vision_use_silu(True)
+
+        rms_norm_eps = self.hparams_vision.get("rms_norm_eps", 1e-5)
+        self.gguf_writer.add_vision_attention_layernorm_eps(rms_norm_eps)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("model.visual."):
+            name = name.replace("model.visual.", "visual.")
+        if name.startswith("visual.merger."):
+            return [(self.map_tensor_name(name), data_torch)]
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Qwen3VLForConditionalGeneration")
 class Qwen3VLTextModel(Qwen3Model):
     model_arch = gguf.MODEL_ARCH.QWEN3VL
```
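
In the new `Glm4VVisionModel`, only `visual.merger.*` tensors are mapped directly; everything else is stripped of the `model.visual.` prefix and falls through to the inherited Qwen3VL vision handling. A toy trace of that routing, using hypothetical checkpoint tensor names:

```python
# Hypothetical GLM4V tensor names traced through Glm4VVisionModel.modify_tensors.
for name in ("model.visual.merger.up_proj.weight",
             "model.visual.patch_embed.proj.weight"):
    name = name.replace("model.visual.", "visual.")
    if name.startswith("visual.merger."):
        print(f"{name} -> mapped directly via map_tensor_name")
    else:
        print(f"{name} -> deferred to Qwen3VLVisionModel.modify_tensors")
```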
```diff
@@ -4385,20 +4414,6 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()

         # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
-        text_config = self.hparams.get("text_config", {})
-        # rope_scaling is deprecated in V5, use rope_parameters instead
-        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
-
-        if rope_scaling.get("mrope_section"):
-            # mrope_section contains [time, height, width] dimensions
-            mrope_section = rope_scaling["mrope_section"]
-            # Pad to 4 dimensions [time, height, width, extra]
-            while len(mrope_section) < 4:
-                mrope_section.append(0)
-            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
-
-            logger.info(f"MRoPE sections: {mrope_section[:4]}")
-
         vision_config = self.hparams.get("vision_config", {})
         deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
         self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
@@ -4417,22 +4432,6 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel):

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-
-        # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
-        text_config = self.hparams.get("text_config", {})
-        # rope_scaling is deprecated in V5, use rope_parameters instead
-        rope_scaling = text_config.get("rope_scaling") or text_config.get("rope_parameters") or {}
-
-        if rope_scaling.get("mrope_section"):
-            # mrope_section contains [time, height, width] dimensions
-            mrope_section = rope_scaling["mrope_section"]
-            # Pad to 4 dimensions [time, height, width, extra]
-            while len(mrope_section) < 4:
-                mrope_section.append(0)
-            self.gguf_writer.add_rope_dimension_sections(mrope_section[:4])
-
-            logger.info(f"MRoPE sections: {mrope_section[:4]}")
-
         vision_config = self.hparams.get("vision_config", {})
         deepstack_layer_num = len(vision_config.get("deepstack_visual_indexes", []))
         self.gguf_writer.add_num_deepstack_layers(deepstack_layer_num)
@@ -7795,6 +7794,15 @@ def prepare_tensors(self):
 @ModelBase.register("Glm4ForCausalLM", "Glm4vForConditionalGeneration")
 class Glm4Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GLM4
+    use_mrope = False
+    partial_rotary_factor = 0.5
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.partial_rotary_factor = self.rope_parameters.get("partial_rotary_factor", 0.5)
+        if "mrope_section" in self.rope_parameters:
+            self.use_mrope = True
+            logger.info("Q/K weight will need to be permuted for M-RoPE")

     def set_vocab(self):
         from transformers import AutoTokenizer
@@ -7816,17 +7824,49 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         if (rope_dim := self.hparams.get("head_dim")) is None:
             rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
-        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.partial_rotary_factor))
+
+    @staticmethod
+    def normal_to_neox(weights: Tensor, n_head: int, n_head_kv: int, head_dim: int, partial_rotary_factor: float) -> Tensor:
+        orig_shape = weights.shape
+        if len(orig_shape) == 1:
+            weights = weights.unsqueeze(1)  # [out_dim, 1]
+        if len(weights.shape) != 2:
+            raise ValueError("Only 1D and 2D tensors are supported.")
+        n_effective_heads = weights.shape[0] // head_dim
+        if n_head_kv is not None and n_effective_heads != n_head:
+            if n_effective_heads != n_head_kv:
+                raise AssertionError(f"Mismatch in effective heads: computed {n_effective_heads}, expected {n_head} or {n_head_kv}")
+        rotary_dim = int(head_dim * partial_rotary_factor)
+        if rotary_dim % 2 != 0:
+            raise ValueError("rotary_dim must be even.")
+        reshaped = weights.reshape(n_effective_heads, head_dim, -1)
+        rot_part = reshaped[:, :rotary_dim, :]
+        non_rot_part = reshaped[:, rotary_dim:, :]
+        permuted_rot = torch.cat((rot_part[:, ::2, :], rot_part[:, 1::2, :]), dim=1)
+        combined = torch.cat((permuted_rot, non_rot_part), dim=1)
+        result = combined.reshape(weights.shape)
+        return result if len(orig_shape) != 1 else result.squeeze(1)

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if name.startswith("model.visual."):  # ignore visual part of Glm4v
             return []
         elif name.startswith("model.language_model."):
             name = name.replace("language_model.", "")  # for Glm4v
+        if self.use_mrope:
+            n_head = self.hparams["num_attention_heads"]
+            n_kv_head = self.hparams["num_key_value_heads"]
+            n_embd = self.hparams["hidden_size"]
+            head_dim = n_embd // n_head
+            # because llama.cpp M-RoPE kernel only supports Neox ordering, we have to permute the weights here
+            if name.endswith(("q_proj.weight", "q_proj.bias")):
+                data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_head, head_dim, self.partial_rotary_factor)
+            if name.endswith(("k_proj.weight", "k_proj.bias")):
+                data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_kv_head, head_dim, self.partial_rotary_factor)
         return super().modify_tensors(data_torch, name, bid)


-@ModelBase.register("Glm4MoeForCausalLM")
+@ModelBase.register("Glm4MoeForCausalLM", "Glm4vMoeForConditionalGeneration")
 class Glm4MoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GLM4_MOE

@@ -7893,6 +7933,7 @@ def set_gguf_parameters(self):

     _experts: list[dict[str, Tensor]] | None = None

+    # note: unlike GLM4V non-MoE, we don't need to permute Q/K here since GLM4V_MOE uses Neox ordering already
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
     ) -> Iterable[tuple[str, Tensor]]:
```
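
The `normal_to_neox` helper above reorders the rotary rows of each attention head: the "normal" layout rotates adjacent pairs (2i, 2i+1), while the NeoX layout, the only ordering the llama.cpp M-RoPE kernel supports, rotates pairs split across the two halves of the rotary block. A minimal sketch of the permutation on a single toy head, where values are just row labels:

```python
import torch

head_dim = 8
rotary_dim = int(head_dim * 0.5)  # partial_rotary_factor = 0.5, as in GLM4

# One head, one input column; row i is labeled i so the reordering is visible.
w = torch.arange(head_dim, dtype=torch.float32).reshape(head_dim, 1)

rot, non_rot = w[:rotary_dim], w[rotary_dim:]
# even-indexed rotary rows first, then odd-indexed ones; non-rotary rows stay put
permuted = torch.cat((rot[::2], rot[1::2], non_rot), dim=0)
print(permuted.flatten().tolist())  # [0.0, 2.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0]
```

Rows 0 and 1, adjacent in the original layout, end up rotary_dim/2 positions apart, which is exactly the pairing the NeoX-style kernel expects.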

gguf-py/gguf/constants.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -643,6 +643,7 @@ class MODEL_TENSOR(IntEnum):
     V_MMPROJ_PEG = auto()
     V_ENC_EMBD_CLS = auto()
     V_ENC_EMBD_PATCH = auto()
+    V_ENC_EMBD_NORM = auto()
     V_ENC_EMBD_POS = auto()
     V_ENC_INPUT_NORM = auto()
     V_ENC_ATTN_QKV = auto()
@@ -661,6 +662,7 @@ class MODEL_TENSOR(IntEnum):
     V_LAYER_SCALE_2 = auto()
     V_PRE_NORM = auto()
     V_POST_NORM = auto()
+    V_MM_POST_NORM = auto()
     V_MM_INP_NORM = auto()
     V_MM_INP_PROJ = auto()  # gemma3
     V_MM_SOFT_EMB_NORM = auto()  # gemma3
@@ -1016,6 +1018,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}",
     MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
     MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
+    MODEL_TENSOR.V_ENC_EMBD_NORM: "v.norm_embd",
     MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
     MODEL_TENSOR.V_ENC_ATTN_QKV: "v.blk.{bid}.attn_qkv",
     MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
@@ -1034,6 +1037,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2",
     MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
     MODEL_TENSOR.V_POST_NORM: "v.post_ln",
+    MODEL_TENSOR.V_MM_POST_NORM: "mm.post_norm",
     MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
     MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm",
     MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",
@@ -1094,6 +1098,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.V_MMPROJ_PEG,
     MODEL_TENSOR.V_ENC_EMBD_CLS,
     MODEL_TENSOR.V_ENC_EMBD_PATCH,
+    MODEL_TENSOR.V_ENC_EMBD_NORM,
     MODEL_TENSOR.V_ENC_EMBD_POS,
     MODEL_TENSOR.V_ENC_INPUT_NORM,
     MODEL_TENSOR.V_ENC_ATTN_QKV,
@@ -1112,6 +1117,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.V_LAYER_SCALE_2,
     MODEL_TENSOR.V_PRE_NORM,
     MODEL_TENSOR.V_POST_NORM,
+    MODEL_TENSOR.V_MM_POST_NORM,
     MODEL_TENSOR.V_MM_INP_PROJ,
     MODEL_TENSOR.V_MM_INP_NORM,
     MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
@@ -3357,6 +3363,7 @@ class VisionProjectorType:
     LIGHTONOCR = "lightonocr"
     COGVLM = "cogvlm"
     JANUS_PRO = "janus_pro"
+    GLM4V = "glm4v"


 # Items here are (block size, type size)
```

gguf-py/gguf/tensor_mapping.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -1212,6 +1212,7 @@ class TensorNameMap:
         MODEL_TENSOR.V_MMPROJ_FC: (
             "model.connector.modality_projection.proj",  # SmolVLM
             "model.vision.linear_proj.linear_proj",  # cogvlm
+            "visual.merger.proj",  # glm4v
         ),

         MODEL_TENSOR.V_MMPROJ_MLP: (
@@ -1245,6 +1246,10 @@ class TensorNameMap:
             "model.vision.patch_embedding.proj",  # cogvlm
         ),

+        MODEL_TENSOR.V_ENC_EMBD_NORM: (
+            "visual.post_conv_layernorm",  # glm4v
+        ),
+
         MODEL_TENSOR.V_ENC_EMBD_POS: (
             "vision_tower.vision_model.embeddings.position_embedding",
             "model.vision_tower.embeddings.position_embeddings",  # Intern-S1
@@ -1254,6 +1259,7 @@ class TensorNameMap:
             "vision_tower.patch_embed.pos_emb",  # kimi-vl
             "visual.pos_embed",  # qwen3vl
             "model.vision.patch_embedding.position_embedding",  # cogvlm
+            "visual.embeddings.position_embedding",  # glm4v
         ),

         MODEL_TENSOR.V_ENC_ATTN_QKV: (
@@ -1409,6 +1415,11 @@ class TensorNameMap:
             "vision_model.layernorm_post",  # llama4
             "visual.merger.ln_q",  # qwen2vl
             "vision_tower.encoder.final_layernorm",  # kimi-vl
+            "visual.post_layernorm",  # glm4v
+        ),
+
+        MODEL_TENSOR.V_MM_POST_NORM: (
+            "visual.merger.post_projection_norm",  # glm4v
         ),

         MODEL_TENSOR.V_MM_INP_PROJ: (
@@ -1478,6 +1489,7 @@ class TensorNameMap:
         MODEL_TENSOR.V_MM_PATCH_MERGER: (
             "multi_modal_projector.patch_merger.merging_layer",  # mistral small 3.1 - hf
             "patch_merger.merging_layer",  # mistral
+            "visual.downsample",  # glm4v
         ),

         MODEL_TENSOR.V_DS_NORM: (
@@ -1498,14 +1510,17 @@ class TensorNameMap:

         MODEL_TENSOR.V_MM_UP: (
             "model.vision.linear_proj.dense_h_to_4h",  # cogvlm
+            "visual.merger.up_proj",  # glm4v
         ),

         MODEL_TENSOR.V_MM_DOWN: (
             "model.vision.linear_proj.dense_4h_to_h",  # cogvlm
+            "visual.merger.down_proj",  # glm4v
         ),

         MODEL_TENSOR.V_MM_GATE: (
             "model.vision.linear_proj.gate_proj",  # cogvlm
+            "visual.merger.gate_proj",  # glm4v
         ),

         MODEL_TENSOR.V_TOK_BOI: (
```
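
Each new entry above extends the many-to-one table that resolves HF checkpoint names to canonical GGUF tensor names. A simplified stand-in for the lookup, using a plain dict rather than the real `TensorNameMap` class (which additionally expands per-block `{bid}` templates):

```python
# Only the two glm4v entries shown above; suffix handling is simplified.
HF_TO_GGUF = {
    "visual.post_conv_layernorm":         "v.norm_embd",   # MODEL_TENSOR.V_ENC_EMBD_NORM
    "visual.merger.post_projection_norm": "mm.post_norm",  # MODEL_TENSOR.V_MM_POST_NORM
}

def map_tensor_name(hf_name: str) -> str:
    for suffix in (".weight", ".bias"):
        if hf_name.endswith(suffix):
            base = hf_name.removesuffix(suffix)
            if base in HF_TO_GGUF:
                return HF_TO_GGUF[base] + suffix
    raise ValueError(f"unmapped tensor: {hf_name}")

assert map_tensor_name("visual.post_conv_layernorm.weight") == "v.norm_embd.weight"
```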

src/llama-hparams.cpp

Lines changed: 4 additions & 0 deletions
```diff
@@ -231,3 +231,7 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {

     return false;
 }
+
+bool llama_hparams::use_mrope() const {
+    return rope_sections[0] > 0 && rope_sections[1] > 0;
+}
```

src/llama-hparams.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -270,6 +270,8 @@ struct llama_hparams {
     // TODO: think of a better place for this function
     // TODO: pack the SWA params in a struct?
     static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
+
+    bool use_mrope() const;
 };

 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
```

src/llama-model.cpp

Lines changed: 10 additions & 5 deletions
```diff
@@ -1689,7 +1689,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GLM4:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,    hparams.f_norm_rms_eps);
+                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);
                 switch (hparams.n_layer) {
                     case 40: type = LLM_TYPE_9B; break;
                     case 61: type = LLM_TYPE_32B; break;
@@ -1698,8 +1699,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GLM4_MOE:
             {
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,     hparams.n_ff_exp);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,    hparams.f_norm_rms_eps);
+                ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false);

                 // MoE parameters
                 ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert);
@@ -7792,7 +7794,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_DEEPSEEK2:
         case LLM_ARCH_PLM:
         case LLM_ARCH_CHATGLM:
-        case LLM_ARCH_GLM4:
         case LLM_ARCH_GRANITE:
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_GRANITE_HYBRID:
@@ -7854,7 +7855,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_LFM2:
         case LLM_ARCH_LFM2MOE:
         case LLM_ARCH_SMALLTHINKER:
-        case LLM_ARCH_GLM4_MOE:
         case LLM_ARCH_SEED_OSS:
         case LLM_ARCH_GROVEMOE:
         case LLM_ARCH_APERTUS:
@@ -7871,6 +7871,11 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_QWEN3VLMOE:
             return LLAMA_ROPE_TYPE_IMROPE;

+        case LLM_ARCH_GLM4:
+            return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NORM;
+        case LLM_ARCH_GLM4_MOE:
+            return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX;
+
         // all model arches should be listed explicitly here
         case LLM_ARCH_UNKNOWN:
             GGML_ABORT("unknown architecture");
```
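
GLM4 and GLM4_MOE leave the static NORM/NEOX case lists and instead choose their rope type at load time, keyed on whether the converted GGUF carries non-trivial rope sections. A sketch of the same gate in Python (assuming, as the non-required `get_key_or_arr` load suggests, that `rope_sections` stays all-zero when the key is absent):

```python
def glm4_rope_type(rope_sections: list[int], moe: bool) -> str:
    # mirrors llama_hparams::use_mrope()
    mrope = rope_sections[0] > 0 and rope_sections[1] > 0
    if mrope:
        return "MROPE"                # GLM4V text models (vision variants)
    return "NEOX" if moe else "NORM"  # plain GLM4 / GLM4-MoE

assert glm4_rope_type([16, 24, 24, 0], moe=False) == "MROPE"
assert glm4_rope_type([0, 0, 0, 0], moe=True) == "NEOX"
```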
