Commit 5420a5e

fix(lora): adjust convert_weight and add lora_config helper
1 parent fc1edac commit 5420a5e

4 files changed: +115 -30 lines changed


python/mlc_llm/cli/convert_weight.py

Lines changed: 7 additions & 6 deletions
@@ -31,6 +31,12 @@ def _parse_output(path: Union[str, Path]) -> Path:
     path.mkdir(parents=True, exist_ok=True)
     return path
 
+def _parse_lora_adapter(path: Union[str, Path]) -> Path:
+    path = Path(path)
+    if not path.exists():
+        raise argparse.ArgumentTypeError(f"LoRA adapter path does not exist: {path}")
+    return path
+
 parser = ArgumentParser("MLC AutoLLM Quantization Framework")
 parser.add_argument(
     "config",
@@ -77,8 +83,7 @@ def _parse_output(path: Union[str, Path]) -> Path:
     required=True,
     help=HELP["output_quantize"] + " (required)",
 )
-<<<<<<< Updated upstream
-=======
+
 # Mutually exclusive LoRA options: merge vs separate
 lora_group = parser.add_mutually_exclusive_group()
 lora_group.add_argument(
@@ -99,7 +104,6 @@ def _parse_output(path: Union[str, Path]) -> Path:
     default=1.0,
     help="Scaling factor for LoRA when used with --lora-separate (default: %(default)s).",
 )
->>>>>>> Stashed changes
 
 parsed = parser.parse_args(argv)
 parsed.source, parsed.source_format = detect_weight(
@@ -116,10 +120,7 @@ def _parse_output(path: Union[str, Path]) -> Path:
     source=parsed.source,
     source_format=parsed.source_format,
     output=parsed.output,
-<<<<<<< Updated upstream
-=======
     lora_adapter=parsed.lora_adapter,
     lora_separate=parsed.lora_separate,
     lora_alpha=parsed.lora_alpha,
->>>>>>> Stashed changes
 )
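For context, a minimal standalone sketch of the CLI pattern this file now uses: two mutually exclusive LoRA flags plus a scaling factor, validated by the new _parse_lora_adapter helper. The full add_argument bodies are truncated in the hunks above, so the help strings and defaults for the two path flags below are assumptions, not the commit's exact text.

# Sketch only: mirrors the mutually exclusive LoRA options added above.
# Only the flag names, the helper, and --lora-alpha's 1.0 default appear in the diff;
# everything else here is assumed for illustration.
import argparse
from pathlib import Path
from typing import Union


def _parse_lora_adapter(path: Union[str, Path]) -> Path:
    path = Path(path)
    if not path.exists():
        raise argparse.ArgumentTypeError(f"LoRA adapter path does not exist: {path}")
    return path


parser = argparse.ArgumentParser("lora-options-demo")
lora_group = parser.add_mutually_exclusive_group()
lora_group.add_argument("--lora-adapter", type=_parse_lora_adapter, default=None,
                        help="Merge this LoRA adapter into the base weights (legacy mode).")
lora_group.add_argument("--lora-separate", type=_parse_lora_adapter, default=None,
                        help="Pack this LoRA adapter separately from the base weights.")
parser.add_argument("--lora-alpha", type=float, default=1.0,
                    help="Scaling factor for LoRA when used with --lora-separate.")

# Passing both path flags at once is rejected by argparse itself:
#   parser.parse_args(["--lora-adapter", "a/", "--lora-separate", "b/"])  # -> SystemExit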

python/mlc_llm/interface/convert_weight.py

Lines changed: 11 additions & 24 deletions
@@ -5,7 +5,7 @@
 import os
 from io import StringIO
 from pathlib import Path
-from typing import Any, Dict, Iterator, Tuple
+from typing import Any, Dict, Iterator, Optional, Tuple
 
 from tvm import tir
 from tvm.contrib import tvmjs
@@ -34,14 +34,11 @@ class ConversionArgs: # pylint: disable=too-many-instance-attributes
     source: Path
     source_format: str
     output: Path
-<<<<<<< Updated upstream
-=======
     # Legacy merge-mode
     lora_adapter: Optional[Path] = None
     # New separate-mode
     lora_separate: Optional[Path] = None
     lora_alpha: float = 1.0
->>>>>>> Stashed changes
 
     def display(self) -> None:
         """Display the arguments to stdout."""
@@ -58,20 +55,23 @@ def _device_to_str(device: Device) -> str:
         print(f" {bold('--source'):<25} {self.source}", file=out)
         print(f" {bold('--source-format'):<25} {self.source_format}", file=out)
         print(f" {bold('--output'):<25} {self.output}", file=out)
-<<<<<<< Updated upstream
-=======
         if self.lora_adapter:
             print(f" {bold('--lora-adapter'):<25} {self.lora_adapter}", file=out)
         if self.lora_separate:
             print(f" {bold('--lora-separate'):<25} {self.lora_separate}", file=out)
             print(f" {bold('--lora-alpha'):<25} {self.lora_alpha}", file=out)
->>>>>>> Stashed changes
         print(out.getvalue().rstrip())
 
 
+def _merge_lora_weights(args: ConversionArgs) -> Path:
+    """Merge LoRA weights into base model weights (legacy mode)."""
+    # TODO: Implement LoRA weight merging for legacy mode
+    # For now, just return the original source path
+    logger.warning("LoRA weight merging not yet implemented, using base weights only")
+    return args.source
+
+
 def _convert_args(args: ConversionArgs) -> None: # pylint: disable=too-many-locals
-<<<<<<< Updated upstream
-=======
     # ------------------------------------------------------------------
     # Handle LoRA: separate-pack or legacy merge
     # ------------------------------------------------------------------
@@ -93,7 +93,6 @@ def _convert_args(args: ConversionArgs) -> None: # pylint: disable=too-many-loc
     # legacy merge path (if provided)
     source_path = _merge_lora_weights(args) if args.lora_adapter else args.source
 
->>>>>>> Stashed changes
     pre_shards_num = os.getenv("MLC_INTERNAL_PRESHARD_NUM")
     # model config & quantization config
     model_config = args.model.config.from_file(args.config)
@@ -160,7 +159,7 @@ def _param_generator() -> Iterator[Tuple[str, NDArray]]:
         nonlocal total_params, total_bytes
         with Target.from_device(args.device), tqdm.redirect():
             loader = LOADER[args.source_format](
-                path=args.source,
+                path=source_path,
                 extern_param_map=args.model.source[args.source_format](
                     model_config, args.quantization
                 ),
@@ -175,13 +174,11 @@ def _param_generator() -> Iterator[Tuple[str, NDArray]]:
         total_params = loader.stats.total_param_num
 
     def _metadata_callback() -> Dict[str, Any]:
-        return {
+        metadata = {
             "ParamSize": len(param_names),
             "ParamBytes": total_bytes,
             "BitsPerParam": total_bytes * 8.0 / total_params,
         }
-<<<<<<< Updated upstream
-=======
         # Add LoRA metadata if adapter was used
         if args.lora_separate:
             metadata["LoRASeparate"] = True
@@ -191,7 +188,6 @@ def _metadata_callback() -> Dict[str, Any]:
             metadata["LoRAAdapter"] = str(args.lora_adapter)
             metadata["LoRAMerged"] = True
         return metadata
->>>>>>> Stashed changes
 
     # dump to output directory
     tvmjs.dump_ndarray_cache(
@@ -215,13 +211,10 @@ def _metadata_callback() -> Dict[str, Any]:
         green("Bits per parameter"),
         total_bytes * 8.0 / total_params,
     )
-<<<<<<< Updated upstream
-=======
     if args.lora_separate:
         logger.info("%s: %s", green("LoRA adapter packed from"), bold(str(args.lora_separate)))
     elif args.lora_adapter:
         logger.info("%s: %s", green("LoRA adapter merged from"), bold(str(args.lora_adapter)))
->>>>>>> Stashed changes
     logger.info("Saved to directory: %s", bold(str(args.output)))
 
 
@@ -233,11 +226,6 @@ def convert_weight( # pylint: disable=too-many-arguments
     source: Path,
     source_format: str,
     output: Path,
-<<<<<<< Updated upstream
-):
-    """MLC LLM's weight conversation and quantization flow."""
-    args = ConversionArgs(config, quantization, model, device, source, source_format, output)
-=======
     lora_adapter: Optional[Path] = None,
     lora_separate: Optional[Path] = None,
     lora_alpha: float = 1.0,
@@ -255,6 +243,5 @@ def convert_weight( # pylint: disable=too-many-arguments
         lora_separate,
         lora_alpha,
     )
->>>>>>> Stashed changes
     args.display()
     _convert_args(args)
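The _merge_lora_weights helper above is still a stub that logs a warning and returns the base path. For readers wondering what the legacy merge would eventually do, here is a hedged sketch of the standard LoRA fold-in, W' = W + (alpha / r) * (B @ A), written over plain NumPy arrays. The dict-of-arrays layout and the "<param>.lora_A" / "<param>.lora_B" key naming are assumptions for illustration; the commit does not define an on-disk adapter format.

# Hypothetical sketch, not the commit's implementation.
from typing import Dict

import numpy as np


def merge_lora_delta(
    base: Dict[str, np.ndarray],
    adapter: Dict[str, np.ndarray],
    lora_alpha: float,
    rank: int,
) -> Dict[str, np.ndarray]:
    """Return base weights with W + (lora_alpha / rank) * (B @ A) folded in per target."""
    scaling = lora_alpha / rank
    merged = dict(base)
    for name, weight in base.items():
        lora_a = adapter.get(f"{name}.lora_A")  # assumed key; shape (rank, in_features)
        lora_b = adapter.get(f"{name}.lora_B")  # assumed key; shape (out_features, rank)
        if lora_a is not None and lora_b is not None:
            merged[name] = weight + scaling * (lora_b @ lora_a)
    return merged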

python/mlc_llm/lora/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+"""LoRA (Low-Rank Adaptation) module for MLC LLM."""
+
+from .lora import upload_lora, set_lora, get_registered_lora_dirs
+from .lora_config import LoRAConfig
+
+__all__ = [
+    "upload_lora",
+    "set_lora",
+    "get_registered_lora_dirs",
+    "LoRAConfig",
+]

python/mlc_llm/lora/lora_config.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+"""LoRA configuration dataclass for MLC LLM."""
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+
+@dataclass
+class LoRAConfig:
+    """Configuration for LoRA (Low-Rank Adaptation) parameters.
+
+    This configuration is used to define LoRA adaptation parameters
+    for fine-tuning large language models with low-rank matrices.
+
+    Parameters
+    ----------
+    r : int
+        LoRA rank (dimension of the low-rank matrices). Common values are 4, 8, 16, 32.
+        Higher values provide more capacity but increase parameters.
+
+    lora_alpha : float
+        LoRA scaling factor. Controls the magnitude of the LoRA adaptation.
+        Typically set to the same value as r or higher.
+
+    lora_dropout : float
+        Dropout probability for LoRA layers during training.
+        Set to 0.0 for inference.
+
+    target_modules : List[str]
+        List of module names to apply LoRA to.
+        Common targets: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
+
+    fan_in_fan_out : bool
+        Whether the layer uses fan_in_fan_out convention.
+        Set to True for Conv1D layers, False for Linear layers.
+
+    bias : str
+        Bias type for LoRA layers. Options: "none", "all", "lora_only"
+
+    task_type : Optional[str]
+        Task type for the LoRA adaptation (e.g., "CAUSAL_LM")
+
+    inference_mode : bool
+        Whether the model is in inference mode.
+
+    merge_weights : bool
+        Whether to merge LoRA weights into base weights during inference.
+    """
+
+    r: int = 8
+    lora_alpha: float = 16.0
+    lora_dropout: float = 0.1
+    target_modules: List[str] = None
+    fan_in_fan_out: bool = False
+    bias: str = "none"
+    task_type: Optional[str] = None
+    inference_mode: bool = False
+    merge_weights: bool = True
+
+    def __post_init__(self):
+        """Set default target modules if not provided."""
+        if self.target_modules is None:
+            self.target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
+
+    @property
+    def scaling(self) -> float:
+        """Return the scaling factor for LoRA: alpha / r."""
+        return self.lora_alpha / self.r
+
+    def to_dict(self) -> dict:
+        """Convert configuration to dictionary."""
+        return {
+            "r": self.r,
+            "lora_alpha": self.lora_alpha,
+            "lora_dropout": self.lora_dropout,
+            "target_modules": self.target_modules,
+            "fan_in_fan_out": self.fan_in_fan_out,
+            "bias": self.bias,
+            "task_type": self.task_type,
+            "inference_mode": self.inference_mode,
+            "merge_weights": self.merge_weights,
+        }
+
+    @classmethod
+    def from_dict(cls, config_dict: dict) -> "LoRAConfig":
+        """Create configuration from dictionary."""
+        return cls(**config_dict)
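A quick usage sketch of the new dataclass, imported through the mlc_llm.lora package added in this commit; the field values are illustrative. Note that target_modules is declared with a None default and filled in by __post_init__, so it can safely be omitted at construction time.

# Illustrative round-trip through LoRAConfig; values are examples.
from mlc_llm.lora import LoRAConfig

config = LoRAConfig(r=16, lora_alpha=32.0, lora_dropout=0.0, inference_mode=True)
assert config.scaling == 2.0           # scaling = lora_alpha / r
assert config.target_modules           # defaults filled by __post_init__

restored = LoRAConfig.from_dict(config.to_dict())  # serialize, then rebuild
assert restored == config              # dataclass equality compares all fields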
