
Conversation

@kashif (Contributor) commented Jan 16, 2026

What does this PR do?

This pull request sets attention masks that are all ones to None.


Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif kashif requested a review from sayakpaul January 16, 2026 15:10
@sayakpaul (Member) left a comment:

Thanks!

Could you update the PR with the following?

  • Shed light on what caused the speed regression
  • Add a test with masks to the compilation tests here
  • Do a before/after comparison of the outputs with this PR

@kashif (Contributor, Author) commented Jan 16, 2026

Will do!

```python
model = self.model_class(**init_dict).to(torch_device)
model.eval()

compiled_model = torch.compile(model, mode="default", fullgraph=False)
```
@sayakpaul (Member) commented on this diff:

Some notes:

  • Usually, it should be `model.compile()` as it doesn't wrap the underlying model into a dynamo wrapper. This way, we don't have to add any extra code to handle it.
  • Why is `fullgraph=False` here?
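For illustration, a minimal sketch of the difference (reusing `model` and `inputs_no_mask` from the test; the `mode` choice is assumed):

```python
import torch

# torch.compile(...) returns an OptimizedModule that wraps the original model,
# so test utilities have to unwrap it to reach the module's own attributes.
compiled_model = torch.compile(model, mode="default")
assert isinstance(compiled_model, torch._dynamo.eval_frame.OptimizedModule)

# nn.Module.compile() compiles in place: `model` keeps its class and attributes,
# and no extra unwrapping code is needed.
model.compile(mode="default")
output = model(**inputs_no_mask)  # runs the compiled forward
```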

Comment on lines 292 to 293:

```python
with torch.no_grad():
    output_no_mask = compiled_model(**inputs_no_mask)
```
@sayakpaul (Member) commented:

Does it lead to graph breaks? If not, then we should add additional contexts:

`torch._dynamo.config.patch(error_on_recompile=True)`
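For reference, a minimal sketch of how that guard could wrap the forward passes (`inputs_with_mask` is a hypothetical name for the masked inputs):

```python
import torch

# error_on_recompile=True makes dynamo raise instead of silently recompiling,
# so the test fails loudly if the masked inputs trigger a new graph.
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    output_no_mask = compiled_model(**inputs_no_mask)
    output_with_mask = compiled_model(**inputs_with_mask)  # hypothetical masked inputs
```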

@dxqb (Contributor) commented Jan 19, 2026

Does this PR lack a merge, or is the amount of code changes intentional and really part of this PR only? (+318 −170)
If intentional, could you explain what it does and why all these changes are necessary?

Avoiding masks when they're not necessary used to be a one-liner (on a boolean mask):

`attention_mask = attention_mask if not torch.all(text_attention_mask) else None`

@kashif (Contributor, Author) commented Jan 19, 2026

@dxqb I initially miscalculated and was adding support for compiling the pipeline; I will revert and simplify.

@dxqb (Contributor) commented Jan 19, 2026

> @dxqb I initially miscalculated and was adding support for compiling the pipeline; I will revert and simplify.

Thanks! Regarding compile:

  • Regional compilation is usually as efficient as full compilation (even if full compilation were possible). Regional compilation compiles the transformer blocks, but not the entire pipeline.
  • Whenever you want to branch on GPU data, that's a graph break for torch.compile: either it's inefficient, or it fails, depending on fullgraph. Or, in less abstract terms:
    you want to set your attention mask to None if the entire attention mask is True. But the attention mask lives on the GPU. Checking whether all values of a GPU tensor are True requires a transfer back to the CPU - that's always a graph break and cannot be compiled (efficiently).

Therefore, I'd suggest to:

  • check for an all-True mask in the pipeline, and
  • pass a None mask to the transformer block in that case - but don't check any GPU data in the transformer block, so the transformer block can be compiled. (A sketch of this pattern follows below.)
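A minimal sketch of that split, under the assumption that the pipeline stays eager and only the transformer blocks are compiled (function names here are illustrative, not the PR's actual call sites):

```python
import torch
import torch.nn.functional as F


def prepare_attention_mask(attention_mask):
    # Pipeline side (eager): the .all() check syncs GPU -> CPU once per call,
    # which is fine here but would be a graph break inside a compiled region.
    if attention_mask is not None and bool(attention_mask.all()):
        return None
    return attention_mask


def attention(q, k, v, attention_mask=None):
    # Transformer-block side (compiled): only a Python-level None check,
    # which torch.compile specializes on without reading GPU data.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
```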

@dxqb (Contributor) commented Jan 19, 2026

> Thanks!
>
> Could you update the PR with the following?
>
> • Shed light on what caused the speed regression

Here is a benchmark of the impact of using a mask unnecessarily (second graph): #12870 (comment)
torch SDPA falls back to a non-flash algorithm if a mask is used.
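A quick way to observe the fallback (a sketch: shapes and dtype are assumed, and `torch.nn.attention` needs PyTorch ≥ 2.3):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = k = v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
mask = torch.ones(2, 1, 1024, 1024, device="cuda", dtype=torch.bool)

# Restricted to the flash backend, SDPA works without a mask...
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
    # ...but the call below would raise, because the flash kernel accepts no
    # explicit attn_mask; outside this context SDPA silently falls back to
    # the slower efficient/math backends instead.
    # F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```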

@kashif (Contributor, Author) commented Jan 19, 2026

Yes, I also benchmarked the mask and got:

[Nsight profile: nsight_sdpa_mask_regression]

@kashif (Contributor, Author) commented Jan 19, 2026

Thanks @dxqb, please check now.
