src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py (3 additions, 0 deletions)

@@ -262,6 +262,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask

     def check_inputs(
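The same three-line guard recurs in each hunk of this PR. A minimal standalone sketch of the pattern, with an illustrative helper name (`_maybe_drop_mask` is ours, not the PR's):

from typing import Optional

import torch


def _maybe_drop_mask(mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # An all-ones mask carries no information: every token is valid.
    # Returning None instead lets downstream attention skip the masked
    # code path and gives torch.compile one stable specialization.
    if mask is not None and mask.all():
        return None
    return mask


# No padding: the mask is dropped.
assert _maybe_drop_mask(torch.ones(2, 7, dtype=torch.long)) is None

# Real padding (last 3 of 7 tokens masked out): the mask is kept.
padded = torch.ones(2, 7, dtype=torch.long)
padded[:, 4:] = 0
assert _maybe_drop_mask(padded) is not None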
@@ -324,6 +324,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask

     def check_inputs(
@@ -305,6 +305,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask

     def check_inputs(
src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py (3 additions, 0 deletions)

@@ -309,6 +309,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask

     def check_inputs(
@@ -321,6 +321,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask

     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs
@@ -323,6 +323,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask

     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
@@ -305,6 +305,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask

     def check_inputs(
@@ -316,6 +316,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask

     def check_inputs(
@@ -328,6 +328,9 @@ def encode_prompt(
         prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
         prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

+        if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+            prompt_embeds_mask = None
+
         return prompt_embeds, prompt_embeds_mask

     def get_image_caption(self, prompt_image, use_en_prompt=True, device=None):
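All nine hunks apply the same guard. The reason it pays off downstream is that attention kernels treat `attn_mask=None` as a distinct fast path; an all-ones mask is mathematically identical but forces the masked path (and, under torch.compile, a separate specialization). A sketch of that equivalence using `torch.nn.functional.scaled_dot_product_attention` (shapes are illustrative):

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 7, 64)  # (batch, heads, seq_len, head_dim)

# attn_mask=None lets SDPA dispatch to its unmasked / fused kernels.
out_none = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

# An all-True boolean mask yields the same values, but takes the masked
# code path instead of the fast one.
all_true = torch.ones(1, 1, 7, 7, dtype=torch.bool)
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=all_true)

# Loose tolerances, since the two paths may use different kernels.
torch.testing.assert_close(out_none, out_masked, rtol=1e-4, atol=1e-4)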
tests/models/transformers/test_models_transformer_qwenimage.py (71 additions, 0 deletions)

@@ -276,3 +276,74 @@ def prepare_dummy_input(self, height, width):

     def test_torch_compile_recompilation_and_graph_break(self):
         super().test_torch_compile_recompilation_and_graph_break()
+
+    def test_torch_compile_with_and_without_mask(self):
+        """Test that torch.compile works with both a None mask and a padding mask."""
+        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
+        model.compile(mode="default", fullgraph=True)
+
+        # Test 1: run with a None mask (no padding, all tokens are valid).
+        inputs_no_mask = inputs.copy()
+        inputs_no_mask["encoder_hidden_states_mask"] = None
+
+        # First run to allow compilation.
+        with torch.no_grad():
+            output_no_mask = model(**inputs_no_mask)
+
+        # Second run to verify there is no recompilation.
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch.no_grad(),
+        ):
+            output_no_mask_2 = model(**inputs_no_mask)
+
+        self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
+        self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
+
+        # Test 2: run with an all-ones mask (should behave like None).
+        inputs_all_ones = inputs.copy()
+        # Keep the all-ones mask.
+        self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
+
+        # First run to allow compilation.
+        with torch.no_grad():
+            output_all_ones = model(**inputs_all_ones)
+
+        # Second run to verify there is no recompilation.
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch.no_grad(),
+        ):
+            output_all_ones_2 = model(**inputs_all_ones)
+
+        self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
+        self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
+
+        # Test 3: run with an actual padding mask (contains zeros).
+        inputs_with_padding = inputs.copy()
+        mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
+        mask_with_padding[:, 4:] = 0  # Last 3 tokens are padding.
+
+        inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
+
+        # First run to allow compilation.
+        with torch.no_grad():
+            output_with_padding = model(**inputs_with_padding)
+
+        # Second run to verify there is no recompilation.
+        with (
+            torch._inductor.utils.fresh_inductor_cache(),
+            torch._dynamo.config.patch(error_on_recompile=True),
+            torch.no_grad(),
+        ):
+            output_with_padding_2 = model(**inputs_with_padding)
+
+        self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
+        self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
+
+        # Verify that the outputs differ (the mask should affect results).
+        self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
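The recompilation check above is the standard dynamo idiom, repeated three times in the test; factored into a helper it reads as follows (the helper name and structure are ours, not the test file's):

from contextlib import contextmanager

import torch


@contextmanager
def assert_no_recompile():
    # fresh_inductor_cache avoids reusing stale compiled artifacts, and
    # error_on_recompile turns any new dynamo compilation into a hard error,
    # so the wrapped call fails loudly if its inputs triggered a recompile.
    with (
        torch._inductor.utils.fresh_inductor_cache(),
        torch._dynamo.config.patch(error_on_recompile=True),
        torch.no_grad(),
    ):
        yield


# Usage, mirroring the test above:
# with assert_no_recompile():
#     _ = model(**inputs_with_padding)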