diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py
index b5abcbf..8de1395 100644
--- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py
+++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py
@@ -17,7 +17,6 @@
 from typing import Optional
 
 # Third Party
-from packaging.version import Version
 from torch import Tensor
 import torch
 import torch.nn.functional as F
@@ -30,62 +29,71 @@
 
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
-if Version(torch.__version__) <= Version("2.7"):
-    # PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set,
-    # while for earlier versions we need a custom definition
-    def _scaled_mm_cpu_out(
-        mat1: Tensor,
-        mat2: Tensor,
-        scale1: Tensor,
-        scale2: Tensor,
-        bias: Optional[Tensor] = None,
-        scale_result: Optional[Tensor] = None,
-        out_dtype: Optional[torch.dtype] = None,
-        use_fast_accum: bool = False,
-        *,
-        out: Optional[Tensor] = None,
-    ) -> Tensor:
-        if out_dtype is None:
-            out_dtype = torch.float32
-        mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
-        mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
-
-        if bias is not None:
-            ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
-        else:
-            ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
-
-        if out is not None:
-            out.copy_(ret)
-            return out
-        return ret
+# PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set.
+# This CPU implementation is not enough for our use case, so we still have to
+# keep our own custom version.
+def _scaled_mm_cpu_out(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    bias: Optional[Tensor] = None,
+    scale_result: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+    *,
+    out: Optional[Tensor] = None,
+) -> Tensor:
+    if out_dtype is None:
+        out_dtype = torch.float32
+    mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
+    mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
+
+    if bias is not None:
+        ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+    else:
+        ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
+    if out is not None:
+        out.copy_(ret)
+        return out
+    return ret
+
+
+def _scaled_mm_cpu(
+    mat1: Tensor,
+    mat2: Tensor,
+    scale1: Tensor,
+    scale2: Tensor,
+    bias: Optional[Tensor] = None,
+    scale_result: Optional[Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    use_fast_accum: bool = False,
+) -> Tensor:
+    return _scaled_mm_cpu_out(
+        mat1,
+        mat2,
+        scale1,
+        scale2,
+        bias,
+        scale_result,
+        out_dtype,
+        use_fast_accum,
+        out=None,
+    )
+
+
+if torch.__version__ >= "2.8":
+    DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
+    torch.ops.aten._scaled_mm.out.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu_out
+    torch.ops.aten._scaled_mm.default.py_kernels[DispatchKey.CPU] = _scaled_mm_cpu
+else:
     torch.library.register_kernel(
         torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
     )
 
-
-    @torch.library.register_kernel("aten::_scaled_mm", "cpu")
-    def _scaled_mm_cpu(
-        mat1: Tensor,
-        mat2: Tensor,
-        scale1: Tensor,
-        scale2: Tensor,
-        bias: Optional[Tensor] = None,
-        scale_result: Optional[Tensor] = None,
-        out_dtype: Optional[torch.dtype] = None,
-        use_fast_accum: bool = False,
-    ) -> Tensor:
-        return _scaled_mm_cpu_out(
-            mat1,
-            mat2,
-            scale1,
-            scale2,
-            bias,
-            scale_result,
-            out_dtype,
-            use_fast_accum,
-            out=None,
-        )
+    torch.library.register_kernel(
+        torch.ops.aten._scaled_mm.default, "cpu", _scaled_mm_cpu
+    )
 
 
 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
@@ -115,7 +123,7 @@ def spyre_scaled_bmm(
         device=mat1.device,
     )
     for b_idx in range(mat1.shape[0]):
-        out[b_idx] = torch._scaled_mm(
+        out[b_idx] = _scaled_mm_cpu_out(
             mat1[b_idx],
             mat2[b_idx],
             scale1,
@@ -218,6 +226,7 @@ def scaled_paged_attn_compute(
     num_kv_heads = value_cache.shape[2]
     head_size = value_cache.shape[3]
     block_size = value_cache.shape[1]
+    seq_len_q = query.shape[1]
     num_seqs = query.shape[0]
 
     block_tables_lst = block_table.cpu().tolist()
@@ -228,6 +237,7 @@ def scaled_paged_attn_compute(
         block_table = block_tables_lst[i]
         start_pos = int(left_padded_prompt_mask[i].item())
         seq_len = int(seq_lens_lst[i])
+        seq_len_q_i = seq_len_q
 
         keys_lst: list[torch.Tensor] = []
         values_lst: list[torch.Tensor] = []
@@ -243,6 +253,13 @@ def scaled_paged_attn_compute(
             values_lst.append(v)
         keys = torch.stack(keys_lst, dim=0)
         values = torch.stack(values_lst, dim=0)
+        seq_len_kv = keys.shape[0]
+
+        # cut the left-padding from the query on first prefill
+        if q.shape[0] > seq_len_kv:
+            seq_len_q_i = seq_len_kv
+            q = q[-seq_len_kv:]
+
         if num_kv_heads > 1:
             # Handle MQA and GQA
             keys = torch.repeat_interleave(keys, num_query_heads // num_kv_heads, dim=1)
@@ -250,6 +267,11 @@ def scaled_paged_attn_compute(
                 values, num_query_heads // num_kv_heads, dim=1
             )
 
+        # Generate mask for prefix attention
+        mask = torch.ones((1, 1, seq_len_q_i, seq_len_kv), dtype=torch.bool)
+        mask[:, :, :, -seq_len_q_i:] = torch.tril(mask[:, :, :, -seq_len_q_i:])
+        mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
+
         out = F.scaled_dot_product_attention(  # noqa: E1102
             q.transpose(0, 1).unsqueeze(0),  # format for sdpa
             (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale[i]).to(
@@ -258,12 +280,12 @@ def scaled_paged_attn_compute(
             (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale[i]).to(
                 dtype=q.dtype
            ),  # format for sdpa
-            is_causal=False,  # decode assumes no causal mask
+            attn_mask=mask,  # causal over new query tokens only (no-op for decode)
             scale=scale,
         )
 
-        out = out.view(num_query_heads, head_size)
-        output[i].copy_(out, non_blocking=True)
+        out = out.transpose(1, 2).view(seq_len_q_i, num_query_heads, head_size)
+        output[i][-seq_len_q_i:] = out
 
     return output
 
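
Below is a minimal standalone sketch (not part of the patch) of what the now-unconditional CPU fallback computes, assuming `fms_mo` is installed so that the module-level `_scaled_mm_cpu_out` shown above is importable. It dequantizes the two FP8 operands with their per-tensor scales and runs a plain matmul in `out_dtype`; after importing the module, the same kernel also backs `torch.ops.aten._scaled_mm` on CPU (via `py_kernels` on torch >= 2.8, via `torch.library.register_kernel` otherwise) and the per-batch loop in `spyre_scaled_bmm`.

```python
import torch

# Assumption: the patched module is importable from an installed fms_mo.
from fms_mo.aiu_addons.fp8.fp8_spyre_op import _scaled_mm_cpu_out

# Per-tensor FP8 (e4m3) quantization of two small matrices (sizes illustrative).
a = torch.randn(4, 8)
b = torch.randn(8, 16)
scale_a = a.abs().max() / torch.finfo(torch.float8_e4m3fn).max
scale_b = b.abs().max() / torch.finfo(torch.float8_e4m3fn).max
a_fp8 = (a / scale_a).to(torch.float8_e4m3fn)
b_fp8 = (b / scale_b).to(torch.float8_e4m3fn)

# The fallback dequantizes each operand as (x.to(out_dtype) * scale) and then
# calls torch.mm (or torch.addmm when a bias is given).
out = _scaled_mm_cpu_out(a_fp8, b_fp8, scale_a, scale_b, out_dtype=torch.float32)

print(out.shape)                  # torch.Size([4, 16])
print((out - a @ b).abs().max())  # small residual from FP8 quantization
```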
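
The second half of the patch is easiest to review with the mask isolated. The sketch below (not part of the patch) reproduces the prefix-attention mask built per sequence in `scaled_paged_attn_compute`: every query row sees the full cached KV prefix, while the `seq_len_q_i` new query tokens attend causally among themselves; for single-token decode (`seq_len_q_i == 1`) the mask is all zeros, matching the previous `is_causal=False` behavior. Toy sizes are illustrative only.

```python
import torch

# Toy sizes (illustrative): 3 cached prefix tokens + 4 new query tokens.
seq_len_q_i = 4
seq_len_kv = 7

# Same construction as in the patch: start fully visible, then make the
# query-aligned tail (last seq_len_q_i key positions) lower-triangular.
mask = torch.ones((1, 1, seq_len_q_i, seq_len_kv), dtype=torch.bool)
mask[:, :, :, -seq_len_q_i:] = torch.tril(mask[:, :, :, -seq_len_q_i:])
mask = torch.where(mask.logical_not(), -torch.inf, 0.0)

print(mask[0, 0])
# [[0., 0., 0., 0., -inf, -inf, -inf],
#  [0., 0., 0., 0.,   0., -inf, -inf],
#  [0., 0., 0., 0.,   0.,   0., -inf],
#  [0., 0., 0., 0.,   0.,   0.,   0.]]
```

The `(1, 1, seq_len_q_i, seq_len_kv)` shape broadcasts over the `(1, num_heads, seq_len_q_i, seq_len_kv)` attention scores when passed to `F.scaled_dot_product_attention` as an additive `attn_mask`. Trimming the left-padded query to its last `seq_len_kv` rows keeps the mask and score shapes aligned, and the result is written back to `output[i][-seq_len_q_i:]` so the padded rows are left untouched.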