Add mps backend (python only) #1823
Open

MekkCyber wants to merge 2 commits into bitsandbytes-foundation:main from MekkCyber:add-metal-backend.
Changes from all commits (2 commits)
New file (30 lines): the workspace-root `__init__.py` dispatcher shim.

```python
"""Dispatcher shimming the editable layout.

When this repository is used via ``pip install -e .`` the real Python
package lives under ``bitsandbytes/bitsandbytes``. Importing from the
workspace root (e.g. running scripts from ``.../ai/kernels``) would
otherwise resolve to this outer directory, yielding a namespace module
with no attributes. Import the inner package eagerly and mirror its
symbols so ``import bitsandbytes`` always behaves the same as the
installed wheel.
"""

from __future__ import annotations

import importlib
from types import ModuleType

_inner: ModuleType = importlib.import_module(".bitsandbytes", __name__)

# Copy dunder metadata expected by consumers.
for _name in ("__all__", "__doc__", "__file__", "__loader__", "__path__", "__spec__", "__version__"):
    if hasattr(_inner, _name):
        globals()[_name] = getattr(_inner, _name)

# Re-export public symbols while leaving dunders alone.
for _name, _value in vars(_inner).items():
    if not _name.startswith("__"):
        globals()[_name] = _value

del _inner, _name, _value, ModuleType, importlib
```
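For context, a hypothetical sketch of the failure mode the shim guards against (assumes the editable-install layout described in the docstring; this session is illustrative and not part of the PR):

```python
# Run from the workspace root WITHOUT the shim above: "bitsandbytes" resolves
# to the outer directory as a bare namespace package with no attributes.
import bitsandbytes

print(getattr(bitsandbytes, "__version__", "missing"))  # -> "missing"

# With the shim in place, the inner package's symbols are mirrored eagerly,
# so the same lookup returns the real version string, as the wheel would.
```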
New file (2 lines): the MPS backend package `__init__.py`.

```python
# MPS backend registrations are defined in ops.py
```
New file (188 lines): `ops.py` with the MPS kernel registrations.

```python
from collections.abc import Sequence
from typing import Optional

import ctypes as ct
from ctypes import _CFuncPtr

import torch

from ..._ops import register_kernel
from ...cextension import lib
from ..default.ops import _dequantize_4bit_impl, _quantize_4bit_impl
from ..utils import CODE
from .shim import MPSTensorShim  # , configure_mps_blockwise_kernel


def _sync_mps_if_needed() -> None:
    if torch.backends.mps.is_available():
        torch.mps.synchronize()


def _check_mps_device(tensor: torch.Tensor, name: str) -> None:
    torch._check(
        tensor.device.type == "mps",
        lambda: f"{name} must live on an MPS device for the MPS backend, got {tensor.device.type}",
    )


def _supports_dtype(dtype: torch.dtype) -> bool:
    return dtype in (torch.float16, torch.float32)


def _resolve_quant_fn(dtype: torch.dtype, quant_type: str) -> Optional[_CFuncPtr]:
    try:
        if dtype == torch.float16:
            fn = getattr(
                lib,
                "cquantize_blockwise_fp16_fp4" if quant_type == "fp4" else "cquantize_blockwise_fp16_nf4",
            )
            # configure_mps_blockwise_kernel(fn)
            return fn
        if dtype == torch.float32:
            fn = getattr(
                lib,
                "cquantize_blockwise_fp32_fp4" if quant_type == "fp4" else "cquantize_blockwise_fp32_nf4",
            )
            # configure_mps_blockwise_kernel(fn)
            return fn
    except AttributeError:
        return None
    return None


def _resolve_dequant_fn(dtype: torch.dtype, quant_type: str) -> Optional[_CFuncPtr]:
    try:
        if dtype == torch.float16:
            fn = getattr(
                lib,
                "cdequantize_blockwise_fp16_fp4" if quant_type == "fp4" else "cdequantize_blockwise_fp16_nf4",
            )
            # configure_mps_blockwise_kernel(fn)
            return fn
        if dtype == torch.float32:
            fn = getattr(
                lib,
                "cdequantize_blockwise_fp32_fp4" if quant_type == "fp4" else "cdequantize_blockwise_fp32_nf4",
            )
            # configure_mps_blockwise_kernel(fn)
            return fn
    except AttributeError:
        return None
    return None


def _quantize_4bit_native(
    A: torch.Tensor,
    blocksize: int,
    quant_type: str,
    quant_storage: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor] | None:
    if quant_storage != torch.uint8 or not _supports_dtype(A.dtype):
        return None

    fn = _resolve_quant_fn(A.dtype, quant_type)
    if fn is None:
        return None

    n = A.numel()
    blocks = -(n // -blocksize)  # ceil(n / blocksize)
    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
    # Two 4-bit codes are packed per storage element (uint8).
    out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)

    input_shim = MPSTensorShim.from_tensor(A)
    absmax_shim = MPSTensorShim.from_tensor(absmax)
    out_shim = MPSTensorShim.from_tensor(out)

    _sync_mps_if_needed()
    fn(
        input_shim.struct,
        absmax_shim.struct,
        out_shim.struct,
        ct.c_int32(blocksize),
        ct.c_int32(n),
    )
    return out, absmax


def _dequantize_4bit_native(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> bool:
    if A.dtype != torch.uint8 or not _supports_dtype(dtype):
        return False

    _check_mps_device(absmax, "absmax")
    fn = _resolve_dequant_fn(dtype, quant_type)
    if fn is None:
        return False

    packed_shim = MPSTensorShim.from_tensor(A)
    absmax_shim = MPSTensorShim.from_tensor(absmax)
    out_shim = MPSTensorShim.from_tensor(out)

    _sync_mps_if_needed()
    fn(
        packed_shim.struct,
        absmax_shim.struct,
        out_shim.struct,
        ct.c_int32(blocksize),
        ct.c_int32(out.numel()),
    )

    return True


@register_kernel("bitsandbytes::quantize_4bit", "mps")
def _(
    A: torch.Tensor,
    blocksize: int,
    quant_type: str,
    quant_storage: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    _check_mps_device(A, "A")
    # Native path disabled in this draft; see the PR discussion below.
    # result = _quantize_4bit_native(A, blocksize, quant_type, quant_storage)
    # if result is not None:
    #     return result
    return _quantize_4bit_impl(A, blocksize, quant_type, quant_storage)


@register_kernel("bitsandbytes::dequantize_4bit", "mps")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    _check_mps_device(A, "A")
    _check_mps_device(absmax, "absmax")

    out = torch.empty(shape, dtype=dtype, device=A.device)
    if _dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out):
        return out
    # Fall back to the default implementation when the native path is
    # unavailable; without this the kernel would return None for
    # unsupported dtypes, violating its declared return type.
    return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)


@register_kernel("bitsandbytes::dequantize_4bit.out", "mps")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    _check_mps_device(A, "A")
    _check_mps_device(out, "out")
    _check_mps_device(absmax, "absmax")
    torch._check(out.shape == tuple(shape), lambda: f"Expected out.shape == {tuple(shape)}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")

    if not _dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out):
        # Fall back to the default implementation when the native path is
        # unavailable, so that out is always written.
        out.copy_(_dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype))
```
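A hedged usage sketch of the ops registered above (the `torch.ops.bitsandbytes.*` entry points and argument order are assumed from the `register_kernel` decorators; the exact schemas live in `_ops.py`):

```python
import torch

n, blocksize = 4096, 64
A = torch.randn(n, dtype=torch.float16, device="mps")

# The dispatcher routes MPS tensors to the kernels registered above.
packed, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, "nf4", torch.uint8)

# Packing arithmetic from _quantize_4bit_native: two 4-bit codes per uint8
# byte, one float32 absmax per block (ceil division via -(n // -blocksize)).
assert packed.numel() == (n + 1) // 2
assert absmax.numel() == -(n // -blocksize)

restored = torch.ops.bitsandbytes.dequantize_4bit(packed, absmax, blocksize, "nf4", (n,), torch.float16)
assert restored.shape == A.shape and restored.dtype == A.dtype
```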
New file (63 lines): `shim.py`, the ctypes shim that exposes a tensor's Metal storage.

```python
import ctypes as ct
from dataclasses import dataclass
from typing import Callable

import torch


class _BNBMPSTensor(ct.Structure):
    _fields_ = [
        ("storage", ct.c_void_p),
        ("byte_offset", ct.c_size_t),
        ("nbytes", ct.c_size_t),
    ]


@dataclass(slots=True)
class MPSTensorShim:
    """
    Lightweight wrapper that keeps a Tensor alive while exposing its Metal storage.

    PyTorch stores an ``id<MTLBuffer>`` inside the tensor's untyped storage data
    pointer on MPS. We capture that pointer once and forward the storage offset
    so native kernels can bind the correct buffer without any host copies.
    """

    tensor: torch.Tensor
    struct: _BNBMPSTensor

    @classmethod
    def from_tensor(cls, tensor: torch.Tensor) -> "MPSTensorShim":
        if hasattr(tensor, "untyped_storage"):
            storage = tensor.untyped_storage()
        else:
            storage = tensor.storage()

        storage_ptr = storage.data_ptr()
        byte_offset = tensor.storage_offset() * tensor.element_size()
        nbytes = tensor.nbytes

        struct = _BNBMPSTensor(
            ct.c_void_p(storage_ptr),
            ct.c_size_t(byte_offset),
            ct.c_size_t(nbytes),
        )
        return cls(tensor=tensor, struct=struct)


# def configure_mps_blockwise_kernel(fn: Callable[[object], None]) -> None:
#     """
#     Ensure ctypes knows the function expects our tensor shim structs by value.
#     """
#     try:
#         argtypes = getattr(fn, "argtypes")
#     except AttributeError:
#         argtypes = None
#
#     desired = [_BNBMPSTensor, _BNBMPSTensor, _BNBMPSTensor, ct.c_int32, ct.c_int32]
#     if argtypes != desired:
#         fn.argtypes = desired
#     if getattr(fn, "restype", None) is not None:
#         fn.restype = None
```
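A small sketch of what the shim captures, with values that follow directly from `from_tensor` above (the import path is an assumption based on the package layout):

```python
import torch
from bitsandbytes.backends.mps.shim import MPSTensorShim  # path assumed

# A view with a non-zero storage offset: 8 fp16 elements, sliced past the first 2.
t = torch.arange(8, device="mps", dtype=torch.float16)[2:]
shim = MPSTensorShim.from_tensor(t)

print(shim.struct.byte_offset)  # 4  (storage_offset 2 * element_size 2)
print(shim.struct.nbytes)       # 12 (6 remaining elements * 2 bytes)
# shim.struct.storage carries the MTLBuffer-backed storage pointer, and
# shim.tensor keeps the tensor alive for the duration of the native call.
```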
Review comment:
I don't think we need this; CI for tests currently breaks with this, and I wouldn't normally expect anyone to run a script from the project root.
I've been thinking separately of switching to a src layout from the flat layout we have now, which would probably take care of things like this too, but that's probably out of scope here. I think we could just move `test_bnb_mac.py` into `examples`.
Author reply:
Yes, of course, this is just a draft to get the discussion started on what we'll need. I managed to get it working using a shim to pass the underlying MTL buffer to the kernel, but ran into another issue: I have to commit the work and wait manually, since I can't synchronize with the PyTorch stream afterward. I'm also working on another PR to integrate libtorch and see if it gives us better performance.
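To make that synchronization constraint concrete, here is a minimal sketch of the manual pattern the draft relies on (it mirrors `_sync_mps_if_needed` in `ops.py`; the commit-and-wait on the native side is assumed to happen inside the Metal kernel wrapper):

```python
import torch

# Flush pending PyTorch MPS work before handing raw buffers to the native
# kernel; the native side then commits its own command buffer and blocks,
# because ctypes calls cannot enqueue work onto PyTorch's MPS stream.
if torch.backends.mps.is_available():
    torch.mps.synchronize()
# ... invoke the ctypes kernel here ...
```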