55import os
66from io import StringIO
77from pathlib import Path
8- from typing import Any , Dict , Iterator , Tuple
8+ from typing import Any , Dict , Iterator , Optional , Tuple
99
1010from tvm import tir
1111from tvm .contrib import tvmjs
@@ -34,14 +34,11 @@ class ConversionArgs: # pylint: disable=too-many-instance-attributes
3434 source : Path
3535 source_format : str
3636 output : Path
37- < << << << Updated upstream
38- == == == =
3937 # Legacy merge-mode
4038 lora_adapter : Optional [Path ] = None
4139 # New separate-mode
4240 lora_separate : Optional [Path ] = None
4341 lora_alpha : float = 1.0
44- >> >> >> > Stashed changes
4542
4643 def display (self ) -> None :
4744 """Display the arguments to stdout."""
@@ -58,20 +55,23 @@ def _device_to_str(device: Device) -> str:
5855 print (f" { bold ('--source' ):<25} { self .source } " , file = out )
5956 print (f" { bold ('--source-format' ):<25} { self .source_format } " , file = out )
6057 print (f" { bold ('--output' ):<25} { self .output } " , file = out )
61- << << << < Updated upstream
62- == == == =
6358 if self .lora_adapter :
6459 print (f" { bold ('--lora-adapter' ):<25} { self .lora_adapter } " , file = out )
6560 if self .lora_separate :
6661 print (f" { bold ('--lora-separate' ):<25} { self .lora_separate } " , file = out )
6762 print (f" { bold ('--lora-alpha' ):<25} { self .lora_alpha } " , file = out )
68- >> >> >> > Stashed changes
6963 print (out .getvalue ().rstrip ())
7064
7165
66+ def _merge_lora_weights (args : ConversionArgs ) -> Path :
67+ """Merge LoRA weights into base model weights (legacy mode)."""
68+ # TODO: Implement LoRA weight merging for legacy mode
69+ # For now, just return the original source path
70+ logger .warning ("LoRA weight merging not yet implemented, using base weights only" )
71+ return args .source
72+
73+
7274def _convert_args (args : ConversionArgs ) -> None : # pylint: disable=too-many-locals
73- << << << < Updated upstream
74- == == == =
7575 # ------------------------------------------------------------------
7676 # Handle LoRA: separate-pack or legacy merge
7777 # ------------------------------------------------------------------
@@ -93,7 +93,6 @@ def _convert_args(args: ConversionArgs) -> None: # pylint: disable=too-many-loc
9393 # legacy merge path (if provided)
9494 source_path = _merge_lora_weights (args ) if args .lora_adapter else args .source
9595
96- >> >> >> > Stashed changes
9796 pre_shards_num = os .getenv ("MLC_INTERNAL_PRESHARD_NUM" )
9897 # model config & quantization config
9998 model_config = args .model .config .from_file (args .config )
@@ -160,7 +159,7 @@ def _param_generator() -> Iterator[Tuple[str, NDArray]]:
160159 nonlocal total_params , total_bytes
161160 with Target .from_device (args .device ), tqdm .redirect ():
162161 loader = LOADER [args .source_format ](
163- path = args . source ,
162+ path = source_path ,
164163 extern_param_map = args .model .source [args .source_format ](
165164 model_config , args .quantization
166165 ),
@@ -175,13 +174,11 @@ def _param_generator() -> Iterator[Tuple[str, NDArray]]:
175174 total_params = loader .stats .total_param_num
176175
177176 def _metadata_callback () -> Dict [str , Any ]:
178- return {
177+ metadata = {
179178 "ParamSize" : len (param_names ),
180179 "ParamBytes" : total_bytes ,
181180 "BitsPerParam" : total_bytes * 8.0 / total_params ,
182181 }
183- << << << < Updated upstream
184- == == == =
185182 # Add LoRA metadata if adapter was used
186183 if args .lora_separate :
187184 metadata ["LoRASeparate" ] = True
@@ -191,7 +188,6 @@ def _metadata_callback() -> Dict[str, Any]:
191188 metadata ["LoRAAdapter" ] = str (args .lora_adapter )
192189 metadata ["LoRAMerged" ] = True
193190 return metadata
194- >> >> >> > Stashed changes
195191
196192 # dump to output directory
197193 tvmjs .dump_ndarray_cache (
@@ -215,13 +211,10 @@ def _metadata_callback() -> Dict[str, Any]:
215211 green ("Bits per parameter" ),
216212 total_bytes * 8.0 / total_params ,
217213 )
218- << << << < Updated upstream
219- == == == =
220214 if args .lora_separate :
221215 logger .info ("%s: %s" , green ("LoRA adapter packed from" ), bold (str (args .lora_separate )))
222216 elif args .lora_adapter :
223217 logger .info ("%s: %s" , green ("LoRA adapter merged from" ), bold (str (args .lora_adapter )))
224- >> >> >> > Stashed changes
225218 logger .info ("Saved to directory: %s" , bold (str (args .output )))
226219
227220
@@ -233,11 +226,6 @@ def convert_weight( # pylint: disable=too-many-arguments
233226 source : Path ,
234227 source_format : str ,
235228 output : Path ,
236- << << << < Updated upstream
237- ):
238- """MLC LLM 's weight conversation and quantization flow ."""
239- args = ConversionArgs (config , quantization , model , device , source , source_format , output )
240- == == == =
241229 lora_adapter : Optional [Path ] = None ,
242230 lora_separate : Optional [Path ] = None ,
243231 lora_alpha : float = 1.0 ,
@@ -255,6 +243,5 @@ def convert_weight( # pylint: disable=too-many-arguments
255243 lora_separate ,
256244 lora_alpha ,
257245 )
258- >> >> >> > Stashed changes
259246 args .display ()
260247 _convert_args (args )
0 commit comments