From 9635441975c60379f2b7f1cf2a1f17ec52ff40ea Mon Sep 17 00:00:00 2001
From: xb
Date: Sat, 17 Jan 2026 00:10:29 +0800
Subject: [PATCH] Fix Context Parallelism crash when sequence length is not
 divisible by mesh size

EquipartitionSharder.shard asserted that the sharded dimension is divisible
by the mesh size, and the Ulysses attention path exchanged shards with
fixed-size _all_to_all_single collectives, so sequence lengths that do not
split evenly across ranks crashed. Replace the fixed-size exchange with a
variable-size all_to_all that first exchanges per-rank chunk shapes and
pre-allocates (and caches) matching receive buffers, and drop the
divisibility assertion.
---
 src/diffusers/hooks/context_parallel.py    |  3 --
 src/diffusers/models/_modeling_parallel.py |  2 +
 src/diffusers/models/attention_dispatch.py | 55 ++++++++++++++++-----
 3 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
index 6491d17b4f46..27b09a00e563 100644
--- a/src/diffusers/hooks/context_parallel.py
+++ b/src/diffusers/hooks/context_parallel.py
@@ -258,9 +258,6 @@ class EquipartitionSharder:
     def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
         # NOTE: the following assertion does not have to be true in general. We simply enforce it for now
         # because the alternate case has not yet been tested/required for any model.
-        assert tensor.size()[dim] % mesh.size() == 0, (
-            "Tensor size along dimension to be sharded must be divisible by mesh size"
-        )
 
         # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
         # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py
index 1c7703a13c52..f43b3dc0249f 100644
--- a/src/diffusers/models/_modeling_parallel.py
+++ b/src/diffusers/models/_modeling_parallel.py
@@ -77,6 +77,7 @@ class ContextParallelConfig:
     _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
     _ring_local_rank: int = None
     _ulysses_local_rank: int = None
+    _pre_allocated_all2all_output_tensor_map: Dict[str, List[torch.Tensor]] = None
 
     def __post_init__(self):
         if self.ring_degree is None:
@@ -120,6 +121,7 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di
         self._ulysses_mesh = self._mesh["ulysses"]
         self._ring_local_rank = self._ring_mesh.get_local_rank()
         self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
+        self._pre_allocated_all2all_output_tensor_map = {}
 
 
 @dataclass
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index f086c2d42579..1973219f3c68 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -24,6 +24,7 @@
 
 
 if torch.distributed.is_available():
+    import torch.distributed as dist
     import torch.distributed._functional_collectives as funcol
 
 from ..utils import (
@@ -1272,6 +1273,43 @@ def backward(ctx, grad_outputs):
             ctx.scatter_id,  # reversed
         )
         return (None, grad_input, None, None)
+
+
+def _all_to_all(
+    input: torch.Tensor,
+    dim_to_chunk: int,
+    dim_to_cat: int,
+    group,
+    tensor_name: str,
+    pre_allocated_all2all_output_tensor_map: Dict[str, List[torch.Tensor]],
+    check_if_tensor_map_need_reset: bool = False,
+) -> torch.Tensor:
+    # Variable-size all-to-all: unlike _all_to_all_single, this handles per-rank chunks of
+    # unequal size, which occur when the sharded dimension is not divisible by the mesh size.
+    world_size = dist.get_world_size(group=group)
+    input_list = [chunk.contiguous() for chunk in torch.chunk(input, world_size, dim=dim_to_chunk)]
+
+    key = f"{tensor_name}_{input.shape[dim_to_chunk]}_{input.shape[dim_to_cat]}"
+    output_tensor_list = pre_allocated_all2all_output_tensor_map.get(key, None)
+
+    if output_tensor_list is None:
+        if check_if_tensor_map_need_reset:
+            pre_allocated_all2all_output_tensor_map.clear()
+
+        dtype_ = input.dtype
+        device_ = input.device
+
+        # Exchange per-rank chunk shapes so every rank can pre-allocate receive buffers of the
+        # correct (possibly uneven) sizes; the buffers are cached and reused on later calls.
+        shape_list_in = [torch.as_tensor(input_list[i].shape, device=device_) for i in range(world_size)]
+        shape_list_out = [torch.empty_like(shape_list_in[0]) for _ in range(world_size)]
+        dist.all_to_all(shape_list_out, shape_list_in, group=group)
+
+        output_tensor_list = [torch.empty(*_shape, dtype=dtype_, device=device_) for _shape in shape_list_out]
+        pre_allocated_all2all_output_tensor_map[key] = output_tensor_list
+
+    dist.all_to_all(output_tensor_list, input_list, group=group)
+    return torch.cat(output_tensor_list, dim=dim_to_cat)
 
 
 class TemplatedRingAttention(torch.autograd.Function):
@@ -1426,14 +1464,12 @@ def forward(
         ctx.backward_op = backward_op
         ctx._parallel_config = _parallel_config
 
         B, S_Q_LOCAL, H, D = query.shape
-        _, S_KV_LOCAL, _, _ = key.shape
         H_LOCAL = H // world_size
-        query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
-        key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
-        value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
-        query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
-        query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
+        tensor_map = _parallel_config.context_parallel_config._pre_allocated_all2all_output_tensor_map
+        query = _all_to_all(query, 2, 1, group, "query", tensor_map, check_if_tensor_map_need_reset=True)
+        key = _all_to_all(key, 2, 1, group, "key", tensor_map)
+        value = _all_to_all(value, 2, 1, group, "value", tensor_map)
 
         out = forward_op(
             ctx,
@@ -1451,10 +1487,9 @@ def forward(
         )
         if return_lse:
             out, lse, *_ = out
 
-        out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
-        out = _all_to_all_single(out, group)
-        out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
+        # All-to-all back: re-scatter the sequence dimension (dim 1) and re-gather the heads (dim 2).
+        out = _all_to_all(out, 1, 2, group, "out", tensor_map)
 
         if return_lse:
             lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
@@ -1651,8 +1686,6 @@ def _flash_attention(
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
     lse = None
-    if attn_mask is not None:
-        raise ValueError("`attn_mask` is not supported for flash-attn 2.")
 
     if _parallel_config is None:
        out = flash_attn_func(
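
As a sanity check of the approach, below is a minimal standalone sketch of the shape-exchange
trick the new _all_to_all helper relies on: each rank chunks its local tensor with torch.chunk
(so chunk sizes can differ when a dimension is not divisible by the world size), exchanges the
chunk shapes first, allocates matching receive buffers, and only then runs dist.all_to_all.
The 2-rank NCCL setup, the 7-token sequence split 4 + 3 across ranks, and the
uneven_all_to_all/main helper names are illustrative assumptions, not part of the patch;
launch with torchrun --nproc_per_node=2.

import torch
import torch.distributed as dist


def uneven_all_to_all(x: torch.Tensor, dim_to_chunk: int, dim_to_cat: int) -> torch.Tensor:
    world_size = dist.get_world_size()
    # Chunk sizes may differ across ranks when the chunked/sharded dimensions are uneven.
    send = [c.contiguous() for c in torch.chunk(x, world_size, dim=dim_to_chunk)]

    # Exchange chunk shapes first so every rank can allocate correctly sized receive buffers.
    shapes_in = [torch.tensor(list(c.shape), dtype=torch.int64, device=x.device) for c in send]
    shapes_out = [torch.empty_like(shapes_in[0]) for _ in range(world_size)]
    dist.all_to_all(shapes_out, shapes_in)

    recv = [torch.empty(*s.tolist(), dtype=x.dtype, device=x.device) for s in shapes_out]
    dist.all_to_all(recv, send)
    return torch.cat(recv, dim=dim_to_cat)


def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Uneven shards of a 7-token sequence: rank 0 holds 4 tokens, rank 1 holds 3.
    seq_local = 4 if rank == 0 else 3
    x = torch.full((1, seq_local, 4, 8), float(rank), device="cuda")  # (B, S_local, H, D)

    # Ulysses-style exchange: scatter heads (dim 2), gather the sequence (dim 1).
    # A fixed-size all_to_all_single would require equal shards; this handles 4 + 3 tokens.
    y = uneven_all_to_all(x, dim_to_chunk=2, dim_to_cat=1)
    print(f"rank {rank}: {tuple(x.shape)} -> {tuple(y.shape)}")  # both ranks end with (1, 7, 2, 8)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()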