14 changes: 13 additions & 1 deletion src/diffusers/hooks/context_parallel.py
@@ -28,6 +28,7 @@
ContextParallelModelPlan,
ContextParallelOutput,
)
from ..models._ulysses_anything_utils import PartitionAnythingSharder
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook
@@ -208,6 +209,10 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) ->
)
return x
else:
if self.parallel_config.ulysses_anything:
return PartitionAnythingSharder.shard_anything(
x, cp_input.split_dim, self.parallel_config._flattened_mesh
)
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)


@@ -233,7 +238,14 @@ def post_forward(self, module, output):
for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
if self.parallel_config.ulysses_anything:
output[i] = PartitionAnythingSharder.unshard_anything(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)
else:
output[i] = EquipartitionSharder.unshard(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)

return output[0] if is_tensor else tuple(output)

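With `ulysses_anything` enabled, the hook shards inputs with `tensor_split` (uneven chunks are allowed) instead of requiring the dimension to divide evenly across the mesh, and gathers outputs back with a shape-aware all-gather. Below is a minimal single-process sketch of that partitioning round trip, with illustrative tensor sizes and no process group involved:

import torch

# Hypothetical sizes: a 10-token sequence partitioned across a 4-rank mesh.
world_size = 4
x = torch.randn(2, 10, 64)  # (B, S_GLOBAL, D)

# Shard step: tensor_split allows uneven chunks, here sizes [3, 3, 2, 2].
shards = x.tensor_split(world_size, dim=1)
print([s.shape[1] for s in shards])  # [3, 3, 2, 2]

# Unshard step: each rank contributes its (possibly different) local length,
# and concatenating along the same dim restores the original tensor exactly.
assert torch.equal(torch.cat(shards, dim=1), x)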
3 changes: 3 additions & 0 deletions src/diffusers/models/_modeling_parallel.py
@@ -67,6 +67,9 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
    # Whether to enable Ulysses Anything attention, which supports
    # arbitrary sequence lengths and arbitrary numbers of attention heads.
ulysses_anything: bool = False

_rank: int = None
_world_size: int = None
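For orientation, a minimal sketch of how the new flag might be set, assuming the existing `ContextParallelConfig` entry point and an `enable_parallelism`-style call on the model; the import path, method name, and `ulysses_degree` argument are taken from current diffusers usage and may differ:

import torch.distributed as dist
from diffusers import ContextParallelConfig

# Hypothetical setup: ulysses_anything lifts the requirement that the sequence
# length and head count divide evenly by the Ulysses degree.
cp_config = ContextParallelConfig(
    ulysses_degree=dist.get_world_size(),
    ulysses_anything=True,
)
# transformer.enable_parallelism(config=cp_config)  # assumed entry point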
274 changes: 274 additions & 0 deletions src/diffusers/models/_ulysses_anything_utils.py
@@ -0,0 +1,274 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from: https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/attention/_templated_ulysses.py
import copy
import functools
from typing import Callable, List, Tuple

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as fc
import torch.nn.functional as F

from ..utils.torch_utils import maybe_allow_in_graph


# Helper functions for shape gathering
def _get_rank_world_size(group: dist.ProcessGroup) -> Tuple[int, int]:
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
return rank, world_size


@functools.lru_cache(maxsize=128)
def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
r"""Gather the local size from all ranks.
    size: int, local size
    return: List[int], list of sizes from all ranks
"""
world_size = dist.get_world_size(group=group)
# HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead
comm_backends = str(dist.get_backend(group=group))
# NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl")
gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator()
gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)]
dist.all_gather(
gathered_sizes,
torch.tensor([size], device=gather_device, dtype=torch.int64),
group=group,
)

gathered_sizes = [s[0].item() for s in gathered_sizes]
    # NOTE: Don't use tolist() here: it causes a graph break, and the `inductor`
    # backend compiler fails on aten._local_scalar_dense.default.
return gathered_sizes


# Helper functions to pad/unpad head dimension for QKV and O projections
def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]:
r"""Maybe pad the head dimension to be divisible by world_size.
    x: torch.Tensor, shape (B, S_LOCAL, H, D)
    H: int, original global head num
    return: Tuple[torch.Tensor, int], padded tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD
"""
_, world_size = _get_rank_world_size(group)
H_PAD = 0
if H % world_size != 0:
H_PAD = world_size - (H % world_size)
NEW_H_LOCAL = (H + H_PAD) // world_size
        # e.g., allowed: H=30, world_size=8 -> NEW_H_LOCAL=4, H_PAD=2;
        # not allowed: H=30, world_size=16 -> NEW_H_LOCAL=2, H_PAD=14.
assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
return x, H_PAD


def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
r"""Maybe unpad the head dimension.
    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D)
    H_PAD: int, head padding num
    return: torch.Tensor, unpadded tensor (B, S_GLOBAL, H_LOCAL, D)
"""
rank, world_size = _get_rank_world_size(group)
# Only the last rank may have padding
if H_PAD > 0 and rank == world_size - 1:
x = x[:, :, :-H_PAD, :]
return x.contiguous()


def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]:
r"""Maybe pad the head dimension to be divisible by world_size.
    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D)
    H: int, original global head num
    return: Tuple[torch.Tensor, int], padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD
"""
if H is None:
return x, 0

rank, world_size = _get_rank_world_size(group)
H_PAD = 0
# Only the last rank may need padding
if H % world_size != 0:
        # H_PAD is computed on every rank (not only the padded one) so that the
        # later unpadding step stays consistent across all ranks.
H_PAD = world_size - (H % world_size)
NEW_H_LOCAL = (H + H_PAD) // world_size
assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
if rank == world_size - 1:
x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
return x, H_PAD


def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
r"""Maybe unpad the head dimension.
    x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D)
    H_PAD: int, head padding num
    return: torch.Tensor, unpadded tensor (B, S_LOCAL, H_GLOBAL, D)
"""
if H_PAD > 0:
x = x[:, :, :-H_PAD, :]
return x.contiguous()


# Helper functions for all-to-all communication with Ulysses Anything Attention
def _wait_tensor(tensor) -> torch.Tensor:
if isinstance(tensor, fc.AsyncCollectiveTensor):
tensor = tensor.wait()

return tensor


def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict:
# query: (B, S_LOCAL, H_GLOBAL, D)
assert len(query.shape) == 4, "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)"
extra_kwargs = {}
extra_kwargs["NUM_QO_HEAD"] = query.shape[2]
extra_kwargs["Q_S_LOCAL"] = query.shape[1]
# Add other kwargs if needed in future
return extra_kwargs


@maybe_allow_in_graph
def all_to_all_single_any_qkv_async(
x: torch.Tensor, group: dist.ProcessGroup, **kwargs
) -> Callable[..., torch.Tensor]:
r"""
    x: torch.Tensor, shape (B, S_LOCAL, H, D)
    return: Callable that returns a tensor of shape (B, S_GLOBAL, H_LOCAL, D)
"""
_, world_size = _get_rank_world_size(group)
B, S_LOCAL, H, D = x.shape
x, H_PAD = _maybe_pad_qkv_head(x, H, group)
H_LOCAL = (H + H_PAD) // world_size
# (world_size, S_LOCAL, B, H_LOCAL, D)
x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()

input_split_sizes = [S_LOCAL] * world_size
    # S_LOCAL may differ across ranks in the dynamic-shape case, and the actual
    # shapes are not known ahead of time, so we first all-gather each rank's
    # S_LOCAL to build the output split sizes.
output_split_sizes = _gather_size_by_comm(S_LOCAL, group)
x = x.flatten(0, 1) # (world_size * S_LOCAL, B, H_LOCAL, D)
x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

def wait() -> torch.Tensor:
nonlocal x, H_PAD
x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D)
# (S_GLOBAL, B, H_LOCAL, D)
# -> (B, S_GLOBAL, H_LOCAL, D)
x = x.permute(1, 0, 2, 3).contiguous()
x = _maybe_unpad_qkv_head(x, H_PAD, group)
return x

return wait


@maybe_allow_in_graph
def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]:
r"""
    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D)
    return: Callable that returns a tensor of shape (B, S_LOCAL, H_GLOBAL, D)
"""
# Assume H is provided in kwargs, since we can't infer H from x's shape.
# The padding logic needs H to determine if padding is necessary.
H = kwargs.get("NUM_QO_HEAD", None)
rank, world_size = _get_rank_world_size(group)
x, H_PAD = _maybe_pad_o_head(x, H, group)
shape = x.shape # (B, S_GLOBAL, H_LOCAL, D)
(B, S_GLOBAL, H_LOCAL, D) = shape

    # input_split: e.g., S_GLOBAL=9, input splits across ranks [[5,4], [5,4], ...]
    # output_split: e.g., S_GLOBAL=9, output splits across ranks [[5,5], [4,4], ...]

    # WARN: In some cases, e.g., joint attention in Qwen-Image, S_LOCAL cannot be
    # inferred from a tensor split: if c = torch.cat((a, b)) and world_size=4, then
    # c.tensor_split(4)[0].shape[1] may differ from
    # (a.tensor_split(4)[0].shape[1] + b.tensor_split(4)[0].shape[1]).

S_LOCAL = kwargs.get("Q_S_LOCAL")
input_split_sizes = _gather_size_by_comm(S_LOCAL, group)
x = x.permute(1, 0, 2, 3).contiguous() # (S_GLOBAL, B, H_LOCAL, D)
output_split_sizes = [S_LOCAL] * world_size
x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

def wait() -> torch.Tensor:
nonlocal x, H_PAD
x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D)
x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D)
x = x.permute(2, 1, 0, 3, 4).contiguous()
x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D)
x = _maybe_unpad_o_head(x, H_PAD, group)
return x

return wait


@functools.lru_cache(maxsize=64)
def _fill_gather_shapes(shape: Tuple[int, ...], gather_dims: Tuple[int, ...], dim: int, world_size: int) -> List[List[int]]:
gather_shapes = []
for i in range(world_size):
rank_shape = list(copy.deepcopy(shape))
rank_shape[dim] = gather_dims[i]
gather_shapes.append(rank_shape)
return gather_shapes


@maybe_allow_in_graph
def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.ProcessGroup) -> torch.Tensor:
_, world_size = _get_rank_world_size(group)
tensor = tensor.contiguous()
shape = tensor.shape
rank_dim = shape[dim]
gather_dims = _gather_size_by_comm(rank_dim, group)

gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size)

gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes]

dist.all_gather(gathered_tensors, tensor, group=group)
gathered_tensor = torch.cat(gathered_tensors, dim=dim)
return gathered_tensor


class AllGatherAnythingFunction(torch.autograd.Function):
@staticmethod
    def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.ProcessGroup):
ctx.dim = dim
ctx.group = group
ctx.world_size = dist.get_world_size(group)
ctx.rank = dist.get_rank(group)
gathered_tensor = _all_gather_anything(tensor, dim, group)
return gathered_tensor

@staticmethod
def backward(ctx, grad_output):
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
# function may return fewer than the specified number of chunks!
grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim)
return grad_splits[ctx.rank], None, None


class PartitionAnythingSharder:
@classmethod
def shard_anything(
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
) -> torch.Tensor:
assert tensor.size()[dim] >= mesh.size(), (
f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}."
)
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
# function may return fewer than the specified number of chunks!
return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())]

@classmethod
def unshard_anything(
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
) -> torch.Tensor:
tensor = tensor.contiguous()
tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group())
return tensor
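The comments above rely on a behavioral difference between `torch.chunk` and `torch.tensor_split`, and on the fact that splitting a concatenated sequence is not the same as concatenating per-tensor splits (the Qwen-Image joint-attention case). A small standalone illustration with made-up sizes:

import torch

# chunk() may return fewer pieces than requested; tensor_split() never does.
x = torch.arange(9)
print(len(x.chunk(4)))         # 3 chunks, sizes [3, 3, 3]
print(len(x.tensor_split(4)))  # 4 chunks, sizes [3, 2, 2, 2]

# Why Q_S_LOCAL is carried explicitly in the metadata:
a = torch.randn(1, 5, 8)  # e.g. text tokens
b = torch.randn(1, 9, 8)  # e.g. image tokens
c = torch.cat((a, b), dim=1)
print(c.tensor_split(4, dim=1)[0].shape[1])  # 4
print(a.tensor_split(4, dim=1)[0].shape[1]
      + b.tensor_split(4, dim=1)[0].shape[1])  # 2 + 3 = 5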