diff --git a/README.md b/README.md
index aa422d2f..a7348d2e 100644
--- a/README.md
+++ b/README.md
@@ -105,6 +105,7 @@ The following optional dependencies are available:
 - `mx`: `microxcaling` package for MX quantization
 - `opt`: Shortcut for `fp8`, `gptq`, and `mx` installs
 - `torchvision`: `torch` package for image recognition training and inference
+- `triton`: `triton` package for matrix multiplication kernels
 - `visualize`: Dependencies for visualizing models and performance data
 - `test`: Dependencies needed for unit testing
 - `dev`: Dependencies needed for development
diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py
index 24fbde89..beae5a2e 100644
--- a/fms_mo/custom_ext_kernels/triton_kernels.py
+++ b/fms_mo/custom_ext_kernels/triton_kernels.py
@@ -15,8 +15,20 @@
 """This file contains external kernels for FP and INT8 matmul written in triton."""

 # Third Party
-from triton.language.extra import libdevice
 import torch
+
+# Local
+from fms_mo.utils.import_utils import available_packages
+
+# Assume any use of this module requires triton
+if not available_packages["triton"]:
+    raise ImportError(
+        "triton python package is not available, please check your installation."
+    )
+
+# Third Party
+# pylint: disable=wrong-import-position
+from triton.language.extra import libdevice
 import triton
 import triton.language as tl
diff --git a/pyproject.toml b/pyproject.toml
index 11a0cc52..be9723fc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,6 @@ dependencies = [
 "accelerate>=0.20.3,!=0.34,<1.7",
 "transformers>=4.45,<4.52",
 "torch>=2.2.0,<2.6",
-"triton>=3.0,<3.4",
 "tqdm>=4.66.2,<5.0",
 "datasets>=3.0.0,<4.0",
 "ninja>=1.11.1.1,<2.0",
@@ -47,6 +46,7 @@ mx = ["microxcaling>=1.1"]
 opt = ["fms-model-optimizer[fp8, gptq, mx]"]
 torchvision = ["torchvision>=0.17"]
 flash-attn = ["flash-attn>=2.5.3,<3.0"]
+triton = ["triton>=3.0,<3.4"]
 visualize = ["matplotlib", "graphviz", "pygraphviz"]
 dev = ["pre-commit>=3.0.4,<5.0"]
 test = ["pytest", "pillow"]
diff --git a/tests/triton_kernels/test_triton_mm.py b/tests/triton_kernels/test_triton_mm.py
index 07328888..20cd24b5 100644
--- a/tests/triton_kernels/test_triton_mm.py
+++ b/tests/triton_kernels/test_triton_mm.py
@@ -32,7 +32,7 @@
 )


-@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096])
+@pytest.mark.parametrize("mkn", [64, 256, 1024])
 @pytest.mark.parametrize(
     "dtype_to_test",
     [
         torch.float16,
         torch.bfloat16,
         torch.float8_e4m3fn,
@@ -43,11 +43,12 @@
         torch.float8_e5m2,
     ],
 )
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="test_triton_matmul_fp can only run when GPU is available",
+)
 def test_triton_matmul_fp(mkn, dtype_to_test):
     """Parametric tests for triton matmul kernel using variety of tensor sizes and dtypes."""
-    if not torch.cuda.is_available():
-        # only run the test when GPU is available
-        return
     torch.manual_seed(23)

     m = n = k = mkn
@@ -79,12 +80,13 @@ def test_triton_matmul_fp(mkn, dtype_to_test):
         assert torch.norm(diff_trun_8b) / torch.norm(torch_output) < 1e-3


-@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096])
+@pytest.mark.parametrize("mkn", [64, 256, 1024])
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="test_triton_matmul_int8 can only run when GPU is available",
+)
 def test_triton_matmul_int8(mkn):
     """Parametric tests for triton imatmul kernel using variety of tensor sizes."""
-    if not torch.cuda.is_available():
-        # only run the test when GPU is available
-        return
     torch.manual_seed(23)

     m = n = k = mkn
@@ -121,13 +123,14 @@
@pytest.mark.parametrize("feat_in_out", [(64, 128), (256, 1024), (1024, 4096)]) @pytest.mark.parametrize("trun_bits", [0, 8, 12, 16]) +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="test_linear_fpx_acc can only when GPU is available", +) def test_linear_fpx_acc(feat_in_out, trun_bits): """Parametric tests for LinearFPxAcc. This Linear utilizes triton kernel hence can only be run on CUDA. """ - if not torch.cuda.is_available(): - # only run the test when GPU is available - return torch.manual_seed(23) feat_in, feat_out = feat_in_out