3 changes: 0 additions & 3 deletions src/diffusers/hooks/context_parallel.py
@@ -258,9 +258,6 @@ class EquipartitionSharder:
def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
# NOTE: the following assertion does not have to be true in general. We simply enforce it for now
# because the alternate case has not yet been tested/required for any model.
assert tensor.size()[dim] % mesh.size() == 0, (
"Tensor size along dimension to be sharded must be divisible by mesh size"
)

# The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
# return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
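As an aside, a minimal single-process sketch (not from the PR) of why the removed assertion existed: torch.chunk returns unequal slices when the dimension size is not divisible by the number of ranks, so equipartition sharding is only well defined when the sizes divide evenly. world_size and rank below are hypothetical stand-ins for mesh.size() and mesh.get_rank().

import torch

# Hypothetical stand-ins for mesh.size() and mesh.get_rank().
world_size, rank = 4, 1

x = torch.arange(10)  # 10 is not divisible by world_size
chunks = torch.chunk(x, world_size, dim=0)
print([c.shape[0] for c in chunks])  # [3, 3, 3, 1] -- ranks would hold unequal shards
local_shard = chunks[rank]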
2 changes: 2 additions & 0 deletions src/diffusers/models/_modeling_parallel.py
@@ -77,6 +77,7 @@ class ContextParallelConfig:
_ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
_ring_local_rank: int = None
_ulysses_local_rank: int = None
_pre_allocated_all2all_output_tensor_map: Dict[str, List[torch.Tensor]] = None

def __post_init__(self):
if self.ring_degree is None:
@@ -120,6 +121,7 @@ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.di
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
self._pre_allocated_all2all_output_tensor_map = {}


@dataclass
61 changes: 47 additions & 14 deletions src/diffusers/models/attention_dispatch.py
@@ -24,6 +24,7 @@


if torch.distributed.is_available():
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

from ..utils import (
@@ -1272,6 +1273,43 @@ def backward(ctx, grad_outputs):
ctx.scatter_id, # reversed
)
return (None, grad_input, None, None)


def _all_to_all(
    input: torch.Tensor,
    dim_to_chunk: int,
    dim_to_cat: int,
    group,
    tensor_name: str,
    pre_allocated_all2all_output_tensor_map: Dict[str, List[torch.Tensor]],
    check_if_tensor_map_need_reset: bool = False,
) -> torch.Tensor:
    world_size = dist.get_world_size(group=group)

    # Split the input along `dim_to_chunk` into one contiguous slice per rank.
    input_list = [chunk.contiguous() for chunk in torch.chunk(input, world_size, dim=dim_to_chunk)]

    # Output buffers are cached, keyed by the tensor name and the sizes of the chunk/cat
    # dimensions, so the shape exchange and allocation below happen only once per configuration.
    key = f"{tensor_name}_{input.shape[dim_to_chunk]}_{input.shape[dim_to_cat]}"
    output_tensor_list = pre_allocated_all2all_output_tensor_map.get(key, None)

    if output_tensor_list is None:
        # On a cache miss, optionally drop all previously allocated buffers (e.g. when shapes changed).
        if check_if_tensor_map_need_reset:
            pre_allocated_all2all_output_tensor_map.clear()

        dtype_ = input.dtype
        device_ = input.device

        # Exchange the per-rank chunk shapes so each rank can pre-allocate matching output buffers.
        shape_list_in = [torch.as_tensor(input_list[i].shape, device=device_) for i in range(world_size)]
        shape_list_out = [torch.empty_like(shape_list_in[0]) for _ in range(world_size)]
        dist.all_to_all(shape_list_out, shape_list_in, group=group)

        output_tensor_list = [torch.empty(*_shape, dtype=dtype_, device=device_) for _shape in shape_list_out]
        pre_allocated_all2all_output_tensor_map[key] = output_tensor_list

    dist.all_to_all(output_tensor_list, input_list, group=group)

    return torch.cat(output_tensor_list, dim=dim_to_cat)
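To make the data movement concrete, here is a hedged usage sketch (not from the PR) of how this helper performs the Ulysses-style head/sequence swap: chunking over the head dimension and concatenating over the sequence dimension turns a sequence-sharded tensor into a head-sharded one holding the full sequence, and swapping the two dimension arguments reverses the exchange. The script name, tensor sizes, and the bare tensor_map dict are illustrative stand-ins; it assumes the _all_to_all helper above is in scope (for example, pasted into the same script) and one NCCL-capable GPU per process.

# Hypothetical sketch: run with
#   torchrun --nproc_per_node=2 ulysses_a2a_sketch.py
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

B, S_LOCAL, H, D = 1, 4, 8, 16   # hypothetical sizes; H must be divisible by world_size
tensor_map = {}                  # stands in for _pre_allocated_all2all_output_tensor_map

# Sequence-sharded input: each rank holds S_LOCAL tokens and all H heads.
query = torch.randn(B, S_LOCAL, H, D, device="cuda")

# Head <-> sequence exchange: chunk over heads (dim 2), concatenate over sequence (dim 1).
query = _all_to_all(query, 2, 1, dist.group.WORLD, "query", tensor_map)
assert query.shape == (B, S_LOCAL * world_size, H // world_size, D)

# ... attention over the full sequence with the local heads would run here ...

# Reverse exchange: chunk over sequence (dim 1), concatenate over heads (dim 2).
out = _all_to_all(query, 1, 2, dist.group.WORLD, "out", tensor_map)
assert out.shape == (B, S_LOCAL, H, D)

dist.destroy_process_group()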


class TemplatedRingAttention(torch.autograd.Function):
@@ -1426,14 +1464,12 @@ def forward(
ctx.backward_op = backward_op
ctx._parallel_config = _parallel_config

B, S_Q_LOCAL, H, D = query.shape
_, S_KV_LOCAL, _, _ = key.shape
H_LOCAL = H // world_size
query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
tensor_map = _parallel_config.context_parallel_config._pre_allocated_all2all_output_tensor_map
query = _all_to_all(query, 2, 1, group, 'query', tensor_map, check_if_tensor_map_need_reset=True)
key = _all_to_all(key, 2, 1, group, 'key', tensor_map)
value = _all_to_all(value, 2, 1, group, 'value', tensor_map)

out = forward_op(
ctx,
@@ -1451,10 +1487,9 @@ def forward(
)
if return_lse:
out, lse, *_ = out

out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
out = _all_to_all_single(out, group)
out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()

out = _all_to_all(out, 1, 2, group, 'out', tensor_map)

if return_lse:
lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
@@ -1651,8 +1686,6 @@ def _flash_attention(
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
lse = None
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 2.")

if _parallel_config is None:
out = flash_attn_func(