diff --git a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py
index 66679a8b..b5abcbfe 100644
--- a/fms_mo/aiu_addons/fp8/fp8_spyre_op.py
+++ b/fms_mo/aiu_addons/fp8/fp8_spyre_op.py
@@ -17,6 +17,7 @@
 from typing import Optional
 
 # Third Party
+from packaging.version import Version
 from torch import Tensor
 import torch
 import torch.nn.functional as F
@@ -29,60 +30,63 @@
 # open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
 
 
-def _scaled_mm_cpu_out(
-    mat1: Tensor,
-    mat2: Tensor,
-    scale1: Tensor,
-    scale2: Tensor,
-    bias: Optional[Tensor] = None,
-    scale_result: Optional[Tensor] = None,
-    out_dtype: Optional[torch.dtype] = None,
-    use_fast_accum: bool = False,
-    *,
-    out: Optional[Tensor] = None,
-) -> Tensor:
-    if out_dtype is None:
-        out_dtype = torch.float32
-    mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
-    mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
-
-    if bias is not None:
-        ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
-    else:
-        ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
-
-    if out is not None:
-        out.copy_(ret)
-        return out
-    return ret
-
-
-torch.library.register_kernel(torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out)
-
-
-@torch.library.register_kernel("aten::_scaled_mm", "cpu")
-def _scaled_mm_cpu(
-    mat1: Tensor,
-    mat2: Tensor,
-    scale1: Tensor,
-    scale2: Tensor,
-    bias: Optional[Tensor] = None,
-    scale_result: Optional[Tensor] = None,
-    out_dtype: Optional[torch.dtype] = None,
-    use_fast_accum: bool = False,
-) -> Tensor:
-    return _scaled_mm_cpu_out(
-        mat1,
-        mat2,
-        scale1,
-        scale2,
-        bias,
-        scale_result,
-        out_dtype,
-        use_fast_accum,
-        out=None,
+if Version(torch.__version__) <= Version("2.7"):
+    # PyTorch 2.8 adds scaled_mm_out op for CPU in the ATen set,
+    # while for earlier versions we need a custom definition
+    def _scaled_mm_cpu_out(
+        mat1: Tensor,
+        mat2: Tensor,
+        scale1: Tensor,
+        scale2: Tensor,
+        bias: Optional[Tensor] = None,
+        scale_result: Optional[Tensor] = None,
+        out_dtype: Optional[torch.dtype] = None,
+        use_fast_accum: bool = False,
+        *,
+        out: Optional[Tensor] = None,
+    ) -> Tensor:
+        if out_dtype is None:
+            out_dtype = torch.float32
+        mat1 = (mat1.to(dtype=out_dtype) * scale1).to(dtype=out_dtype)
+        mat2 = (mat2.to(dtype=out_dtype) * scale2).to(dtype=out_dtype)
+
+        if bias is not None:
+            ret = torch.addmm(bias, mat1, mat2).to(dtype=out_dtype)
+        else:
+            ret = torch.mm(mat1, mat2).to(dtype=out_dtype)
+
+        if out is not None:
+            out.copy_(ret)
+            return out
+        return ret
+
+    torch.library.register_kernel(
+        torch.ops.aten._scaled_mm.out, "cpu", _scaled_mm_cpu_out
     )
 
+    @torch.library.register_kernel("aten::_scaled_mm", "cpu")
+    def _scaled_mm_cpu(
+        mat1: Tensor,
+        mat2: Tensor,
+        scale1: Tensor,
+        scale2: Tensor,
+        bias: Optional[Tensor] = None,
+        scale_result: Optional[Tensor] = None,
+        out_dtype: Optional[torch.dtype] = None,
+        use_fast_accum: bool = False,
+    ) -> Tensor:
+        return _scaled_mm_cpu_out(
+            mat1,
+            mat2,
+            scale1,
+            scale2,
+            bias,
+            scale_result,
+            out_dtype,
+            use_fast_accum,
+            out=None,
+        )
+
 
 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
 def spyre_scaled_bmm(
diff --git a/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py b/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
index 41aa896f..4bea6114 100644
--- a/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
+++ b/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
@@ -36,10 +36,8 @@ def implement_op_decorator(op_namespace_id):
     Always compare against pytorch version in current environment.
     """
 
-    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
-
     def decorator(func):
-        if torch_version < Version("2.4"):
+        if Version(torch.__version__) < Version("2.4"):
             return torch.library.impl(op_namespace_id, "default")(func)
         return torch.library.custom_op(op_namespace_id, mutates_args=())(func)
 
@@ -51,10 +49,8 @@ def register_op_decorator(op_namespace_id):
     Always compare against pytorch version in current environment.
     """
 
-    torch_version = Version(torch.__version__.split("+", maxsplit=1)[0])
-
     def decorator(func):
-        if torch_version < Version("2.4"):
+        if Version(torch.__version__) < Version("2.4"):
             return torch.library.impl_abstract(op_namespace_id)(func)
         return torch.library.register_fake(op_namespace_id)(func)
 
@@ -73,7 +69,7 @@ def register_aiu_i8i8_op():
         logger.warning("AIU op has already been registered")
         return
     op_namespace_id = "fms_mo::i8i8_aiu"
-    if Version(torch.__version__.split("+", maxsplit=1)[0]) < Version("2.4"):
+    if Version(torch.__version__) < Version("2.4"):
         torch.library.define(
             op_namespace_id,
             "(Tensor x, Tensor weight, Tensor bias, Tensor qdata, "
diff --git a/tests/aiu_addons/test_fp8_addon.py b/tests/aiu_addons/test_fp8_addon.py
index 81d263b3..a382c63c 100644
--- a/tests/aiu_addons/test_fp8_addon.py
+++ b/tests/aiu_addons/test_fp8_addon.py
@@ -51,9 +51,9 @@ def test_fp8_op() -> None:
     # Local
     from fms_mo.aiu_addons.fp8.fp8_attn import _math_fp8_compute_op
 
-    query = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
-    key = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
-    value = torch.randn((1, 32, 64, 128), dtype=torch.bfloat16, device="cuda")
+    query = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
+    key = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
+    value = torch.randn((1, 64, 32, 128), dtype=torch.bfloat16, device="cuda")
     out = _math_fp8_compute_op(query, key, value, 32, 32, 0.0, None)
 
     assert out.size() == query.size()