From dd5642917e0306cfc5b6764b2eabba2c3c36ce77 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 19 Jan 2026 08:12:38 +0000 Subject: [PATCH 1/6] feat: support Ulysses Anything Attention --- src/diffusers/hooks/context_parallel.py | 8 + src/diffusers/models/_ulysses_anything.py | 403 +++++++++++++++++++++ src/diffusers/models/attention_dispatch.py | 48 ++- src/diffusers/utils/constants.py | 1 + 4 files changed, 445 insertions(+), 15 deletions(-) create mode 100644 src/diffusers/models/_ulysses_anything.py diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py index 6491d17b4f46..965879f3180b 100644 --- a/src/diffusers/hooks/context_parallel.py +++ b/src/diffusers/hooks/context_parallel.py @@ -28,7 +28,9 @@ ContextParallelModelPlan, ContextParallelOutput, ) +from ..models._ulysses_anything import PartitionAnythingSharder from ..utils import get_logger +from ..utils.constants import DIFFUSERS_ULYSSES_ANYTHING from ..utils.torch_utils import unwrap_module from .hooks import HookRegistry, ModelHook @@ -256,6 +258,9 @@ def backward(ctx, grad_output): class EquipartitionSharder: @classmethod def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + if DIFFUSERS_ULYSSES_ANYTHING: + return PartitionAnythingSharder.shard_anything(tensor, dim, mesh) + # NOTE: the following assertion does not have to be true in general. We simply enforce it for now # because the alternate case has not yet been tested/required for any model. assert tensor.size()[dim] % mesh.size() == 0, ( @@ -269,6 +274,9 @@ def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_me @classmethod def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: + if DIFFUSERS_ULYSSES_ANYTHING: + return PartitionAnythingSharder.unshard_anything(tensor, dim, mesh) + tensor = tensor.contiguous() tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group()) return tensor diff --git a/src/diffusers/models/_ulysses_anything.py b/src/diffusers/models/_ulysses_anything.py new file mode 100644 index 000000000000..2aaf573b8166 --- /dev/null +++ b/src/diffusers/models/_ulysses_anything.py @@ -0,0 +1,403 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
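+#
+# Overview (summary of the code below, not part of the upstream design docs): Ulysses-style
+# context parallelism keeps the sequence dimension sharded across ranks; an all-to-all before
+# attention re-shards tensors to (full sequence, local heads), and a second all-to-all afterwards
+# restores (local sequence, full heads). The "Anything" variant additionally handles sequence
+# lengths and head counts that are not divisible by the world size, by gathering per-rank sizes
+# and padding the head dimension where needed.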
+ +# Adapted from: https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/attention/_templated_ulysses.py +import copy +import functools +from typing import Callable, List, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.distributed._functional_collectives as fc +import torch.nn.functional as F + +from diffusers.models._modeling_parallel import ParallelConfig + + +def _wait_tensor(tensor) -> torch.Tensor: + if isinstance(tensor, fc.AsyncCollectiveTensor): + tensor = tensor.wait() + + return tensor + + +def _get_rank_world_size( + group: dist.ProcessGroup, +) -> Tuple[int, int]: + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + return rank, world_size + + +@functools.lru_cache(maxsize=128) +def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]: + r"""Gather the local size from all ranks. + size: int, local size return: List[int], list of size from all ranks + """ + world_size = dist.get_world_size(group=group) + # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead + comm_backends = str(dist.get_backend(group=group)) + # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl") + gather_device = "cpu" if "cpu" in comm_backends else "cuda" + gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)] + dist.all_gather( + gathered_sizes, + torch.tensor([size], device=gather_device, dtype=torch.int64), + group=group, + ) + + gathered_sizes = [s[0].item() for s in gathered_sizes] + # NOTE: DON'T use tolist here due to graph break - Explanation: + # Backend compiler `inductor` failed with aten._local_scalar_dense.default + return gathered_sizes + + +# Helper functions to pad/unpad head dimension for QKV and O projections +def _maybe_pad_qkv_head( + x: torch.Tensor, + H: int, + group: dist.ProcessGroup, +) -> Tuple[torch.Tensor, int]: + r"""Maybe pad the head dimension to be divisible by world_size. + x: torch.Tensor, shape (B, S_LOCAL, H, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded + tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD + """ + _, world_size = _get_rank_world_size(group) + H_PAD = 0 + if H % world_size != 0: + H_PAD = world_size - (H % world_size) + NEW_H_LOCAL = (H + H_PAD) // world_size + # e.g., Allow: H=30, world_size=8 -> NEW_H_LOCAL=4, H_PAD=2. + # NOT ALLOW: H=30, world_size=16 -> NEW_H_LOCAL=2, H_PAD=14. + assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}" + x = F.pad(x, (0, 0, 0, H_PAD)).contiguous() + return x, H_PAD + + +def _maybe_unpad_qkv_head( + x: torch.Tensor, + H_PAD: int, + group: dist.ProcessGroup, +) -> torch.Tensor: + r"""Maybe unpad the head dimension. + x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, + unpadded tensor (B, S_GLOBAL, H_LOCAL, D) + """ + rank, world_size = _get_rank_world_size(group) + # Only the last rank may have padding + if H_PAD > 0 and rank == world_size - 1: + x = x[:, :, :-H_PAD, :] + return x.contiguous() + + +def _maybe_pad_o_head( + x: torch.Tensor, + H: int, + group: dist.ProcessGroup, +) -> Tuple[torch.Tensor, int]: + r"""Maybe pad the head dimension to be divisible by world_size. 
+ x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) H: int, original global head num return: Tuple[torch.Tensor, int], + padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD + """ + if H is None: + return x, 0 + + rank, world_size = _get_rank_world_size(group) + H_PAD = 0 + # Only the last rank may need padding + if H % world_size != 0: + # We need to broadcast H_PAD to all ranks to keep consistency + # in unpadding step later for all ranks. + H_PAD = world_size - (H % world_size) + NEW_H_LOCAL = (H + H_PAD) // world_size + assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}" + if rank == world_size - 1: + x = F.pad(x, (0, 0, 0, H_PAD)).contiguous() + return x, H_PAD + + +def _maybe_unpad_o_head( + x: torch.Tensor, + H_PAD: int, + group: dist.ProcessGroup, +) -> torch.Tensor: + r"""Maybe unpad the head dimension. + x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, + unpadded tensor (B, S_LOCAL, H_GLOBAL, D) + """ + if H_PAD > 0: + x = x[:, :, :-H_PAD, :] + return x.contiguous() + + +# Helper functions to for all-to-all communication with Ulysses Anything Attention +def _comm_metadata( + query: torch.Tensor, + **kwargs, +) -> dict: + num_qo_head = query.shape[2] # (B, S_LOCAL, H_GLOBAL, D) + extra_kwargs = {} + extra_kwargs["num_qo_head"] = num_qo_head + # May ddd other kwargs if needed in future + return extra_kwargs + + +@torch.compiler.allow_in_graph +def _all_to_all_single_any_qkv_async( + x: torch.Tensor, + group: dist.ProcessGroup, + **kwargs, +) -> Callable[..., torch.Tensor]: + r""" + x: torch.Tensor, shape (B, S_LOCAL, H, D) return: Callable that returns (B, S_GLOBAL, H_LOCAL, D) + """ + _, world_size = _get_rank_world_size(group) + B, S_LOCAL, H, D = x.shape + x, H_PAD = _maybe_pad_qkv_head(x, H, group) + H_LOCAL = (H + H_PAD) // world_size + # (world_size, S_LOCAL, B, H_LOCAL, D) + x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + + input_split_sizes = [S_LOCAL] * world_size + # S_LOCAL maybe not equal for all ranks in dynamic shape case, + # since we don't know the actual shape before this timing, thus, + # we have to use all gather to collect the S_LOCAL first. + output_split_sizes = _gather_size_by_comm(S_LOCAL, group) + x = x.flatten(0, 1) # (world_size * S_LOCAL, B, H_LOCAL, D) + x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group) + + def wait() -> torch.Tensor: + nonlocal x, H_PAD + x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D) + # (S_GLOBAL, B, H_LOCAL, D) + # -> (B, S_GLOBAL, H_LOCAL, D) + x = x.permute(1, 0, 2, 3).contiguous() + x = _maybe_unpad_qkv_head(x, H_PAD, group) + return x + + return wait + + +@torch.compiler.allow_in_graph +def _all_to_all_single_any_o_async( + x: torch.Tensor, + group: dist.ProcessGroup, + **kwargs, +) -> Callable[..., torch.Tensor]: + r""" + x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) return: Callable that returns (B, S_LOCAL, H_GLOBAL, D) + """ + # Assume H is provided in kwargs, since we can't infer H from x's shape. + # The padding logic needs H to determine if padding is necessary. + H = kwargs.get("num_qo_head", None) + rank, world_size = _get_rank_world_size(group) + x, H_PAD = _maybe_pad_o_head(x, H, group) + shape = x.shape # (B, S_GLOBAL, H_LOCAL, D) + (B, S_GLOBAL, H_LOCAL, D) = shape + # NOTE: We use tensor_split here to ensure the same split policy + # that we have used in the EquipartitionSharder sharding strategy. 
Please + # note that the 'tensor_split' splits a tensor into multiple sub-tensors, + # all of which are views of input, thus may not introduce extra IO access. + input_split_sizes = [o.size(1) for o in torch.tensor_split(x, world_size, dim=1)] + # input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..] + # output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..] + S_LOCAL = input_split_sizes[rank] + x = x.permute(1, 0, 2, 3).contiguous() # (S_GLOBAL, B, H_LOCAL, D) + output_split_sizes = [S_LOCAL] * world_size + x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group) + + def wait() -> torch.Tensor: + nonlocal x, H_PAD + x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D) + x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D) + x = x.permute(2, 1, 0, 3, 4).contiguous() + x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D) + x = _maybe_unpad_o_head(x, H_PAD, group) + return x + + return wait + + +class TemplatedUlyssesAnythingAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + **kwargs, + ): + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + group = ulysses_mesh.get_group() + + ctx.forward_op = forward_op + ctx.backward_op = backward_op + ctx._parallel_config = _parallel_config + + metadata = _comm_metadata(query) + query_wait = _all_to_all_single_any_qkv_async(query, group, **metadata) + key_wait = _all_to_all_single_any_qkv_async(key, group, **metadata) + value_wait = _all_to_all_single_any_qkv_async(value, group, **metadata) + + query = query_wait() # type: torch.Tensor + key = key_wait() # type: torch.Tensor + value = value_wait() # type: torch.Tensor + + out = forward_op( + ctx, + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + _save_ctx=False, + _parallel_config=_parallel_config, + ) + if return_lse: + out, lse, *_ = out + + # out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D) + out_wait = _all_to_all_single_any_o_async(out, group, **metadata) + + if return_lse: + # lse: (B, S_Q_GLOBAL, H_LOCAL) + lse = lse.unsqueeze(-1) # (B, S_Q_GLOBAL, H_LOCAL, D=1) + lse_wait = _all_to_all_single_any_o_async(lse, group, **metadata) + out = out_wait() # type: torch.Tensor + lse = lse_wait() # type: torch.Tensor + lse = lse.squeeze(-1).contiguous() # (B, S_Q_LOCAL, H_GLOBAL) + else: + out = out_wait() # type: torch.Tensor + lse = None + + return (out, lse) if return_lse else out + + @staticmethod + def backward( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + ): + raise NotImplementedError("Backward pass for Ulysses Anything Attention in diffusers is not implemented yet.") + + +@functools.lru_cache(maxsize=64) +def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]: + gather_shapes = [] + for i in range(world_size): + rank_shape = list(copy.deepcopy(shape)) + rank_shape[dim] = gather_dims[i] + gather_shapes.append(rank_shape) + return gather_shapes + + +@torch.compiler.allow_in_graph +def _all_gather_anything( # noqa: F811 + tensor: torch.Tensor, + dim: int, + group: dist.device_mesh.DeviceMesh, +) -> torch.Tensor: + _, world_size 
= _get_rank_world_size(group) + tensor = tensor.contiguous() + shape = tensor.shape + rank_dim = shape[dim] + gather_dims = _gather_size_by_comm(rank_dim, group) + + gather_shapes = _fill_gather_shapes( + tuple(shape), + tuple(gather_dims), + dim, + world_size, + ) + + gathered_tensors = [ + torch.empty( + shape, + device=tensor.device, + dtype=tensor.dtype, + ) + for shape in gather_shapes + ] + + dist.all_gather(gathered_tensors, tensor, group=group) + gathered_tensor = torch.cat(gathered_tensors, dim=dim) + return gathered_tensor + + +class AllGatherAnythingFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + dim: int, + group: dist.device_mesh.DeviceMesh, + ): + ctx.dim = dim + ctx.group = group + ctx.world_size = dist.get_world_size(group) + ctx.rank = dist.get_rank(group) + gathered_tensor = _all_gather_anything(tensor, dim, group) + return gathered_tensor + + @staticmethod + def backward(ctx, grad_output): + # NOTE: We use `tensor_split` instead of chunk, because the `chunk` + # function may return fewer than the specified number of chunks! + grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim) + return grad_splits[ctx.rank], None, None + + +class PartitionAnythingSharder: + @classmethod + def shard_anything( + cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh + ) -> torch.Tensor: + assert tensor.size()[dim] >= mesh.size(), ( + f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}." + ) + # NOTE: We use `tensor_split` instead of chunk, because the `chunk` + # function may return fewer than the specified number of chunks! For example, + # x = torch.tensor([1,2,3,4,5]), torch.chunk(x, 4) will return only 3 chunks: + # (tensor([1, 2]), tensor([3, 4]), tensor([5])). This behavior can lead to + # inconsistencies when sharding tensors across multiple devices. In contrast, + # tensor_split will always return the specified number of chunks, the last chunk + # may be smaller if the tensor size is not divisible by the number of chunks. + # For example, torch.tensor_split(x, 4) will return 4 chunks: + # (tensor([1, 2]), tensor([3]), tensor([4]), tensor([5])). + return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())] + + @classmethod + def unshard_anything( + cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh + ) -> torch.Tensor: + tensor = tensor.contiguous() + # NOTE: We use AllGatherAnythingFunction to support gathering + # tensors with complex and uneven sizes across all ranks. It handles the + # case where the tensor size (the seq_len of hidden_states) along the + # specified dimension is not divisible by the number of ranks in the mesh. 
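+        # Worked example (illustrative numbers): with mesh.size()=4 and a global
+        # seq_len of 9, shard_anything's tensor_split gives local lengths
+        # [3, 2, 2, 2]; the all-gather below collects those uneven shards and
+        # concatenates them back into the original length-9 sequence.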
+ tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group()) + return tensor diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 61c478b03c4f..b76283eb1cfc 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -43,7 +43,8 @@ is_xformers_available, is_xformers_version, ) -from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS +from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ULYSSES_ANYTHING +from ._ulysses_anything import TemplatedUlyssesAnythingAttention if TYPE_CHECKING: @@ -1618,20 +1619,37 @@ def _templated_context_parallel_attention( _parallel_config, ) elif _parallel_config.context_parallel_config.ulysses_degree > 1: - return TemplatedUlyssesAttention.apply( - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - forward_op, - backward_op, - _parallel_config, - ) + if DIFFUSERS_ULYSSES_ANYTHING: + # For Any sequence lengths and Any head num support + return TemplatedUlyssesAnythingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + else: + return TemplatedUlyssesAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) else: raise ValueError("Reaching this branch of code is unexpected. Please report a bug.") diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index c46fa4363483..b9407d8945dd 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -46,6 +46,7 @@ DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES +DIFFUSERS_ULYSSES_ANYTHING = os.getenv("DIFFUSERS_ULYSSES_ANYTHING", "0").upper() in ENV_VARS_TRUE_VALUES # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. 
Will automatically fall back to PEFT backend if the correct versions of the libraries are From 123f5264bbdd2163cf224194365e0d931f1a325c Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 19 Jan 2026 09:16:33 +0000 Subject: [PATCH 2/6] feat: support Ulysses Anything Attention --- src/diffusers/models/_ulysses_anything.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/_ulysses_anything.py b/src/diffusers/models/_ulysses_anything.py index 2aaf573b8166..3517c3ede4d4 100644 --- a/src/diffusers/models/_ulysses_anything.py +++ b/src/diffusers/models/_ulysses_anything.py @@ -22,8 +22,8 @@ import torch.distributed._functional_collectives as fc import torch.nn.functional as F -from diffusers.models._modeling_parallel import ParallelConfig - +from ..utils.torch_utils import maybe_allow_in_graph +from ._modeling_parallel import ParallelConfig def _wait_tensor(tensor) -> torch.Tensor: if isinstance(tensor, fc.AsyncCollectiveTensor): @@ -153,7 +153,7 @@ def _comm_metadata( return extra_kwargs -@torch.compiler.allow_in_graph +@maybe_allow_in_graph def _all_to_all_single_any_qkv_async( x: torch.Tensor, group: dist.ProcessGroup, @@ -189,7 +189,7 @@ def wait() -> torch.Tensor: return wait -@torch.compiler.allow_in_graph +@maybe_allow_in_graph def _all_to_all_single_any_o_async( x: torch.Tensor, group: dist.ProcessGroup, @@ -315,7 +315,7 @@ def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, wo return gather_shapes -@torch.compiler.allow_in_graph +@maybe_allow_in_graph def _all_gather_anything( # noqa: F811 tensor: torch.Tensor, dim: int, From af9af62d2d1e64eda8df5b7b95687f3840f46932 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Mon, 19 Jan 2026 09:17:11 +0000 Subject: [PATCH 3/6] feat: support Ulysses Anything Attention --- src/diffusers/models/_ulysses_anything.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/_ulysses_anything.py b/src/diffusers/models/_ulysses_anything.py index 3517c3ede4d4..a5594ec8cbd5 100644 --- a/src/diffusers/models/_ulysses_anything.py +++ b/src/diffusers/models/_ulysses_anything.py @@ -25,6 +25,7 @@ from ..utils.torch_utils import maybe_allow_in_graph from ._modeling_parallel import ParallelConfig + def _wait_tensor(tensor) -> torch.Tensor: if isinstance(tensor, fc.AsyncCollectiveTensor): tensor = tensor.wait() From b4d3f077f47325ccd949fbae68fa8cfdf0c5a446 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Tue, 20 Jan 2026 02:24:39 +0000 Subject: [PATCH 4/6] feat: support Ulysses Anything Attention --- src/diffusers/hooks/context_parallel.py | 22 ++- src/diffusers/models/_modeling_parallel.py | 3 + ...anything.py => _ulysses_anything_utils.py} | 187 +++--------------- src/diffusers/models/attention_dispatch.py | 86 +++++++- src/diffusers/utils/constants.py | 1 - 5 files changed, 124 insertions(+), 175 deletions(-) rename src/diffusers/models/{_ulysses_anything.py => _ulysses_anything_utils.py} (65%) diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py index 965879f3180b..53e2b53d986e 100644 --- a/src/diffusers/hooks/context_parallel.py +++ b/src/diffusers/hooks/context_parallel.py @@ -28,9 +28,8 @@ ContextParallelModelPlan, ContextParallelOutput, ) -from ..models._ulysses_anything import PartitionAnythingSharder +from ..models._ulysses_anything_utils import PartitionAnythingSharder from ..utils import get_logger -from ..utils.constants import DIFFUSERS_ULYSSES_ANYTHING from ..utils.torch_utils import unwrap_module from .hooks import 
HookRegistry, ModelHook @@ -210,6 +209,10 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> ) return x else: + if self.parallel_config.ulysses_anything: + return PartitionAnythingSharder.shard_anything( + x, cp_input.split_dim, self.parallel_config._flattened_mesh + ) return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) @@ -235,7 +238,14 @@ def post_forward(self, module, output): for i, cpm in enumerate(self.metadata): if cpm is None: continue - output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh) + if self.parallel_config.ulysses_anything: + output[i] = PartitionAnythingSharder.unshard_anything( + output[i], cpm.gather_dim, self.parallel_config._flattened_mesh + ) + else: + output[i] = EquipartitionSharder.unshard( + output[i], cpm.gather_dim, self.parallel_config._flattened_mesh + ) return output[0] if is_tensor else tuple(output) @@ -258,9 +268,6 @@ def backward(ctx, grad_output): class EquipartitionSharder: @classmethod def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: - if DIFFUSERS_ULYSSES_ANYTHING: - return PartitionAnythingSharder.shard_anything(tensor, dim, mesh) - # NOTE: the following assertion does not have to be true in general. We simply enforce it for now # because the alternate case has not yet been tested/required for any model. assert tensor.size()[dim] % mesh.size() == 0, ( @@ -274,9 +281,6 @@ def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_me @classmethod def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor: - if DIFFUSERS_ULYSSES_ANYTHING: - return PartitionAnythingSharder.unshard_anything(tensor, dim, mesh) - tensor = tensor.contiguous() tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group()) return tensor diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py index 1c7703a13c52..f301ba771cc7 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -67,6 +67,9 @@ class ContextParallelConfig: convert_to_fp32: bool = True # TODO: support alltoall rotate_method: Literal["allgather", "alltoall"] = "allgather" + # Whether to enable ulysses anything attention to support + # any sequence lengths and any head numbers. 
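+    # Minimal usage sketch (assumes the existing `enable_parallelism` API; the
+    # degree value is illustrative):
+    #   config = ContextParallelConfig(ulysses_degree=2, ulysses_anything=True)
+    #   transformer.enable_parallelism(config=config)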
+ ulysses_anything: bool = False _rank: int = None _world_size: int = None diff --git a/src/diffusers/models/_ulysses_anything.py b/src/diffusers/models/_ulysses_anything_utils.py similarity index 65% rename from src/diffusers/models/_ulysses_anything.py rename to src/diffusers/models/_ulysses_anything_utils.py index a5594ec8cbd5..fd85f374d0f5 100644 --- a/src/diffusers/models/_ulysses_anything.py +++ b/src/diffusers/models/_ulysses_anything_utils.py @@ -15,7 +15,7 @@ # Adapted from: https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/attention/_templated_ulysses.py import copy import functools -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Tuple import torch import torch.distributed as dist @@ -23,19 +23,10 @@ import torch.nn.functional as F from ..utils.torch_utils import maybe_allow_in_graph -from ._modeling_parallel import ParallelConfig -def _wait_tensor(tensor) -> torch.Tensor: - if isinstance(tensor, fc.AsyncCollectiveTensor): - tensor = tensor.wait() - - return tensor - - -def _get_rank_world_size( - group: dist.ProcessGroup, -) -> Tuple[int, int]: +# Helper functions for shape gathering +def _get_rank_world_size(group: dist.ProcessGroup) -> Tuple[int, int]: world_size = dist.get_world_size(group=group) rank = dist.get_rank(group=group) return rank, world_size @@ -50,7 +41,7 @@ def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]: # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead comm_backends = str(dist.get_backend(group=group)) # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl") - gather_device = "cpu" if "cpu" in comm_backends else "cuda" + gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator() gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)] dist.all_gather( gathered_sizes, @@ -65,11 +56,7 @@ def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]: # Helper functions to pad/unpad head dimension for QKV and O projections -def _maybe_pad_qkv_head( - x: torch.Tensor, - H: int, - group: dist.ProcessGroup, -) -> Tuple[torch.Tensor, int]: +def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]: r"""Maybe pad the head dimension to be divisible by world_size. x: torch.Tensor, shape (B, S_LOCAL, H, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD @@ -86,11 +73,7 @@ def _maybe_pad_qkv_head( return x, H_PAD -def _maybe_unpad_qkv_head( - x: torch.Tensor, - H_PAD: int, - group: dist.ProcessGroup, -) -> torch.Tensor: +def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor: r"""Maybe unpad the head dimension. x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, unpadded tensor (B, S_GLOBAL, H_LOCAL, D) @@ -102,11 +85,7 @@ def _maybe_unpad_qkv_head( return x.contiguous() -def _maybe_pad_o_head( - x: torch.Tensor, - H: int, - group: dist.ProcessGroup, -) -> Tuple[torch.Tensor, int]: +def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]: r"""Maybe pad the head dimension to be divisible by world_size. 
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD @@ -128,11 +107,7 @@ def _maybe_pad_o_head( return x, H_PAD -def _maybe_unpad_o_head( - x: torch.Tensor, - H_PAD: int, - group: dist.ProcessGroup, -) -> torch.Tensor: +def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor: r"""Maybe unpad the head dimension. x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor, unpadded tensor (B, S_LOCAL, H_GLOBAL, D) @@ -143,10 +118,14 @@ def _maybe_unpad_o_head( # Helper functions to for all-to-all communication with Ulysses Anything Attention -def _comm_metadata( - query: torch.Tensor, - **kwargs, -) -> dict: +def _wait_tensor(tensor) -> torch.Tensor: + if isinstance(tensor, fc.AsyncCollectiveTensor): + tensor = tensor.wait() + + return tensor + + +def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict: num_qo_head = query.shape[2] # (B, S_LOCAL, H_GLOBAL, D) extra_kwargs = {} extra_kwargs["num_qo_head"] = num_qo_head @@ -155,10 +134,8 @@ def _comm_metadata( @maybe_allow_in_graph -def _all_to_all_single_any_qkv_async( - x: torch.Tensor, - group: dist.ProcessGroup, - **kwargs, +def all_to_all_single_any_qkv_async( + x: torch.Tensor, group: dist.ProcessGroup, **kwargs ) -> Callable[..., torch.Tensor]: r""" x: torch.Tensor, shape (B, S_LOCAL, H, D) return: Callable that returns (B, S_GLOBAL, H_LOCAL, D) @@ -191,11 +168,7 @@ def wait() -> torch.Tensor: @maybe_allow_in_graph -def _all_to_all_single_any_o_async( - x: torch.Tensor, - group: dist.ProcessGroup, - **kwargs, -) -> Callable[..., torch.Tensor]: +def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]: r""" x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) return: Callable that returns (B, S_LOCAL, H_GLOBAL, D) """ @@ -207,9 +180,7 @@ def _all_to_all_single_any_o_async( shape = x.shape # (B, S_GLOBAL, H_LOCAL, D) (B, S_GLOBAL, H_LOCAL, D) = shape # NOTE: We use tensor_split here to ensure the same split policy - # that we have used in the EquipartitionSharder sharding strategy. Please - # note that the 'tensor_split' splits a tensor into multiple sub-tensors, - # all of which are views of input, thus may not introduce extra IO access. + # that we have used in the EquipartitionSharder sharding strategy. input_split_sizes = [o.size(1) for o in torch.tensor_split(x, world_size, dim=1)] # input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..] # output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..] 
@@ -230,82 +201,6 @@ def wait() -> torch.Tensor: return wait -class TemplatedUlyssesAnythingAttention(torch.autograd.Function): - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor], - dropout_p: float, - is_causal: bool, - scale: Optional[float], - enable_gqa: bool, - return_lse: bool, - forward_op, - backward_op, - _parallel_config: Optional["ParallelConfig"] = None, - **kwargs, - ): - ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh - group = ulysses_mesh.get_group() - - ctx.forward_op = forward_op - ctx.backward_op = backward_op - ctx._parallel_config = _parallel_config - - metadata = _comm_metadata(query) - query_wait = _all_to_all_single_any_qkv_async(query, group, **metadata) - key_wait = _all_to_all_single_any_qkv_async(key, group, **metadata) - value_wait = _all_to_all_single_any_qkv_async(value, group, **metadata) - - query = query_wait() # type: torch.Tensor - key = key_wait() # type: torch.Tensor - value = value_wait() # type: torch.Tensor - - out = forward_op( - ctx, - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - _save_ctx=False, - _parallel_config=_parallel_config, - ) - if return_lse: - out, lse, *_ = out - - # out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D) - out_wait = _all_to_all_single_any_o_async(out, group, **metadata) - - if return_lse: - # lse: (B, S_Q_GLOBAL, H_LOCAL) - lse = lse.unsqueeze(-1) # (B, S_Q_GLOBAL, H_LOCAL, D=1) - lse_wait = _all_to_all_single_any_o_async(lse, group, **metadata) - out = out_wait() # type: torch.Tensor - lse = lse_wait() # type: torch.Tensor - lse = lse.squeeze(-1).contiguous() # (B, S_Q_LOCAL, H_GLOBAL) - else: - out = out_wait() # type: torch.Tensor - lse = None - - return (out, lse) if return_lse else out - - @staticmethod - def backward( - ctx: torch.autograd.function.FunctionCtx, - grad_out: torch.Tensor, - *args, - ): - raise NotImplementedError("Backward pass for Ulysses Anything Attention in diffusers is not implemented yet.") - - @functools.lru_cache(maxsize=64) def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]: gather_shapes = [] @@ -317,32 +212,16 @@ def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, wo @maybe_allow_in_graph -def _all_gather_anything( # noqa: F811 - tensor: torch.Tensor, - dim: int, - group: dist.device_mesh.DeviceMesh, -) -> torch.Tensor: +def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh) -> torch.Tensor: _, world_size = _get_rank_world_size(group) tensor = tensor.contiguous() shape = tensor.shape rank_dim = shape[dim] gather_dims = _gather_size_by_comm(rank_dim, group) - gather_shapes = _fill_gather_shapes( - tuple(shape), - tuple(gather_dims), - dim, - world_size, - ) + gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size) - gathered_tensors = [ - torch.empty( - shape, - device=tensor.device, - dtype=tensor.dtype, - ) - for shape in gather_shapes - ] + gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes] dist.all_gather(gathered_tensors, tensor, group=group) gathered_tensor = torch.cat(gathered_tensors, dim=dim) @@ -351,12 +230,7 @@ def _all_gather_anything( # noqa: F811 class AllGatherAnythingFunction(torch.autograd.Function): @staticmethod - def forward( - ctx, - tensor: 
torch.Tensor, - dim: int, - group: dist.device_mesh.DeviceMesh, - ): + def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh): ctx.dim = dim ctx.group = group ctx.world_size = dist.get_world_size(group) @@ -381,14 +255,7 @@ def shard_anything( f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}." ) # NOTE: We use `tensor_split` instead of chunk, because the `chunk` - # function may return fewer than the specified number of chunks! For example, - # x = torch.tensor([1,2,3,4,5]), torch.chunk(x, 4) will return only 3 chunks: - # (tensor([1, 2]), tensor([3, 4]), tensor([5])). This behavior can lead to - # inconsistencies when sharding tensors across multiple devices. In contrast, - # tensor_split will always return the specified number of chunks, the last chunk - # may be smaller if the tensor size is not divisible by the number of chunks. - # For example, torch.tensor_split(x, 4) will return 4 chunks: - # (tensor([1, 2]), tensor([3]), tensor([4]), tensor([5])). + # function may return fewer than the specified number of chunks! return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())] @classmethod @@ -396,9 +263,5 @@ def unshard_anything( cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh ) -> torch.Tensor: tensor = tensor.contiguous() - # NOTE: We use AllGatherAnythingFunction to support gathering - # tensors with complex and uneven sizes across all ranks. It handles the - # case where the tensor size (the seq_len of hidden_states) along the - # specified dimension is not divisible by the number of ranks in the mesh. tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group()) return tensor diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index b76283eb1cfc..4cd238dd5502 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -43,8 +43,12 @@ is_xformers_available, is_xformers_version, ) -from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ULYSSES_ANYTHING -from ._ulysses_anything import TemplatedUlyssesAnythingAttention +from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS +from ._ulysses_anything_utils import ( + all_to_all_single_any_o_async, + all_to_all_single_any_qkv_async, + ulysses_anything_metadata, +) if TYPE_CHECKING: @@ -1502,6 +1506,82 @@ def backward( return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None +class TemplatedUlyssesAnythingAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + **kwargs, + ): + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + group = ulysses_mesh.get_group() + + ctx.forward_op = forward_op + ctx.backward_op = backward_op + ctx._parallel_config = _parallel_config + + metadata = ulysses_anything_metadata(query) + query_wait = all_to_all_single_any_qkv_async(query, group, **metadata) + key_wait = all_to_all_single_any_qkv_async(key, group, **metadata) + value_wait = all_to_all_single_any_qkv_async(value, group, **metadata) + + query = query_wait() # 
type: torch.Tensor
+        key = key_wait()  # type: torch.Tensor
+        value = value_wait()  # type: torch.Tensor
+
+        out = forward_op(
+            ctx,
+            query,
+            key,
+            value,
+            attn_mask,
+            dropout_p,
+            is_causal,
+            scale,
+            enable_gqa,
+            return_lse,
+            _save_ctx=False,  # Ulysses Anything only supports the forward pass for now.
+            _parallel_config=_parallel_config,
+        )
+        if return_lse:
+            out, lse, *_ = out
+
+        # out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D)
+        out_wait = all_to_all_single_any_o_async(out, group, **metadata)
+
+        if return_lse:
+            # lse: (B, S_Q_GLOBAL, H_LOCAL)
+            lse = lse.unsqueeze(-1)  # (B, S_Q_GLOBAL, H_LOCAL, D=1)
+            lse_wait = all_to_all_single_any_o_async(lse, group, **metadata)
+            out = out_wait()  # type: torch.Tensor
+            lse = lse_wait()  # type: torch.Tensor
+            lse = lse.squeeze(-1).contiguous()  # (B, S_Q_LOCAL, H_GLOBAL)
+        else:
+            out = out_wait()  # type: torch.Tensor
+            lse = None
+
+        return (out, lse) if return_lse else out
+
+    @staticmethod
+    def backward(
+        ctx: torch.autograd.function.FunctionCtx,
+        grad_out: torch.Tensor,
+        *args,
+    ):
+        raise NotImplementedError("Backward pass for Ulysses Anything Attention in diffusers is not implemented yet.")
+
+
 def _templated_unified_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -1619,7 +1699,7 @@ def _templated_context_parallel_attention(
             _parallel_config,
         )
     elif _parallel_config.context_parallel_config.ulysses_degree > 1:
-        if DIFFUSERS_ULYSSES_ANYTHING:
+        if _parallel_config.context_parallel_config.ulysses_anything:
             # For Any sequence lengths and Any head num support
             return TemplatedUlyssesAnythingAttention.apply(
                 query,
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index b9407d8945dd..c46fa4363483 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -46,7 +46,6 @@ DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
 DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
-DIFFUSERS_ULYSSES_ANYTHING = os.getenv("DIFFUSERS_ULYSSES_ANYTHING", "0").upper() in ENV_VARS_TRUE_VALUES
 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

From 403c204ab8d29b5ef6197ebeb46f97c6e923fac6 Mon Sep 17 00:00:00 2001
From: DefTruth
Date: Wed, 21 Jan 2026 13:19:36 +0000
Subject: [PATCH 5/6] fix UAA broken while using joint attn

---
 .../models/_ulysses_anything_utils.py         | 24 ++++++++++++-------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/models/_ulysses_anything_utils.py b/src/diffusers/models/_ulysses_anything_utils.py
index fd85f374d0f5..c17c0ab84812 100644
--- a/src/diffusers/models/_ulysses_anything_utils.py
+++ b/src/diffusers/models/_ulysses_anything_utils.py
@@ -32,7 +32,6 @@ def _get_rank_world_size(group: dist.ProcessGroup) -> Tuple[int, int]:
     return rank, world_size
 
 
-@functools.lru_cache(maxsize=128)
 def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
     r"""Gather the local size from all ranks.
    size: int, local size return: List[int], list of size from all ranks
@@ -126,10 +125,12 @@ def _wait_tensor(tensor) -> torch.Tensor:
 
 
 def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict:
-    num_qo_head = query.shape[2]  # (B, S_LOCAL, H_GLOBAL, D)
+    # query: (B, S_LOCAL, H_GLOBAL, D)
+    assert len(query.shape) == 4, "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)"
     extra_kwargs = {}
-    extra_kwargs["num_qo_head"] = num_qo_head
-    # May ddd other kwargs if needed in future
+    extra_kwargs["NUM_QO_HEAD"] = query.shape[2]
+    extra_kwargs["Q_S_LOCAL"] = query.shape[1]
+    # Add other kwargs if needed in future
     return extra_kwargs
 
 
@@ -174,17 +175,22 @@ def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **k
     """
     # Assume H is provided in kwargs, since we can't infer H from x's shape.
    # The padding logic needs H to determine if padding is necessary.
-    H = kwargs.get("num_qo_head", None)
+    H = kwargs.get("NUM_QO_HEAD", None)
     rank, world_size = _get_rank_world_size(group)
     x, H_PAD = _maybe_pad_o_head(x, H, group)
     shape = x.shape  # (B, S_GLOBAL, H_LOCAL, D)
     (B, S_GLOBAL, H_LOCAL, D) = shape
-    # NOTE: We use tensor_split here to ensure the same split policy
-    # that we have used in the EquipartitionSharder sharding strategy.
-    input_split_sizes = [o.size(1) for o in torch.tensor_split(x, world_size, dim=1)]
+
     # input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..]
     # output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..]
-    S_LOCAL = input_split_sizes[rank]
+
+    # WARN: In some cases, e.g., joint attention in Qwen-Image, S_LOCAL cannot be
+    # inferred from tensor_split, because if c = torch.cat((a, b)) and world_size=4,
+    # then c.tensor_split(4)[0].shape[1] may differ from
+    # (a.tensor_split(4)[0].shape[1] + b.tensor_split(4)[0].shape[1]).
+
+    S_LOCAL = kwargs.get("Q_S_LOCAL")
+    input_split_sizes = _gather_size_by_comm(S_LOCAL, group)
     x = x.permute(1, 0, 2, 3).contiguous()  # (S_GLOBAL, B, H_LOCAL, D)
     output_split_sizes = [S_LOCAL] * world_size
     x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

From 9280e2b6321304e52e76b90db32107d60b0bc1c3 Mon Sep 17 00:00:00 2001
From: DefTruth
Date: Wed, 21 Jan 2026 13:24:30 +0000
Subject: [PATCH 6/6] update

---
 src/diffusers/models/_ulysses_anything_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/models/_ulysses_anything_utils.py b/src/diffusers/models/_ulysses_anything_utils.py
index c17c0ab84812..284ba6241b99 100644
--- a/src/diffusers/models/_ulysses_anything_utils.py
+++ b/src/diffusers/models/_ulysses_anything_utils.py
@@ -32,6 +32,7 @@ def _get_rank_world_size(group: dist.ProcessGroup) -> Tuple[int, int]:
     return rank, world_size
 
 
+@functools.lru_cache(maxsize=128)
 def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
     r"""Gather the local size from all ranks.
     size: int, local size return: List[int], list of size from all ranks