
Commit fa36271

BenjaminBossan authored and leaderofARS committed
FIX Error when trying to load non-LoRA PEFT (huggingface#42663)
* FIX Error when trying to load non-LoRA PEFT

  This PR fixes a bug that prevented non-LoRA PEFT adapters from being loaded into a transformers model. A test for this has been added. Additionally, a test was added for adding a non-LoRA adapter to a transformers model; that path was not broken but lacked test coverage.

* Reviewer feedback: remove the check completely
1 parent 7d49fb9 commit fa36271
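For context, a minimal sketch of the round trip this commit repairs, assuming a recent peft installation. The model id "facebook/opt-125m" and the explicit target_modules are placeholder choices for illustration only and do not come from this commit.

# Hypothetical end-to-end sketch: create a non-LoRA (LoKr) adapter with PEFT,
# save it, then load it back into a plain transformers model. Before this fix,
# the final load_adapter call raised "Hotswapping is currently only supported
# for LoRA" even though hotswap defaults to False.
import tempfile

from peft import LoKrConfig, get_peft_model
from transformers import AutoModelForCausalLM

model_id = "facebook/opt-125m"  # placeholder model, not taken from the commit
config = LoKrConfig(target_modules=["q_proj", "v_proj"], init_weights=False)

peft_model = get_peft_model(AutoModelForCausalLM.from_pretrained(model_id), config)

with tempfile.TemporaryDirectory() as tmpdir:
    peft_model.save_pretrained(tmpdir)

    base_model = AutoModelForCausalLM.from_pretrained(model_id)
    base_model.load_adapter(tmpdir)  # now succeeds for non-LoRA adapters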

2 files changed: 54 additions and 3 deletions

src/transformers/integrations/peft.py

Lines changed: 0 additions & 3 deletions
@@ -279,9 +279,6 @@ def load_adapter(
             )
         peft_config.inference_mode = not is_trainable
 
-        if peft_config.peft_type != PeftType.LORA:
-            raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
-
         if not hotswap:
             # TODO: WE NEED TOO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
             # Create and add fresh new adapters into the model, unless the weights are hotswapped
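The removed guard rejected every non-LoRA config here even when hotswap was left at its default of False. For comparison, a short sketch of the in-place add path that the second new test below now covers, using the same placeholder model id and target modules as above:

# add_adapter already accepted non-LoRA PEFT configs; this path only lacked a test.
from peft import LoKrConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder
model.add_adapter(LoKrConfig(target_modules=["q_proj", "v_proj"], init_weights=False))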

tests/peft_integration/test_peft_integration.py

Lines changed: 54 additions & 0 deletions
@@ -889,6 +889,60 @@ def test_peft_pipeline_no_warning(self):
         # Generate text to verify pipeline works
         _ = lora_generator(text, max_new_tokens=20)
 
+    def test_non_lora_load_adapter(self):
+        """
+        Check that loading a non-LoRA adapter works. Using LoKr as an example, not testing all possible PEFT methods.
+        """
+        from peft import LoKrConfig, get_peft_model
+
+        inputs = torch.randint(0, 100, (1, 10)).to(torch_device)
+        atol, rtol = 1e-4, 1e-4
+
+        for model_id in self.transformers_test_model_ids:
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+                with torch.inference_mode():
+                    output_base = model(inputs).logits
+
+                peft_config = LoKrConfig(init_weights=False)
+                peft_model = get_peft_model(model, peft_config)
+                with torch.inference_mode():
+                    output_peft = peft_model(inputs).logits
+
+                # sanity check: should be different
+                assert not torch.allclose(output_base, output_peft, atol=atol, rtol=rtol)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    peft_model.save_pretrained(tmpdirname)
+                    del model, peft_model
+
+                    model = transformers_class.from_pretrained(tmpdirname).to(torch_device)
+                    with torch.inference_mode():
+                        output_transformers = model(inputs).logits
+                    self.assertTrue(torch.allclose(output_peft, output_transformers, atol=atol, rtol=rtol))
+
+    def test_non_lora_add_adapter(self):
+        """
+        Check that adding a non-LoRA adapter works. Using LoKr as an example, not testing all possible PEFT methods.
+        """
+        from peft import LoKrConfig
+
+        inputs = torch.randint(0, 100, (1, 10)).to(torch_device)
+        atol, rtol = 1e-4, 1e-4
+
+        for model_id in self.transformers_test_model_ids:
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+                with torch.inference_mode():
+                    output_base = model(inputs).logits
+
+                peft_config = LoKrConfig(init_weights=False)
+                model.add_adapter(peft_config)
+                with torch.inference_mode():
+                    output_peft = model(inputs).logits
+                # should be different
+                assert not torch.allclose(output_base, output_peft, atol=atol, rtol=rtol)
+
 
 @require_peft
 @require_torch
