From c5a4f44c38a98944eeaa5fcd7379e8bb70a6d891 Mon Sep 17 00:00:00 2001 From: Brandon Groth Date: Mon, 23 Jun 2025 10:00:50 -0400 Subject: [PATCH 1/4] build: Move triton to optional Signed-off-by: Brandon Groth --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 11a0cc52..be9723fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,6 @@ dependencies = [ "accelerate>=0.20.3,!=0.34,<1.7", "transformers>=4.45,<4.52", "torch>=2.2.0,<2.6", -"triton>=3.0,<3.4", "tqdm>=4.66.2,<5.0", "datasets>=3.0.0,<4.0", "ninja>=1.11.1.1,<2.0", @@ -47,6 +46,7 @@ mx = ["microxcaling>=1.1"] opt = ["fms-model-optimizer[fp8, gptq, mx]"] torchvision = ["torchvision>=0.17"] flash-attn = ["flash-attn>=2.5.3,<3.0"] +triton = ["triton>=3.0,<3.4"] visualize = ["matplotlib", "graphviz", "pygraphviz"] dev = ["pre-commit>=3.0.4,<5.0"] test = ["pytest", "pillow"] From 67af0b14ffa7142c268cdb0398cb5fd6be6d4879 Mon Sep 17 00:00:00 2001 From: Brandon Groth Date: Mon, 23 Jun 2025 10:01:41 -0400 Subject: [PATCH 2/4] build: Add guard for entering triton_kernels.py Signed-off-by: Brandon Groth --- fms_mo/custom_ext_kernels/triton_kernels.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py index 24fbde89..7316dc25 100644 --- a/fms_mo/custom_ext_kernels/triton_kernels.py +++ b/fms_mo/custom_ext_kernels/triton_kernels.py @@ -15,11 +15,23 @@ """This file contains external kernels for FP and INT8 matmul written in triton.""" # Third Party -from triton.language.extra import libdevice import torch + +# First Party +from fms_mo.utils.import_utils import available_packages + +# Assume any calls to the file are requesting triton +if not available_packages["triton"]: + raise ImportError( + "triton python package is not available, please check your installation." 
+ ) + +# Third Party +from triton.language.extra import libdevice import triton import triton.language as tl + DTYPE_I8 = [torch.int8] DTYPE_F8 = [torch.float8_e4m3fn, torch.float8_e5m2] DTYPE_8BIT = DTYPE_I8 + DTYPE_F8 From f9f97d43772e5bcba9d556f7ca83afa6b279d9e6 Mon Sep 17 00:00:00 2001 From: Brandon Groth Date: Mon, 23 Jun 2025 10:02:45 -0400 Subject: [PATCH 3/4] test: Moved triton quick return to skip and reduced test sizes Signed-off-by: Brandon Groth --- fms_mo/custom_ext_kernels/triton_kernels.py | 4 ++-- tests/triton_kernels/test_triton_mm.py | 25 ++++++++++++--------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py index 7316dc25..beae5a2e 100644 --- a/fms_mo/custom_ext_kernels/triton_kernels.py +++ b/fms_mo/custom_ext_kernels/triton_kernels.py @@ -17,7 +17,7 @@ # Third Party import torch -# First Party +# Local from fms_mo.utils.import_utils import available_packages # Assume any calls to the file are requesting triton @@ -27,11 +27,11 @@ ) # Third Party +# pylint: disable=wrong-import-position from triton.language.extra import libdevice import triton import triton.language as tl - DTYPE_I8 = [torch.int8] DTYPE_F8 = [torch.float8_e4m3fn, torch.float8_e5m2] DTYPE_8BIT = DTYPE_I8 + DTYPE_F8 diff --git a/tests/triton_kernels/test_triton_mm.py b/tests/triton_kernels/test_triton_mm.py index 07328888..20cd24b5 100644 --- a/tests/triton_kernels/test_triton_mm.py +++ b/tests/triton_kernels/test_triton_mm.py @@ -32,7 +32,7 @@ ) -@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096]) +@pytest.mark.parametrize("mkn", [64, 256, 1024]) @pytest.mark.parametrize( "dtype_to_test", [ @@ -43,11 +43,12 @@ torch.float8_e5m2, ], ) +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="test_triton_matmul_fp can only run when GPU is available", +) def test_triton_matmul_fp(mkn, dtype_to_test): """Parametric tests for triton matmul kernel using variety of tensor sizes 
and dtypes.""" - if not torch.cuda.is_available(): - # only run the test when GPU is available - return torch.manual_seed(23) m = n = k = mkn @@ -79,12 +80,13 @@ def test_triton_matmul_fp(mkn, dtype_to_test): assert torch.norm(diff_trun_8b) / torch.norm(torch_output) < 1e-3 -@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096]) +@pytest.mark.parametrize("mkn", [64, 256, 1024]) +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="test_triton_matmul_int8 can only run when GPU is available", +) def test_triton_matmul_int8(mkn): """Parametric tests for triton imatmul kernel using variety of tensor sizes.""" - if not torch.cuda.is_available(): - # only run the test when GPU is available - return torch.manual_seed(23) m = n = k = mkn @@ -121,13 +123,14 @@ @pytest.mark.parametrize("feat_in_out", [(64, 128), (256, 1024), (1024, 4096)]) @pytest.mark.parametrize("trun_bits", [0, 8, 12, 16]) +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="test_linear_fpx_acc can only run when GPU is available", +) def test_linear_fpx_acc(feat_in_out, trun_bits): """Parametric tests for LinearFPxAcc. This Linear utilizes triton kernel hence can only be run on CUDA. 
""" - if not torch.cuda.is_available(): - # only run the test when GPU is available - return torch.manual_seed(23) feat_in, feat_out = feat_in_out From 7df7d791a3cba3f4c3e8c66fe8f5c7c293601505 Mon Sep 17 00:00:00 2001 From: Brandon Groth Date: Mon, 23 Jun 2025 15:27:42 -0400 Subject: [PATCH 4/4] docs: Added optional dependency for triton in readme Signed-off-by: Brandon Groth --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index aa422d2f..a7348d2e 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ The following optional dependencies are available: - `mx`: `microxcaling` package for MX quantization - `opt`: Shortcut for `fp8`, `gptq`, and `mx` installs - `torchvision`: `torch` package for image recognition training and inference +- `triton`: `triton` package for matrix multiplication kernels - `visualize`: Dependencies for visualizing models and performance data - `test`: Dependencies needed for unit testing - `dev`: Dependencies needed for development