1 change: 1 addition & 0 deletions README.md
@@ -105,6 +105,7 @@ The following optional dependencies are available:
- `mx`: `microxcaling` package for MX quantization
- `opt`: Shortcut for `fp8`, `gptq`, and `mx` installs
- `torchvision`: `torch` package for image recognition training and inference
- `triton`: `triton` package for matrix multiplication kernels
- `visualize`: Dependencies for visualizing models and performance data
- `test`: Dependencies needed for unit testing
- `dev`: Dependencies needed for development
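With `triton` now an optional extra rather than a core dependency (see the `pyproject.toml` change below), it can presumably be installed on demand, e.g. `pip install fms-model-optimizer[triton]`.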
14 changes: 13 additions & 1 deletion fms_mo/custom_ext_kernels/triton_kernels.py
@@ -15,8 +15,20 @@
"""This file contains external kernels for FP and INT8 matmul written in triton."""

# Third Party
from triton.language.extra import libdevice
import torch

# Local
from fms_mo.utils.import_utils import available_packages

# Assume any calls to the file are requesting triton
if not available_packages["triton"]:
raise ImportError(
"triton python package is not avaialble, please check your installation."
)

# Third Party
# pylint: disable=wrong-import-position
from triton.language.extra import libdevice
import triton
import triton.language as tl

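For context, the new guard relies on `available_packages` from `fms_mo.utils.import_utils`. A minimal sketch of how such an availability map could behave (an assumption for illustration, not the actual fms_mo implementation):

```python
# Sketch only: probe for optional packages without importing them, mirroring
# the fail-fast guard added in triton_kernels.py above.
import importlib.util

available_packages = {
    name: importlib.util.find_spec(name) is not None
    for name in ("triton", "torchvision", "matplotlib")
}

if not available_packages["triton"]:
    raise ImportError(
        "triton python package is not available, please check your installation."
    )
```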
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -26,7 +26,6 @@ dependencies = [
"accelerate>=0.20.3,!=0.34,<1.7",
"transformers>=4.45,<4.52",
"torch>=2.2.0,<2.6",
"triton>=3.0,<3.4",
"tqdm>=4.66.2,<5.0",
"datasets>=3.0.0,<4.0",
"ninja>=1.11.1.1,<2.0",
@@ -47,6 +46,7 @@ mx = ["microxcaling>=1.1"]
opt = ["fms-model-optimizer[fp8, gptq, mx]"]
torchvision = ["torchvision>=0.17"]
flash-attn = ["flash-attn>=2.5.3,<3.0"]
triton = ["triton>=3.0,<3.4"]
visualize = ["matplotlib", "graphviz", "pygraphviz"]
dev = ["pre-commit>=3.0.4,<5.0"]
test = ["pytest", "pillow"]
25 changes: 14 additions & 11 deletions tests/triton_kernels/test_triton_mm.py
@@ -32,7 +32,7 @@
)


@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096])
@pytest.mark.parametrize("mkn", [64, 256, 1024])
@pytest.mark.parametrize(
"dtype_to_test",
[
@@ -43,11 +43,12 @@
torch.float8_e5m2,
],
)
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="test_triton_matmul_fp can only when GPU is available",
)
def test_triton_matmul_fp(mkn, dtype_to_test):
"""Parametric tests for triton matmul kernel using variety of tensor sizes and dtypes."""
if not torch.cuda.is_available():
# only run the test when GPU is available
return

torch.manual_seed(23)
m = n = k = mkn
@@ -79,12 +80,13 @@ def test_triton_matmul_fp(mkn, dtype_to_test):
assert torch.norm(diff_trun_8b) / torch.norm(torch_output) < 1e-3


@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096])
@pytest.mark.parametrize("mkn", [64, 256, 1024])
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="test_triton_matmul_int8 can only when GPU is available",
)
def test_triton_matmul_int8(mkn):
"""Parametric tests for triton imatmul kernel using variety of tensor sizes."""
if not torch.cuda.is_available():
# only run the test when GPU is available
return

torch.manual_seed(23)
m = n = k = mkn
@@ -121,13 +123,14 @@

@pytest.mark.parametrize("feat_in_out", [(64, 128), (256, 1024), (1024, 4096)])
@pytest.mark.parametrize("trun_bits", [0, 8, 12, 16])
@pytest.mark.skipif(
not torch.cuda.is_available(),
reason="test_linear_fpx_acc can only when GPU is available",
)
def test_linear_fpx_acc(feat_in_out, trun_bits):
"""Parametric tests for LinearFPxAcc. This Linear utilizes triton kernel hence can only be run
on CUDA.
"""
if not torch.cuda.is_available():
# only run the test when GPU is available
return

torch.manual_seed(23)
feat_in, feat_out = feat_in_out
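As a standalone illustration of the pattern these tests adopt, here is a minimal sketch (not part of this PR) of gating a GPU-only test with `pytest.mark.skipif` instead of an early return, so CPU-only runs report the test as skipped rather than silently passed:

```python
# Sketch only: a reusable skipif marker for GPU-dependent tests.
import pytest
import torch

requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="this test can only run when a GPU is available",
)


@requires_cuda
def test_gpu_matmul_smoke():
    """Trivial GPU matmul check; reported as skipped on CPU-only machines."""
    a = torch.randn(8, 8, device="cuda")
    identity = torch.eye(8, device="cuda")
    assert torch.allclose(a @ identity, a)
```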