Commit bcee5f3

Merge pull request #146 from BrandonGroth/no_triton
build: Move triton to an optional dependency
2 parents 06e371a + 7df7d79

4 files changed: +29 -13 lines changed

README.md

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ The following optional dependencies are available:
 - `mx`: `microxcaling` package for MX quantization
 - `opt`: Shortcut for `fp8`, `gptq`, and `mx` installs
 - `torchvision`: `torch` package for image recognition training and inference
+- `triton`: `triton` package for matrix multiplication kernels
 - `visualize`: Dependencies for visualizing models and performance data
 - `test`: Dependencies needed for unit testing
 - `dev`: Dependencies needed for development
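
With `triton` documented as an optional dependency, the kernels' requirement can be pulled in on demand through the extra defined in pyproject.toml, e.g. `pip install "fms-model-optimizer[triton]"` (or `pip install ".[triton]"` from a local checkout); the exact command depends on how the package is installed in your environment.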

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 13 additions & 1 deletion
@@ -15,8 +15,20 @@
 """This file contains external kernels for FP and INT8 matmul written in triton."""
 
 # Third Party
-from triton.language.extra import libdevice
 import torch
+
+# Local
+from fms_mo.utils.import_utils import available_packages
+
+# Assume any calls to the file are requesting triton
+if not available_packages["triton"]:
+    raise ImportError(
+        "triton python package is not available, please check your installation."
+    )
+
+# Third Party
+# pylint: disable=wrong-import-position
+from triton.language.extra import libdevice
 import triton
 import triton.language as tl
 
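
The guard relies on `fms_mo.utils.import_utils.available_packages`, which this commit imports but does not show. As a rough sketch of the pattern (assuming the helper simply maps package names to importability; the real implementation may differ):

# Hypothetical sketch only -- the actual fms_mo.utils.import_utils is not shown in this commit.
from importlib.util import find_spec

# Map each optional package name to whether an importable distribution is present.
available_packages = {
    name: find_spec(name) is not None
    for name in ("triton", "torchvision", "matplotlib")
}

# Modules that need an optional package can then fail fast with a clear message.
if not available_packages["triton"]:
    raise ImportError(
        "triton python package is not available, please check your installation."
    )

Deferring `from triton.language.extra import libdevice` until after this check means the module fails with an actionable message instead of a bare ModuleNotFoundError when the extra is not installed.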

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,6 @@ dependencies = [
     "accelerate>=0.20.3,!=0.34,<1.7",
     "transformers>=4.45,<4.53",
     "torch>=2.2.0,<2.6",
-    "triton>=3.0,<3.4",
     "tqdm>=4.66.2,<5.0",
     "datasets>=3.0.0,<4.0",
     "ninja>=1.11.1.1,<2.0",
@@ -47,6 +46,7 @@ mx = ["microxcaling>=1.1"]
 opt = ["fms-model-optimizer[fp8, gptq, mx]"]
 torchvision = ["torchvision>=0.17"]
 flash-attn = ["flash-attn>=2.5.3,<3.0"]
+triton = ["triton>=3.0,<3.4"]
 visualize = ["matplotlib", "graphviz", "pygraphviz"]
 dev = ["pre-commit>=3.0.4,<5.0"]
 test = ["pytest", "pillow"]

tests/triton_kernels/test_triton_mm.py

Lines changed: 14 additions & 11 deletions
@@ -32,7 +32,7 @@
 )
 
 
-@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096])
+@pytest.mark.parametrize("mkn", [64, 256, 1024])
 @pytest.mark.parametrize(
     "dtype_to_test",
     [
@@ -43,11 +43,12 @@
         torch.float8_e5m2,
     ],
 )
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="test_triton_matmul_fp can only run when GPU is available",
+)
 def test_triton_matmul_fp(mkn, dtype_to_test):
     """Parametric tests for triton matmul kernel using variety of tensor sizes and dtypes."""
-    if not torch.cuda.is_available():
-        # only run the test when GPU is available
-        return
 
     torch.manual_seed(23)
     m = n = k = mkn
@@ -81,12 +82,13 @@ def test_triton_matmul_fp(mkn, dtype_to_test):
     assert torch.norm(diff_trun_8b) / torch.norm(torch_output) < 1e-3
 
 
-@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096])
+@pytest.mark.parametrize("mkn", [64, 256, 1024])
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="test_triton_matmul_int8 can only run when GPU is available",
+)
 def test_triton_matmul_int8(mkn):
     """Parametric tests for triton imatmul kernel using variety of tensor sizes."""
-    if not torch.cuda.is_available():
-        # only run the test when GPU is available
-        return
 
     torch.manual_seed(23)
     m = n = k = mkn
@@ -123,13 +125,14 @@ def test_triton_matmul_int8(mkn):
 
 @pytest.mark.parametrize("feat_in_out", [(64, 128), (256, 1024), (1024, 4096)])
 @pytest.mark.parametrize("trun_bits", [0, 8, 12, 16])
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="test_linear_fpx_acc can only run when GPU is available",
+)
 def test_linear_fpx_acc(feat_in_out, trun_bits):
     """Parametric tests for LinearFPxAcc. This Linear utilizes triton kernel hence can only be run
     on CUDA.
     """
-    if not torch.cuda.is_available():
-        # only run the test when GPU is available
-        return
 
     torch.manual_seed(23)
     feat_in, feat_out = feat_in_out
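
All three tests now carry the same `skipif` condition; one possible refactor (not part of this commit) is a module-level marker so the GPU check is written once:

# Sketch only -- the committed tests repeat the skipif decorator on each function instead.
import pytest
import torch

requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="triton kernels can only run when a GPU is available",
)


@requires_cuda
@pytest.mark.parametrize("mkn", [64, 256, 1024])
def test_triton_matmul_int8(mkn):
    """Body identical to the version in the diff; only the CUDA check is factored out."""
    ...

Either form reports the tests as skipped with the given reason, which is the improvement over the old early `return` that let them pass silently on CPU-only machines.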
