From 73a10e43049a3220bac807083b0c04a353fc9ede Mon Sep 17 00:00:00 2001
From: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Date: Tue, 17 Feb 2026 16:36:05 +0100
Subject: [PATCH 1/6] initial commit

---
 bitsandbytes/__init__.py              |   3 +
 bitsandbytes/backends/mps/__init__.py |   0
 bitsandbytes/backends/mps/ops.py      | 144 ++++++++++++++++++++++++++
 3 files changed, 147 insertions(+)
 create mode 100644 bitsandbytes/backends/mps/__init__.py
 create mode 100644 bitsandbytes/backends/mps/ops.py

diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 8bea82fb3..4d7c94abb 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -38,6 +38,9 @@
 if hasattr(torch, "xpu") and torch.xpu.is_available():
     from .backends.xpu import ops as xpu_ops
 
+if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+    from .backends.mps import ops as mps_ops
+
 if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
     # In case not automatically imported
     import habana_frameworks.torch
diff --git a/bitsandbytes/backends/mps/__init__.py b/bitsandbytes/backends/mps/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/bitsandbytes/backends/mps/ops.py b/bitsandbytes/backends/mps/ops.py
new file mode 100644
index 000000000..834de02d8
--- /dev/null
+++ b/bitsandbytes/backends/mps/ops.py
@@ -0,0 +1,144 @@
+"""MPS backend for bitsandbytes 4-bit quantization ops.
+
+Uses Metal kernels from kernels-community/bitsandbytes-mps via the
+HuggingFace Kernels Hub.
+"""
+
+from collections.abc import Sequence
+from math import prod
+from typing import Optional
+
+import torch
+
+from ..._ops import register_kernel
+
+# ---------------------------------------------------------------------------
+# Quant-type mapping: BnB uses strings, our Metal kernel uses ints.
+# ---------------------------------------------------------------------------
+_QUANT_MAP = {"fp4": 1, "nf4": 2}
+_kernel = None
+
+def _get_kernel():
+    """Lazily load the bitsandbytes-mps kernel (local build or Hub)."""
+    global _kernel
+    if _kernel is None:
+        from kernels import get_kernel
+        # TODO: use kernels-community/bitsandbytes-mps when it's available
+        _kernel = get_kernel("medmekk/bitsandbytes-mps")
+    return _kernel
+
+
+# ============================= quantize_4bit =================================
+
+
+@register_kernel("bitsandbytes::quantize_4bit", "mps")
+def _(
+    A: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    quant_storage: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    torch._check(blocksize in [64, 128])
+    torch._check(quant_type in ("fp4", "nf4"))
+
+    k = _get_kernel()
+    packed, absmax = k.quantize_4bit(A.contiguous(), blocksize, _QUANT_MAP[quant_type])
+
+    packed = packed.view(quant_storage).unsqueeze(1)
+
+    return packed, absmax
+
+
+# ============================ dequantize_4bit ================================
+
+
+def _dequantize_4bit_impl(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    if A.dtype != torch.uint8:
+        A = A.view(torch.uint8)
+
+    numel = prod(shape)
+    k = _get_kernel()
+    out = k.dequantize_4bit(A, absmax, blocksize, _QUANT_MAP[quant_type], numel, dtype)
+    return out.reshape(shape)
+
+
+@register_kernel("bitsandbytes::dequantize_4bit", "mps")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    torch._check(blocksize in [64, 128])
+    torch._check(quant_type in ("fp4", "nf4"))
+    return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
+
+
+@register_kernel("bitsandbytes::dequantize_4bit.out", "mps")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+    out: torch.Tensor,
+) -> None:
+    result = _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
+    out.copy_(result)
+
+
+# ================================ gemv_4bit ==================================
+
+
+def _gemv_4bit_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    shapeB: Sequence[int],
+    absmax: torch.Tensor,
+    code: torch.Tensor,
+    blocksize: int,
+) -> torch.Tensor:
+    if B.dtype != torch.uint8:
+        B = B.view(torch.uint8)
+
+    quant_type_int = _QUANT_MAP["fp4"] if code[1] > 0 else _QUANT_MAP["nf4"]
+    output_features = shapeB[0]
+
+    k = _get_kernel()
+    return k.gemv_4bit(A, B, absmax, output_features, blocksize, quant_type_int)
+
+
+@register_kernel("bitsandbytes::gemv_4bit", "mps")
+def _(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    shapeB: Sequence[int],
+    absmax: torch.Tensor,
+    code: torch.Tensor,
+    blocksize: int,
+) -> torch.Tensor:
+    return _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize)
+
+
+@register_kernel("bitsandbytes::gemv_4bit.out", "mps")
+def _(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    shapeB: Sequence[int],
+    absmax: torch.Tensor,
+    code: torch.Tensor,
+    blocksize: int,
+    out: torch.Tensor,
+) -> None:
+    result = _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize)
+    out.copy_(result)
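A quick sanity check for the ops this patch registers is a quantize/dequantize round trip through the torch.ops dispatcher. The sketch below is illustrative and not part of the series; it assumes an Apple Silicon machine and that the op schemas match the registrations above.

# Hedged round-trip sketch for the new MPS backend.
import torch
import bitsandbytes  # noqa: F401 -- importing registers the MPS kernels

A = torch.randn(1024, 1024, dtype=torch.float16, device="mps")
packed, absmax = torch.ops.bitsandbytes.quantize_4bit(A, 64, "nf4", torch.uint8)
restored = torch.ops.bitsandbytes.dequantize_4bit(
    packed, absmax, 64, "nf4", A.shape, A.dtype
)
# Blockwise 4-bit quantization is lossy: expect a small error, not equality.
print((A - restored).abs().max())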
From e1578cd37f8899acf603798e84c1c8d11292f146 Mon Sep 17 00:00:00 2001
From: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Date: Tue, 17 Feb 2026 17:11:15 +0100
Subject: [PATCH 2/6] fix

---
 tests/test_ops.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_ops.py b/tests/test_ops.py
index 8d9aa5ab2..a0f354e89 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -219,8 +219,8 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
         out_features = 1024
         in_features = 256
 
-        if device == "cpu" and blocksize > in_features:
-            pytest.skip("CPU implementation only suppoer blocksize <= in_features")
+        if device in ("cpu", "mps") and blocksize > in_features:
+            pytest.skip("CPU/MPS implementation only supports blocksize <= in_features")
 
         A = torch.randn((1, 1, in_features), dtype=dtype, device=device)
         B = torch.randn((out_features, in_features), dtype=dtype, device=A.device)
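For context on the skip above: gemv_4bit reduces along in_features, and each absmax scale covers one block of blocksize consecutive elements, so a block cannot span more elements than one row of B holds. The bookkeeping below is a sketch of the shapes used in this test, not test code.

# Blockwise bookkeeping for the test_gemv_4bit shapes (illustrative).
from math import ceil

out_features, in_features = 1024, 256
blocksize = 64                        # must satisfy blocksize <= in_features
numel = out_features * in_features    # 262144 weight elements in B
num_blocks = ceil(numel / blocksize)  # 4096 blocks -> one absmax scale each
packed_bytes = ceil(numel / 2)        # 131072 bytes -> two 4-bit codes per byte
print(num_blocks, packed_bytes)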
From ec4b140ee55baa26887f68714f2ef43024f8856c Mon Sep 17 00:00:00 2001
From: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Date: Tue, 17 Feb 2026 17:11:23 +0100
Subject: [PATCH 3/6] fix

---
 bitsandbytes/backends/mps/ops.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/bitsandbytes/backends/mps/ops.py b/bitsandbytes/backends/mps/ops.py
index 834de02d8..b580a5874 100644
--- a/bitsandbytes/backends/mps/ops.py
+++ b/bitsandbytes/backends/mps/ops.py
@@ -38,7 +38,7 @@ def _(
     quant_type: str,
     quant_storage: torch.dtype,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check(blocksize in [64, 128])
+    torch._check(blocksize in [64, 128, 256, 512])
     torch._check(quant_type in ("fp4", "nf4"))
 
     k = _get_kernel()
@@ -78,7 +78,7 @@ def _(
     shape: Sequence[int],
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    torch._check(blocksize in [64, 128])
+    torch._check(blocksize in [64, 128, 256, 512])
     torch._check(quant_type in ("fp4", "nf4"))
     return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)

From f6630d82a87d094b0f1cf7604f916c68af69d3af Mon Sep 17 00:00:00 2001
From: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Date: Tue, 17 Feb 2026 17:47:10 +0100
Subject: [PATCH 4/6] fix linter

---
 bitsandbytes/backends/mps/ops.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/bitsandbytes/backends/mps/ops.py b/bitsandbytes/backends/mps/ops.py
index b580a5874..e294fa857 100644
--- a/bitsandbytes/backends/mps/ops.py
+++ b/bitsandbytes/backends/mps/ops.py
@@ -18,13 +18,15 @@
 _QUANT_MAP = {"fp4": 1, "nf4": 2}
 _kernel = None
 
+
 def _get_kernel():
     """Lazily load the bitsandbytes-mps kernel (local build or Hub)."""
     global _kernel
     if _kernel is None:
-        from kernels import get_kernel
-        # TODO: use kernels-community/bitsandbytes-mps when it's available
-        _kernel = get_kernel("medmekk/bitsandbytes-mps")
+        from kernels import get_kernel
+
+        # TODO: use kernels-community/bitsandbytes-mps when it's available
+        _kernel = get_kernel("medmekk/bitsandbytes-mps")
     return _kernel
 
 

From 5b3c3955c107fefb514594ea14b23c27b53c4012 Mon Sep 17 00:00:00 2001
From: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Date: Tue, 17 Feb 2026 17:56:08 +0100
Subject: [PATCH 5/6] fix

---
 bitsandbytes/backends/mps/ops.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/bitsandbytes/backends/mps/ops.py b/bitsandbytes/backends/mps/ops.py
index e294fa857..6eedfcc17 100644
--- a/bitsandbytes/backends/mps/ops.py
+++ b/bitsandbytes/backends/mps/ops.py
@@ -6,7 +6,6 @@
 
 from collections.abc import Sequence
 from math import prod
-from typing import Optional
 
 import torch
 
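The linter fix in PATCH 4/6 keeps the lazy-load behaviour intact: the Hub fetch in _get_kernel() still happens on the first op call rather than at import time. The same pattern can be expressed with functools.lru_cache instead of a module global; this is an illustrative alternative, not what the patch itself does.

# Equivalent lazy-load pattern via lru_cache (sketch).
from functools import lru_cache

@lru_cache(maxsize=1)
def _load_mps_kernel():
    # Deferred import: `kernels` is only needed once an MPS op runs,
    # and the result is memoized after the first call.
    from kernels import get_kernel

    return get_kernel("kernels-community/bitsandbytes-mps")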
From c34772c1ed7e2268e7afb8fd3b387f467c07e9c4 Mon Sep 17 00:00:00 2001
From: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Date: Thu, 19 Feb 2026 09:29:01 +0100
Subject: [PATCH 6/6] add kernels-community kernel

---
 bitsandbytes/backends/mps/ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bitsandbytes/backends/mps/ops.py b/bitsandbytes/backends/mps/ops.py
index 6eedfcc17..83c2ae89e 100644
--- a/bitsandbytes/backends/mps/ops.py
+++ b/bitsandbytes/backends/mps/ops.py
@@ -25,7 +25,7 @@ def _get_kernel():
         from kernels import get_kernel
 
         # TODO: use kernels-community/bitsandbytes-mps when it's available
-        _kernel = get_kernel("medmekk/bitsandbytes-mps")
+        _kernel = get_kernel("kernels-community/bitsandbytes-mps")
     return _kernel
 
 
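With the series applied, end-to-end usage should need no MPS-specific code. A hedged sketch, assuming a PyTorch build with MPS enabled and the kernels package installed:

# Linear4bit on Apple Silicon (illustrative). Moving the layer to "mps"
# quantizes the weight via bitsandbytes::quantize_4bit, and a batch-size-1
# forward should take the gemv_4bit path registered in this series.
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(256, 1024, quant_type="nf4", compute_dtype=torch.float16)
layer = layer.to("mps")
x = torch.randn(1, 256, dtype=torch.float16, device="mps")
print(layer(x).shape)  # torch.Size([1, 1024])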