From 02ae19d448c1379b9a3f48beef74688265313c90 Mon Sep 17 00:00:00 2001
From: Beinsezii
Date: Fri, 9 Jan 2026 16:57:46 -0800
Subject: [PATCH] ZImageTransformer2D: Only build attention mask if seqlens are
 not equal

---
 .../models/transformers/transformer_z_image.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py
index 5983c34ab640..085e9000d7bd 100644
--- a/src/diffusers/models/transformers/transformer_z_image.py
+++ b/src/diffusers/models/transformers/transformer_z_image.py
@@ -789,9 +789,12 @@ def _prepare_sequence(
         freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
 
         # Attention mask
-        attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
-        for i, seq_len in enumerate(item_seqlens):
-            attn_mask[i, :seq_len] = 1
+        if all(seq == max_seqlen for seq in item_seqlens):
+            attn_mask = None
+        else:
+            attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
+            for i, seq_len in enumerate(item_seqlens):
+                attn_mask[i, :seq_len] = 1
 
         # Noise mask
         noise_mask_tensor = None
@@ -872,9 +875,12 @@ def _build_unified_sequence(
         unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
 
         # Attention mask
-        attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
-        for i, seq_len in enumerate(unified_seqlens):
-            attn_mask[i, :seq_len] = 1
+        if all(seq == max_seqlen for seq in unified_seqlens):
+            attn_mask = None
+        else:
+            attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
+            for i, seq_len in enumerate(unified_seqlens):
+                attn_mask[i, :seq_len] = 1
 
         # Noise mask
         noise_mask_tensor = None
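
Note (not part of the patch): the sketch below illustrates the pattern the patch applies in both `_prepare_sequence` and `_build_unified_sequence`: when every item in the batch already fills `max_seqlen`, skip materializing an all-True padding mask and return `None`, which lets `scaled_dot_product_attention` take its unmasked path. The `make_padding_mask` helper name is hypothetical and used only for illustration.

```python
# Minimal sketch, assuming a boolean padding mask where True means "attend".
import torch
import torch.nn.functional as F


def make_padding_mask(item_seqlens, max_seqlen, device):
    # Hypothetical helper mirroring the patched logic: return None when no
    # item needs padding, otherwise build a per-item boolean mask.
    if all(seq == max_seqlen for seq in item_seqlens):
        return None
    mask = torch.zeros((len(item_seqlens), max_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(item_seqlens):
        mask[i, :seq_len] = True
    return mask


if __name__ == "__main__":
    device = "cpu"
    q = k = v = torch.randn(2, 4, 16, 8)  # (batch, heads, seq, head_dim)

    # Equal seqlens: no mask is built, attention runs unmasked.
    assert make_padding_mask([16, 16], 16, device) is None

    # Ragged seqlens: boolean mask, broadcast to (B, 1, 1, S) for SDPA.
    mask = make_padding_mask([16, 12], 16, device)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[:, None, None, :])
    print(out.shape)  # torch.Size([2, 4, 16, 8])
```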