From eaad33b4c01fac5e5bc294edc1642d80bcdff3d0 Mon Sep 17 00:00:00 2001
From: Andrea Fasoli
Date: Tue, 14 Oct 2025 16:50:19 -0400
Subject: [PATCH 1/3] Remove custom scaled bmm op on cpu and fix fp8 test

Signed-off-by: Andrea Fasoli
---
 fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 58 ---------------------------
 tests/aiu_addons/test_fp8_addon.py    |  6 +--
 2 files changed, 3 insertions(+), 61 deletions(-)

diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py
index 66679a8b..b1c7482c 100644
--- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py
+++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py
@@ -13,9 +13,6 @@
 # limitations under the License.
 """Torch registration of FP8xFP8 operation for attention BMMs."""
 
-# Standard
-from typing import Optional
-
 # Third Party
 from torch import Tensor
 import torch
@@ -29,61 +26,6 @@
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
 
-def _scaled_mm_cpu_out(
-    mat1: Tensor,
-    mat2: Tensor,
-    scale1: Tensor,
-    scale2: Tensor,
-    bias: Optional[Tensor] = None,
-    scale_result: Optional[Tensor] = None,
-    out_dtype: Optional[torch.dtype] = None,
-    use_fast_accum: bool = False,
-    *,
-    out: Optional[Tensor] = None,
-) -> Tensor:
-    if out_dtype is None:
-        out_dtype = torch.float32
-    mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
-    mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
-
-    if bias is not None:
-        ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
-    else:
-        ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
-
-    if out is not None:
-        out.copy_(ret)
-        return out
-    return ret
-
-
-torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out)
-
-
-@torch.library.register_kernel("aten::_scaled_mm", "cpu")
-def _scaled_mm_cpu(
-    mat1: Tensor,
-    mat2: Tensor,
-    scale1: Tensor,
-    scale2: Tensor,
-    bias: Optional[Tensor] = None,
-    scale_result: Optional[Tensor] = None,
-    out_dtype: Optional[torch.dtype] = None,
-    use_fast_accum: bool = False,
-) -> Tensor:
-    return _scaled_mm_cpu_out(
-        mat1,
-        mat2,
-        scale1,
-        scale2,
-        bias,
-        scale_result,
-        out_dtype,
-        use_fast_accum,
-        out=None,
-    )
-
-
 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
 def spyre_scaled_bmm(
     mat1: Tensor,
diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py
index 81d263b3..a382c63c 100644
--- a/tests/aiu_addons/test_fp8_addon.py
+++ b/tests/aiu_addons/test_fp8_addon.py
@@ -51,9 +51,9 @@ def test_fp8_op() -> None:
     # Local
     from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op
 
-    query = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
-    key = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
-    value = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
+    query = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
+    key = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
+    value = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
 
     out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None)
     assert out.size() == query.size()

From f61c56110c9691d10adff5eff10cd71e065bda5a Mon Sep 17 00:00:00 2001
From: Andrea Fasoli
Date: Wed, 15 Oct 2025 17:37:09 -0400
Subject: [PATCH 2/3] Re-enable custom op for pt<=2.7

Signed-off-by: Andrea Fasoli
---
 fms_mo/aiu_addons/fp8/fp8_spyre_op.py | 62 +++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)

diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py
index b1c7482c..b5abcbfe 100644
--- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py
+++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py
@@ -13,7 +13,11 @@
 # limitations under the License.
 """Torch registration of FP8xFP8 operation for attention BMMs."""
 
+# Standard
+from typing import Optional
+
 # Third Party
+from packaging.version import Version
 from torch import Tensor
 import torch
 import torch.nn.functional as F
@@ -26,6 +30,64 @@
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
 
+if Version(torch.__version__) <= Version("2.7"):
+    # PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set,
+    # while for earlier versions we need a custom definition
+    def _scaled_mm_cpu_out(
+        mat1: Tensor,
+        mat2: Tensor,
+        scale1: Tensor,
+        scale2: Tensor,
+        bias: Optional[Tensor] = None,
+        scale_result: Optional[Tensor] = None,
+        out_dtype: Optional[torch.dtype] = None,
+        use_fast_accum: bool = False,
+        *,
+        out: Optional[Tensor] = None,
+    ) -> Tensor:
+        if out_dtype is None:
+            out_dtype = torch.float32
+        mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
+        mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
+
+        if bias is not None:
+            ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+        else:
+            ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
+
+        if out is not None:
+            out.copy_(ret)
+            return out
+        return ret
+
+    torch.library.register_kernel(
+        torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
+    )
+
+    @torch.library.register_kernel("aten::_scaled_mm", "cpu")
+    def _scaled_mm_cpu(
+        mat1: Tensor,
+        mat2: Tensor,
+        scale1: Tensor,
+        scale2: Tensor,
+        bias: Optional[Tensor] = None,
+        scale_result: Optional[Tensor] = None,
+        out_dtype: Optional[torch.dtype] = None,
+        use_fast_accum: bool = False,
+    ) -> Tensor:
+        return _scaled_mm_cpu_out(
+            mat1,
+            mat2,
+            scale1,
+            scale2,
+            bias,
+            scale_result,
+            out_dtype,
+            use_fast_accum,
+            out=None,
+        )
+
+
 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
 def spyre_scaled_bmm(
     mat1: Tensor,

From 1c4f22e0c5af923ff206edb7135af74a4a41c055 Mon Sep 17 00:00:00 2001
From: Andrea Fasoli
Date: Wed, 15 Oct 2025 17:44:59 -0400
Subject: [PATCH 3/3] Clean up versioning for int8 aiu op

Signed-off-by: Andrea Fasoli
---
 fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py b/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
index 41aa896f..4bea6114 100644
--- a/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
+++ b/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
@@ -36,10 +36,8 @@ def implement_op_decorator(op_namespace_id):
     Always compare against pytorch version in current environment.
     """
 
-    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
-
     def decorator(func):
-        if torch_version < Version("2.4"):
+        if Version(torch.__version__) < Version("2.4"):
             return torch.library.impl(op_namespace_id, "default")(func)
         return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
 
@@ -51,10 +49,8 @@ def register_op_decorator(op_namespace_id):
     Always compare against pytorch version in current environment.
     """
 
-    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
-
     def decorator(func):
-        if torch_version < Version("2.4"):
+        if Version(torch.__version__) < Version("2.4"):
             return torch.library.impl_abstract(op_namespace_id)(func)
         return torch.library.register_fake(op_namespace_id)(func)
 
@@ -73,7 +69,7 @@ def register_aiu_i8i8_op():
         logger.warning("AIU op has already been registered")
         return
     op_namespace_id = "fms_mo::i8i8_aiu"
-    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
+    if Version(torch.__version__) < Version("2.4"):
         torch.library.define(
             op_namespace_id,
             "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "