From 0052f58c4afd1fd25c1c397f6e5061d32b1de6d8 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Tue, 13 Jan 2026 13:25:17 +0000 Subject: [PATCH 1/4] fix qwen-image cp --- src/diffusers/models/transformers/transformer_qwenimage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index a8c98201d96b..ce3af64933d0 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -761,11 +761,11 @@ class QwenImageTransformer2DModel( _no_split_modules = ["QwenImageTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] _repeated_blocks = ["QwenImageTransformerBlock"] + # Make CP plan compatible with https://github.com/huggingface/diffusers/pull/12702 _cp_plan = { - "": { + "transformer_blocks.0": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), }, "pos_embed": { 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), From caf595a6075f0625bc78a487b81b23499b2c32cb Mon Sep 17 00:00:00 2001 From: DefTruth Date: Wed, 14 Jan 2026 02:10:32 +0000 Subject: [PATCH 2/4] relax attn_mask limit for cp --- src/diffusers/models/attention_dispatch.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index f4ec49703850..f086c2d42579 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1573,8 +1573,6 @@ def _templated_context_parallel_attention( backward_op, _parallel_config: Optional["ParallelConfig"] = None, ): - if attn_mask is not None: - raise ValueError("Attention mask is not yet supported for templated attention.") if is_causal: raise ValueError("Causal attention is not yet supported for templated attention.") if enable_gqa: From 22db1fdef4464d8a22cab522130b53fde7d71ede Mon Sep 17 00:00:00 2001 From: DefTruth Date: Wed, 14 Jan 2026 02:24:28 +0000 Subject: [PATCH 3/4] CP plan compatible with zero_cond_t --- .../models/transformers/transformer_qwenimage.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index ce3af64933d0..7980dd30d57e 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -826,6 +826,17 @@ def __init__( self.gradient_checkpointing = False self.zero_cond_t = zero_cond_t + # Make CP plan compatible with zero_cond_t + if self.zero_cond_t: + # modulate_index: [b, l=seq_len], introduce by Qwen-Image-Edit-2511 + self._cp_plan.update( + { + "transformer_blocks.*": { + "modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + } + } + ) + def forward( self, hidden_states: torch.Tensor, From 6314b7799e2e57b4ad13fdd03a2c00d5d769d218 Mon Sep 17 00:00:00 2001 From: DefTruth Date: Wed, 14 Jan 2026 03:37:39 +0000 Subject: [PATCH 4/4] move modulate_index plan to top level --- .../models/transformers/transformer_qwenimage.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 7980dd30d57e..cf11d8e01fb4 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -767,6 +767,9 @@ class QwenImageTransformer2DModel( "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), }, + "transformer_blocks.*": { + "modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, "pos_embed": { 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), @@ -826,17 +829,6 @@ def __init__( self.gradient_checkpointing = False self.zero_cond_t = zero_cond_t - # Make CP plan compatible with zero_cond_t - if self.zero_cond_t: - # modulate_index: [b, l=seq_len], introduce by Qwen-Image-Edit-2511 - self._cp_plan.update( - { - "transformer_blocks.*": { - "modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), - } - } - ) - def forward( self, hidden_states: torch.Tensor,