3 changes: 3 additions & 0 deletions bitsandbytes/__init__.py
@@ -38,6 +38,9 @@
if hasattr(torch, "xpu") and torch.xpu.is_available():
from .backends.xpu import ops as xpu_ops

if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
from .backends.mps import ops as mps_ops

if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
# In case not automatically imported
import habana_frameworks.torch
Empty file.
145 changes: 145 additions & 0 deletions bitsandbytes/backends/mps/ops.py
@@ -0,0 +1,145 @@
"""MPS backend for bitsandbytes 4-bit quantization ops.

Uses Metal kernels from kernels-community/bitsandbytes-mps via the
HuggingFace Kernels Hub.
"""

from collections.abc import Sequence
from math import prod

import torch

from ..._ops import register_kernel

# ---------------------------------------------------------------------------
# Quant-type mapping: BnB uses strings, our Metal kernel uses ints.
# ---------------------------------------------------------------------------
_QUANT_MAP = {"fp4": 1, "nf4": 2}
_kernel = None


def _get_kernel():
"""Lazily load the bitsandbytes-mps kernel (local build or Hub)."""
global _kernel
if _kernel is None:
from kernels import get_kernel

# TODO: use kernels-community/bitsandbytes-mps when it's available
_kernel = get_kernel("kernels-community/bitsandbytes-mps")
return _kernel


# ============================= quantize_4bit =================================


@register_kernel("bitsandbytes::quantize_4bit", "mps")
def _(
    A: torch.Tensor,
    blocksize: int,
    quant_type: str,
    quant_storage: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check(blocksize in [64, 128, 256, 512])
    torch._check(quant_type in ("fp4", "nf4"))

    k = _get_kernel()
    packed, absmax = k.quantize_4bit(A.contiguous(), blocksize, _QUANT_MAP[quant_type])

    packed = packed.view(quant_storage).unsqueeze(1)

    return packed, absmax


# ============================ dequantize_4bit ================================


def _dequantize_4bit_impl(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    if A.dtype != torch.uint8:
        A = A.view(torch.uint8)

    numel = prod(shape)
    k = _get_kernel()
    out = k.dequantize_4bit(A, absmax, blocksize, _QUANT_MAP[quant_type], numel, dtype)
    return out.reshape(shape)


@register_kernel("bitsandbytes::dequantize_4bit", "mps")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check(blocksize in [64, 128, 256, 512])
    torch._check(quant_type in ("fp4", "nf4"))
    return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)


@register_kernel("bitsandbytes::dequantize_4bit.out", "mps")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    result = _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
    out.copy_(result)


# ================================ gemv_4bit ==================================


def _gemv_4bit_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    if B.dtype != torch.uint8:
        B = B.view(torch.uint8)

    # The fp4 lookup table has a positive entry at index 1, while nf4's is negative.
    quant_type_int = _QUANT_MAP["fp4"] if code[1] > 0 else _QUANT_MAP["nf4"]
    output_features = shapeB[0]

    k = _get_kernel()
    return k.gemv_4bit(A, B, absmax, output_features, blocksize, quant_type_int)


@register_kernel("bitsandbytes::gemv_4bit", "mps")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    return _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize)


@register_kernel("bitsandbytes::gemv_4bit.out", "mps")
def _(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
    out: torch.Tensor,
) -> None:
    result = _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize)
    out.copy_(result)
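A minimal usage sketch (not part of this diff) of how the kernels registered above are reached through the bitsandbytes custom ops; the op names and argument order mirror the register_kernel signatures in ops.py, and the shapes and values are illustrative only:

import torch
import bitsandbytes  # noqa: F401  -- importing registers the MPS kernels via the guard in __init__.py

A = torch.randn(64, 256, dtype=torch.float16, device="mps")

# quantize to NF4 with blocksize 64, storing the packed nibbles as uint8
packed, absmax = torch.ops.bitsandbytes.quantize_4bit(A, 64, "nf4", torch.uint8)
# dequantize back to the original shape and dtype
restored = torch.ops.bitsandbytes.dequantize_4bit(packed, absmax, 64, "nf4", A.shape, A.dtype)

print((A - restored).abs().max())  # expect a small, nonzero quantization error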
4 changes: 2 additions & 2 deletions tests/test_ops.py
@@ -219,8 +219,8 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
         out_features = 1024
         in_features = 256

-        if device == "cpu" and blocksize > in_features:
-            pytest.skip("CPU implementation only suppoer blocksize <= in_features")
+        if device in ("cpu", "mps") and blocksize > in_features:
+            pytest.skip("CPU/MPS implementation only supports blocksize <= in_features")

         A = torch.randn((1, 1, in_features), dtype=dtype, device=device)
         B = torch.randn((out_features, in_features), dtype=dtype, device=A.device)
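As a higher-level sketch of the gemv path this test exercises (assuming the standard bitsandbytes.functional.quantize_4bit and bitsandbytes.matmul_4bit APIs; not part of this diff):

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

in_features, out_features = 256, 1024
x = torch.randn(1, 1, in_features, dtype=torch.float16, device="mps")
W = torch.randn(out_features, in_features, dtype=torch.float16, device="mps")

# blocksize must not exceed in_features on the CPU/MPS path, per the skip above
W4, quant_state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")

# for a single-row input this may dispatch to the gemv_4bit kernel registered above
out = bnb.matmul_4bit(x, W4.t(), quant_state=quant_state)
print(out.shape)  # expected: (1, 1, 1024)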