From afa4a23c6c25d81296aced4f76448bbd09284702 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 19 Jan 2026 10:04:24 +0530 Subject: [PATCH] feat: implement apply_lora_scale to remove boilerplate. --- .../models/transformers/transformer_flux.py | 21 +------- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/peft_utils.py | 54 +++++++++++++++++++ 3 files changed, 57 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 1a4464432425..f6bcaa6735a9 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward @@ -634,6 +634,7 @@ def __init__( self.gradient_checkpointing = False + @apply_lora_scale("joint_attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -675,20 +676,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) hidden_states = self.x_embedder(hidden_states) @@ -785,10 +772,6 @@ def forward( hidden_states = self.norm_out(hidden_states, temb) output = self.proj_out(hidden_states) - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - if not return_dict: return (output,) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index e726bbb46913..2c9f4c995001 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -130,6 +130,7 @@ from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( + apply_lora_scale, check_peft_version, delete_adapter_layers, get_adapter_name, diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 12066ee3f89b..d6f5bcde2d97 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -16,6 +16,7 @@ """ import collections +import functools import importlib from typing import Optional @@ -275,6 +276,59 @@ def get_module_weight(weight_for_adapter, module_name): module.set_scale(adapter_name, get_module_weight(weight, module_name)) +def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"): + """ + Decorator to automatically handle LoRA layer scaling/unscaling in forward methods. 
+
+    This decorator pops the LoRA `scale` from the specified kwargs dictionary, scales the LoRA layers before the
+    forward pass, and ensures they are unscaled afterwards, even if the forward pass raises an exception.
+
+    Args:
+        kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
+            The name of the keyword argument that carries the LoRA scale. Common values are
+            `"joint_attention_kwargs"`, `"attention_kwargs"`, and `"cross_attention_kwargs"`.
+    """
+
+    def decorator(forward_fn):
+        @functools.wraps(forward_fn)
+        def wrapper(self, *args, **kwargs):
+            from . import USE_PEFT_BACKEND
+
+            lora_scale = 1.0
+            attention_kwargs = kwargs.get(kwargs_name)
+
+            if attention_kwargs is not None:
+                attention_kwargs = attention_kwargs.copy()
+                kwargs[kwargs_name] = attention_kwargs
+                if not USE_PEFT_BACKEND and attention_kwargs.get("scale", None) is not None:
+                    logger.warning(
+                        f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
+                    )
+                lora_scale = attention_kwargs.pop("scale", 1.0)
+
+            # Weight the LoRA layers by `lora_scale` when the PEFT backend is available
+            if USE_PEFT_BACKEND:
+                scale_lora_layers(self, lora_scale)
+
+            try:
+                # Execute the forward pass
+                result = forward_fn(self, *args, **kwargs)
+                return result
+            finally:
+                # Always unscale, even if the forward pass raised an exception
+                if USE_PEFT_BACKEND:
+                    unscale_lora_layers(self, lora_scale)
+
+        return wrapper
+
+    return decorator
+
+
 def check_peft_version(min_version: str) -> None:
     r"""
     Checks if the version of PEFT is compatible.
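
For illustration, a minimal sketch of how a model exposing an `attention_kwargs` argument could adopt the decorator in place of the manual scale/unscale boilerplate. `ExampleTransformer` and its single linear layer are hypothetical and not part of the diffusers codebase.

    import torch

    from diffusers.utils import apply_lora_scale


    class ExampleTransformer(torch.nn.Module):
        # Hypothetical model, for illustration only.
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(8, 8)

        @apply_lora_scale("attention_kwargs")
        def forward(self, hidden_states, attention_kwargs=None):
            # By the time the body runs, `scale` has been popped from `attention_kwargs`
            # and, when the PEFT backend is available, the LoRA layers are already scaled;
            # they are unscaled again once this method returns.
            return self.proj(hidden_states)


    model = ExampleTransformer()
    output = model(torch.randn(1, 8), attention_kwargs={"scale": 0.5})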