diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py index 999116c97..88f448bcd 100755 --- a/bitsandbytes/backends/xpu/ops.py +++ b/bitsandbytes/backends/xpu/ops.py @@ -1,14 +1,16 @@ from collections.abc import Sequence import warnings +from packaging import version import torch from ..._ops import register_kernel from ..utils import ipex_xpu, triton_available -# _int_mm is available in torch starting from 2.7 version, -# but currently it's don't have xpu implementation. -if ipex_xpu and torch.__version__ >= (2, 7): +# _int_mm has an XPU implementation starting from torch 2.9, or from ipex 2.7 +if version.parse(torch.__version__).release >= version.parse("2.9").release or ( +    ipex_xpu and torch.__version__ >= (2, 7) +): @register_kernel("bitsandbytes::int8_linear_matmul", "xpu") def _(A: torch.Tensor, B: torch.Tensor): diff --git a/pyproject.toml b/pyproject.toml index d26832e4f..6626d1fa8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,7 +43,8 @@ classifiers = [ ] dependencies = [ "torch>=2.2,<3", - "numpy>=1.17" + "numpy>=1.17", + "packaging>=20.9" ] [project.urls]