From 661febb64b3f99def7f042871816e197745985f3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 16 Jan 2026 14:22:36 +0000 Subject: [PATCH 1/8] avoid creating attention masks when there is no padding --- .../pipelines/qwenimage/pipeline_qwenimage.py | 26 ++---- .../pipeline_qwenimage_controlnet.py | 20 ++--- .../pipeline_qwenimage_controlnet_inpaint.py | 20 ++--- .../qwenimage/pipeline_qwenimage_edit.py | 20 ++--- .../pipeline_qwenimage_edit_inpaint.py | 20 ++--- .../qwenimage/pipeline_qwenimage_edit_plus.py | 65 +++++++++----- .../qwenimage/pipeline_qwenimage_img2img.py | 25 ++---- .../qwenimage/pipeline_qwenimage_inpaint.py | 25 ++---- .../qwenimage/pipeline_qwenimage_layered.py | 25 ++---- src/diffusers/pipelines/qwenimage/utils.py | 89 +++++++++++++++++++ 10 files changed, 183 insertions(+), 152 deletions(-) create mode 100644 src/diffusers/pipelines/qwenimage/utils.py diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index bc3ce84e1019..bf289be905ff 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -27,6 +27,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -210,14 +211,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -248,19 +242,15 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] - - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, max_sequence_length + ) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py 
b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index ce6fc974a56e..28803542867a 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -274,14 +275,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -313,16 +307,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 77d78a5ca7a1..4c0a96a4eb3d 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -256,14 +257,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in 
attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -294,16 +288,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index dd723460a59e..e65be467df54 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -257,14 +258,7 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -298,16 +292,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index cf467203a9d2..40a0d9f35464 100644 --- 
a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -29,6 +29,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -268,14 +269,7 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -310,16 +304,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index 257e2d846c7c..a33366d7d1df 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, concat_prompt_embeds_for_cfg, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -270,14 +271,7 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = 
prompt_embeds.to(dtype=dtype, device=device) @@ -312,16 +306,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask @@ -724,6 +714,15 @@ def __call__( max_sequence_length=max_sequence_length, ) + use_batch_cfg = do_true_cfg and not self.transformer.is_cache_enabled + if use_batch_cfg: + prompt_embeds, prompt_embeds_mask = concat_prompt_embeds_for_cfg( + prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds, + negative_prompt_embeds_mask, + ) + # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, image_latents = self.prepare_latents( @@ -799,7 +798,11 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - with self.transformer.cache_context("cond"): + if use_batch_cfg: + latent_model_input = torch.cat([latent_model_input] * 2) + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + + if use_batch_cfg: noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, @@ -811,20 +814,36 @@ def __call__( return_dict=False, )[0] noise_pred = noise_pred[:, : latents.size(1)] - - if do_true_cfg: - with self.transformer.cache_context("uncond"): - neg_noise_pred = self.transformer( + neg_noise_pred, noise_pred = noise_pred.chunk(2) + else: + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + + if do_true_cfg: comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index e0b41b8b8799..c0aa0d56dd8f 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ 
b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -13,6 +13,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -217,14 +218,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -291,19 +285,16 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, max_sequence_length + ) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 83f02539b1ba..52326f9001eb 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -14,6 +14,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -227,14 +228,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = 
torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -302,19 +296,16 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, max_sequence_length + ) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 53d2c169ee63..4da2406e046f 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -275,14 +276,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -314,19 +308,16 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, max_sequence_length + ) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = 
prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/utils.py b/src/diffusers/pipelines/qwenimage/utils.py new file mode 100644 index 000000000000..7271fce9304e --- /dev/null +++ b/src/diffusers/pipelines/qwenimage/utils.py @@ -0,0 +1,89 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def build_prompt_embeds_and_mask(split_hidden_states): + seq_lens = [e.size(0) for e in split_hidden_states] + max_seq_len = max(seq_lens) + if all(seq_len == max_seq_len for seq_len in seq_lens): + prompt_embeds = torch.stack(split_hidden_states) + return prompt_embeds, None + + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + return prompt_embeds, encoder_attention_mask + + +def slice_prompt_embeds_and_mask(prompt_embeds, prompt_embeds_mask, max_sequence_length): + prompt_embeds = prompt_embeds[:, :max_sequence_length] + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + return prompt_embeds, prompt_embeds_mask + + +def repeat_prompt_embeds_and_mask(prompt_embeds, prompt_embeds_mask, num_images_per_prompt): + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + return prompt_embeds, prompt_embeds_mask + + +def concat_prompt_embeds_for_cfg( + prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask +): + pos_len = prompt_embeds.shape[1] + neg_len = negative_prompt_embeds.shape[1] + max_len = max(pos_len, neg_len) + + def _pad_prompt(embeds, mask): + orig_len = embeds.shape[1] + if orig_len != max_len: + pad_len = max_len - orig_len + embeds = torch.cat([embeds, embeds.new_zeros(embeds.shape[0], pad_len, embeds.shape[2])], dim=1) + if mask is None and orig_len != max_len: + mask = torch.ones((embeds.shape[0], orig_len), dtype=torch.long, device=embeds.device) + if mask is not None and mask.shape[1] != max_len: + pad_len = max_len - mask.shape[1] + mask = torch.cat([mask, 
mask.new_zeros(mask.shape[0], pad_len)], dim=1) + return embeds, mask + + prompt_embeds, prompt_embeds_mask = _pad_prompt(prompt_embeds, prompt_embeds_mask) + negative_prompt_embeds, negative_prompt_embeds_mask = _pad_prompt( + negative_prompt_embeds, negative_prompt_embeds_mask + ) + + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + if prompt_embeds_mask is None and negative_prompt_embeds_mask is None: + prompt_embeds_mask = None + else: + batch_half = prompt_embeds.shape[0] // 2 + if prompt_embeds_mask is None: + prompt_embeds_mask = torch.ones((batch_half, max_len), dtype=torch.long, device=prompt_embeds.device) + if negative_prompt_embeds_mask is None: + negative_prompt_embeds_mask = torch.ones( + (batch_half, max_len), dtype=torch.long, device=prompt_embeds.device + ) + prompt_embeds_mask = torch.cat([negative_prompt_embeds_mask, prompt_embeds_mask], dim=0) + + return prompt_embeds, prompt_embeds_mask From 5507b5e8b982cb837df5354f3d4ca3f3fefecedd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 16 Jan 2026 14:28:05 +0000 Subject: [PATCH 2/8] make fix-copies --- src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 4da2406e046f..11d11167d359 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -314,7 +314,6 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( prompt_embeds, prompt_embeds_mask, max_sequence_length ) - prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) From 4839fcfc312c7e85a76cd99df7b616ea813fa372 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 17 Jan 2026 23:34:48 +0000 Subject: [PATCH 3/8] torch compile tests --- .../controlnets/controlnet_qwenimage.py | 2 - .../transformers/transformer_qwenimage.py | 3 -- .../pipelines/qwenimage/pipeline_qwenimage.py | 3 ++ src/diffusers/pipelines/qwenimage/utils.py | 3 ++ .../test_models_transformer_qwenimage.py | 42 +++++++++++++++++++ 5 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index fa374285eec1..78a566549377 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -213,10 +213,8 @@ def forward( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - # Construct joint attention mask once to avoid reconstructing in every block block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {} if encoder_hidden_states_mask is not None: - # Build joint mask: [text_mask, all_ones_for_image] batch_size, image_seq_len = hidden_states.shape[:2] image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index cf11d8e01fb4..8cf0b19d09d0 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ 
b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -935,11 +935,8 @@ def forward( image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) - # Construct joint attention mask once to avoid reconstructing in every block - # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} if encoder_hidden_states_mask is not None: - # Build joint mask: [text_mask, all_ones_for_image] batch_size, image_seq_len = hidden_states.shape[:2] image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index bf289be905ff..88c6d74f92a8 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -252,6 +252,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( diff --git a/src/diffusers/pipelines/qwenimage/utils.py b/src/diffusers/pipelines/qwenimage/utils.py index 7271fce9304e..7c91fec05a0a 100644 --- a/src/diffusers/pipelines/qwenimage/utils.py +++ b/src/diffusers/pipelines/qwenimage/utils.py @@ -86,4 +86,7 @@ def _pad_prompt(embeds, mask): ) prompt_embeds_mask = torch.cat([negative_prompt_embeds_mask, prompt_embeds_mask], dim=0) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 384954dfbad7..6acd7fb500ee 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -276,3 +276,45 @@ def prepare_dummy_input(self, height, width): def test_torch_compile_recompilation_and_graph_break(self): super().test_torch_compile_recompilation_and_graph_break() + + def test_torch_compile_with_and_without_mask(self): + """Test that torch.compile works with both None mask and padding mask.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + compiled_model = torch.compile(model, mode="default", fullgraph=False) + + # Test 1: Run with None mask (no padding, all tokens are valid) + inputs_no_mask = inputs.copy() + inputs_no_mask["encoder_hidden_states_mask"] = None + + with torch.no_grad(): + output_no_mask = compiled_model(**inputs_no_mask) + + self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Test 2: Run with all-ones mask (should behave like None) + inputs_all_ones = inputs.copy() + # Keep the all-ones mask + self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item()) + + with torch.no_grad(): + output_all_ones = compiled_model(**inputs_all_ones) + + self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Test 3: Run with actual padding mask (has zeros) + inputs_with_padding = inputs.copy() + mask_with_padding = 
inputs["encoder_hidden_states_mask"].clone() + mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding + + inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding + + with torch.no_grad(): + output_with_padding = compiled_model(**inputs_with_padding) + + self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Verify that outputs are different (mask should affect results) + self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)) From 23150e46ab7ee87132fe4d23808d0c0b2961631b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Jan 2026 11:35:24 +0000 Subject: [PATCH 4/8] set all ones mask to none --- .../pipelines/qwenimage/pipeline_qwenimage_controlnet.py | 3 +++ .../qwenimage/pipeline_qwenimage_controlnet_inpaint.py | 3 +++ src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py | 3 +++ .../pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py | 3 +++ .../pipelines/qwenimage/pipeline_qwenimage_edit_plus.py | 3 +++ .../pipelines/qwenimage/pipeline_qwenimage_img2img.py | 3 +++ .../pipelines/qwenimage/pipeline_qwenimage_inpaint.py | 3 +++ .../pipelines/qwenimage/pipeline_qwenimage_layered.py | 3 +++ 8 files changed, 24 insertions(+) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index 28803542867a..4ee51151701f 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -314,6 +314,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 4c0a96a4eb3d..1f39ce08246e 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -295,6 +295,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index e65be467df54..fd278def7245 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -299,6 +299,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index 40a0d9f35464..4b56e8c9daa8 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -311,6 +311,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = 
None + return prompt_embeds, prompt_embeds_mask # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index a33366d7d1df..d5fc2b78ae73 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -313,6 +313,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index c0aa0d56dd8f..4da613e4d6a2 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -296,6 +296,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 52326f9001eb..109148c0f923 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -307,6 +307,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 11d11167d359..c8c5994e612b 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -318,6 +318,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def get_image_caption(self, prompt_image, use_en_prompt=True, device=None): From 47f6585e445f989089d683914d8d6c2fa7629970 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Jan 2026 17:29:03 +0000 Subject: [PATCH 5/8] fix positional encoding from becoming > 4096 --- .../transformers/transformer_qwenimage.py | 39 ++++++++++++------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 8cf0b19d09d0..74198d7303ba 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -235,7 +235,7 @@ def forward( video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], txt_seq_lens: Optional[List[int]] = None, device: torch.device = None, - max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None, + max_txt_seq_len: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: 
@@ -245,9 +245,9 @@ def forward( Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used. device: (`torch.device`, *optional*): The device on which to perform the RoPE computation. - max_txt_seq_len (`int` or `torch.Tensor`, *optional*): + max_txt_seq_len (`int`, *optional*): The maximum text sequence length for RoPE computation. This should match the encoder hidden states - sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). + sequence length. """ # Handle deprecated txt_seq_lens parameter if txt_seq_lens is not None: @@ -296,9 +296,14 @@ def forward( else: max_vid_index = max(height, width, max_vid_index) - max_txt_seq_len_int = int(max_txt_seq_len) - # Create device-specific copy for text freqs without modifying self.pos_freqs - txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] + pos_freqs = self.pos_freqs.to(device) + + # Clamp text sequence length to avoid buffer overflow + buffer_size = pos_freqs.shape[0] + available_space = buffer_size - max_vid_index + safe_txt_seq_len = min(max_txt_seq_len, available_space) + + txt_freqs = pos_freqs[max_vid_index : max_vid_index + safe_txt_seq_len] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @@ -367,7 +372,7 @@ def rope_params(self, index, dim, theta=10000): def forward( self, video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], - max_txt_seq_len: Union[int, torch.Tensor], + max_txt_seq_len: int, device: torch.device = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -375,9 +380,9 @@ def forward( video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer structures. - max_txt_seq_len (`int` or `torch.Tensor`): + max_txt_seq_len (`int`): The maximum text sequence length for RoPE computation. This should match the encoder hidden states - sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). + sequence length. device: (`torch.device`, *optional*): The device on which to perform the RoPE computation. """ @@ -417,9 +422,15 @@ def forward( max_vid_index = max(height, width, max_vid_index) max_vid_index = max(max_vid_index, layer_num) - max_txt_seq_len_int = int(max_txt_seq_len) - # Create device-specific copy for text freqs without modifying self.pos_freqs - txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] 
+ + pos_freqs = self.pos_freqs.to(device) + + # Clamp text sequence length to avoid buffer overflow + buffer_size = pos_freqs.shape[0] + available_space = buffer_size - max_vid_index + safe_txt_seq_len = min(max_txt_seq_len, available_space) + + txt_freqs = pos_freqs[max_vid_index : max_vid_index + safe_txt_seq_len] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @@ -920,7 +931,7 @@ def forward( encoder_hidden_states = self.txt_in(encoder_hidden_states) # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask - text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + text_seq_len, per_sample_len, encoder_hidden_states_mask = compute_text_seq_len_from_mask( encoder_hidden_states, encoder_hidden_states_mask ) @@ -933,6 +944,8 @@ def forward( else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond) ) + # Pass the static text_seq_len to RoPE (encoder_hidden_states.shape[1]) + # The RoPE class will clamp it to avoid buffer overflow image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} From da6e128e9f1e15b300318f59a3640d7d418c4dca Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 19 Jan 2026 08:49:04 +0000 Subject: [PATCH 6/8] fix from review --- .../test_models_transformer_qwenimage.py | 39 ++++++++++++++++--- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 6acd7fb500ee..e6b19377b14f 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -282,27 +282,46 @@ def test_torch_compile_with_and_without_mask(self): init_dict, inputs = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).to(torch_device) model.eval() - - compiled_model = torch.compile(model, mode="default", fullgraph=False) + model.compile(mode="default", fullgraph=True) # Test 1: Run with None mask (no padding, all tokens are valid) inputs_no_mask = inputs.copy() inputs_no_mask["encoder_hidden_states_mask"] = None + # First run to allow compilation with torch.no_grad(): - output_no_mask = compiled_model(**inputs_no_mask) + output_no_mask = model(**inputs_no_mask) + + # Second run to verify no recompilation + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(error_on_recompile=True), + torch.no_grad(), + ): + output_no_mask_2 = model(**inputs_no_mask) self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1]) + self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1]) # Test 2: Run with all-ones mask (should behave like None) inputs_all_ones = inputs.copy() # Keep the all-ones mask self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item()) + # First run to allow compilation with torch.no_grad(): - output_all_ones = compiled_model(**inputs_all_ones) + output_all_ones = model(**inputs_all_ones) + + # Second run to verify no recompilation + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(error_on_recompile=True), + torch.no_grad(), + ): + output_all_ones_2 = model(**inputs_all_ones) self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1]) + 
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1]) # Test 3: Run with actual padding mask (has zeros) inputs_with_padding = inputs.copy() @@ -311,10 +330,20 @@ def test_torch_compile_with_and_without_mask(self): inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding + # First run to allow compilation with torch.no_grad(): - output_with_padding = compiled_model(**inputs_with_padding) + output_with_padding = model(**inputs_with_padding) + + # Second run to verify no recompilation + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(error_on_recompile=True), + torch.no_grad(), + ): + output_with_padding_2 = model(**inputs_with_padding) self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1]) + self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1]) # Verify that outputs are different (mask should affect results) self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)) From 283df92342538a47b895d2a54ab05549da2a8dc2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 19 Jan 2026 10:51:18 +0000 Subject: [PATCH 7/8] slice freqs_cis to match the input sequence length --- .../models/transformers/transformer_qwenimage.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 74198d7303ba..320c8d8e3fe8 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -136,6 +136,14 @@ def apply_rotary_emb_qwen( return out else: x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + seq_len = x.shape[1] + + # Handle shape mismatch: slice freqs_cis to match the input sequence length. + if freqs_cis.dim() == 3 and freqs_cis.shape[1] > seq_len: + freqs_cis = freqs_cis[:, :seq_len] + elif freqs_cis.dim() == 2 and freqs_cis.shape[0] > seq_len: + freqs_cis = freqs_cis[:seq_len] + freqs_cis = freqs_cis.unsqueeze(1) x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) @@ -264,7 +272,8 @@ def forward( max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens if max_txt_seq_len is None: - raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") + # The RoPE computation will clamp this to the available buffer space below. 
+ max_txt_seq_len = 4096 # Validate batch inference with variable-sized images if isinstance(video_fhw, list) and len(video_fhw) > 1: From 3a0fd2db8f56c7067fc50852ef5683234cdba302 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 19 Jan 2026 21:31:59 +0000 Subject: [PATCH 8/8] keep only attention masking change --- .../controlnets/controlnet_qwenimage.py | 2 + .../transformers/transformer_qwenimage.py | 53 ++++------- .../pipelines/qwenimage/pipeline_qwenimage.py | 26 ++++-- .../pipeline_qwenimage_controlnet.py | 20 +++- .../pipeline_qwenimage_controlnet_inpaint.py | 20 +++- .../qwenimage/pipeline_qwenimage_edit.py | 20 +++- .../pipeline_qwenimage_edit_inpaint.py | 20 +++- .../qwenimage/pipeline_qwenimage_edit_plus.py | 65 +++++-------- .../qwenimage/pipeline_qwenimage_img2img.py | 25 +++-- .../qwenimage/pipeline_qwenimage_inpaint.py | 25 +++-- .../qwenimage/pipeline_qwenimage_layered.py | 26 ++++-- src/diffusers/pipelines/qwenimage/utils.py | 92 ------------------- 12 files changed, 172 insertions(+), 222 deletions(-) delete mode 100644 src/diffusers/pipelines/qwenimage/utils.py diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index 78a566549377..fa374285eec1 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -213,8 +213,10 @@ def forward( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) + # Construct joint attention mask once to avoid reconstructing in every block block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {} if encoder_hidden_states_mask is not None: + # Build joint mask: [text_mask, all_ones_for_image] batch_size, image_seq_len = hidden_states.shape[:2] image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index 320c8d8e3fe8..cf11d8e01fb4 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -136,14 +136,6 @@ def apply_rotary_emb_qwen( return out else: x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - seq_len = x.shape[1] - - # Handle shape mismatch: slice freqs_cis to match the input sequence length. - if freqs_cis.dim() == 3 and freqs_cis.shape[1] > seq_len: - freqs_cis = freqs_cis[:, :seq_len] - elif freqs_cis.dim() == 2 and freqs_cis.shape[0] > seq_len: - freqs_cis = freqs_cis[:seq_len] - freqs_cis = freqs_cis.unsqueeze(1) x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) @@ -243,7 +235,7 @@ def forward( video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], txt_seq_lens: Optional[List[int]] = None, device: torch.device = None, - max_txt_seq_len: Optional[int] = None, + max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -253,9 +245,9 @@ def forward( Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used. device: (`torch.device`, *optional*): The device on which to perform the RoPE computation.
- max_txt_seq_len (`int`, *optional*): + max_txt_seq_len (`int` or `torch.Tensor`, *optional*): The maximum text sequence length for RoPE computation. This should match the encoder hidden states - sequence length. + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). """ # Handle deprecated txt_seq_lens parameter if txt_seq_lens is not None: @@ -272,8 +264,7 @@ def forward( max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens if max_txt_seq_len is None: - # The RoPE computation will clamp this to the available buffer space below. - max_txt_seq_len = 4096 + raise ValueError("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.") # Validate batch inference with variable-sized images if isinstance(video_fhw, list) and len(video_fhw) > 1: @@ -305,14 +296,9 @@ def forward( else: max_vid_index = max(height, width, max_vid_index) - pos_freqs = self.pos_freqs.to(device) - - # Clamp text sequence length to avoid buffer overflow - buffer_size = pos_freqs.shape[0] - available_space = buffer_size - max_vid_index - safe_txt_seq_len = min(max_txt_seq_len, available_space) - - txt_freqs = pos_freqs[max_vid_index : max_vid_index + safe_txt_seq_len] + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @@ -381,7 +367,7 @@ def rope_params(self, index, dim, theta=10000): def forward( self, video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], - max_txt_seq_len: int, + max_txt_seq_len: Union[int, torch.Tensor], device: torch.device = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -389,9 +375,9 @@ def forward( video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer structures. - max_txt_seq_len (`int`): + max_txt_seq_len (`int` or `torch.Tensor`): The maximum text sequence length for RoPE computation. This should match the encoder hidden states - sequence length. + sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility). device: (`torch.device`, *optional*): The device on which to perform the RoPE computation. """ @@ -431,15 +417,9 @@ def forward( max_vid_index = max(height, width, max_vid_index) max_vid_index = max(max_vid_index, layer_num) - - pos_freqs = self.pos_freqs.to(device) - - # Clamp text sequence length to avoid buffer overflow - buffer_size = pos_freqs.shape[0] - available_space = buffer_size - max_vid_index - safe_txt_seq_len = min(max_txt_seq_len, available_space) - - txt_freqs = pos_freqs[max_vid_index : max_vid_index + safe_txt_seq_len] + max_txt_seq_len_int = int(max_txt_seq_len) + # Create device-specific copy for text freqs without modifying self.pos_freqs + txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...] 
vid_freqs = torch.cat(vid_freqs, dim=0) return vid_freqs, txt_freqs @@ -940,7 +920,7 @@ def forward( encoder_hidden_states = self.txt_in(encoder_hidden_states) # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask - text_seq_len, per_sample_len, encoder_hidden_states_mask = compute_text_seq_len_from_mask( + text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask( encoder_hidden_states, encoder_hidden_states_mask ) @@ -953,12 +933,13 @@ def forward( else self.time_text_embed(timestep, guidance, hidden_states, additional_t_cond) ) - # Pass the static text_seq_len to RoPE (encoder_hidden_states.shape[1]) - # The RoPE class will clamp it to avoid buffer overflow image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) + # Construct joint attention mask once to avoid reconstructing in every block + # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} if encoder_hidden_states_mask is not None: + # Build joint mask: [text_mask, all_ones_for_image] batch_size, image_seq_len = hidden_states.shape[:2] image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 88c6d74f92a8..21515df60897 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -27,7 +27,6 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput -from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -211,7 +210,14 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -242,15 +248,19 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( - prompt_embeds, prompt_embeds_mask, max_sequence_length - ) - prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( - prompt_embeds, prompt_embeds_mask, num_images_per_prompt - ) + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, 
:max_sequence_length] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) if prompt_embeds_mask is not None and prompt_embeds_mask.all(): prompt_embeds_mask = None diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index 4ee51151701f..714ec0423804 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -28,7 +28,6 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput -from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -275,7 +274,14 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -307,12 +313,16 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( - prompt_embeds, prompt_embeds_mask, num_images_per_prompt - ) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) if prompt_embeds_mask is not None and prompt_embeds_mask.all(): prompt_embeds_mask = None diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 1f39ce08246e..f8521318e630 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -28,7 +28,6 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput -from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -257,7 +256,14 @@ def _get_qwen_prompt_embeds( hidden_states = 
encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -288,12 +294,16 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( - prompt_embeds, prompt_embeds_mask, num_images_per_prompt - ) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) if prompt_embeds_mask is not None and prompt_embeds_mask.all(): prompt_embeds_mask = None diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index fd278def7245..353aadcbf08a 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -28,7 +28,6 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput -from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -258,7 +257,14 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -292,12 +298,16 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( - prompt_embeds, prompt_embeds_mask, 
num_images_per_prompt - ) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) if prompt_embeds_mask is not None and prompt_embeds_mask.all(): prompt_embeds_mask = None diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index 4b56e8c9daa8..75be62cf6db2 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -29,7 +29,6 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput -from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -269,7 +268,14 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -304,12 +310,16 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( - prompt_embeds, prompt_embeds_mask, num_images_per_prompt - ) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) if prompt_embeds_mask is not None and prompt_embeds_mask.all(): prompt_embeds_mask = None diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index d5fc2b78ae73..bc688aeee319 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -28,7 +28,6 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput -from .utils import build_prompt_embeds_and_mask, concat_prompt_embeds_for_cfg, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -271,7 +270,14 @@ def _get_qwen_prompt_embeds( hidden_states = 
outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -306,12 +312,16 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( - prompt_embeds, prompt_embeds_mask, num_images_per_prompt - ) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) if prompt_embeds_mask is not None and prompt_embeds_mask.all(): prompt_embeds_mask = None @@ -717,15 +727,6 @@ def __call__( max_sequence_length=max_sequence_length, ) - use_batch_cfg = do_true_cfg and not self.transformer.is_cache_enabled - if use_batch_cfg: - prompt_embeds, prompt_embeds_mask = concat_prompt_embeds_for_cfg( - prompt_embeds, - prompt_embeds_mask, - negative_prompt_embeds, - negative_prompt_embeds_mask, - ) - # 4. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, image_latents = self.prepare_latents( @@ -801,11 +802,7 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - if use_batch_cfg: - latent_model_input = torch.cat([latent_model_input] * 2) - timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) - - if use_batch_cfg: + with self.transformer.cache_context("cond"): noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, @@ -817,36 +814,20 @@ def __call__( return_dict=False, )[0] noise_pred = noise_pred[:, : latents.size(1)] - neg_noise_pred, noise_pred = noise_pred.chunk(2) - else: - with self.transformer.cache_context("cond"): - noise_pred = self.transformer( + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, - encoder_hidden_states_mask=prompt_embeds_mask, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, img_shapes=img_shapes, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - noise_pred = noise_pred[:, : latents.size(1)] - - if do_true_cfg: - with self.transformer.cache_context("uncond"): - neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, - img_shapes=img_shapes, - attention_kwargs=self.attention_kwargs, - return_dict=False, - )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - - if do_true_cfg: + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index 4da613e4d6a2..2c9da7545e8a 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -13,7 +13,6 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput -from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -218,7 +217,14 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -285,16 +291,19 @@ def encode_prompt( device = device 
or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( - prompt_embeds, prompt_embeds_mask, max_sequence_length - ) + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] - prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( - prompt_embeds, prompt_embeds_mask, num_images_per_prompt - ) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) if prompt_embeds_mask is not None and prompt_embeds_mask.all(): prompt_embeds_mask = None diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 109148c0f923..536a7984e8e8 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -14,7 +14,6 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput -from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -228,7 +227,14 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -296,16 +302,19 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( - prompt_embeds, prompt_embeds_mask, max_sequence_length - ) + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] - prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( - prompt_embeds, prompt_embeds_mask, num_images_per_prompt - ) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, 
num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) if prompt_embeds_mask is not None and prompt_embeds_mask.all(): prompt_embeds_mask = None diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index c8c5994e612b..0a53a0ac7719 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -28,7 +28,6 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput -from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -276,7 +275,14 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -308,15 +314,19 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( - prompt_embeds, prompt_embeds_mask, max_sequence_length - ) - prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( - prompt_embeds, prompt_embeds_mask, num_images_per_prompt - ) + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) if prompt_embeds_mask is not None and prompt_embeds_mask.all(): prompt_embeds_mask = None diff --git a/src/diffusers/pipelines/qwenimage/utils.py b/src/diffusers/pipelines/qwenimage/utils.py deleted file mode 100644 index 7c91fec05a0a..000000000000 --- a/src/diffusers/pipelines/qwenimage/utils.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - - -def build_prompt_embeds_and_mask(split_hidden_states): - seq_lens = [e.size(0) for e in split_hidden_states] - max_seq_len = max(seq_lens) - if all(seq_len == max_seq_len for seq_len in seq_lens): - prompt_embeds = torch.stack(split_hidden_states) - return prompt_embeds, None - - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) - return prompt_embeds, encoder_attention_mask - - -def slice_prompt_embeds_and_mask(prompt_embeds, prompt_embeds_mask, max_sequence_length): - prompt_embeds = prompt_embeds[:, :max_sequence_length] - if prompt_embeds_mask is not None: - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] - return prompt_embeds, prompt_embeds_mask - - -def repeat_prompt_embeds_and_mask(prompt_embeds, prompt_embeds_mask, num_images_per_prompt): - batch_size, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - if prompt_embeds_mask is not None: - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) - return prompt_embeds, prompt_embeds_mask - - -def concat_prompt_embeds_for_cfg( - prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask -): - pos_len = prompt_embeds.shape[1] - neg_len = negative_prompt_embeds.shape[1] - max_len = max(pos_len, neg_len) - - def _pad_prompt(embeds, mask): - orig_len = embeds.shape[1] - if orig_len != max_len: - pad_len = max_len - orig_len - embeds = torch.cat([embeds, embeds.new_zeros(embeds.shape[0], pad_len, embeds.shape[2])], dim=1) - if mask is None and orig_len != max_len: - mask = torch.ones((embeds.shape[0], orig_len), dtype=torch.long, device=embeds.device) - if mask is not None and mask.shape[1] != max_len: - pad_len = max_len - mask.shape[1] - mask = torch.cat([mask, mask.new_zeros(mask.shape[0], pad_len)], dim=1) - return embeds, mask - - prompt_embeds, prompt_embeds_mask = _pad_prompt(prompt_embeds, prompt_embeds_mask) - negative_prompt_embeds, negative_prompt_embeds_mask = _pad_prompt( - negative_prompt_embeds, negative_prompt_embeds_mask - ) - - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - if prompt_embeds_mask is None and negative_prompt_embeds_mask is None: - prompt_embeds_mask = None - else: - batch_half = prompt_embeds.shape[0] // 2 - if prompt_embeds_mask is None: - prompt_embeds_mask = torch.ones((batch_half, max_len), dtype=torch.long, device=prompt_embeds.device) - if negative_prompt_embeds_mask is None: - negative_prompt_embeds_mask = torch.ones( - (batch_half, max_len), dtype=torch.long, device=prompt_embeds.device - ) - prompt_embeds_mask = 
torch.cat([negative_prompt_embeds_mask, prompt_embeds_mask], dim=0) - - if prompt_embeds_mask is not None and prompt_embeds_mask.all(): - prompt_embeds_mask = None - - return prompt_embeds, prompt_embeds_mask
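For reference, the mask handling this patch keeps can be summarized in a short sketch: per-prompt hidden states are padded to a common length with an accompanying keep-mask, and encode_prompt then drops the mask entirely when it is all ones, i.e. when no prompt in the batch needed padding. The helper below is a minimal illustration of that pattern rather than the pipeline code itself; the name pad_and_mask and the bool dtype are simplifications.

import torch


def pad_and_mask(split_hidden_states):
    # Pad variable-length per-prompt hidden states ([seq_i, dim]) to a common length
    # and build a boolean keep-mask; drop the mask when nothing was actually padded.
    seq_lens = [e.size(0) for e in split_hidden_states]
    max_seq_len = max(seq_lens)
    embeds = torch.stack(
        [torch.cat([e, e.new_zeros(max_seq_len - e.size(0), e.size(1))]) for e in split_hidden_states]
    )
    mask = torch.zeros(len(seq_lens), max_seq_len, dtype=torch.bool, device=embeds.device)
    for i, n in enumerate(seq_lens):
        mask[i, :n] = True
    if mask.all():  # no padding anywhere -> downstream attention can skip masking
        mask = None
    return embeds, mask


# Equal lengths -> mask is None; unequal lengths -> padded embeds plus a keep-mask.
embeds, mask = pad_and_mask([torch.randn(7, 16), torch.randn(5, 16)])

With the mask set to None, the transformer never builds a joint attention mask and the attention kernels can take the dense path, which is the point of dropping all-ones masks.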
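When a text mask does survive, the transformer and controlnet forwards assemble the joint text+image mask once before the block loop rather than once per block. A rough sketch of that construction, assuming a boolean text mask of shape [batch, text_len] and a known image sequence length (build_joint_attention_mask is a placeholder name):

import torch


def build_joint_attention_mask(text_mask, image_seq_len):
    # Concatenate the text keep-mask with an all-ones image mask so every block can
    # reuse the same [batch, text_len + image_len] mask.
    if text_mask is None:
        return None  # no text padding -> attend everywhere, skip masking entirely
    image_mask = torch.ones(
        (text_mask.shape[0], image_seq_len), dtype=torch.bool, device=text_mask.device
    )
    return torch.cat([text_mask, image_mask], dim=1)


# Batch of 2 prompts, 10 text tokens (the second prompt has 2 padded) and 64 image tokens.
text_mask = torch.ones(2, 10, dtype=torch.bool)
text_mask[1, 8:] = False
joint_mask = build_joint_attention_mask(text_mask, image_seq_len=64)  # shape [2, 74]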