diff --git a/__init__.py b/__init__.py new file mode 100644 index 000000000..c41923d0b --- /dev/null +++ b/__init__.py @@ -0,0 +1,30 @@ +"""Dispatcher shimming the editable layout. + +When this repository is used via ``pip install -e .`` the real Python +package lives under ``bitsandbytes/bitsandbytes``. Importing from the +workspace root (e.g. running scripts from ``.../ai/kernels``) would +otherwise resolve to this outer directory, yielding a namespace module +with no attributes. Import the inner package eagerly and mirror its +symbols so ``import bitsandbytes`` always behaves the same as the +installed wheel. +""" + +from __future__ import annotations + +import importlib +from types import ModuleType + +_inner: ModuleType = importlib.import_module(".bitsandbytes", __name__) + +# Copy dunder metadata expected by consumers. +for _name in ("__all__", "__doc__", "__file__", "__loader__", "__path__", "__spec__", "__version__"): + if hasattr(_inner, _name): + globals()[_name] = getattr(_inner, _name) + +# Re-export public symbols while leaving dunders alone. +for _name, _value in vars(_inner).items(): + if not _name.startswith("__"): + globals()[_name] = _value + +del _inner, _name, _value, ModuleType, importlib + diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 8bea82fb3..8d4fa4c7e 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -9,7 +9,7 @@ import torch -from . import _ops, research, utils +from . import _ops, nn, research, utils from .autograd._functions import ( MatmulLtState, matmul, @@ -38,6 +38,9 @@ if hasattr(torch, "xpu") and torch.xpu.is_available(): from .backends.xpu import ops as xpu_ops +if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + from .backends.mps import ops as mps_ops + if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"): # In case not automatically imported import habana_frameworks.torch diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py index a0f0d2a34..9ab44a7e4 100644 --- a/bitsandbytes/backends/default/ops.py +++ b/bitsandbytes/backends/default/ops.py @@ -189,8 +189,7 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, return out -@register_kernel("bitsandbytes::quantize_4bit", "default") -def _( +def _quantize_4bit_impl( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) @@ -232,6 +231,13 @@ def _( return packed, absmax.float() +@register_kernel("bitsandbytes::quantize_4bit", "default") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + return _quantize_4bit_impl(A, blocksize, quant_type, quant_storage) + + def _dequantize_4bit_impl( A: torch.Tensor, absmax: torch.Tensor, @@ -243,7 +249,6 @@ def _dequantize_4bit_impl( # Enable non uint8 dtype if A.dtype != torch.uint8: A = A.view(torch.uint8) - A = A.reshape(-1) # Map nf4 to [-1, 1] out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) @@ -290,7 +295,6 @@ def _( dtype in [torch.bfloat16, torch.float16, torch.float32], lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", ) - return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype) diff --git a/bitsandbytes/backends/mps/__init__.py b/bitsandbytes/backends/mps/__init__.py new file mode 100644 index 000000000..662a206e0 --- 
/dev/null +++ b/bitsandbytes/backends/mps/__init__.py @@ -0,0 +1,2 @@ +# MPS backend registrations are defined in ops.py + diff --git a/bitsandbytes/backends/mps/ops.py b/bitsandbytes/backends/mps/ops.py new file mode 100644 index 000000000..1702a4034 --- /dev/null +++ b/bitsandbytes/backends/mps/ops.py @@ -0,0 +1,188 @@ +from collections.abc import Sequence +from typing import Optional + +import ctypes as ct +from ctypes import _CFuncPtr +import torch + +from ..._ops import register_kernel +from ...cextension import lib +from ..default.ops import _dequantize_4bit_impl, _quantize_4bit_impl +from ..utils import CODE +from .shim import MPSTensorShim#, configure_mps_blockwise_kernel + + +def _sync_mps_if_needed() -> None: + if torch.backends.mps.is_available(): + torch.mps.synchronize() + + +def _check_mps_device(tensor: torch.Tensor, name: str) -> None: + torch._check( + tensor.device.type == "mps", + lambda: f"{name} must live on an MPS device for the MPS backend, got {tensor.device.type}", + ) + + +def _supports_dtype(dtype: torch.dtype) -> bool: + return dtype in (torch.float16, torch.float32) + + +def _resolve_quant_fn(dtype: torch.dtype, quant_type: str) -> Optional[_CFuncPtr]: + try: + if dtype == torch.float16: + fn = getattr( + lib, + "cquantize_blockwise_fp16_fp4" if quant_type == "fp4" else "cquantize_blockwise_fp16_nf4", + ) + # configure_mps_blockwise_kernel(fn) + return fn + if dtype == torch.float32: + fn = getattr( + lib, + "cquantize_blockwise_fp32_fp4" if quant_type == "fp4" else "cquantize_blockwise_fp32_nf4", + ) + # configure_mps_blockwise_kernel(fn) + return fn + except AttributeError: + return None + return None + + +def _resolve_dequant_fn(dtype: torch.dtype, quant_type: str) -> Optional[_CFuncPtr]: + try: + if dtype == torch.float16: + fn = getattr( + lib, + "cdequantize_blockwise_fp16_fp4" if quant_type == "fp4" else "cdequantize_blockwise_fp16_nf4", + ) + # configure_mps_blockwise_kernel(fn) + return fn + if dtype == torch.float32: + fn = getattr( + lib, + "cdequantize_blockwise_fp32_fp4" if quant_type == "fp4" else "cdequantize_blockwise_fp32_nf4", + ) + # configure_mps_blockwise_kernel(fn) + return fn + except AttributeError: + return None + return None + + +def _quantize_4bit_native( + A: torch.Tensor, + blocksize: int, + quant_type: str, + quant_storage: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor] | None: + if quant_storage != torch.uint8 or not _supports_dtype(A.dtype): + return None + + fn = _resolve_quant_fn(A.dtype, quant_type) + if fn is None: + return None + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + input_shim = MPSTensorShim.from_tensor(A) + absmax_shim = MPSTensorShim.from_tensor(absmax) + out_shim = MPSTensorShim.from_tensor(out) + + _sync_mps_if_needed() + fn( + input_shim.struct, + absmax_shim.struct, + out_shim.struct, + ct.c_int32(blocksize), + ct.c_int32(n), + ) + return out, absmax + + +def _dequantize_4bit_native( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> bool: + if A.dtype != torch.uint8 or not _supports_dtype(dtype): + return False + + _check_mps_device(absmax, "absmax") + fn = _resolve_dequant_fn(dtype, quant_type) + if fn is None: + return False + + packed_shim = MPSTensorShim.from_tensor(A) + absmax_shim = MPSTensorShim.from_tensor(absmax) + out_shim = 
MPSTensorShim.from_tensor(out)
+
+    _sync_mps_if_needed()
+    fn(
+        packed_shim.struct,
+        absmax_shim.struct,
+        out_shim.struct,
+        ct.c_int32(blocksize),
+        ct.c_int32(out.numel()),
+    )
+
+    return True
+
+
+@register_kernel("bitsandbytes::quantize_4bit", "mps")
+def _(
+    A: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    quant_storage: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    _check_mps_device(A, "A")
+    # result = _quantize_4bit_native(A, blocksize, quant_type, quant_storage)
+    # if result is not None:
+    #     return result
+    return _quantize_4bit_impl(A, blocksize, quant_type, quant_storage)
+
+
+@register_kernel("bitsandbytes::dequantize_4bit", "mps")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    _check_mps_device(A, "A")
+    _check_mps_device(absmax, "absmax")
+
+    out = torch.empty(shape, dtype=dtype, device=A.device)
+    if _dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out):
+        return out
+    # Fall back to the portable default implementation when the native Metal path is unavailable.
+    return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
+
+
+@register_kernel("bitsandbytes::dequantize_4bit.out", "mps")
+def _(
+    A: torch.Tensor,
+    absmax: torch.Tensor,
+    blocksize: int,
+    quant_type: str,
+    shape: Sequence[int],
+    dtype: torch.dtype,
+    out: torch.Tensor,
+) -> None:
+    _check_mps_device(A, "A")
+    _check_mps_device(out, "out")
+    _check_mps_device(absmax, "absmax")
+    torch._check(out.shape == tuple(shape), lambda: f"Expected out.shape == {tuple(shape)}, got {out.shape}")
+    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
+
+    if not _dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out):
+        # Fall back to the portable default implementation when the native Metal path is unavailable.
+        out.copy_(_dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype))
diff --git a/bitsandbytes/backends/mps/shim.py b/bitsandbytes/backends/mps/shim.py
new file mode 100644
index 000000000..53c18fd92
--- /dev/null
+++ b/bitsandbytes/backends/mps/shim.py
@@ -0,0 +1,63 @@
+import ctypes as ct
+from dataclasses import dataclass
+from typing import Callable
+
+import torch
+
+
+class _BNBMPSTensor(ct.Structure):
+    _fields_ = [
+        ("storage", ct.c_void_p),
+        ("byte_offset", ct.c_size_t),
+        ("nbytes", ct.c_size_t),
+    ]
+
+
+@dataclass(slots=True)
+class MPSTensorShim:
+    """
+    Lightweight wrapper that keeps a Tensor alive while exposing its Metal storage.
+
+    PyTorch stores an ``id<MTLBuffer>`` inside the tensor's untyped storage data
+    pointer on MPS. We capture that pointer once and forward the storage offset
+    so native kernels can bind the correct buffer without any host copies.
+    """
+
+    tensor: torch.Tensor
+    struct: _BNBMPSTensor
+
+    @classmethod
+    def from_tensor(cls, tensor: torch.Tensor) -> "MPSTensorShim":
+        if hasattr(tensor, "untyped_storage"):
+            storage = tensor.untyped_storage()
+        else:
+            storage = tensor.storage()
+
+        storage_ptr = storage.data_ptr()
+        byte_offset = tensor.storage_offset() * tensor.element_size()
+        nbytes = tensor.nbytes
+
+        struct = _BNBMPSTensor(
+            ct.c_void_p(storage_ptr),
+            ct.c_size_t(byte_offset),
+            ct.c_size_t(nbytes),
+        )
+        return cls(tensor=tensor, struct=struct)
+
+
+# def configure_mps_blockwise_kernel(fn: Callable[[object], None]) -> None:
+#     """
+#     Ensure ctypes knows the function expects our tensor shim structs by value.
+# """ + +# try: +# argtypes = getattr(fn, "argtypes") +# except AttributeError: +# argtypes = None + +# desired = [_BNBMPSTensor, _BNBMPSTensor, _BNBMPSTensor, ct.c_int32, ct.c_int32] +# if argtypes != desired: +# fn.argtypes = desired +# if getattr(fn, "restype", None) is not None: +# fn.restype = None + diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 188576225..48d933384 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -278,19 +278,35 @@ def get_native_library() -> BNBNativeLibrary: """ Load CUDA library XOR CPU, as the latter contains a subset of symbols of the former. """ - cuda_specs = get_cuda_specs() - binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}" - - if cuda_specs: - cuda_binary_path = get_cuda_bnb_library_path(cuda_specs) - - if not cuda_binary_path.exists(): - raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}") - - binary_path = cuda_binary_path - - if torch._C._has_xpu: + cpu_binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}" + binary_path = cpu_binary_path + + if BNB_BACKEND in {"CUDA", "ROCm"}: + cuda_specs = get_cuda_specs() + if cuda_specs: + candidate = get_cuda_bnb_library_path(cuda_specs) + if not candidate.exists(): + raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {candidate}") + binary_path = candidate + else: + logger.warning( + "bitsandbytes: CUDA/ROCm backend requested but PyTorch did not expose runtime specs; " + "falling back to CPU implementation." + ) + elif BNB_BACKEND == "XPU": binary_path = PACKAGE_DIR / f"libbitsandbytes_xpu{DYNAMIC_LIBRARY_SUFFIX}" + elif BNB_BACKEND == "MPS": + binary_path = PACKAGE_DIR / f"libbitsandbytes_mps{DYNAMIC_LIBRARY_SUFFIX}" + + if not binary_path.exists(): + if BNB_BACKEND == "MPS": + logger.warning( + "bitsandbytes: libbitsandbytes_mps was not found. Falling back to CPU kernels; " + "MPS-specific optimizations will be unavailable." 
+ ) + binary_path = cpu_binary_path + else: + raise RuntimeError(f"bitsandbytes: native library not found at {binary_path}") logger.debug(f"Loading bitsandbytes native library from: {binary_path}") @@ -313,6 +329,8 @@ def get_native_library() -> BNBNativeLibrary: BNB_BACKEND = "ROCm" elif torch.cuda.is_available(): BNB_BACKEND = "CUDA" +elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + BNB_BACKEND = "MPS" elif torch._C._has_xpu: BNB_BACKEND = "XPU" diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index d3332acfe..e8561a893 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -527,7 +527,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): def forward(self, x: torch.Tensor): fix_4bit_weight_quant_state_from_module(self) quant_state = self.weight.quant_state - if ( not getattr(quant_state, "packing_format_for_cpu", False) and x.device.type == "cpu" diff --git a/csrc/mps_kernels.metal b/csrc/mps_kernels.metal index 63b3bf78c..0d9b01fe4 100644 --- a/csrc/mps_kernels.metal +++ b/csrc/mps_kernels.metal @@ -1,117 +1,265 @@ #include +#include using namespace metal; -#define HLF_MAX 65504 -#define TH 1024 -#define NUM 4 -#define NUM_BLOCK 4096 - -template -static unsigned char quantize_scalar( - float rand, - device float* code, - float x) -{ - int pivot = 127; - int upper_pivot = 255; - int lower_pivot = 0; - - float lower = -1.0f; - float upper = 1.0f; - - float val = code[pivot]; - // i>>=1 = {32, 16, 8, 4, 2, 1} - for(int i = 64; i > 0; i>>=1) - { - if(x > val) - { - lower_pivot = pivot; - lower = val; - pivot+=i; - } - else - { - upper_pivot = pivot; - upper = val; - pivot-=i; +namespace { + +constant float NF4_CODE[16] = { + -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, + -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, + 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, + 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0 +}; + +constant float FP4_CODE[16] = { + 0.0, 0.0052, 0.6667, 1.0, 0.3333, 0.5, 0.1667, 0.25, + 0.0, -0.0052, -0.6667, -1.0, -0.3333, -0.5, -0.1667, -0.25 +}; + +template +inline uchar encode_value(float value, constant float* code_table) { + float best = fabs(value - code_table[0]); + uchar index = 0; + for (uchar i = 1; i < 16; ++i) { + float diff = fabs(value - code_table[i]); + if (diff < best) { + best = diff; + index = i; } - val = code[pivot]; } + return index; +} + +template +inline void quantize_block( + device const scalar_t* input, + device float* absmax, + device uchar* packed, + uint n, + uint blocksize, + uint block_index, + constant float* code_table +) { + uint start = block_index * blocksize; + if (start >= n) { + return; + } + + uint end = min(start + blocksize, n); + float max_val = 0.0f; + for (uint i = start; i < end; ++i) { + float current = fabs((float)input[i]); + max_val = max(max_val, current); + } + + absmax[block_index] = max_val; + float inv = max_val > 0.0f ? 1.0f / max_val : 0.0f; - if(upper_pivot == 255) - upper = code[upper_pivot]; - if(lower_pivot == 0) - lower = code[lower_pivot]; - - if(!STOCHASTIC) - { - if(x > val) - { - float midpoint = (upper+val)*0.5f; - if(x > midpoint) - { - return upper_pivot; + uint out_byte = start >> 1; + bool has_pending = false; + uchar pending = 0; + + for (uint i = start; i < end; ++i) { + float normalized = (max_val > 0.0f) ? 
clamp((float)input[i] * inv, -1.0f, 1.0f) : 0.0f; + uchar q = encode_value(normalized, code_table) & 0xF; + + if (!has_pending) { + pending = q << 4; + has_pending = true; + if (i == end - 1) { + packed[out_byte++] = pending; + has_pending = false; + } + } else { + packed[out_byte++] = pending | q; + has_pending = false; } - else - return pivot; - } - else - { - float midpoint = (lower+val)*0.5f; - if(x < midpoint) - return lower_pivot; - else - return pivot; - } } - else - { - if(x > val) - { - float dist_to_upper = fabs(upper-x); - float dist_full = upper-val; - if(rand >= dist_to_upper/dist_full) return upper_pivot; - else return pivot; - } - else - { - float dist_to_lower = fabs(lower-x); - float dist_full = val-lower; - if(rand >= dist_to_lower/dist_full) return lower_pivot; - else return pivot; - } +} + +template +inline void dequantize_block( + device const uchar* packed, + device const float* absmax, + device scalar_t* output, + uint n, + uint blocksize, + uint block_index, + uint thread_idx, + uint threadgroup_size, + constant float* code_table, + threadgroup float& shared_scale +) { + uint block_start = block_index * blocksize; + if (block_start >= n) { + return; + } + uint block_end = min(block_start + blocksize, n); + uint pairs_in_block = (block_end - block_start + 1) >> 1; + + if (thread_idx == 0) { + shared_scale = absmax[block_index]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + float scale = shared_scale; + + for (uint pair = thread_idx; pair < pairs_in_block; pair += threadgroup_size) { + uint value_index0 = block_start + pair * 2; + if (value_index0 >= block_end) { + break; + } + + uint byte_index0 = value_index0 >> 1; + uchar byte_val0 = packed[byte_index0]; + bool upper0 = ((value_index0 & 1) == 0); + uchar nibble0 = upper0 ? ((byte_val0 >> 4) & 0xF) : (byte_val0 & 0xF); + float decoded0 = code_table[nibble0] * scale; + output[value_index0] = scalar_t(decoded0); + + uint value_index1 = value_index0 + 1; + if (value_index1 < block_end) { + uint byte_index1 = value_index1 >> 1; + uchar byte_val1 = (byte_index1 == byte_index0) ? byte_val0 : packed[byte_index1]; + bool upper1 = ((value_index1 & 1) == 0); + uchar nibble1 = upper1 ? ((byte_val1 >> 4) & 0xF) : (byte_val1 & 0xF); + float decoded1 = code_table[nibble1] * scale; + output[value_index1] = scalar_t(decoded1); + } } } -kernel void quantize(device float* code [[buffer(0)]], - device float* A [[buffer(1)]], - device uchar* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint id [[thread_position_in_grid]]) { - const uint n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); - uint valid_items = (id / NUM_BLOCK + 1 == (n + NUM_BLOCK - 1) / NUM_BLOCK) ? n - (id / NUM_BLOCK * NUM_BLOCK) : NUM_BLOCK; - const uint base_idx = (id / NUM_BLOCK * NUM_BLOCK); +} // namespace - float vals[NUM]; - uchar qvals[NUM]; +// Quantization kernels +kernel void quantize_4bit_fp16_fp4( + device const half* input [[buffer(0)]], + device float* absmax [[buffer(1)]], + device uchar* packed [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint gid [[thread_position_in_grid]] +) { + if (gid >= blocks) { + return; + } + quantize_block(input, absmax, packed, n, blocksize, gid, FP4_CODE); +} - for (uint i = base_idx; i < n_full; i += ((n + NUM_BLOCK - 1) / NUM_BLOCK) * NUM_BLOCK) { - valid_items = n - i > NUM_BLOCK ? 
NUM_BLOCK : n - i; +kernel void quantize_4bit_fp16_nf4( + device const half* input [[buffer(0)]], + device float* absmax [[buffer(1)]], + device uchar* packed [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint gid [[thread_position_in_grid]] +) { + if (gid >= blocks) { + return; + } + quantize_block(input, absmax, packed, n, blocksize, gid, NF4_CODE); +} - threadgroup_barrier(mem_flags::mem_threadgroup); +kernel void quantize_4bit_fp32_fp4( + device const float* input [[buffer(0)]], + device float* absmax [[buffer(1)]], + device uchar* packed [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint gid [[thread_position_in_grid]] +) { + if (gid >= blocks) { + return; + } + quantize_block(input, absmax, packed, n, blocksize, gid, FP4_CODE); +} - for (uint j = 0; j < valid_items; j++) { - vals[j] = A[i + j]; +kernel void quantize_4bit_fp32_nf4( + device const float* input [[buffer(0)]], + device float* absmax [[buffer(1)]], + device uchar* packed [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint gid [[thread_position_in_grid]] +) { + if (gid >= blocks) { + return; } + quantize_block(input, absmax, packed, n, blocksize, gid, NF4_CODE); +} - for (uint j = 0; j < valid_items; j++) { - qvals[j] = quantize_scalar(0.0f, code, vals[j]); +// Dequantization kernels +kernel void dequantize_4bit_fp16_fp4( + device const uchar* packed [[buffer(0)]], + device const float* absmax [[buffer(1)]], + device half* output [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint threadgroup_size [[threads_per_threadgroup]] +) { + if (tgid >= blocks) { + return; } + threadgroup float shared_scale; + dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE, shared_scale); +} - threadgroup_barrier(mem_flags::mem_threadgroup); +kernel void dequantize_4bit_fp16_nf4( + device const uchar* packed [[buffer(0)]], + device const float* absmax [[buffer(1)]], + device half* output [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint threadgroup_size [[threads_per_threadgroup]] +) { + if (tgid >= blocks) { + return; + } + threadgroup float shared_scale; + dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE, shared_scale); +} - for (uint j = 0; j < valid_items; j++) { - out[i + j] = qvals[j]; +kernel void dequantize_4bit_fp32_fp4( + device const uchar* packed [[buffer(0)]], + device const float* absmax [[buffer(1)]], + device float* output [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint threadgroup_size [[threads_per_threadgroup]] +) { + if (tgid >= blocks) { + return; } - } + threadgroup float shared_scale; + dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE, shared_scale); } + +kernel void dequantize_4bit_fp32_nf4( + device const uchar* packed 
[[buffer(0)]],
+    device const float* absmax [[buffer(1)]],
+    device float* output [[buffer(2)]],
+    constant uint& n [[buffer(3)]],
+    constant uint& blocksize [[buffer(4)]],
+    constant uint& blocks [[buffer(5)]],
+    uint tgid [[threadgroup_position_in_grid]],
+    uint tid [[thread_index_in_threadgroup]],
+    uint threadgroup_size [[threads_per_threadgroup]]
+) {
+    if (tgid >= blocks) {
+        return;
+    }
+    threadgroup float shared_scale;
+    dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE, shared_scale);
+}
\ No newline at end of file
diff --git a/csrc/mps_ops.mm b/csrc/mps_ops.mm
index 85ed1b1e4..8ef523b1c 100644
--- a/csrc/mps_ops.mm
+++ b/csrc/mps_ops.mm
@@ -1,62 +1,244 @@
-#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
+#import <Metal/Metal.h>
+#import <Foundation/Foundation.h>

-#define HLF_MAX 65504
-#define TH 1024
-#define NUM 4
-#define NUM_BLOCK 4096
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <dlfcn.h>
+#include

-static inline MPSGraph* get_graph() {
-    static MPSGraph* cur = nil;
-    if (!cur) {
-        cur = [[MPSGraph alloc] init];
-    }
-    return cur;
-}
+namespace {
+
+typedef struct {
+    void* storage;
+    size_t byte_offset;
+    size_t nbytes;
+} BNBMPSTensor;

 static inline id<MTLDevice> get_device() {
-    NSError* error = nil;
     static id<MTLDevice> device = nil;
-    if (!device) {
+    static dispatch_once_t onceToken;
+    dispatch_once(&onceToken, ^{
         device = MTLCreateSystemDefaultDevice();
-    }
         if (!device) {
-            NSLog(@"Failed to get MPS device");
+            NSLog(@"bitsandbytes: failed to acquire Metal device");
             abort();
         }
+    });
     return device;
 }

+static inline id<MTLCommandQueue> get_command_queue() {
+    static id<MTLCommandQueue> queue = nil;
+    static dispatch_once_t onceToken;
+    dispatch_once(&onceToken, ^{
+        queue = [get_device() newCommandQueue];
+        if (!queue) {
+            NSLog(@"bitsandbytes: failed to create Metal command queue");
+            abort();
+        }
+    });
+    return queue;
+}
+
+static inline NSURL* metallib_url() {
+    Dl_info info;
+    if (dladdr(reinterpret_cast<const void*>(&metallib_url), &info) == 0) {
+        NSLog(@"bitsandbytes: dladdr failed to resolve metallib path");
+        abort();
+    }
+    NSString* dylibPath = [NSString stringWithUTF8String:info.dli_fname];
+    NSString* directory = [dylibPath stringByDeletingLastPathComponent];
+    NSString* metallibPath = [directory stringByAppendingPathComponent:@"bitsandbytes.metallib"];
+    return [NSURL fileURLWithPath:metallibPath];
+}
+
 static inline id<MTLLibrary> get_library() {
-    NSError* error = nil;
     static id<MTLLibrary> library = nil;
+    static dispatch_once_t onceToken;
+    dispatch_once(&onceToken, ^{
+        NSError* error = nil;
+        library = [get_device() newLibraryWithURL:metallib_url() error:&error];
         if (!library) {
-            library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error];
+            NSLog(@"bitsandbytes: failed to load bitsandbytes.metallib (%@)", error);
+            abort();
+        }
+    });
+    return library;
+}
+
+static inline id<MTLComputePipelineState> get_pipeline(NSString* functionName) {
+    static NSMutableDictionary<NSString*, id<MTLComputePipelineState>>* cache = nil;
+    static dispatch_once_t onceToken;
+    dispatch_once(&onceToken, ^{
+        cache = [[NSMutableDictionary alloc] init];
+    });
+
+    id<MTLComputePipelineState> pipeline = cache[functionName];
+    if (pipeline) {
+        return pipeline;
     }
-
-    if (!library) {
-        NSLog(@"Failed to load bitsandbytes.metallib");
+
+    NSError* error = nil;
+    id<MTLFunction> function = [get_library() newFunctionWithName:functionName];
+    if (!function) {
+        NSLog(@"bitsandbytes: missing Metal kernel %@", functionName);
         abort();
     }
-    return library;
+
+    pipeline = [get_device() newComputePipelineStateWithFunction:function error:&error];
+    [function release];
+
+    if (!pipeline) {
+        NSLog(@"bitsandbytes: failed to create pipeline for %@ (%@)", functionName, error);
+        abort();
+    }
+
+    cache[functionName] = pipeline;
+    return pipeline;
 }

-/*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n)
-{
-  id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0
-dataType:MPSDataTypeInt8 axis:0 name:@"out"]; return out;
-}*/
-
-// MPSGraph function for quantize
-extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) {
-    id<MTLDevice> device = get_device();
-    id<MTLLibrary> library = get_library();
-    static id<MTLFunction> kernel = nil;
-    if (!kernel) {
-        kernel = [library newFunctionWithName:@"quantize"];
-        if (!kernel) {
-            NSLog(@"Failed to load bitsandbytes.metallib");
-            abort();
-        }
+struct TensorView {
+    id<MTLBuffer> buffer;
+    NSUInteger offset;
+};
+
+static inline TensorView make_tensor_view(const BNBMPSTensor& tensor, const char* label) {
+    TensorView view;
+    view.buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage);
+    view.offset = static_cast<NSUInteger>(tensor.byte_offset);
+    if (!view.buffer && tensor.nbytes > 0) {
+        NSLog(@"bitsandbytes: missing MTLBuffer for %s tensor (storage=%p, bytes=%zu)", label, tensor.storage, tensor.nbytes);
+        abort();
     }
-    NSLog(@"Not implemented");
-    return nil;
+    return view;
 }
+
+static inline void dispatch_quant_kernel(
+    NSString* name,
+    const BNBMPSTensor& input,
+    const BNBMPSTensor& absmax,
+    const BNBMPSTensor& out,
+    uint32_t blocksize,
+    uint32_t n
+) {
+    if (n == 0) {
+        return;
+    }
+
+    uint32_t blocks = (n + blocksize - 1) / blocksize;
+    TensorView inputView = make_tensor_view(input, "input");
+    TensorView absmaxView = make_tensor_view(absmax, "absmax");
+    TensorView outView = make_tensor_view(out, "out");
+
+    id<MTLCommandBuffer> commandBuffer = [get_command_queue() commandBuffer];
+    id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
+    id<MTLComputePipelineState> pipeline = get_pipeline(name);
+
+    [encoder setComputePipelineState:pipeline];
+    [encoder setBuffer:inputView.buffer offset:inputView.offset atIndex:0];
+    [encoder setBuffer:absmaxView.buffer offset:absmaxView.offset atIndex:1];
+    [encoder setBuffer:outView.buffer offset:outView.offset atIndex:2];
+    [encoder setBytes:&n length:sizeof(uint32_t) atIndex:3];
+    [encoder setBytes:&blocksize length:sizeof(uint32_t) atIndex:4];
+    [encoder setBytes:&blocks length:sizeof(uint32_t) atIndex:5];
+
+    NSUInteger threadsPerThreadgroup = pipeline.threadExecutionWidth;
+    if (threadsPerThreadgroup == 0) {
+        threadsPerThreadgroup = 1;
+    }
+    MTLSize threads = MTLSizeMake(threadsPerThreadgroup, 1, 1);
+    MTLSize grid = MTLSizeMake(blocks, 1, 1);
+    [encoder dispatchThreads:grid threadsPerThreadgroup:threads];
+    [encoder endEncoding];
+
+    [commandBuffer commit];
+    [commandBuffer waitUntilCompleted];
+
+}
+
+static inline void dispatch_dequant_kernel(
+    NSString* name,
+    const BNBMPSTensor& packed,
+    const BNBMPSTensor& absmax,
+    const BNBMPSTensor& output,
+    uint32_t blocksize,
+    uint32_t n
+) {
+    if (n == 0) {
+        return;
+    }
+    uint32_t blocks = (n + blocksize - 1) / blocksize;
+    TensorView packedView = make_tensor_view(packed, "packed");
+    TensorView absmaxView = make_tensor_view(absmax, "absmax");
+    TensorView outputView = make_tensor_view(output, "output");
+
+    id<MTLCommandBuffer> commandBuffer = [get_command_queue() commandBuffer];
+    id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
+    id<MTLComputePipelineState> pipeline = get_pipeline(name);
+
+    [encoder setComputePipelineState:pipeline];
+    [encoder setBuffer:packedView.buffer offset:packedView.offset atIndex:0];
+    [encoder setBuffer:absmaxView.buffer offset:absmaxView.offset atIndex:1];
+    [encoder setBuffer:outputView.buffer offset:outputView.offset atIndex:2];
+    [encoder setBytes:&n length:sizeof(uint32_t) atIndex:3];
+    [encoder setBytes:&blocksize length:sizeof(uint32_t) atIndex:4];
+    [encoder setBytes:&blocks length:sizeof(uint32_t) atIndex:5];
+
+    NSUInteger maxThreadsPerTG = pipeline.maxTotalThreadsPerThreadgroup;
+    NSUInteger desiredThreads = (blocksize + 1) / 2;
+    if (desiredThreads == 0) {
+        desiredThreads = 1;
+    }
+    NSUInteger threadsPerThreadgroup = std::min(maxThreadsPerTG, std::max<NSUInteger>(1, desiredThreads));
+    if (threadsPerThreadgroup < pipeline.threadExecutionWidth) {
+        threadsPerThreadgroup = std::min(pipeline.threadExecutionWidth, maxThreadsPerTG);
+    }
+
+    NSUInteger totalThreads = threadsPerThreadgroup * blocks;
+    MTLSize threads = MTLSizeMake(threadsPerThreadgroup, 1, 1);
+    MTLSize grid = MTLSizeMake(totalThreads, 1, 1);
+    [encoder dispatchThreads:grid threadsPerThreadgroup:threads];
+    [encoder endEncoding];
+
+    [commandBuffer commit];
+    // [commandBuffer waitUntilCompleted];
+}
+
+} // namespace
+
+extern "C" {
+
+void cquantize_blockwise_fp16_fp4(BNBMPSTensor input, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) {
+    dispatch_quant_kernel(@"quantize_4bit_fp16_fp4", input, absmax, out, blocksize, n);
+}
+
+void cquantize_blockwise_fp16_nf4(BNBMPSTensor input, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) {
+    dispatch_quant_kernel(@"quantize_4bit_fp16_nf4", input, absmax, out, blocksize, n);
+}
+
+void cquantize_blockwise_fp32_fp4(BNBMPSTensor input, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) {
+    dispatch_quant_kernel(@"quantize_4bit_fp32_fp4", input, absmax, out, blocksize, n);
+}
+
+void cquantize_blockwise_fp32_nf4(BNBMPSTensor input, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) {
+    dispatch_quant_kernel(@"quantize_4bit_fp32_nf4", input, absmax, out, blocksize, n);
+}
+
+void cdequantize_blockwise_fp16_fp4(BNBMPSTensor packed, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) {
+    dispatch_dequant_kernel(@"dequantize_4bit_fp16_fp4", packed, absmax, out, blocksize, n);
+}
+
+void cdequantize_blockwise_fp16_nf4(BNBMPSTensor packed, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) {
+    dispatch_dequant_kernel(@"dequantize_4bit_fp16_nf4", packed, absmax, out, blocksize, n);
+}
+
+void cdequantize_blockwise_fp32_fp4(BNBMPSTensor packed, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) {
+    dispatch_dequant_kernel(@"dequantize_4bit_fp32_fp4", packed, absmax, out, blocksize, n);
+}
+
+void cdequantize_blockwise_fp32_nf4(BNBMPSTensor packed, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) {
+    dispatch_dequant_kernel(@"dequantize_4bit_fp32_nf4", packed, absmax, out, blocksize, n);
+}
+
+} // extern "C"
diff --git a/pyproject.toml b/pyproject.toml
index 65f9314c5..8d35ecda3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -74,6 +74,7 @@ test = [
 [tool.setuptools]
 package-data = { "*" = ["libbitsandbytes*.*", "py.typed"] }
+# package-data = { "*" = ["libbitsandbytes*.*", "bitsandbytes.metallib", "py.typed"] }

 [tool.setuptools.packages.find]
 include = ["bitsandbytes*"]
diff --git a/script.sh b/script.sh
new file mode 100755
index 000000000..dec3e8f6f
--- /dev/null
+++ b/script.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+PYTHON_PATH=/Users/medmekk/miniforge3/envs/gpt/bin/python
+$PYTHON_PATH ./test_bnb_mac.py
\ No newline at end of file
diff --git a/test_bnb_mac.py b/test_bnb_mac.py
new file mode 100644
index 000000000..3d395e66b
--- /dev/null
+++ b/test_bnb_mac.py
@@ -0,0 +1,70 @@
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig + +tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1b-Instruct") +quantization_config = BitsAndBytesConfig(load_in_4bit=True) +model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1b-Instruct", device_map="mps", quantization_config=quantization_config) +print("model.device:", model.device) +prompt = "Hello, how are you?" +inputs = tokenizer(prompt, return_tensors="pt").to(model.device) +outputs = model.generate(**inputs, max_new_tokens=20) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) # or whatever entry function you have + +# import torch +# import bitsandbytes as bnb +# A = torch.randn(2048, device='mps', dtype=torch.float16) +# q, absmax = torch.ops.bitsandbytes.quantize_4bit(A, 64, 'nf4', torch.uint8) +# print('q.shape:', q.shape, q.dtype) +# print('absmax.shape:', absmax.shape, absmax.dtype) +# B = torch.ops.bitsandbytes.dequantize_4bit(q, absmax, 64, 'nf4', A.shape, A.dtype) +# print('ok', float((A-B).abs().max())) + +# import torch, bitsandbytes as bnb + +# torch.manual_seed(0) +# A = torch.randn(256, device="mps", dtype=torch.float16) + +# q, absmax = torch.ops.bitsandbytes.quantize_4bit(A, 64, "nf4", torch.uint8) +# B_native = torch.ops.bitsandbytes.dequantize_4bit(q, absmax, 64, "nf4", A.shape, A.dtype) + +# # CPU reference (uses the default implementation, then move back to MPS) +# B_ref = torch.ops.bitsandbytes.dequantize_4bit.default( +# q.cpu(), absmax.cpu(), 64, "nf4", A.shape, A.dtype +# ).to("mps") + +# print("A[:8] ", A[:8].cpu()) +# print("B_native[:8]", B_native[:8].cpu()) +# print("B_ref[:8] ", B_ref[:8].cpu()) +# print("max |A-B_native|:", float((A - B_native).abs().max())) +# print("max |A-B_ref| :", float((A - B_ref).abs().max())) + +# diff = (B_native - B_ref).cpu() +# print("B_native shape:", B_native.shape) +# print("B_ref shape:", B_ref.shape) +# print("max |B_native - B_ref|:", float(diff.abs().max())) +# print("first 16 diffs:", diff[:16]) + +# q_cpu, absmax_cpu = torch.ops.bitsandbytes.quantize_4bit.default( +# A.cpu(), 64, "nf4", torch.uint8 +# ) + +# print("q identical? ", torch.equal(q.cpu(), q_cpu)) +# print("absmax max diff:", float((absmax.cpu() - absmax_cpu).abs().max())) +# print("q_mps[:8]:", q.view(-1)[:8].cpu()) +# print("q_cpu[:8]:", q_cpu.view(-1)[:8]) +# print("absmax_mps[:4]:", absmax[:4].cpu()) +# print("absmax_cpu[:4]:", absmax_cpu[:4]) + +# import torch, bitsandbytes as bnb, time + +# torch.manual_seed(0) +# A = torch.randn(4096 * 4096, device="mps", dtype=torch.float16) +# blocksize = 64 + +# q, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, "nf4", torch.uint8) + +# torch.mps.synchronize() +# t0 = time.perf_counter() +# torch.ops.bitsandbytes.dequantize_4bit(q, absmax, blocksize, "nf4", A.shape, A.dtype) +# torch.mps.synchronize() +# dt = time.perf_counter() - t0 +# print(f"Dequant time: {dt*1000:.2f} ms for {A.numel()/1e6:.1f}M elements") \ No newline at end of file
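The checks in test_bnb_mac.py above are largely commented out. Below is a minimal, self-contained round-trip sketch; it assumes this branch is installed in the active environment and that an MPS device is available, and it only calls the torch.ops.bitsandbytes entry points already exercised in the commented-out snippets, so no model download is required.

import torch
import bitsandbytes  # noqa: F401  # importing registers the bitsandbytes::* ops, including the MPS backend

torch.manual_seed(0)
A = torch.randn(4096, device="mps", dtype=torch.float16)
blocksize = 64

# Quantize and dequantize through the dispatcher (routes to the "mps" registrations above).
q, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, "nf4", torch.uint8)
B = torch.ops.bitsandbytes.dequantize_4bit(q, absmax, blocksize, "nf4", A.shape, A.dtype)

# Reference: the default (device-agnostic) implementation, run on CPU.
q_ref, absmax_ref = torch.ops.bitsandbytes.quantize_4bit.default(A.cpu(), blocksize, "nf4", torch.uint8)
B_ref = torch.ops.bitsandbytes.dequantize_4bit.default(q_ref, absmax_ref, blocksize, "nf4", A.shape, A.dtype)

print("packed bytes identical:", torch.equal(q.cpu(), q_ref))
print("max |absmax - absmax_ref|:", float((absmax.cpu() - absmax_ref).abs().max()))
print("max |A - B| (quantization error):", float((A - B).abs().max()))
print("max |B - B_ref| (MPS vs. default):", float((B.cpu() - B_ref).abs().max()))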