|
17 | 17 | import os |
18 | 18 | from typing import Any, Literal |
19 | 19 |
|
| 20 | +from ..conversion_mapping import get_model_conversion_mapping |
20 | 21 | from ..core_model_loading import WeightRenaming, rename_source_key |
21 | 22 | from ..utils import ( |
22 | 23 | CONFIG_NAME, |
|
46 | 47 | logger = logging.get_logger(__name__) |
47 | 48 |
|
48 | 49 |
|
49 | | -# DO NOT MODIFY, KEPT FOR BC ONLY |
50 | | -VLMS = [ |
51 | | - "aria", |
52 | | - "ayavision", |
53 | | - "emu3", |
54 | | - "fuyu", |
55 | | - "gotocr2", |
56 | | - "gemma3", |
57 | | - "internvl", |
58 | | - "llava", # all llava prefixed models fall under this check |
59 | | - "mistral3", |
60 | | - "mllama", |
61 | | - "paligemma", |
62 | | - "qwen2vl", |
63 | | - "qwen2_5_vl", |
64 | | - "videollava", |
65 | | - "vipllava", |
66 | | -] |
67 | | - |
68 | | - |
69 | 50 | class PeftAdapterMixin: |
70 | 51 | """ |
71 | 52 | A class containing all functions for loading and using adapters weights that are supported in PEFT library. For |
@@ -211,11 +192,10 @@ def load_adapter( |
211 | 192 | if any(conf.peft_type != PeftType.LORA for conf in self.peft_config.values()): |
212 | 193 | raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.") |
213 | 194 |
|
| 195 | + key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None |
| 196 | + weight_conversions = get_model_conversion_mapping(self, key_mapping=key_mapping) |
214 | 197 | # peft only supports low_cpu_mem_usage starting from v0.13.0 |
215 | 198 | peft_load_kwargs = {} |
216 | | - key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None |
217 | | - if key_mapping is None and any(allowed_name in self.__class__.__name__.lower() for allowed_name in VLMS): |
218 | | - key_mapping = self._checkpoint_conversion_mapping |
219 | 199 | peft_load_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage |
220 | 200 |
|
221 | 201 | adapter_name = adapter_name if adapter_name is not None else "default" |
@@ -292,8 +272,8 @@ def load_adapter( |
292 | 272 |
|
293 | 273 | # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility |
294 | 274 | renamings = [] |
295 | | - if key_mapping: |
296 | | - renamings = [entry for entry in key_mapping if isinstance(entry, WeightRenaming)] |
| 275 | + if weight_conversions: |
| 276 | + renamings = [entry for entry in weight_conversions if isinstance(entry, WeightRenaming)] |
297 | 277 | processed_adapter_state_dict = {} |
298 | 278 | prefix = "base_model.model." |
299 | 279 | state_dict = self.state_dict() |
|
0 commit comments