@@ -713,6 +713,9 @@ def load_hparams(dir_model: Path, is_mistral_format: bool):
         if "llm_config" in config:
             # rename for InternVL
             config["text_config"] = config["llm_config"]
+        if "lm_config" in config:
+            # rename for GlmASR
+            config["text_config"] = config["lm_config"]
         if "thinker_config" in config:
             # rename for Qwen2.5-Omni
             config["text_config"] = config["thinker_config"]["text_config"]
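For context, a hedged sketch of the config.json shape this aliasing targets: GlmASR checkpoints nest the language model under `lm_config` (and, per the later hunks, the audio encoder under `whisper_config`). The key names follow this diff; the values are invented for illustration.

```python
# Hypothetical GlmASR config.json, as a dict (values made up):
config = {
    "architectures": ["GlmasrModel"],
    "lm_config": {"hidden_size": 4096, "num_hidden_layers": 40},
    "whisper_config": {"d_model": 1280, "encoder_layers": 32, "num_mel_bins": 128},
    "merge_factor": 2,
}
# The hunk above makes the text model reachable under the uniform key:
if "lm_config" in config:
    config["text_config"] = config["lm_config"]
assert config["text_config"]["hidden_size"] == 4096
```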
@@ -1529,6 +1532,21 @@ def _try_set_pooling_type(self) -> None:
             raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported")
         self.gguf_writer.add_pooling_type(pooling_type)
 
+    def _set_vocab_glmedge(self):
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def _set_vocab_interns1(self):
         tokens: list[str] = []
         toktypes: list[int] = []
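Note the model reuses `<|endoftext|>` for eos/unk/bos and treats `<|user|>` as end-of-turn. A hedged sketch of how those ids resolve (the path and printed ids are placeholders; `get_added_vocab()` is a standard transformers API returning a `dict[str, int]` of tokens added on top of the base vocab):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/glm-asr")  # hypothetical path
added = tokenizer.get_added_vocab()  # maps added-token strings to ids
eos_id = added["<|endoftext|>"]      # reused for eos, unk, and bos above
eot_id = added["<|user|>"]           # marks the end of an assistant turn
print(f"eos/unk/bos={eos_id} eot={eot_id}")
```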
@@ -1658,7 +1676,7 @@ class MmprojModel(ModelBase):
     preprocessor_config: dict[str, Any]
     global_config: dict[str, Any]
 
-    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]
+    n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
 
     has_vision_encoder: bool = True  # by default
     has_audio_encoder: bool = False
@@ -1734,7 +1752,8 @@ def get_vision_config(self) -> dict[str, Any] | None:
         return self.global_config.get(config_name)
 
     def get_audio_config(self) -> dict[str, Any] | None:
-        return self.global_config.get("audio_config")
+        mm_config_key = "whisper_config" if "whisper_config" in self.hparams else "audio_config"
+        return self.global_config.get(mm_config_key)
 
     def set_type(self):
         self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
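These two `MmprojModel` changes work together: GlmASR keeps its encoder under `whisper_config`, which stores the layer count as `encoder_layers`. A sketch of the resulting block-count lookup, reusing the hypothetical config shape from the first hunk's note:

```python
audio_cfg = {"d_model": 1280, "encoder_layers": 32}  # hypothetical whisper_config
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
# First matching key wins, so "encoder_layers" is only reached by whisper-style configs:
n_blocks = next(audio_cfg[k] for k in n_block_keys if k in audio_cfg)
assert n_blocks == 32
```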
@@ -2372,8 +2391,13 @@ def __init__(self, *args, **kwargs):
         # fix for SmolVLM2, missing `num_attention_heads` in config.json
         if self.hf_arch == "VLlama3ForCausalLM":
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
+        hparams = ModelBase.load_hparams(self.dir_model, is_mistral_format=False)
+        self.origin_hf_arch = hparams.get('architectures', [None])[0]
 
     def set_vocab(self):
+        if self.origin_hf_arch == "GlmasrModel":
+            return self._set_vocab_glmedge()
+
         if self.is_mistral_format:
             return self._set_vocab_mistral()
 
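Why the hparams are re-read here: by this point the hparams handed to the class may already be the remapped `text_config` view, so the checkpoint's top-level `architectures` entry has to be recovered from disk to detect GlmASR. A minimal standalone equivalent, assuming a plain HF directory layout:

```python
import json
from pathlib import Path

def origin_hf_arch(dir_model: Path) -> str | None:
    # First entry of "architectures" in config.json, e.g. "GlmasrModel".
    config = json.loads((dir_model / "config.json").read_text(encoding="utf-8"))
    return config.get("architectures", [None])[0]
```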
@@ -2444,6 +2468,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
             "vision_language_adapter.",
             "patch_merger.",
             "pre_mm_projector_norm",
+            "audio_encoder.",
         ]
 
         is_multimodal_tensor = "vision_tower" in name \
@@ -8846,6 +8871,63 @@ def __init__(self, *args, **kwargs):
         raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument")
 
 
+@ModelBase.register("GlmasrModel")
+class GlmASRWhisperEncoderModel(MmprojModel):
+    has_vision_encoder = False
+    has_audio_encoder = True
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if "hidden_size" not in self.hparams and "intermediate_size" not in self.hparams:
+            self.hparams["hidden_size"] = self.hparams["d_model"]
+            self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
+            self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GLMA)
+        self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"])
+        self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
+        self.gguf_writer.add_audio_stack_factor(self.global_config["merge_factor"])
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        if ".conv" in name and ".weight" in name:
+            return gguf.GGMLQuantizationType.F16
+        return super().tensor_force_quant(name, new_name, bid, n_dims)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.startswith("model.") or name.startswith("lm_head."):
+            # skip language model tensors
+            return []
+
+        if name.startswith("audio_encoder.whisper."):
+            name = name.replace("audio_encoder.whisper.", "audio_tower.")
+        if "audio_encoder.layer_norm." in name or "audio_encoder.proj." in name:
+            name = name.replace("audio_encoder.", "audio_encoder.adapting.")
+
+        if name.startswith("audio_encoder.audio_bos_eos_token."):
+            return [(self.map_tensor_name("model.vision.boi"), data_torch[0]), (self.map_tensor_name("model.vision.eoi"), data_torch[1])]
+
+        if name.startswith("audio_encoder.adapting."):
+            name = name.replace("audio_encoder.adapting.", "audio.multi_modal_projector.")
+            if ".layer_norm." in name:
+                name = name.replace(".layer_norm.", ".ln_pre.")
+            if ".0." in name:
+                name = name.replace(".0.", ".linear_1.")
+            if ".2." in name:
+                name = name.replace(".2.", ".linear_2.")
+            if ".proj." in name:
+                return []
+
+        if "conv1.bias" in name or "conv2.bias" in name:
+            # transpose conv1 and conv2 bias
+            data_torch = data_torch.unsqueeze(-1)
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @ModelBase.register("Qwen2AudioForConditionalGeneration")
 class WhisperEncoderModel(MmprojModel):
     has_vision_encoder = False  # no vision encoder
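To make the rename paths in `GlmASRWhisperEncoderModel.modify_tensors` easier to audit, a hedged trace of representative names, constructed from the patterns the code matches rather than dumped from a real checkpoint; `map_tensor_name` then applies the usual gguf mapping on top of these intermediate names:

```python
# Intermediate names produced by modify_tensors (source names hypothetical):
flows = {
    "audio_encoder.whisper.conv1.weight": "audio_tower.conv1.weight",
    "audio_encoder.layer_norm.weight": "audio.multi_modal_projector.ln_pre.weight",
    "audio_encoder.adapting.0.weight": "audio.multi_modal_projector.linear_1.weight",
    "audio_encoder.adapting.2.weight": "audio.multi_modal_projector.linear_2.weight",
    "model.layers.0.self_attn.q_proj.weight": None,  # language-model tensor: skipped
    "audio_encoder.proj.weight": None,               # adapter "proj": dropped after renaming
}
# "audio_encoder.audio_bos_eos_token" is special-cased: row 0 becomes the
# boi (begin-of-input) embedding and row 1 the eoi embedding.
```

Nesting the `.0.`/`.2.` substitutions under the `audio_encoder.adapting.` branch matters: applied unconditionally they would also rewrite whisper layer names such as `audio_tower.layers.0.`.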