diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
index bc3ce84e1019..21515df60897 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -262,6 +262,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def check_inputs(
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
index ce6fc974a56e..714ec0423804 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
@@ -324,6 +324,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def check_inputs(
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py
index 77d78a5ca7a1..f8521318e630 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py
@@ -305,6 +305,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def check_inputs(
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
index dd723460a59e..353aadcbf08a 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
@@ -309,6 +309,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def check_inputs(
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py
index cf467203a9d2..75be62cf6db2 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py
@@ -321,6 +321,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
index 257e2d846c7c..bc688aeee319 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -323,6 +323,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
index e0b41b8b8799..2c9da7545e8a 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py
@@ -305,6 +305,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def check_inputs(
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
index 83f02539b1ba..536a7984e8e8 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py
@@ -316,6 +316,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def check_inputs(
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
index 53d2c169ee63..0a53a0ac7719 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py
@@ -328,6 +328,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
 
+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask
 
     def get_image_caption(self, prompt_image, use_en_prompt=True, device=None):
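The nine hunks above all apply the same normalization: an all-ones mask marks
every token as valid, which is exactly what the attention layers assume when no
mask is passed at all, so the mask carries no information and can be dropped.
Returning None in that case also keeps torch.compile on the single mask-free
code path instead of specializing on the mask's contents. A minimal standalone
sketch of the rule (the helper name and shapes are illustrative, not part of
this diff):

import torch

def normalize_mask(prompt_embeds_mask):
    # An all-ones mask is equivalent to "no mask": every position may be
    # attended to. Collapsing it to None avoids a data-dependent branch
    # downstream and a dynamo guard on the mask's values.
    if prompt_embeds_mask is not None and prompt_embeds_mask.all():
        return None
    return prompt_embeds_mask

# All tokens valid -> collapsed to None.
assert normalize_mask(torch.ones(1, 8, dtype=torch.long)) is None
# Real padding (trailing zeros) is preserved.
padded = torch.ones(1, 8, dtype=torch.long)
padded[:, 6:] = 0
assert normalize_mask(padded) is not None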
diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index 384954dfbad7..e6b19377b14f 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -276,3 +276,74 @@ def prepare_dummy_input(self, height, width):
 
     def test_torch_compile_recompilation_and_graph_break(self):
         super().test_torch_compile_recompilation_and_graph_break()
+
+    def test_torch_compile_with_and_without_mask(self):
+        """Test that torch.compile works with both None mask and padding mask."""
+        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
+        model.compile(mode="default", fullgraph=True)
+
+        # Test 1: Run with None mask (no padding, all tokens are valid)
+        inputs_no_mask = inputs.copy()
+        inputs_no_mask["encoder_hidden_states_mask"] = None
+
+        # First run to allow compilation
+        with torch.no_grad():
+            output_no_mask = model(**inputs_no_mask)
+
+        # Second run to verify no recompilation
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch.no_grad(),
+        ):
+            output_no_mask_2 = model(**inputs_no_mask)
+
+        self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
+        self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
+
+        # Test 2: Run with all-ones mask (should behave like None)
+        inputs_all_ones = inputs.copy()
+        # Keep the all-ones mask
+        self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
+
+        # First run to allow compilation
+        with torch.no_grad():
+            output_all_ones = model(**inputs_all_ones)
+
+        # Second run to verify no recompilation
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch.no_grad(),
+        ):
+            output_all_ones_2 = model(**inputs_all_ones)
+
+        self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
+        self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
+
+        # Test 3: Run with actual padding mask (has zeros)
+        inputs_with_padding = inputs.copy()
+        mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
+        mask_with_padding[:, 4:] = 0  # Last 3 tokens are padding
+
+        inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
+
+        # First run to allow compilation
+        with torch.no_grad():
+            output_with_padding = model(**inputs_with_padding)
+
+        # Second run to verify no recompilation
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch.no_grad(),
+        ):
+            output_with_padding_2 = model(**inputs_with_padding)
+
+        self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
+        self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
+
+        # Verify that outputs are different (mask should affect results)
+        self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
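The recompilation checks in the test hinge on
torch._dynamo.config.patch(error_on_recompile=True), which turns any further
recompilation into a hard error rather than a silent slowdown. A self-contained
sketch of that pattern outside the test harness (the toy module and shapes are
placeholders, not part of this diff):

import torch

model = torch.nn.Linear(16, 16).eval()
model.compile(mode="default", fullgraph=True)

x = torch.randn(2, 16)
with torch.no_grad():
    model(x)  # first call compiles and installs guards

# A second call with the same signature must reuse the compiled artifact;
# with error_on_recompile=True any guard miss raises instead of recompiling.
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    model(x)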