Conversation

@DefTruth (Contributor) commented Jan 19, 2026

Fixes #12706. This PR implements Ulysses Anything Attention for diffusers to support [ANY] sequence length and [ANY] head num for Ulysses:

  • Supports any sequence length
  • Supports any head num (e.g., Z-Image, head num = 30)
  • NO extra padding when the sequence length is not divisible by the number of devices
  • NO loss of precision

@sayakpaul @DN6 @yiyixuxu

About Ulysses Anything Attention

Please refer to our docs for more details: https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention
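
For intuition, here is a minimal sketch of the two ingredients behind "anything" support: near-equal uneven splits instead of strict equipartition, and the variable-split form of torch.distributed.all_to_all_single, which exchanges differently sized chunks across ranks without padding. This is an illustration of the general technique, not the code in this PR.

import torch
import torch.distributed as dist


def uneven_split_sizes(size: int, world_size: int) -> list[int]:
    # Near-equal partition with no padding: the first `size % world_size`
    # ranks get one extra element, e.g. 30 heads on 4 ranks -> [8, 8, 7, 7].
    base, rem = divmod(size, world_size)
    return [base + (1 if r < rem else 0) for r in range(world_size)]


def all_to_all_uneven(t, input_splits, output_splits, group=None):
    # Exchange unevenly sized chunks along dim 0. all_to_all_single accepts
    # per-rank input/output split sizes, so no rank needs padded chunks.
    out = t.new_empty((sum(output_splits), *t.shape[1:]))
    dist.all_to_all_single(
        out,
        t.contiguous(),
        output_split_sizes=output_splits,
        input_split_sizes=input_splits,
        group=group,
    )
    return out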

Test

Qwen-Image && Qwen-Image-2512

import torch
import argparse
from diffusers import QwenImagePipeline
import torch.distributed as dist
from diffusers import ContextParallelConfig
from diffusers.quantizers import PipelineQuantizationConfig


def parse_args():
    parser = argparse.ArgumentParser(description="Test Qwen-Image with Context Parallelism")
    parser.add_argument(
        "--use_2512",
        action="store_true",
        help="Use Qwen-Image-2512 model if set, otherwise use 2509 model.",
    )
    # torch.compile flags
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Enable torch.compile for the pipeline if set.",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        help="Enable quantization for the pipeline if set.",
    )
    parser.add_argument(
        "--ulysses-anything",
        action="store_true",
        help="Enable debug mode if set.",
    )
    return parser.parse_args()

args = parse_args()

if dist.is_available():
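    # "cpu:gloo,cuda:nccl" registers gloo for CPU tensors and NCCL for CUDA tensors.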
    dist.init_process_group(backend="cpu:gloo,cuda:nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    world_size = dist.get_world_size()
    torch.cuda.set_device(device)
else:
    rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = 1

if args.use_2512:
    model_id = "Qwen/Qwen-Image-2512"
else:
    model_id = "Qwen/Qwen-Image"

pipe = QwenImagePipeline.from_pretrained(
    model_id, 
    torch_dtype=torch.bfloat16,
    quantization_config=(
        PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            components_to_quantize=["text_encoder", "transformer"],
        )
    ) if args.quantize else None,
)

if args.quantize:
    pipe.to(device)
else:
    pipe.enable_model_cpu_offload(device=device)

pipe.transformer.set_attention_backend("native")
if world_size > 1:
    from diffusers import QwenImageTransformer2DModel
    assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(
            ulysses_degree=world_size,
            ulysses_anything=args.ulysses_anything,
        )
    )

pipe.set_progress_bar_config(disable=rank != 0)

positive_magic = {
        "en": ", Ultra HD, 4K, cinematic composition.",  # for english prompt
        "zh": ", 超清,4K,电影级构图.",  # for chinese prompt
}
prompt = (
        "A coffee shop entrance features a chalkboard sign reading "
        '"Qwen Coffee 😊 $2 per cup," with a neon light beside it '
        'displaying "通义千问". Next to it hangs a poster showing a '
        "beautiful Chinese woman, and beneath the poster is written "
        '"π≈3.1415926-53589793-23846264-33832795-02384197". '
        "Ultra HD, 4K, cinematic composition"
)


if args.compile:
    torch._dynamo.config.recompile_limit = 256
    torch._dynamo.config.accumulated_recompile_limit = 8096
    torch._inductor.config.reorder_for_compute_comm_overlap = True
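    # compile_repeated_blocks() applies torch.compile per repeated transformer
    # block (regional compilation), cutting compile time vs compiling the whole model.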
    pipe.transformer.compile_repeated_blocks()


def run_pipe():
    with torch.inference_mode():
        inputs = {
            "prompt": prompt + positive_magic["en"],
            "generator": torch.Generator(device="cpu").manual_seed(0),
            "true_cfg_scale": 4.0,
            "negative_prompt": " ",
            "num_inference_steps": 50,
            "num_images_per_prompt": 1,
            "height": 1024,
            "width": 1024,
        }
        output = pipe(**inputs)
        output_image = output.images[0]
    return output_image


if args.compile:
    # Warm-up run for compilation
    for _ in range(2):
        run_pipe()


output_image = run_pipe()

model_version = "2512" if args.use_2512 else None
if world_size > 1:
    if model_version is not None:
        save_path = f"output_image_{model_version}_ulysses{world_size}.png"
    else:
        save_path = f"output_image_ulysses{world_size}.png"
else:
    if model_version is not None:
        save_path = f"output_image_{model_version}.png"
    else:
        save_path = f"output_image.png"
if rank == 0:
    output_image.save(save_path)
    print(f"image saved at {save_path}")

if dist.is_initialized():
    dist.destroy_process_group()

test cmds:

torchrun --nproc_per_node=1 test_qwen_image.py --use_2512 # baseline 2512
torchrun --nproc_per_node=2 test_qwen_image.py --use_2512 # cp2 2512
torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 # cp4 2512, standard ulysses failed
torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 --ulysses-anything # cp4 2512, working as expected
torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 --ulysses-anything --compile # cp4 2512 + compile, working as expected

Before this PR:

torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 # standard ulysses failed

...
[rank3]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 190, in new_forward
[rank3]:     return function_reference.post_forward(module, output)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 201, in post_forward
[rank3]:     current_output = self._prepare_cp_input(current_output, cpm)
[rank3]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 213, in _prepare_cp_input
[rank3]:     return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 266, in shard
[rank3]:     assert tensor.size()[dim] % mesh.size() == 0, (
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size
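
(Likely cause: the sharded sequence length here is not a multiple of 4. At 1024x1024 the image tokens alone are 64x64 = 4096, which is divisible by 4, so the non-divisible length presumably comes from the variable-length text tokens in the joint sequence.)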

After this PR:

torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 --ulysses-anything

...
Attention backends are an experimental feature and the API may be subject to change.
`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning.
100%|█████████████████████████████████████████████████████████| 50/50 [00:49<00:00,  1.02it/s]
image saved at output_image_2512_ulysses4.png
| Qwen-Image-2512 | Qwen-Image-2512 Ulysses-2 | Qwen-Image-2512 Ulysses-Anything-4 |
| --- | --- | --- |
| L20x1 w/ offload, 101s | L20x2 w/ offload, 75s | L20x4 w/ offload, 49s (standard Ulysses failed) |

(images: output_image_2512, output_image_2512_ulysses2, output_image_2512_ulysses4)
  • compile (w/o offload, L20 48GiB)
# NO Compile, ~35s
torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 --ulysses-anything --quantize 
# w/ Compile, ~32s
torchrun --nproc_per_node=4 test_qwen_image.py --use_2512 --ulysses-anything --quantize --compile

@DefTruth DefTruth marked this pull request as ready for review January 19, 2026 08:32
@sayakpaul sayakpaul added the performance Anything related to performance improvements, profiling and benchmarking label Jan 19, 2026
@sayakpaul (Member) left a comment


Thanks for starting this work!

Let's try to think if we can allow the user to more explicitly specify Ulysses Anything through the configs instead of an env var.

@DefTruth (Contributor, Author) replied:

> Thanks for starting this work!
>
> Let's try to think if we can allow the user to more explicitly specify Ulysses Anything through the configs instead of an env var.

Now we can enable Ulysses Anything attention through ContextParallelConfig:

pipe.transformer.enable_parallelism(
    config=ContextParallelConfig(
        ulysses_degree=world_size,
        ulysses_anything=True,
    )
)
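
Note that ulysses_anything is opt-in: in the test scripts here it comes from a store_true CLI flag, so it defaults to False and the standard equal-partition Ulysses path is used unless it is explicitly enabled.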

@DefTruth (Contributor, Author) commented:

Let me prepare more test cases for Ulysses Anything.

@DefTruth (Contributor, Author) commented Jan 20, 2026

FLUX.1-dev, 1008x1008

import os
import argparse
import time
import torch
import torch.distributed as dist
from diffusers import (
    FluxPipeline,
    FluxTransformer2DModel,
    ContextParallelConfig,
    PipelineQuantizationConfig,
)

def parse_args():
    parser = argparse.ArgumentParser(description="Context Parallelism")
    # torch.compile flags
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Enable torch.compile for the pipeline if set.",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        help="Enable quantization for the pipeline if set.",
    )
    parser.add_argument(
        "--ulysses-anything",
        action="store_true",
        help="Enable debug mode if set.",
    )
    # height and width
    parser.add_argument(
        "--height",
        type=int,
        default=None,
        help="Height of the generated image.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=None,
        help="Width of the generated image.",
    )
    return parser.parse_args()

args = parse_args()

print(args)

if dist.is_available():
    dist.init_process_group(backend="cpu:gloo,cuda:nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    world_size = dist.get_world_size()
    torch.cuda.set_device(device)
else:
    rank = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = 1

pipe: FluxPipeline = FluxPipeline.from_pretrained(
    os.environ.get(
        "FLUX_DIR",
        "black-forest-labs/FLUX.1-dev",
    ),
    torch_dtype=torch.bfloat16,
    quantization_config=(
        PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            components_to_quantize=["text_encoder_2"],
        )
        if args.quantize
        else None
    ),
).to("cuda")


assert isinstance(pipe.transformer, FluxTransformer2DModel)
pipe.transformer.set_attention_backend("native")
if world_size > 1:
    pipe.transformer.enable_parallelism(
        config=ContextParallelConfig(
            ulysses_degree=world_size,
            ulysses_anything=args.ulysses_anything,
        )
    )


pipe.set_progress_bar_config(disable=rank != 0)

# Set default prompt
prompt = "A cat holding a sign that says hello world"


height = 1008 if args.height is None else args.height
width = 1008 if args.width is None else args.width


def run_pipe(pipe: FluxPipeline):
    image = pipe(
        prompt,
        height=height,
        width=width,
        num_inference_steps=28,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]
    return image


if args.compile:
    torch._dynamo.config.recompile_limit = 256
    torch._dynamo.config.accumulated_recompile_limit = 8096
    torch._inductor.config.reorder_for_compute_comm_overlap = True
    pipe.transformer = torch.compile(pipe.transformer)

# warmup
_ = run_pipe(pipe)

start = time.time()
image = run_pipe(pipe)
end = time.time()

if rank == 0:
    time_cost = end - start
    save_path = f"flux.{height}x{width}_ulysses{world_size}.png"
    print(f"Time cost: {time_cost:.2f}s")
    print(f"Saving image to {save_path}")
    image.save(save_path)

if dist.is_initialized():
    dist.destroy_process_group()

test cmds:

torchrun --nproc_per_node=1 test_flux.py # baseline
torchrun --nproc_per_node=2 test_flux.py # standard ulysses, failed
torchrun --nproc_per_node=2 test_flux.py --ulysses-anything # working as expected
torchrun --nproc_per_node=2 test_flux.py --ulysses-anything --compile # working as expected

Before this PR:

torchrun --nproc_per_node=2 test_flux.py # standard ulysses, failed

...
[rank1]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/hooks.py", line 188, in new_forward
[rank1]:     args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 158, in pre_forward
[rank1]:     input_val = self._prepare_cp_input(input_val, cpm)
[rank1]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 216, in _prepare_cp_input
[rank1]:     return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/workspace/dev/vipshop/diffusers/src/diffusers/hooks/context_parallel.py", line 273, in shard
[rank1]:     assert tensor.size()[dim] % mesh.size() == 0, (
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size
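
For intuition: with Flux's 8x VAE downsampling and 2x2 patchification (16x per side overall), 1008x1008 yields 63x63 = 3969 image tokens. 3969 is odd, so it cannot be equally partitioned across 2 ranks, which triggers the divisibility assertion above; an uneven split such as 1985 + 1984 avoids both padding and the assertion.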

After this PR:

torchrun --nproc_per_node=2 test_flux.py --ulysses-anything --compile # working as expected

...
|████████████████████████████████████████████████████████| 28/28 [00:11<00:00,  2.35it/s]
Time cost: 12.58s
Saving image to flux.1008x1008_ulysses2_compile1.png
| FLUX.1-dev | FLUX.1-dev Ulysses-Anything-2 | FLUX.1-dev Ulysses-Anything-2 + compile |
| --- | --- | --- |
| L20x1, 23.26s | L20x2, 13.55s | L20x2, 12.58s |

(images: flux.1008x1008_ulysses1, flux.1008x1008_ulysses2_compile0, flux.1008x1008_ulysses2_compile1)

@sayakpaul I can also provide a case demonstrating support for any head number via Z-Image (e.g., head num = 30 with Ulysses Anything-4) in a separate PR after this one is ready. For now, you can quickly try it using the examples in cache-dit:

cd cache-dit/examples
torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py zimage --parallel ulysses --ulysses-anything

logs:

INFO 01-20 03:07:50 [base.py:622] ----------------------------------------------------------------------------------------------------
INFO 01-20 03:07:50 [base.py:395] 🤖 Example Init Config Summary:
INFO 01-20 03:07:50 [base.py:418] - Model: /workspace/dev/vipdev/hf_models/Z-Image-Turbo
INFO 01-20 03:07:50 [base.py:418] - Task Type: T2I - Text to Image
INFO 01-20 03:07:50 [base.py:418] - Torch Dtype: torch.bfloat16
INFO 01-20 03:07:50 [base.py:418] - LoRA Weights: None
INFO 01-20 03:07:50 [base.py:212] 🤖 Example Input Summary:
INFO 01-20 03:07:50 [base.py:212] - prompt: Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights.
INFO 01-20 03:07:50 [base.py:212] - height: 1024
INFO 01-20 03:07:50 [base.py:212] - width: 1024
INFO 01-20 03:07:50 [base.py:212] - guidance_scale: 0.0
INFO 01-20 03:07:50 [base.py:212] - num_inference_steps: 9
INFO 01-20 03:07:50 [base.py:212] - generator: device cpu, seed 0
INFO 01-20 03:07:50 [base.py:307] 🤖 Example Output Summary:
INFO 01-20 03:07:50 [base.py:323] - Model: zimage
INFO 01-20 03:07:50 [base.py:323] - Optimization: C0_Q0_NONE_Ulysses4_ulysses_anything
INFO 01-20 03:07:50 [base.py:323] - Device: NVIDIA L20 x 4
INFO 01-20 03:07:50 [base.py:323] - Load Time: 12.48s
INFO 01-20 03:07:50 [base.py:323] - Warmup Time: 3.08s
INFO 01-20 03:07:50 [base.py:323] - Inference Time: 2.36s
INFO 01-20 03:07:50 [base.py:246] Image saved to zimage.1024x1024.C0_Q0_NONE_Ulysses4_ulysses_anything.png
INFO 01-20 03:07:50 [base.py:633] ----------------------------------------------------------------------------------------------------
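
Conceptually, the head-count case works the same way as the sequence case: with Z-Image's 30 attention heads on 4 ranks, exact equipartition is impossible (30 % 4 != 0), but a near-equal assignment such as [8, 8, 7, 7] heads per rank lets the all-to-all proceed without padding (see the sketch in the PR description).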

@DefTruth (Contributor, Author) commented Jan 21, 2026

Hi @sayakpaul, could you take a look at the latest updates? Thanks!


Development

Successfully merging this pull request may close these issues.

[Feature] Ulysses Attention for any sequence length w/o padding
