From 6c15512dd8e1aa2640a253584043a832478f45f5 Mon Sep 17 00:00:00 2001
From: Antoni Viros i Martin
Date: Wed, 16 Jul 2025 15:23:51 +0000
Subject: [PATCH 1/2] Mark scale dimensions to have the same batch size as input

Signed-off-by: Antoni Viros i Martin
---
 fms_mo/aiu_addons/fp8/fp8_attn.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)

diff --git a/fms_mo/aiu_addons/fp8/fp8_attn.py b/fms_mo/aiu_addons/fp8/fp8_attn.py
index 1d31e827..2b4c4962 100644
--- a/fms_mo/aiu_addons/fp8/fp8_attn.py
+++ b/fms_mo/aiu_addons/fp8/fp8_attn.py
@@ -317,6 +317,23 @@ def _spyre_scaled_paged_compute_op(
         attn_kwargs["left_padded_prompt_mask"],
         attn_kwargs["block_table"],
     )
+    
+def __spyre_scaled_paged_validate_attn_kwargs_op(
+    input_ids: torch.Tensor,
+    position_ids: torch.Tensor,
+    past_key_value_states: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None,
+    **attn_kwargs,
+):
+    __spyre_paged_validate_attn_kwargs_op(input_ids, position_ids, past_key_value_states, **attn_kwargs)
+
+    if past_key_value_states is not None:
+        for k, v in past_key_value_states:
+            assert isinstance(k, ScaledTensor)
+            assert isinstance(v, ScaledTensor)
+
+            # assert that for each layer, the scales are per-sequence
+            assert k._scale.shape[0] == input_ids.shape[0]
+            assert v._scale.shape[0] == input_ids.shape[0]
 
 register_attention_op(
     "spyre_paged_attn_fp8",
@@ -325,5 +342,5 @@ def _spyre_scaled_paged_compute_op(
     is_prefill_op=lambda **attn_kwargs: attn_kwargs.get("block_table", None)
     is None,
     compute_decode_op=_spyre_scaled_paged_compute_op,
-    validate_attn_kwargs_op=__spyre_paged_validate_attn_kwargs_op,
+    validate_attn_kwargs_op=__spyre_scaled_paged_validate_attn_kwargs_op,
 )

From c1a68d7c220a1105501222b182913aad4bf2feb4 Mon Sep 17 00:00:00 2001
From: Antoni Viros i Martin
Date: Wed, 16 Jul 2025 15:32:10 +0000
Subject: [PATCH 2/2] linting

Signed-off-by: Antoni Viros i Martin
---
 fms_mo/aiu_addons/fp8/fp8_attn.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/fms_mo/aiu_addons/fp8/fp8_attn.py b/fms_mo/aiu_addons/fp8/fp8_attn.py
index 2b4c4962..e4e4224e 100644
--- a/fms_mo/aiu_addons/fp8/fp8_attn.py
+++ b/fms_mo/aiu_addons/fp8/fp8_attn.py
@@ -317,14 +317,16 @@ def _spyre_scaled_paged_compute_op(
         attn_kwargs["left_padded_prompt_mask"],
         attn_kwargs["block_table"],
     )
-    
+
 def __spyre_scaled_paged_validate_attn_kwargs_op(
     input_ids: torch.Tensor,
     position_ids: torch.Tensor,
     past_key_value_states: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None,
     **attn_kwargs,
 ):
-    __spyre_paged_validate_attn_kwargs_op(input_ids, position_ids, past_key_value_states, **attn_kwargs)
+    __spyre_paged_validate_attn_kwargs_op(
+        input_ids, position_ids, past_key_value_states, **attn_kwargs
+    )
 
     if past_key_value_states is not None:
         for k, v in past_key_value_states:
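
Note (illustration, not part of the patches): the new validation op asserts that each cached key/value ScaledTensor carries one scale per sequence, i.e. the scale's leading dimension matches the batch size of input_ids. A minimal runnable sketch of that invariant follows; FakeScaledTensor and check_per_sequence_scales are hypothetical stand-ins, not fms_mo APIs.

# Illustrative sketch only: models just the `_scale` attribute that the
# patched __spyre_scaled_paged_validate_attn_kwargs_op inspects.
from dataclasses import dataclass

import torch


@dataclass
class FakeScaledTensor:
    data: torch.Tensor    # quantized cache contents (dtype irrelevant here)
    _scale: torch.Tensor  # one dequantization scale per sequence in the batch


def check_per_sequence_scales(input_ids, past_key_value_states):
    # Mirrors the added assertions: scale leading dim == batch size, per layer.
    batch_size = input_ids.shape[0]
    for k, v in past_key_value_states:
        assert k._scale.shape[0] == batch_size
        assert v._scale.shape[0] == batch_size


if __name__ == "__main__":
    # Batch of 2 sequences, one cached layer with per-sequence scales of shape (2,).
    input_ids = torch.zeros(2, 16, dtype=torch.long)
    kv = (
        FakeScaledTensor(torch.zeros(2, 8, 16, 64), torch.ones(2)),
        FakeScaledTensor(torch.zeros(2, 8, 16, 64), torch.ones(2)),
    )
    check_per_sequence_scales(input_ids, [kv])  # passes: scales are per-sequence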