30 changes: 30 additions & 0 deletions __init__.py
@@ -0,0 +1,30 @@
"""Dispatcher shimming the editable layout.

When this repository is used via ``pip install -e .``, the real Python
package lives under ``bitsandbytes/bitsandbytes``. Importing from the
workspace root (e.g. running scripts from ``.../ai/kernels``) would
otherwise resolve to this outer directory, yielding a namespace module
with no attributes. Import the inner package eagerly and mirror its
symbols so ``import bitsandbytes`` always behaves the same as the
installed wheel.
"""
Comment on lines +1 to +10
Member

I don't think we need this; the test CI currently breaks with it, and I wouldn't normally expect anyone to run a script from the project root.

I've been thinking separately about switching from our current flat layout to a src layout, which would probably take care of things like this too, but that's probably out of scope here. I think we could just move test_bnb_mac.py into examples.

Author

Yes, of course, this is just a draft to get the discussion started on what we'll need. I managed to get it working by using a shim to pass the underlying MTL buffer to the kernel, but ran into another issue: I have to commit the work and wait manually, since I can't synchronize with the PyTorch stream afterward. I'm also working on another PR that integrates libtorch, to see if it gives us better performance.
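
For reference, a minimal sketch of the manual commit-and-wait pattern described above (illustrative only; `native_fn` is a placeholder for one of the ctypes entry points, and the native side is assumed to commit its own command buffer and block until it completes):

```python
import torch

def call_native_mps_kernel(native_fn, *ctypes_args):
    # Flush PyTorch's pending MPS work so the native kernel sees up-to-date buffers.
    if torch.backends.mps.is_available():
        torch.mps.synchronize()
    # The native entry point is assumed to commit its command buffer and wait for
    # completion before returning, since we cannot enqueue onto PyTorch's MPS stream.
    native_fn(*ctypes_args)
```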


from __future__ import annotations

import importlib
from types import ModuleType

_inner: ModuleType = importlib.import_module(".bitsandbytes", __name__)

Check failure on line 17 in __init__.py (GitHub Actions / Lint): Ruff (F821) __init__.py:17:9: Undefined name `ModuleType`

# Copy dunder metadata expected by consumers.
for _name in ("__all__", "__doc__", "__file__", "__loader__", "__path__", "__spec__", "__version__"):
if hasattr(_inner, _name):
globals()[_name] = getattr(_inner, _name)

# Re-export public symbols while leaving dunders alone.
for _name, _value in vars(_inner).items():
if not _name.startswith("__"):
globals()[_name] = _value

del _inner, _name, _value, ModuleType, importlib
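
As a quick illustration of what the dispatcher above is meant to guarantee (a sketch, assuming an editable install is active and scripts are run from the workspace root; the attributes below are just examples of mirrored symbols):

```python
# Run from the workspace root (e.g. .../ai/kernels) after `pip install -e .`.
import bitsandbytes

# The outer dispatcher mirrors the inner package, so metadata and public symbols
# resolve the same way they would for the installed wheel.
print(bitsandbytes.__version__)
print(hasattr(bitsandbytes, "matmul"))
```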

5 changes: 4 additions & 1 deletion bitsandbytes/__init__.py
@@ -9,7 +9,7 @@

import torch

from . import _ops, research, utils
from . import _ops, nn, research, utils
from .autograd._functions import (
MatmulLtState,
matmul,
@@ -38,6 +38,9 @@
if hasattr(torch, "xpu") and torch.xpu.is_available():
from .backends.xpu import ops as xpu_ops

if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
from .backends.mps import ops as mps_ops

if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
# In case not automatically imported
import habana_frameworks.torch
12 changes: 8 additions & 4 deletions bitsandbytes/backends/default/ops.py
@@ -189,8 +189,7 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
return out


@register_kernel("bitsandbytes::quantize_4bit", "default")
def _(
def _quantize_4bit_impl(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
@@ -232,6 +231,13 @@ def _(
return packed, absmax.float()


@register_kernel("bitsandbytes::quantize_4bit", "default")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
return _quantize_4bit_impl(A, blocksize, quant_type, quant_storage)


def _dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
@@ -243,7 +249,6 @@ def _dequantize_4bit_impl(
# Enable non uint8 dtype
if A.dtype != torch.uint8:
A = A.view(torch.uint8)

A = A.reshape(-1)
# Map nf4 to [-1, 1]
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
@@ -290,7 +295,6 @@ def _(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)

return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)


2 changes: 2 additions & 0 deletions bitsandbytes/backends/mps/__init__.py
@@ -0,0 +1,2 @@
# MPS backend registrations are defined in ops.py

188 changes: 188 additions & 0 deletions bitsandbytes/backends/mps/ops.py
@@ -0,0 +1,188 @@
import ctypes as ct
from collections.abc import Sequence
from ctypes import _CFuncPtr
from typing import Optional

import torch

from ..._ops import register_kernel
from ...cextension import lib
from ..default.ops import _dequantize_4bit_impl, _quantize_4bit_impl
from ..utils import CODE
from .shim import MPSTensorShim  # , configure_mps_blockwise_kernel


def _sync_mps_if_needed() -> None:
if torch.backends.mps.is_available():
torch.mps.synchronize()


def _check_mps_device(tensor: torch.Tensor, name: str) -> None:
torch._check(
tensor.device.type == "mps",
lambda: f"{name} must live on an MPS device for the MPS backend, got {tensor.device.type}",
)


def _supports_dtype(dtype: torch.dtype) -> bool:
return dtype in (torch.float16, torch.float32)


def _resolve_quant_fn(dtype: torch.dtype, quant_type: str) -> Optional[_CFuncPtr]:
try:
if dtype == torch.float16:
fn = getattr(
lib,
"cquantize_blockwise_fp16_fp4" if quant_type == "fp4" else "cquantize_blockwise_fp16_nf4",
)
# configure_mps_blockwise_kernel(fn)
return fn
if dtype == torch.float32:
fn = getattr(
lib,
"cquantize_blockwise_fp32_fp4" if quant_type == "fp4" else "cquantize_blockwise_fp32_nf4",
)
# configure_mps_blockwise_kernel(fn)
return fn
except AttributeError:
return None
return None


def _resolve_dequant_fn(dtype: torch.dtype, quant_type: str) -> Optional[_CFuncPtr]:
try:
if dtype == torch.float16:
fn = getattr(
lib,
"cdequantize_blockwise_fp16_fp4" if quant_type == "fp4" else "cdequantize_blockwise_fp16_nf4",
)
# configure_mps_blockwise_kernel(fn)
return fn
if dtype == torch.float32:
fn = getattr(
lib,
"cdequantize_blockwise_fp32_fp4" if quant_type == "fp4" else "cdequantize_blockwise_fp32_nf4",
)
# configure_mps_blockwise_kernel(fn)
return fn
except AttributeError:
return None
return None


def _quantize_4bit_native(
A: torch.Tensor,
blocksize: int,
quant_type: str,
quant_storage: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor] | None:
if quant_storage != torch.uint8 or not _supports_dtype(A.dtype):
return None

fn = _resolve_quant_fn(A.dtype, quant_type)
if fn is None:
return None

n = A.numel()
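    # Ceiling division: number of blocksize-element blocks needed to cover n values.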
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
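    # Each quant_storage element packs itemsize * 2 four-bit codes (2 per byte for uint8).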
out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)

input_shim = MPSTensorShim.from_tensor(A)
absmax_shim = MPSTensorShim.from_tensor(absmax)
out_shim = MPSTensorShim.from_tensor(out)

_sync_mps_if_needed()
fn(
input_shim.struct,
absmax_shim.struct,
out_shim.struct,
ct.c_int32(blocksize),
ct.c_int32(n),
)
return out, absmax


def _dequantize_4bit_native(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
dtype: torch.dtype,
out: torch.Tensor,
) -> bool:
if A.dtype != torch.uint8 or not _supports_dtype(dtype):
return False

_check_mps_device(absmax, "absmax")
fn = _resolve_dequant_fn(dtype, quant_type)
if fn is None:
return False

packed_shim = MPSTensorShim.from_tensor(A)
absmax_shim = MPSTensorShim.from_tensor(absmax)
out_shim = MPSTensorShim.from_tensor(out)

_sync_mps_if_needed()
fn(
packed_shim.struct,
absmax_shim.struct,
out_shim.struct,
ct.c_int32(blocksize),
ct.c_int32(out.numel()),
)

return True


@register_kernel("bitsandbytes::quantize_4bit", "mps")
def _(
A: torch.Tensor,
blocksize: int,
quant_type: str,
quant_storage: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
_check_mps_device(A, "A")
# result = _quantize_4bit_native(A, blocksize, quant_type, quant_storage)
# if result is not None:
# return result
return _quantize_4bit_impl(A, blocksize, quant_type, quant_storage)


@register_kernel("bitsandbytes::dequantize_4bit", "mps")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
_check_mps_device(A, "A")
_check_mps_device(absmax, "absmax")

out = torch.empty(shape, dtype=dtype, device=A.device)
if _dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out):
return out
    # Fall back to the shared reference implementation when no native kernel applies.
    return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)


@register_kernel("bitsandbytes::dequantize_4bit.out", "mps")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
_check_mps_device(A, "A")
_check_mps_device(out, "out")
_check_mps_device(absmax, "absmax")
torch._check(out.shape == tuple(shape), lambda: f"Expected out.shape == {tuple(shape)}, got {out.shape}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")

    if not _dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out):
        # Fall back to the shared reference implementation when no native kernel applies.
        result = _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
        out.copy_(result)
63 changes: 63 additions & 0 deletions bitsandbytes/backends/mps/shim.py
@@ -0,0 +1,63 @@
import ctypes as ct
from dataclasses import dataclass
from typing import Callable

import torch


class _BNBMPSTensor(ct.Structure):
_fields_ = [
("storage", ct.c_void_p),
("byte_offset", ct.c_size_t),
("nbytes", ct.c_size_t),
]


@dataclass(slots=True)
class MPSTensorShim:
"""
Lightweight wrapper that keeps a Tensor alive while exposing its Metal storage.

PyTorch stores an ``id<MTLBuffer>`` inside the tensor's untyped storage data
pointer on MPS. We capture that pointer once and forward the storage offset
so native kernels can bind the correct buffer without any host copies.
"""

tensor: torch.Tensor
struct: _BNBMPSTensor

@classmethod
def from_tensor(cls, tensor: torch.Tensor) -> "MPSTensorShim":
if hasattr(tensor, "untyped_storage"):
storage = tensor.untyped_storage()
else:
storage = tensor.storage()

storage_ptr = storage.data_ptr()
byte_offset = tensor.storage_offset() * tensor.element_size()
nbytes = tensor.nbytes

struct = _BNBMPSTensor(
ct.c_void_p(storage_ptr),
ct.c_size_t(byte_offset),
ct.c_size_t(nbytes),
)
return cls(tensor=tensor, struct=struct)


# def configure_mps_blockwise_kernel(fn: Callable[[object], None]) -> None:
# """
# Ensure ctypes knows the function expects our tensor shim structs by value.
# """

# try:
# argtypes = getattr(fn, "argtypes")
# except AttributeError:
# argtypes = None

# desired = [_BNBMPSTensor, _BNBMPSTensor, _BNBMPSTensor, ct.c_int32, ct.c_int32]
# if argtypes != desired:
# fn.argtypes = desired
# if getattr(fn, "restype", None) is not None:
# fn.restype = None
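
A small usage sketch of the shim above (illustrative only; `lib.some_kernel` is a placeholder, not a real symbol in the native library):

```python
import torch

from bitsandbytes.backends.mps.shim import MPSTensorShim

t = torch.randn(64, device="mps", dtype=torch.float16)
shim = MPSTensorShim.from_tensor(t)

# shim.struct bundles the MTLBuffer handle, byte offset, and size so a native
# entry point can bind the buffer directly without host copies, e.g.:
# lib.some_kernel(shim.struct, ...)
```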
