21 changes: 2 additions & 19 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -22,7 +22,7 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
Member Author: This yields 21 LoC of deletions. We have this pattern in about 32 files, so this amounts to roughly 672 deletions. Not bad, IMO.

from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -634,6 +634,7 @@ def __init__(

self.gradient_checkpointing = False

@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -675,20 +676,6 @@ def forward(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)

hidden_states = self.x_embedder(hidden_states)

@@ -785,10 +772,6 @@ def forward(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)

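To make the author's point about the other ~32 files concrete, here is a minimal sketch of how the same decorator would replace the boilerplate in another model's forward. `SomeOtherTransformer2DModel` and its `attention_kwargs` argument are illustrative stand-ins, not code from this PR, and the import assumes the decorator is exported from `diffusers.utils` as in the diff below:

```python
import torch

from diffusers.utils import apply_lora_scale


class SomeOtherTransformer2DModel(torch.nn.Module):
    """Illustrative stand-in for any model that carries the scale/unscale boilerplate."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    @apply_lora_scale("attention_kwargs")
    def forward(self, hidden_states: torch.Tensor, attention_kwargs=None, return_dict: bool = True):
        # By the time the body runs, the decorator has copied `attention_kwargs`,
        # popped `scale` from the copy, and called `scale_lora_layers(self, scale)`
        # when the PEFT backend is active; `unscale_lora_layers` runs after the
        # body returns (or raises).
        hidden_states = self.proj(hidden_states)
        if not return_dict:
            return (hidden_states,)
        return {"sample": hidden_states}
```

Callers keep passing `attention_kwargs={"scale": 0.8}` exactly as before; because the decorator pops `scale` from a copy, the dict the caller passes in is not mutated.
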
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
@@ -130,6 +130,7 @@
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
apply_lora_scale,
check_peft_version,
delete_adapter_layers,
get_adapter_name,
54 changes: 54 additions & 0 deletions src/diffusers/utils/peft_utils.py
@@ -16,6 +16,7 @@
"""

import collections
import functools
import importlib
from typing import Optional

@@ -275,6 +276,59 @@ def get_module_weight(weight_for_adapter, module_name):
module.set_scale(adapter_name, get_module_weight(weight, module_name))


def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"):
"""
Decorator to automatically handle LoRA layer scaling/unscaling in forward methods.

This decorator extracts the `lora_scale` from the specified kwargs parameter, applies scaling before the forward
pass, and ensures unscaling happens after, even if an exception occurs.

Args:
kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
The name of the keyword argument that contains the LoRA scale. Common values include
"joint_attention_kwargs", "attention_kwargs", "cross_attention_kwargs", etc.
"""

def decorator(forward_fn):
@functools.wraps(forward_fn)
def wrapper(self, *args, **kwargs):
from . import USE_PEFT_BACKEND

lora_scale = 1.0
attention_kwargs = kwargs.get(kwargs_name)

            if attention_kwargs is not None:
                attention_kwargs = attention_kwargs.copy()
                kwargs[kwargs_name] = attention_kwargs
                # Warn before popping: without the PEFT backend the scale has no effect.
                if not USE_PEFT_BACKEND and attention_kwargs.get("scale", None) is not None:
                    logger.warning(
                        f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
                    )
                lora_scale = attention_kwargs.pop("scale", 1.0)

# Apply LoRA scaling if using PEFT backend
if USE_PEFT_BACKEND:
scale_lora_layers(self, lora_scale)

            try:
                # Run the wrapped forward pass.
                return forward_fn(self, *args, **kwargs)
finally:
# Always unscale, even if forward pass raises an exception
if USE_PEFT_BACKEND:
unscale_lora_layers(self, lora_scale)

return wrapper

return decorator


def check_peft_version(min_version: str) -> None:
r"""
Checks if the version of PEFT is compatible.
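As a quick, hypothetical sanity check on the decorator's semantics (not part of the PR; `TinyModel` is an invented toy module, and the import assumes the export added in `src/diffusers/utils/__init__.py` above):

```python
import torch

from diffusers.utils import apply_lora_scale


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    @apply_lora_scale("joint_attention_kwargs")
    def forward(self, x, joint_attention_kwargs=None):
        # The body sees a copy of the kwargs with `scale` already popped.
        assert joint_attention_kwargs is None or "scale" not in joint_attention_kwargs
        return self.linear(x)


model = TinyModel()
caller_kwargs = {"scale": 0.5}
model(torch.randn(2, 4), joint_attention_kwargs=caller_kwargs)
# The decorator copied before popping, so the caller's dict is untouched,
# and any LoRA layers were scaled only for the duration of the call.
assert caller_kwargs == {"scale": 0.5}
```

Because scaling and unscaling happen in a `try`/`finally` around the wrapped call, the layers are restored even if the forward pass raises.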