@@ -862,6 +862,14 @@ def set_gguf_parameters(self):
862862 logger .warning (f"Unknown RoPE type: { rope_type } " )
863863 logger .info (f"gguf: rope scaling type = { rope_gguf_type .name } " )
864864
865+ if "mrope_section" in self .rope_parameters :
866+ mrope_section = self .rope_parameters ["mrope_section" ]
867+ # Pad to 4 dimensions [time, height, width, extra]
868+ while len (mrope_section ) < 4 :
869+ mrope_section .append (0 )
870+ self .gguf_writer .add_rope_dimension_sections (mrope_section [:4 ])
871+ logger .info (f"gguf: mrope sections: { mrope_section [:4 ]} " )
872+
865873 if (rope_theta := rope_params .get ("rope_theta" )) is not None :
866874 self .gguf_writer .add_rope_freq_base (rope_theta )
867875 logger .info (f"gguf: rope theta = { rope_theta } " )
@@ -3739,9 +3747,6 @@ class Qwen2VLModel(TextModel):
37393747
37403748 def set_gguf_parameters (self ):
37413749 super ().set_gguf_parameters ()
3742- mrope_section = self .hparams ["rope_scaling" ]["mrope_section" ]
3743- mrope_section += [0 ] * max (0 , 4 - len (mrope_section ))
3744- self .gguf_writer .add_rope_dimension_sections (mrope_section )
37453750
37463751 def set_vocab (self ):
37473752 try :
@@ -4377,6 +4382,30 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
43774382 return super ().modify_tensors (data_torch , name , bid )
43784383
43794384
4385+ @ModelBase .register ("Glm4vForConditionalGeneration" , "Glm4vMoeForConditionalGeneration" )
4386+ class Glm4VVisionModel (Qwen3VLVisionModel ):
4387+ def set_gguf_parameters (self ):
4388+ MmprojModel .set_gguf_parameters (self ) # skip Qwen3VLVisionModel parameters
4389+ assert self .hparams_vision is not None
4390+ self .gguf_writer .add_clip_projector_type (gguf .VisionProjectorType .GLM4V )
4391+
4392+ hidden_act = str (self .hparams_vision .get ("hidden_act" , "" )).lower ()
4393+ if hidden_act == "gelu" :
4394+ self .gguf_writer .add_vision_use_gelu (True )
4395+ elif hidden_act == "silu" :
4396+ self .gguf_writer .add_vision_use_silu (True )
4397+
4398+ rms_norm_eps = self .hparams_vision .get ("rms_norm_eps" , 1e-5 )
4399+ self .gguf_writer .add_vision_attention_layernorm_eps (rms_norm_eps )
4400+
4401+ def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
4402+ if name .startswith ("model.visual." ):
4403+ name = name .replace ("model.visual." , "visual." )
4404+ if name .startswith ("visual.merger." ):
4405+ return [(self .map_tensor_name (name ), data_torch )]
4406+ return super ().modify_tensors (data_torch , name , bid )
4407+
4408+
43804409@ModelBase .register ("Qwen3VLForConditionalGeneration" )
43814410class Qwen3VLTextModel (Qwen3Model ):
43824411 model_arch = gguf .MODEL_ARCH .QWEN3VL
@@ -4385,20 +4414,6 @@ def set_gguf_parameters(self):
43854414 super ().set_gguf_parameters ()
43864415
43874416 # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
4388- text_config = self .hparams .get ("text_config" , {})
4389- # rope_scaling is deprecated in V5, use rope_parameters instead
4390- rope_scaling = text_config .get ("rope_scaling" ) or text_config .get ("rope_parameters" ) or {}
4391-
4392- if rope_scaling .get ("mrope_section" ):
4393- # mrope_section contains [time, height, width] dimensions
4394- mrope_section = rope_scaling ["mrope_section" ]
4395- # Pad to 4 dimensions [time, height, width, extra]
4396- while len (mrope_section ) < 4 :
4397- mrope_section .append (0 )
4398- self .gguf_writer .add_rope_dimension_sections (mrope_section [:4 ])
4399-
4400- logger .info (f"MRoPE sections: { mrope_section [:4 ]} " )
4401-
44024417 vision_config = self .hparams .get ("vision_config" , {})
44034418 deepstack_layer_num = len (vision_config .get ("deepstack_visual_indexes" , []))
44044419 self .gguf_writer .add_num_deepstack_layers (deepstack_layer_num )
@@ -4417,22 +4432,6 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel):
44174432
44184433 def set_gguf_parameters (self ):
44194434 super ().set_gguf_parameters ()
4420-
4421- # Handle MRoPE (Multi-axis Rotary Position Embedding) for Qwen3-VL
4422- text_config = self .hparams .get ("text_config" , {})
4423- # rope_scaling is deprecated in V5, use rope_parameters instead
4424- rope_scaling = text_config .get ("rope_scaling" ) or text_config .get ("rope_parameters" ) or {}
4425-
4426- if rope_scaling .get ("mrope_section" ):
4427- # mrope_section contains [time, height, width] dimensions
4428- mrope_section = rope_scaling ["mrope_section" ]
4429- # Pad to 4 dimensions [time, height, width, extra]
4430- while len (mrope_section ) < 4 :
4431- mrope_section .append (0 )
4432- self .gguf_writer .add_rope_dimension_sections (mrope_section [:4 ])
4433-
4434- logger .info (f"MRoPE sections: { mrope_section [:4 ]} " )
4435-
44364435 vision_config = self .hparams .get ("vision_config" , {})
44374436 deepstack_layer_num = len (vision_config .get ("deepstack_visual_indexes" , []))
44384437 self .gguf_writer .add_num_deepstack_layers (deepstack_layer_num )
@@ -7795,6 +7794,15 @@ def prepare_tensors(self):
77957794@ModelBase .register ("Glm4ForCausalLM" , "Glm4vForConditionalGeneration" )
77967795class Glm4Model (TextModel ):
77977796 model_arch = gguf .MODEL_ARCH .GLM4
7797+ use_mrope = False
7798+ partial_rotary_factor = 0.5
7799+
7800+ def __init__ (self , * args , ** kwargs ):
7801+ super ().__init__ (* args , ** kwargs )
7802+ self .partial_rotary_factor = self .rope_parameters .get ("partial_rotary_factor" , 0.5 )
7803+ if "mrope_section" in self .rope_parameters :
7804+ self .use_mrope = True
7805+ logger .info ("Q/K weight will need to be permuted for M-RoPE" )
77987806
77997807 def set_vocab (self ):
78007808 from transformers import AutoTokenizer
@@ -7816,17 +7824,49 @@ def set_gguf_parameters(self):
78167824 super ().set_gguf_parameters ()
78177825 if (rope_dim := self .hparams .get ("head_dim" )) is None :
78187826 rope_dim = self .hparams ["hidden_size" ] // self .hparams ["num_attention_heads" ]
7819- self .gguf_writer .add_rope_dimension_count (int (rope_dim * self .hparams .get ("partial_rotary_factor" , 0.5 )))
7827+ self .gguf_writer .add_rope_dimension_count (int (rope_dim * self .partial_rotary_factor ))
7828+
7829+ @staticmethod
7830+ def normal_to_neox (weights : Tensor , n_head : int , n_head_kv : int , head_dim : int , partial_rotary_factor : float ) -> Tensor :
7831+ orig_shape = weights .shape
7832+ if len (orig_shape ) == 1 :
7833+ weights = weights .unsqueeze (1 ) # [out_dim, 1]
7834+ if len (weights .shape ) != 2 :
7835+ raise ValueError ("Only 1D and 2D tensors are supported." )
7836+ n_effective_heads = weights .shape [0 ] // head_dim
7837+ if n_head_kv is not None and n_effective_heads != n_head :
7838+ if n_effective_heads != n_head_kv :
7839+ raise AssertionError (f"Mismatch in effective heads: computed { n_effective_heads } , expected { n_head } or { n_head_kv } " )
7840+ rotary_dim = int (head_dim * partial_rotary_factor )
7841+ if rotary_dim % 2 != 0 :
7842+ raise ValueError ("rotary_dim must be even." )
7843+ reshaped = weights .reshape (n_effective_heads , head_dim , - 1 )
7844+ rot_part = reshaped [:, :rotary_dim , :]
7845+ non_rot_part = reshaped [:, rotary_dim :, :]
7846+ permuted_rot = torch .cat ((rot_part [:, ::2 , :], rot_part [:, 1 ::2 , :]), dim = 1 )
7847+ combined = torch .cat ((permuted_rot , non_rot_part ), dim = 1 )
7848+ result = combined .reshape (weights .shape )
7849+ return result if len (orig_shape ) != 1 else result .squeeze (1 )
78207850
78217851 def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
78227852 if name .startswith ("model.visual." ): # ignore visual part of Glm4v
78237853 return []
78247854 elif name .startswith ("model.language_model." ):
78257855 name = name .replace ("language_model." , "" ) # for Glm4v
7856+ if self .use_mrope :
7857+ n_head = self .hparams ["num_attention_heads" ]
7858+ n_kv_head = self .hparams ["num_key_value_heads" ]
7859+ n_embd = self .hparams ["hidden_size" ]
7860+ head_dim = n_embd // n_head
7861+ # because llama.cpp M-RoPE kernel only supports Neox ordering, we have to permute the weights here
7862+ if name .endswith (("q_proj.weight" , "q_proj.bias" )):
7863+ data_torch = Glm4Model .normal_to_neox (data_torch , n_head , n_head , head_dim , self .partial_rotary_factor )
7864+ if name .endswith (("k_proj.weight" , "k_proj.bias" )):
7865+ data_torch = Glm4Model .normal_to_neox (data_torch , n_head , n_kv_head , head_dim , self .partial_rotary_factor )
78267866 return super ().modify_tensors (data_torch , name , bid )
78277867
78287868
7829- @ModelBase .register ("Glm4MoeForCausalLM" )
7869+ @ModelBase .register ("Glm4MoeForCausalLM" , "Glm4vMoeForConditionalGeneration" )
78307870class Glm4MoeModel (TextModel ):
78317871 model_arch = gguf .MODEL_ARCH .GLM4_MOE
78327872
@@ -7893,6 +7933,7 @@ def set_gguf_parameters(self):
78937933
78947934 _experts : list [dict [str , Tensor ]] | None = None
78957935
7936+ # note: unlike GLM4V non-MoE, we don't need to permute Q/K here since GLM4V_MOE uses Neox ordering already
78967937 def modify_tensors (
78977938 self , data_torch : Tensor , name : str , bid : int | None
78987939 ) -> Iterable [tuple [str , Tensor ]]: