From dacf1d2e32f6f419d37fe3646f0783c05f33b19b Mon Sep 17 00:00:00 2001 From: Antoni Viros i Martin Date: Fri, 11 Jul 2025 22:00:05 +0000 Subject: [PATCH] Add support for per-sequence scaling in FP8 attention, and fix some CPU fallback errors Signed-off-by: Antoni Viros i Martin --- fms_mo/aiu_addons/fp8/fp8_attn.py | 60 +++++++++++++++++++++------ fms_mo/aiu_addons/fp8/fp8_linear.py | 1 + fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 56 ++++++++++++++++++------- 3 files changed, 90 insertions(+), 27 deletions(-) diff --git a/fms_mo/aiu_addons/fp8/fp8_attn.py b/fms_mo/aiu_addons/fp8/fp8_attn.py index ea86e08a..1d31e827 100644 --- a/fms_mo/aiu_addons/fp8/fp8_attn.py +++ b/fms_mo/aiu_addons/fp8/fp8_attn.py @@ -70,12 +70,24 @@ def _math_fp8_store_op( k_scale = key_cache._scale v_scale = value_cache._scale else: - k_scale = (torch.abs(keys).max() / K_RANGE).to(dtype=torch.float32) - v_scale = (torch.abs(values).max() / V_RANGE).to(dtype=torch.float32) + k_scale = ( + (torch.abs(keys).amax(dim=(1, 2, 3)) / K_RANGE) + .clamp(min=1e-5) + .to(dtype=torch.float32) + ) + v_scale = ( + (torch.abs(values).amax(dim=(1, 2, 3)) / V_RANGE) + .clamp(min=1e-5) + .to(dtype=torch.float32) + ) # Scale kv tensors for storage - keys = (keys / k_scale).to(torch.float8_e4m3fn).transpose(2, 1) - values = (values / v_scale).to(torch.float8_e4m3fn).transpose(2, 1) + keys = ( + (keys / k_scale.view(-1, 1, 1, 1)).to(torch.float8_e4m3fn).transpose(2, 1) + ) + values = ( + (values / v_scale.view(-1, 1, 1, 1)).to(torch.float8_e4m3fn).transpose(2, 1) + ) if ( isinstance(key_cache, ScaledTensor) @@ -134,10 +146,20 @@ def _math_fp8_compute_op( value_cache = value_cache._data else: # Store op wasn't run (e.g. encoders, use_cache=False) - k_scale = (torch.abs(key_cache).max() / K_RANGE).to(dtype=torch.float32) - v_scale = (torch.abs(value_cache).max() / V_RANGE).to(dtype=torch.float32) - key_cache = (key_cache / k_scale).to(torch.float8_e4m3fn) - value_cache = (value_cache / v_scale).to(torch.float8_e4m3fn) + k_scale = ( + (torch.abs(key_cache).amax(dim=(1, 2, 3)) / K_RANGE) + .clamp(min=1e-5) + .to(dtype=torch.float32) + ) + v_scale = ( + (torch.abs(value_cache).amax(dim=(1, 2, 3)) / V_RANGE) + .clamp(min=1e-5) + .to(dtype=torch.float32) + ) + key_cache = (key_cache / k_scale.view(-1, 1, 1, 1)).to(torch.float8_e4m3fn) + value_cache = (value_cache / v_scale.view(-1, 1, 1, 1)).to( + torch.float8_e4m3fn + ) # If store wasn't run, we need to transpose the tensors here # TODO: Refactor FMS to avoid edge cases where this fails; add use_cache param here @@ -192,14 +214,20 @@ def _math_fp8_compute_op( * scale_factor ) else: - key_t = (key_cache.to(dtype=orig_dtype) * k_scale).transpose(-2, -1) + key_t = ( + (key_cache.to(dtype=orig_dtype) * k_scale.view(-1, 1, 1, 1)) + .to(dtype=orig_dtype) + .transpose(-2, -1) + ) attn_weight = query @ key_t attn_weight *= scale_factor attn_weight += attn_bias attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, p_dropout, train=True) # Do matmul in orig_dtype - attn = attn_weight @ (value_cache.to(dtype=orig_dtype) * v_scale) + attn = attn_weight @ ( + value_cache.to(dtype=orig_dtype) * v_scale.view(-1, 1, 1, 1) + ).to(dtype=orig_dtype) attn = attn.to(orig_dtype).transpose(2, 1).contiguous() return attn @@ -226,9 +254,15 @@ def _spyre_scaled_paged_store_op( value_cache, ScaledTensor ), "kv cache must be preallocated" if not key_cache._scaled: - key_cache._scale = (torch.abs(keys).max() / 200.0).to(dtype=torch.float32) - value_cache._scale = (torch.abs(values).max() / 100.0).to( - dtype=torch.float32 + key_cache._scale = ( + (torch.abs(keys).amax(dim=(1, 2, 3)) / K_RANGE) + .clamp(min=1e-5) + .to(dtype=torch.float32) + ) + value_cache._scale = ( + (torch.abs(values).amax(dim=(1, 2, 3)) / V_RANGE) + .clamp(min=1e-5) + .to(dtype=torch.float32) ) result_key_cache_data, result_value_cache_data = ( diff --git a/fms_mo/aiu_addons/fp8/fp8_linear.py b/fms_mo/aiu_addons/fp8/fp8_linear.py index 6062665b..7f85f33a 100644 --- a/fms_mo/aiu_addons/fp8/fp8_linear.py +++ b/fms_mo/aiu_addons/fp8/fp8_linear.py @@ -20,6 +20,7 @@ import torch # Local +from fms_mo.aiu_addons.fp8 import fp8_spyre_op # pylint: disable=unused-import from fms_mo.prep import available_packages # pylint: disable=not-callable diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py index 696aab25..66679a8b 100644 --- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py +++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py @@ -29,12 +29,7 @@ # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482 -aten = torch.ops.aten -DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] - - -@torch.library.register_kernel("aten::_scaled_mm", "cpu") -def _scaled_mm_cpu( +def _scaled_mm_cpu_out( mat1: Tensor, mat2: Tensor, scale1: Tensor, @@ -43,6 +38,8 @@ def _scaled_mm_cpu( scale_result: Optional[Tensor] = None, out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, + *, + out: Optional[Tensor] = None, ) -> Tensor: if out_dtype is None: out_dtype = torch.float32 @@ -50,8 +47,41 @@ def _scaled_mm_cpu( mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype) if bias is not None: - return torch.addmm(bias, mat1, mat2).to(dtype=out_dtype) - return torch.mm(mat1, mat2).to(dtype=out_dtype) + ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype) + else: + ret = torch.mm(mat1, mat2).to(dtype=out_dtype) + + if out is not None: + out.copy_(ret) + return out + return ret + + +torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out) + + +@torch.library.register_kernel("aten::_scaled_mm", "cpu") +def _scaled_mm_cpu( + mat1: Tensor, + mat2: Tensor, + scale1: Tensor, + scale2: Tensor, + bias: Optional[Tensor] = None, + scale_result: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, +) -> Tensor: + return _scaled_mm_cpu_out( + mat1, + mat2, + scale1, + scale2, + bias, + scale_result, + out_dtype, + use_fast_accum, + out=None, + ) @torch.library.custom_op("spyre::scaled_bmm", mutates_args=()) @@ -127,7 +157,6 @@ def scaled_paged_attn_store( Scales key and value tensors, and stores them to the paged KV cache using the same schema as vLLM. """ - print("Should never hit") result_key_cache = key_cache.clone() result_value_cache = value_cache.clone() for seq_i, slot_mapping_seq in enumerate(slot_mapping): @@ -136,10 +165,10 @@ def scaled_paged_attn_store( position = slot.item() % 64 result_key_cache[block_number, position, :, :] = ( - key[seq_i, tok_i, :, :] / key_scale + key[seq_i, tok_i, :, :] / key_scale[seq_i] ).to(dtype=torch.float8_e4m3fn) result_value_cache[block_number, position, :, :] = ( - value[seq_i, tok_i, :, :] / value_scale + value[seq_i, tok_i, :, :] / value_scale[seq_i] ).to(dtype=torch.float8_e4m3fn) return result_key_cache, result_value_cache @@ -179,7 +208,6 @@ def scaled_paged_attn_compute( Implements a CPU fallback to run the kernel that has been confirmed to match the vLLM fused kernel. """ - print("Should never hit") # torch.zeros(NUM_BLOCKS, BLOCK_SIZE, kvheads, head_size, dtype=model_dtype), output = torch.zeros_like(query) num_query_heads = query.shape[2] @@ -220,10 +248,10 @@ def scaled_paged_attn_compute( out = F.scaled_dot_product_attention( # noqa: E1102 q.transpose(0, 1).unsqueeze(0), # format for sdpa - (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale).to( + (keys.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * key_scale[i]).to( dtype=q.dtype ), # format for sdpa - (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale).to( + (values.transpose(0, 1).unsqueeze(0).to(dtype=q.dtype) * value_scale[i]).to( dtype=q.dtype ), # format for sdpa is_causal=False, # decode assumes no causal mask