Skip to content

Commit 142ae3d

Browse files
authored
Fix PEFT integration with new weight loader (#42701)
simplify
1 parent 75beab1 commit 142ae3d

File tree

3 files changed

+7
-27
lines changed

3 files changed

+7
-27
lines changed

src/transformers/conversion_mapping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def get_model_conversion_mapping(
228228
"""
229229
weight_conversions = []
230230

231-
# Load models with key mapping
231+
# Load models with explicit, user-provided key mapping
232232
if key_mapping is not None:
233233
weight_conversions = [WeightRenaming(source_patterns=k, target_patterns=v) for k, v in key_mapping.items()]
234234
elif any(

src/transformers/integrations/peft.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
from typing import Any, Literal
1919

20+
from ..conversion_mapping import get_model_conversion_mapping
2021
from ..core_model_loading import WeightRenaming, rename_source_key
2122
from ..utils import (
2223
CONFIG_NAME,
@@ -46,26 +47,6 @@
4647
logger = logging.get_logger(__name__)
4748

4849

49-
# DO NOT MODIFY, KEPT FOR BC ONLY
50-
VLMS = [
51-
"aria",
52-
"ayavision",
53-
"emu3",
54-
"fuyu",
55-
"gotocr2",
56-
"gemma3",
57-
"internvl",
58-
"llava", # all llava prefixed models fall under this check
59-
"mistral3",
60-
"mllama",
61-
"paligemma",
62-
"qwen2vl",
63-
"qwen2_5_vl",
64-
"videollava",
65-
"vipllava",
66-
]
67-
68-
6950
class PeftAdapterMixin:
7051
"""
7152
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
@@ -211,11 +192,10 @@ def load_adapter(
211192
if any(conf.peft_type != PeftType.LORA for conf in self.peft_config.values()):
212193
raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
213194

195+
key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
196+
weight_conversions = get_model_conversion_mapping(self, key_mapping=key_mapping)
214197
# peft only supports low_cpu_mem_usage starting from v0.13.0
215198
peft_load_kwargs = {}
216-
key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
217-
if key_mapping is None and any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS):
218-
key_mapping = self._checkpoint_conversion_mapping
219199
peft_load_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
220200

221201
adapter_name = adapter_name if adapter_name is not None else "default"
@@ -292,8 +272,8 @@ def load_adapter(
292272

293273
# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
294274
renamings = []
295-
if key_mapping:
296-
renamings = [entry for entry in key_mapping if isinstance(entry, WeightRenaming)]
275+
if weight_conversions:
276+
renamings = [entry for entry in weight_conversions if isinstance(entry, WeightRenaming)]
297277
processed_adapter_state_dict = {}
298278
prefix = "base_model.model."
299279
state_dict = self.state_dict()

src/transformers/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4046,7 +4046,7 @@ def from_pretrained(
40464046
hf_quantizer.postprocess_model(model, config=config) # usually a no-op but sometimes needed
40474047

40484048
if _adapter_model_path is not None:
4049-
adapter_kwargs["key_mapping"] = weight_conversions # TODO: Dynamic weight loader for adapters
4049+
adapter_kwargs["key_mapping"] = key_mapping
40504050
model.load_adapter(
40514051
_adapter_model_path,
40524052
adapter_name=adapter_name,

0 commit comments

Comments (0)