-
Notifications
You must be signed in to change notification settings - Fork 6.7k
feat: support Ulysses Anything Attention #12996
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
sayakpaul
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for starting this work!
Let's try to think if we can allow the user to more explicitly specify Ulysses Anything through the configs instead of an env var.
Now, we can enable ulysses anything attention through the ContextParallelConfig: pipe.transformer.enable_parallelism(
config=ContextParallelConfig(
ulysses_degree=world_size,
ulysses_anything=True,
)
) |
|
Let me prepare more test cases for ulysses anything. |
FLUX.1-dev, 1008x1008import os
import argparse
import time
import torch
import torch.distributed as dist
from diffusers import (
FluxPipeline,
FluxTransformer2DModel,
ContextParallelConfig,
PipelineQuantizationConfig,
)
def parse_args():
parser = argparse.ArgumentParser(description="Context Parallelism")
# torch.compile flags
parser.add_argument(
"--compile",
action="store_true",
help="Enable torch.compile for the pipeline if set.",
)
parser.add_argument(
"--quantize",
action="store_true",
help="Enable quantization for the pipeline if set.",
)
parser.add_argument(
"--ulysses-anything",
action="store_true",
help="Enable debug mode if set.",
)
# height and width
parser.add_argument(
"--height",
type=int,
default=None,
help="Height of the generated image.",
)
parser.add_argument(
"--width",
type=int,
default=None,
help="Width of the generated image.",
)
return parser.parse_args()
args = parse_args()
print(args)
if dist.is_available():
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
rank = dist.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
world_size = dist.get_world_size()
torch.cuda.set_device(device)
else:
rank = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
world_size = 1
pipe: FluxPipeline = FluxPipeline.from_pretrained(
os.environ.get(
"FLUX_DIR",
"black-forest-labs/FLUX.1-dev",
),
torch_dtype=torch.bfloat16,
quantization_config=(
PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_compute_dtype": torch.bfloat16,
},
components_to_quantize=["text_encoder_2"],
)
if args.quantize
else None
),
).to("cuda")
assert isinstance(pipe.transformer, FluxTransformer2DModel)
pipe.transformer.set_attention_backend("native")
if world_size > 1:
pipe.transformer.enable_parallelism(
config=ContextParallelConfig(
ulysses_degree=world_size,
ulysses_anything=args.ulysses_anything,
)
)
pipe.set_progress_bar_config(disable=rank != 0)
# Set default prompt
prompt = "A cat holding a sign that says hello world"
height = 1008 if args.height is None else args.height
width = 1008 if args.width is None else args.width
def run_pipe(pipe: FluxPipeline):
image = pipe(
prompt,
height=height,
width=width,
num_inference_steps=28,
generator=torch.Generator("cpu").manual_seed(0),
).images[0]
return image
if args.compile:
torch._dynamo.config.recompile_limit = 256
torch._dynamo.config.accumulated_recompile_limit = 8096
torch._inductor.config.reorder_for_compute_comm_overlap = True
pipe.transformer = torch.compile(pipe.transformer)
# warmup
_ = run_pipe(pipe)
start = time.time()
image = run_pipe(pipe)
end = time.time()
if rank == 0:
time_cost = end - start
save_path = f"flux.{height}x{width}_ulysses{world_size}.png"
print(f"Time cost: {time_cost:.2f}s")
print(f"Saving image to {save_path}")
image.save(save_path)
if dist.is_initialized():
dist.destroy_process_group()test cmds: torchrun --nproc_per_node=1 test_flux.py # baseline
torchrun --nproc_per_node=2 test_flux.py # standard ulysses, failed
torchrun --nproc_per_node=2 test_flux.py --ulysses-anything # working as expected
torchrun --nproc_per_node=2 test_flux.py --ulysses-anything --compile # working as expectedbefore this pr: torchrun --nproc_per_node=2 test_flux.py # standard ulysses, failed
...
[rank1]: File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 188, in new_forward
[rank1]: args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 158, in pre_forward
[rank1]: input_val = self._prepare_cp_input(input_val, cpm)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 216, in _prepare_cp_input
[rank1]: return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 273, in shard
[rank1]: assert tensor.size()[dim] % mesh.size() == 0, (
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh sizeafter this pr: torchrun --nproc_per_node=2 test_flux.py --ulysses-anything --compile # working as expected
...
|████████████████████████████████████████████████████████| 28/28 [00:11<00:00, 2.35it/s]
Time cost: 12.58s
Saving image to flux.1008x1008_ulysses2_compile1.png
@sayakpaul I can also provide a case to demonstrate support for any head number via z‑image (e.g., head num = 30 with Ulysses Anything‑4) in a separate PR after this one is ready. For now, you can quickly try it using the examples in cache‑dit. cd cache-dit/exmaples
torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py zimage --parallel ulysses --ulysses-anythinglogs: INFO 01-20 03:07:50 [base.py:622] ----------------------------------------------------------------------------------------------------
INFO 01-20 03:07:50 [base.py:395] 🤖 Example Init Config Summary:
INFO 01-20 03:07:50 [base.py:418] - Model: /workspace/dev/vipdev/hf_models/Z-Image-Turbo
INFO 01-20 03:07:50 [base.py:418] - Task Type: T2I - Text to Image
INFO 01-20 03:07:50 [base.py:418] - Torch Dtype: torch.bfloat16
INFO 01-20 03:07:50 [base.py:418] - LoRA Weights: None
INFO 01-20 03:07:50 [base.py:212] 🤖 Example Input Summary:
INFO 01-20 03:07:50 [base.py:212] - prompt: Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.
INFO 01-20 03:07:50 [base.py:212] - height: 1024
INFO 01-20 03:07:50 [base.py:212] - width: 1024
INFO 01-20 03:07:50 [base.py:212] - guidance_scale: 0.0
INFO 01-20 03:07:50 [base.py:212] - num_inference_steps: 9
INFO 01-20 03:07:50 [base.py:212] - generator: device cpu, seed 0
INFO 01-20 03:07:50 [base.py:307] 🤖 Example Output Summary:
INFO 01-20 03:07:50 [base.py:323] - Model: zimage
INFO 01-20 03:07:50 [base.py:323] - Optimization: C0_Q0_NONE_Ulysses4_ulysses_anything
INFO 01-20 03:07:50 [base.py:323] - Device: NVIDIA L20 x 4
INFO 01-20 03:07:50 [base.py:323] - Load Time: 12.48s
INFO 01-20 03:07:50 [base.py:323] - Warmup Time: 3.08s
INFO 01-20 03:07:50 [base.py:323] - Inference Time: 2.36s
INFO 01-20 03:07:50 [base.py:246] Image saved to zimage.1024x1024.C0_Q0_NONE_Ulysses4_ulysses_anything.png
INFO 01-20 03:07:50 [base.py:633] ---------------------------------------------------------------------------------------------------- |
|
@sayakpaul Hi~ can you take a look to the latest updates? thanks~ |



fixed #12706, this pr implement ulysses anything attention for diffusers in order to support [ANY] sequence lengths and [ANY] head num for ulysses.
@sayakpaul @DN6 @yiyixuxu
About Ulysses Anything Attention
Please refer to our docs for more details. link: https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention
Test
Qwen-Image && Qwen-Image-2512
test cmds:
before this pr:
after this pr: