[attention backends] use dedicated wrappers from fa3 for cp. #13165
@@ -276,7 +276,11 @@ class _HubKernelConfig:
 _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
     # TODO: temporary revision for now. Remove when merged upstream into `main`.
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+        repo_id="kernels-community/flash-attn3",
+        function_attr="flash_attn_func",
+        revision="fake-ops-return-probs",
+        wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
+        wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
     ),
     AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn3",
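For context, here is a rough sketch (not the actual diffusers implementation; the class and helper names are invented for illustration) of how a config entry like the one above could resolve dotted attribute paths such as `flash_attn_interface._flash_attn_forward` on an already-loaded hub kernel module, producing the `wrapped_forward_fn` / `wrapped_backward_fn` handles that the forward and backward ops below rely on:

# Hypothetical sketch, assuming the config stores dotted attribute paths and
# resolves them with getattr once the hub kernel module has been loaded.
from dataclasses import dataclass
from functools import reduce
from typing import Any, Callable, Optional


def _resolve_attr(module: Any, dotted_path: Optional[str]) -> Optional[Callable]:
    # Walk "a.b.c" one attribute at a time; return None if no path is configured.
    if dotted_path is None:
        return None
    return reduce(getattr, dotted_path.split("."), module)


@dataclass
class HubKernelConfigSketch:
    repo_id: str
    function_attr: str
    revision: Optional[str] = None
    wrapped_forward_attr: Optional[str] = None
    wrapped_backward_attr: Optional[str] = None

    def resolve(self, kernel_module: Any) -> None:
        # Cache the public entry point plus the low-level wrapped ops used by
        # context-parallel forward/backward.
        self.kernel_fn = _resolve_attr(kernel_module, self.function_attr)
        self.wrapped_forward_fn = _resolve_attr(kernel_module, self.wrapped_forward_attr)
        self.wrapped_backward_fn = _resolve_attr(kernel_module, self.wrapped_backward_attr)

With something like this in place, `config.wrapped_forward_fn` and `config.wrapped_backward_fn` are either callables or `None`, which is exactly what the new forward and backward ops check before raising.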
@@ -1237,36 +1241,62 @@ def _flash_attention_3_hub_forward_op(
     if enable_gqa:
         raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")

-    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
-    out = func(
-        q=query,
-        k=key,
-        v=value,
-        softmax_scale=scale,
+    config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
+    wrapped_forward_fn = config.wrapped_forward_fn
+    if wrapped_forward_fn is None:
+        raise RuntimeError(
+            "Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
+            "for context parallel execution."
+        )
+
+    if scale is None:
+        scale = query.shape[-1] ** (-0.5)
+
+    out, softmax_lse, *_ = wrapped_forward_fn(
+        query,
+        key,
+        value,
+        None,
+        None,  # k_new, v_new
+        None,  # qv
+        None,  # out
+        None,
+        None,
+        None,  # cu_seqlens_q/k/k_new
+        None,
+        None,  # seqused_q/k
+        None,
+        None,  # max_seqlen_q/k
+        None,
+        None,
+        None,  # page_table, kv_batch_idx, leftpad_k
+        None,
+        None,
+        None,  # rotary_cos/sin, seqlens_rotary
+        None,
+        None,
+        None,  # q_descale, k_descale, v_descale
+        scale,
         causal=is_causal,
-        qv=None,
-        q_descale=None,
-        k_descale=None,
-        v_descale=None,
-        window_size=window_size,
+        window_size_left=window_size[0],
+        window_size_right=window_size[1],
+        attention_chunk=0,
         softcap=softcap,
         num_splits=num_splits,
         pack_gqa=pack_gqa,
         deterministic=deterministic,
         sm_margin=sm_margin,
-        return_attn_probs=return_lse,
     )

-    lse = None
-    if return_lse:
-        out, lse = out
-        lse = lse.permute(0, 2, 1).contiguous()
+    lse = softmax_lse.permute(0, 2, 1).contiguous() if return_lse else None

     if _save_ctx:
-        ctx.save_for_backward(query, key, value)
+        ctx.save_for_backward(query, key, value, out, softmax_lse)
         ctx.scale = scale
         ctx.is_causal = is_causal
-        ctx._hub_kernel = func
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        ctx.deterministic = deterministic
+        ctx.sm_margin = sm_margin

     return (out, lse) if return_lse else out
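The reason the forward op now surfaces `softmax_lse` (permuted from flash-attn's `(batch, heads, seq)` layout into `(batch, seq, heads)`, presumably the layout the context-parallel code consumes) is that ring-style context parallelism computes attention against each KV shard separately and then merges the partial outputs with a log-sum-exp correction. A generic sketch of that merge, independent of the diffusers code and with tensor shapes assumed as noted in the comments:

# Generic sketch of LSE-based merging of partial attention results, as used by
# ring/context-parallel attention. Not the diffusers CP implementation.
import torch


def merge_partial_attention(
    out_a: torch.Tensor,  # (batch, seq, heads, head_dim)
    lse_a: torch.Tensor,  # (batch, seq, heads)
    out_b: torch.Tensor,
    lse_b: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Combined log-sum-exp of the two softmax normalizers.
    lse = torch.logaddexp(lse_a, lse_b)
    # Rescale each partial output by its share of the combined normalizer.
    out = (
        out_a * torch.exp(lse_a - lse).unsqueeze(-1)
        + out_b * torch.exp(lse_b - lse).unsqueeze(-1)
    )
    return out, lse

Without the LSE from `_flash_attn_forward`, this correction cannot be applied, which is why the op errors out when the wrapped forward is unavailable.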
@@ -1275,55 +1305,50 @@ def _flash_attention_3_hub_backward_op(
     ctx: torch.autograd.function.FunctionCtx,
     grad_out: torch.Tensor,
     *args,
-    window_size: tuple[int, int] = (-1, -1),
-    softcap: float = 0.0,
-    num_splits: int = 1,
-    pack_gqa: bool | None = None,
-    deterministic: bool = False,
-    sm_margin: int = 0,
     **kwargs,
 ):
-    query, key, value = ctx.saved_tensors
-    kernel_fn = ctx._hub_kernel
-    # NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
-    # primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
-    # therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
-    # `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
-    # the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
-    # in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
-    with torch.enable_grad():
-        query_r = query.detach().requires_grad_(True)
-        key_r = key.detach().requires_grad_(True)
-        value_r = value.detach().requires_grad_(True)
-
-        out = kernel_fn(
-            q=query_r,
-            k=key_r,
-            v=value_r,
-            softmax_scale=ctx.scale,
-            causal=ctx.is_causal,
-            qv=None,
-            q_descale=None,
-            k_descale=None,
-            v_descale=None,
-            window_size=window_size,
-            softcap=softcap,
-            num_splits=num_splits,
-            pack_gqa=pack_gqa,
-            deterministic=deterministic,
-            sm_margin=sm_margin,
-            return_attn_probs=False,
-        )
-        if isinstance(out, tuple):
-            out = out[0]
-
-        grad_query, grad_key, grad_value = torch.autograd.grad(
-            out,
-            (query_r, key_r, value_r),
-            grad_out,
-            retain_graph=False,
-            allow_unused=False,
-        )
+    config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
+    wrapped_backward_fn = config.wrapped_backward_fn
+    if wrapped_backward_fn is None:
+        raise RuntimeError(
+            "Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
+            "for context parallel execution."
+        )
+
+    query, key, value, out, softmax_lse = ctx.saved_tensors
+    grad_query = torch.empty_like(query)
+    grad_key = torch.empty_like(key)
+    grad_value = torch.empty_like(value)
+
+    wrapped_backward_fn(
+        grad_out,
+        query,
+        key,
+        value,
+        out,
+        softmax_lse,
+        None,
+        None,  # cu_seqlens_q, cu_seqlens_k
Inline review comment on the `cu_seqlens_q` / `cu_seqlens_k` arguments:
Reviewer: Just FMI, there is no FA varlen CP yet? I think Meta had a version at some point.
Author: We handle varlen separately, and varlens aren't that common in diffusion use cases.
+        None,
+        None,  # seqused_q, seqused_k
+        None,
+        None,  # max_seqlen_q, max_seqlen_k
+        grad_query,
+        grad_key,
+        grad_value,
+        ctx.scale,
+        ctx.is_causal,
+        ctx.window_size[0],
+        ctx.window_size[1],
+        ctx.softcap,
+        ctx.deterministic,
+        ctx.sm_margin,
+    )
+
+    grad_query = grad_query[..., : grad_out.shape[-1]]
+    grad_key = grad_key[..., : grad_out.shape[-1]]
+    grad_value = grad_value[..., : grad_out.shape[-1]]

     return grad_query, grad_key, grad_value
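For comparison, the NOTE removed above described the previous fallback: because the FA3 hub kernel did not expose a wrapped backward, the backward op re-ran the forward under `torch.enable_grad()` and differentiated through it with `torch.autograd.grad()`. A minimal, generic sketch of that pattern is below; `attention_fn` is a hypothetical differentiable stand-in, not the hub kernel.

# Sketch of the recompute-based backward the removed NOTE describes: redo the
# forward with autograd enabled and let torch.autograd.grad produce the input
# gradients. This costs an extra forward pass on every backward.
import torch


def recompute_backward(attention_fn, query, key, value, grad_out):
    with torch.enable_grad():
        q = query.detach().requires_grad_(True)
        k = key.detach().requires_grad_(True)
        v = value.detach().requires_grad_(True)
        out = attention_fn(q, k, v)
        return torch.autograd.grad(out, (q, k, v), grad_out)

Calling the dedicated `_flash_attn_backward` wrapper with preallocated gradient buffers avoids that second forward pass, which is the main point of this PR.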
Reviewer: Ah, didn't notice it before. We don't have it merged into main yet?
Author: We're working on it in #13161.