From 10bbc830c29cb508114d95a556f0f2ba4ad46058 Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Tue, 11 Nov 2025 17:37:17 +0000 Subject: [PATCH 1/3] Fixes for paged fp8 attention with chunked prefill Signed-off-by: Antoni Viros i Martin --- fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 130 ++++++++++++++------------ 1 file changed, 72 insertions(+), 58 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index b5abcbf..d2b5659 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -30,62 +30,62 @@ # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 -if Version(torch.__version__) <= Version("2.7"): - # PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set, - # while for earlier versions we need a custom definition - def _scaled_mm_cpu_out( - mat1: Tensor, - mat2: Tensor, - scale1: Tensor, - scale2: Tensor, - bias: Optional[Tensor] = None, - scale_result: Optional[Tensor] = None, - out_dtype: Optional[torch.dtype] = None, - use_fast_accum: bool = False, - *, - out: Optional[Tensor] = None, - ) -> Tensor: - if out_dtype is None: - out_dtype = torch.float32 - mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype) - mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype) - - if bias is not None: - ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype) - else: - ret = torch.mm(mat1, mat2).to(dtype=out_dtype) - - if out is not None: - out.copy_(ret) - return out - return ret - - torch.library.register_kernel( - torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out - ) +# PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set. +# This CPU implementation is not enough for our use case, so we still have to +# keep our own custom version. 
+def _scaled_mm_cpu_out( + mat1: Tensor, + mat2: Tensor, + scale1: Tensor, + scale2: Tensor, + bias: Optional[Tensor] = None, + scale_result: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, + *, + out: Optional[Tensor] = None, +) -> Tensor: + if out_dtype is None: + out_dtype = torch.float32 + mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype) + mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype) - @torch.library.register_kernel("aten::_scaled_mm", "cpu") - def _scaled_mm_cpu( - mat1: Tensor, - mat2: Tensor, - scale1: Tensor, - scale2: Tensor, - bias: Optional[Tensor] = None, - scale_result: Optional[Tensor] = None, - out_dtype: Optional[torch.dtype] = None, - use_fast_accum: bool = False, - ) -> Tensor: - return _scaled_mm_cpu_out( - mat1, - mat2, - scale1, - scale2, - bias, - scale_result, - out_dtype, - use_fast_accum, - out=None, - ) + if bias is not None: + ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype) + else: + ret = torch.mm(mat1, mat2).to(dtype=out_dtype) + + if out is not None: + out.copy_(ret) + return out + return ret + + +torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out) + + +@torch.library.register_kernel("aten::_scaled_mm", "cpu") +def _scaled_mm_cpu( + mat1: Tensor, + mat2: Tensor, + scale1: Tensor, + scale2: Tensor, + bias: Optional[Tensor] = None, + scale_result: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, +) -> Tensor: + return _scaled_mm_cpu_out( + mat1, + mat2, + scale1, + scale2, + bias, + scale_result, + out_dtype, + use_fast_accum, + out=None, + ) @torch.library.custom_op("spyre::scaled_bmm", mutates_args=()) @@ -218,6 +218,7 @@ def scaled_paged_attn_compute( num_kv_heads = value_cache.shape[2] head_size = value_cache.shape[3] block_size = value_cache.shape[1] + seq_len_q = query.shape[1] num_seqs = query.shape[0] block_tables_lst = block_table.cpu().tolist() @@ -228,6 +229,7 @@ def scaled_paged_attn_compute( block_table = block_tables_lst[i] start_pos = int(left_padded_prompt_mask[i].item()) seq_len = int(seq_lens_lst[i]) + seq_len_q_i = seq_len_q keys_lst: list[torch.Tensor] = [] values_lst: list[torch.Tensor] = [] @@ -243,6 +245,13 @@ def scaled_paged_attn_compute( values_lst.append(v) keys = torch.stack(keys_lst, dim=0) values = torch.stack(values_lst, dim=0) + seq_len_kv = keys.shape[0] + + # cut the pads for first prefill + if q.shape[0] > seq_len_kv: + seq_len_q_i = seq_len_kv + q = q[-seq_len_kv:] + if num_kv_heads > 1: # Handle MQA and GQA keys = torch.repeat_interleave(keys, num_query_heads // num_kv_heads, dim=1) @@ -250,6 +259,11 @@ def scaled_paged_attn_compute( values, num_query_heads // num_kv_heads, dim=1 ) + # Generate mask for prefix attention + mask = torch.ones((1, 1, seq_len_q_i, seq_len_kv), dtype=torch.bool) + mask[:, :, :, -seq_len_q_i:] = torch.tril(mask[:, :, :, -seq_len_q_i:]) + mask = torch.where(mask.logical_not(), -torch.inf, 0.0) + out = F.scaled_dot_product_attention( # noqa: E1102 q.transpose(0, 1).unsqueeze(0), # format for sdpa (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale[i]).to( @@ -258,12 +272,12 @@ def scaled_paged_attn_compute( (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale[i]).to( dtype=q.dtype ), # format for sdpa - is_causal=False, # decode assumes no causal mask + attn_mask=mask, # decode assumes no causal mask scale=scale, ) - out = out.view(num_query_heads, head_size) - output[i].copy_(out, 
non_blocking=True) + out = out.transpose(1, 2).view(seq_len_q_i, num_query_heads, head_size) + output[i][-seq_len_q_i:] = out return output From acdb545534f6bf3f4a8bc7f3cd6e326bc08639a7 Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Tue, 11 Nov 2025 17:49:51 +0000 Subject: [PATCH 2/3] Fix lint warning Signed-off-by: Antoni Viros i Martin --- fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index d2b5659..205f814 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -17,7 +17,6 @@ from typing import Optional # Third Party -from packaging.version import Version from torch import Tensor import torch import torch.nn.functional as F From 1991d09bbe2823c9fc5a87d4f4981e4d7fa51361 Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Mon, 17 Nov 2025 23:15:36 +0000 Subject: [PATCH 3/3] Test fix for FP8 matmul Signed-off-by: Antoni Viros i Martin --- fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index 205f814..8de1395 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -60,10 +60,6 @@ def _scaled_mm_cpu_out( return ret -torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out) - - -@torch.library.register_kernel("aten::_scaled_mm", "cpu") def _scaled_mm_cpu( mat1: Tensor, mat2: Tensor, @@ -87,6 +83,19 @@ def _scaled_mm_cpu( ) +if torch.__version__ >= "2.8": + DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + torch.ops.aten._scaled_mm.out.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu_out + torch.ops.aten._scaled_mm.default.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu +else: + torch.library.register_kernel( + torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out + ) + torch.library.register_kernel( + torch.ops.aten._scaled_mm.default, "cpu", _scaled_mm_cpu + ) + + @torch.library.custom_op("spyre::scaled_bmm", mutates_args=()) def spyre_scaled_bmm( mat1: Tensor, @@ -114,7 +123,7 @@ def spyre_scaled_bmm( device=mat1.device, ) for b_idx in range(mat1.shape[0]): - out[b_idx] = torch._scaled_mm( + out[b_idx] = _scaled_mm_cpu_out( mat1[b_idx], mat2[b_idx], scale1,
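
For reference, the custom CPU kernel that patch 1 keeps (_scaled_mm_cpu_out) simply dequantizes both operands with their per-tensor scales and falls back to a plain matmul in out_dtype. The following is a minimal sketch of that computation under assumed toy inputs (a, b, scale_a, scale_b are hypothetical names; the real op additionally handles bias, scale_result and an out= buffer):

import torch

# Toy per-tensor FP8 quantization; scale_a/scale_b play the role of scale1/scale2.
a = torch.randn(4, 8)
b = torch.randn(8, 16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = a.abs().max() / fp8_max
scale_b = b.abs().max() / fp8_max
a_fp8 = (a / scale_a).to(torch.float8_e4m3fn)
b_fp8 = (b / scale_b).to(torch.float8_e4m3fn)

# Dequantize-then-matmul, mirroring what _scaled_mm_cpu_out computes.
out = (a_fp8.to(torch.float32) * scale_a) @ (b_fp8.to(torch.float32) * scale_b)
print((out - a @ b).abs().max())  # only FP8 quantization error remains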
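
The attention change in patch 1 replaces is_causal=False with an explicit additive mask so that chunked prefill works: every query token in the new chunk may attend to all previously cached keys, while attention within the chunk itself stays causal. Below is a standalone sketch of that mask construction; build_prefix_causal_mask is a hypothetical helper name, the patch builds the same tensor inline in scaled_paged_attn_compute.

import torch

def build_prefix_causal_mask(seq_len_q: int, seq_len_kv: int) -> torch.Tensor:
    # Start fully visible: every query may attend all cached/context keys.
    mask = torch.ones((1, 1, seq_len_q, seq_len_kv), dtype=torch.bool)
    # Restrict the trailing seq_len_q x seq_len_q block (the chunk's own keys) to causal.
    mask[:, :, :, -seq_len_q:] = torch.tril(mask[:, :, :, -seq_len_q:])
    # Convert to the additive form expected by F.scaled_dot_product_attention.
    return torch.where(mask.logical_not(), -torch.inf, 0.0)

# A 3-token chunk attending over 5 cached keys plus its own 3 keys:
print(build_prefix_causal_mask(3, 8).squeeze(0).squeeze(0))

Note that for single-token decode (seq_len_q == 1) the mask degenerates to all zeros, which matches the previous is_causal=False behaviour, so the change only affects prefill chunks.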
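
Patch 3 gates the new registration path on a plain string comparison (torch.__version__ >= "2.8"), which compares lexicographically, so a future "2.10" would sort below "2.8". A hedged alternative is sketched below using packaging.version, the dependency that patch 2 removed; TORCH_GE_2_8 is a hypothetical name and re-introducing the import is only a suggestion, not part of the series.

from packaging.version import Version
import torch

# Hypothetical guard; Version compares numerically, so "2.10" > "2.8" as expected.
TORCH_GE_2_8 = Version(torch.__version__) >= Version("2.8")

if TORCH_GE_2_8:
    # PyTorch 2.8+ already ships an ATen CPU kernel for _scaled_mm (see the comment
    # kept in patch 1), so the patch overrides the op's Python kernel table here
    # rather than registering a second CPU kernel.
    pass
else:
    # Older releases have no CPU kernel, so torch.library.register_kernel is used.
    pass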