diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 12088a70c..9a2524953 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -34,6 +34,9 @@
 if torch.cuda.is_available():
     from .backends.cuda import ops as cuda_ops
 
+if torch.xpu.is_available():
+    from .backends.xpu import ops as xpu_ops
+
 
 def _import_backends():
     """
diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py
index 9a3ac46ac..a260852f5 100644
--- a/bitsandbytes/_ops.py
+++ b/bitsandbytes/_ops.py
@@ -4,6 +4,8 @@
 
 import torch
 
+from .cextension import ipex_cpu, ipex_xpu
+
 _IS_TORCH_GTE_24 = False
 
 if hasattr(torch.library, "register_fake"):
@@ -327,3 +329,22 @@ def _(
     )
     torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
     torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
+
+
+if ipex_cpu or ipex_xpu:
+    # Register the dequantize_nf4_ipex implementation
+    torch.library.define(
+        "bitsandbytes::dequantize_nf4_ipex",
+        "(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor",
+    )
+
+    @register_fake("bitsandbytes::dequantize_nf4_ipex")
+    def _(
+        A: torch.Tensor,
+        absmax: torch.Tensor,
+        blocksize: int,
+        shape: Sequence[int],
+        dtype: torch.dtype,
+    ) -> torch.Tensor:
+        torch._check_is_size(blocksize)
+        return torch.empty(shape, dtype=dtype, device=A.device)
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index c7ad3a82c..746d6c1ec 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -8,6 +8,7 @@
 from typing_extensions import deprecated
 
 import bitsandbytes.functional as F
+from bitsandbytes.functional import ipex_cpu, ipex_xpu
 
 # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
 # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
@@ -298,6 +299,63 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
         return grad_A, grad_B, None, grad_bias, None
 
 
+class MatMul8bitFp(torch.autograd.Function):
+    # On Intel CPU and XPU, MatMul8bitFp is much faster (~3x) than MatMul8bitLt for finetuning,
+    # because MatMul8bitLt does extra work to compute gradients and there is no fast
+    # 8-bit quant/dequant kernel for CPU/XPU, which makes that path very slow.
+    # Instead, we dequantize the weight and use a regular matmul to get good finetuning performance.
+ + @staticmethod + def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): + if state.has_fp16_weights or state.CB is None: + has_grad = getattr(B, "grad", None) is not None + is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) + if is_transposed: + B = B.contiguous() + + if (state.is_training and not has_grad) or state.CB is None or state.SCB is None: + state.reset_grads() + state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16)) + B = state.CB + + CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + output = torch.nn.functional.linear(A, CB, bias) + # to pass the test: tests/test_modules.py::test_linear8bitlt_no_fp16_weights[2.0-xpu] + state.idx = False + ctx.state = state + ctx.dtype_A = A.dtype + ctx.grad_shape = A.shape + ctx.A = A + ctx.dtype_bias = None if bias is None else bias.dtype + return output + + @staticmethod + def backward(ctx, grad_output): + req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad + A = ctx.A + state = ctx.state + grad_A = grad_B = grad_bias = None + if req_gradBias: + # compute grad_bias first before changing grad_output dtype + grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) + + # Cast grad_output to fp16 + if len(grad_output.shape) == 3: + grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() + + if req_gradB: + grad_B = torch.matmul(A.t(), grad_output).t() + + if req_gradA: + if state.CB is not None: + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + grad_A = torch.matmul(grad_output.to(ctx.dtype_A), CB).view(ctx.grad_shape) + else: + raise Exception("State must contain CB matrix for backward") + + return grad_A, grad_B, None, grad_bias, None + + class MatMul4Bit(torch.autograd.Function): # forward is the same, but we added the fallback for pre-turing GPUs # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") @@ -366,6 +424,10 @@ def matmul( state = state or MatmulLtState() if threshold > 0.0: state.threshold = threshold + # MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU + if state.is_training: + if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu): + return MatMul8bitFp.apply(A, B, out, bias, state) return MatMul8bitLt.apply(A, B, out, bias, state) @@ -378,6 +440,17 @@ def matmul_4bit( ): assert quant_state is not None + if A.device.type in ("cpu", "xpu") and A.requires_grad == False: + if getattr(quant_state, "ipex", False): + # IPEX CPU will change weight to 4D so don't need transpose + B = B.t() if B.dim() == 2 else B + out = F.gemv_4bit(A, B, out, state=quant_state) + if bias is not None: + out += bias + return out + else: + return MatMul4Bit.apply(A, B, out, bias, quant_state) + if A.numel() == A.shape[-1] and A.requires_grad == False: if A.shape[-1] % quant_state.blocksize != 0: warn( diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index d5ab9aa88..5f009ea40 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -7,6 +7,7 @@ from ..._ops import register_kernel from ...cextension import lib +from ..utils import ipex_cpu # torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+. # However, we can overflow if we use this without AVX512_VNNI support. 
@@ -26,22 +27,42 @@ def _(A: torch.Tensor, B: torch.Tensor):
 @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
-    torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")
 
     n = A.numel()
-    blocks = -(n // -blocksize)
-
-    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
-    out = torch.empty_like(A, dtype=torch.uint8)
-
-    lib.cquantize_blockwise_cpu_fp32(
-        get_ptr(code),
-        get_ptr(A),
-        get_ptr(absmax),
-        get_ptr(out),
-        ct.c_longlong(blocksize),
-        ct.c_longlong(n),
-    )
+
+    # Only FP32 has a C++ kernel
+    if A.dtype == torch.float32:
+        blocks = -(n // -blocksize)
+
+        absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+        out = torch.empty_like(A, dtype=torch.uint8)
+
+        lib.cquantize_blockwise_cpu_fp32(
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_longlong(blocksize),
+            ct.c_longlong(n),
+        )
+    else:
+        rem = n % blocksize
+        has_rem = rem > 0
+        blocks = n // blocksize + has_rem
+        absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+        A_reshaped = A.reshape(n)
+        A_com = A_reshaped[: n - rem]
+        A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
+        absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
+        scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
+        scaled_A = scaled_A.reshape(-1)
+        if has_rem:
+            absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
+            scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
+            scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
+
+        diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
+        out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
 
     return out, absmax
 
@@ -50,144 +71,50 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
 def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
     torch._check_is_size(blocksize)
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
-    torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}")
-
-    out = torch.empty_like(A, dtype=dtype)
-    lib.cdequantize_blockwise_cpu_fp32(
-        get_ptr(code),
-        get_ptr(A),
-        get_ptr(absmax),
-        get_ptr(out),
-        ct.c_longlong(blocksize),
-        ct.c_longlong(A.numel()),
-    )
+    # Only FP32 has a C++ kernel
+    if dtype == torch.float32:
+        out = torch.empty_like(A, dtype=dtype)
+
+        lib.cdequantize_blockwise_cpu_fp32(
+            get_ptr(code),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_longlong(blocksize),
+            ct.c_longlong(A.numel()),
+        )
+    else:
+        out = code[A.reshape(-1).int()]
+        blocks = out.shape[-1] // blocksize
+        res = out.shape[-1] % blocksize
+        if res != 0:
+            out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
+        out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
+        out = out[: blocks * blocksize + res]
+        out = out.reshape(A.shape)
 
     return out
 
 
-_NF4_QUANT_TABLE = torch.tensor(
-    [
-        -1.0,
-        -0.6961928009986877,
-        -0.5250730514526367,
-        -0.39491748809814453,
-        -0.28444138169288635,
-        -0.18477343022823334,
-        -0.09105003625154495,
-        0.0,
-        0.07958029955625534,
-        0.16093020141124725,
-        0.24611230194568634,
-        0.33791524171829224,
-        0.44070982933044434,
-        0.5626170039176941,
-        0.7229568362236023,
-        1.0,
-    ],
-    dtype=torch.float32,
device="cpu", -) - - -@register_kernel("bitsandbytes::quantize_4bit", "cpu") -def _( - A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype -) -> tuple[torch.Tensor, torch.Tensor]: - torch._check_is_size(blocksize) - torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}") - torch._check( - A.dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", - ) - - n = A.numel() - - # TODO: Support when weight matrix is not divisible by blocksize - torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}") - - # Divide into blocks and normalize - blocks = A.reshape(-1, blocksize) - absmax = blocks.abs().max(dim=1).values.float() - scaled = blocks / absmax.unsqueeze(-1) - - # Quantize with the lookup table - quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8) - - # Pack two quantized values per byte - packed = quantized[::2] << 4 | quantized[1::2] - - if quant_storage != torch.uint8: - packed = packed.squeeze().view(quant_storage).unsqueeze(1) - - return packed, absmax.float() - - -@register_kernel("bitsandbytes::dequantize_4bit", "cpu") -def _( - A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, -) -> torch.Tensor: - torch._check_is_size(blocksize) - torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}") - torch._check( - dtype in [torch.bfloat16, torch.float16, torch.float32], - lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", - ) - torch._check( - A.dtype == torch.uint8, - lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}", - ) - - A = A.view(-1, 1) - - # Grab upper and lower nibbles. Using int64 for indexing in the LUT. - upper = (A >> 4).to(torch.int64) - lower = (A & 0x0F).to(torch.int64) - - # Expand to blocks - blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize) - - # Dequantize - blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None] - - # Reshape to original shape - blocks = blocks.reshape(-1, *shape[1:]) - - return blocks.to(dtype) - - -@register_kernel("bitsandbytes::gemv_4bit", "cpu") -def _( - A: torch.Tensor, - B: torch.Tensor, - shapeB: Sequence[int], - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, -) -> torch.Tensor: - # TODO: We need to determine whether `code` is NF4, FP4, or other. - # Right now we assume NF4, as this is the only one supported on CPU. - - B_dq = torch.ops.bitsandbytes.dequantize_4bit.default( - B, - absmax, - blocksize, - "nf4", - shape=shapeB, - dtype=A.dtype, - ) - - # User called gemv with B.t(), so we need to transpose it back. 
- # if B.shape[0] == 1: - # B_dq = B_dq.t() - - return torch.nn.functional.linear( - A, - B_dq, - bias=None, - ) +if ipex_cpu: + from bitsandbytes.utils import _reverse_4bit_compress_format + + @register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu") + def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + shape: Sequence[int], + dtype: torch.dtype, + ) -> torch.Tensor: + ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2) + A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1) + return torch.ops.bitsandbytes.dequantize_4bit.default( + A, + absmax, + blocksize, + "nf4", + shape, + dtype, + ) diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py index 729c2b047..ce5926979 100644 --- a/bitsandbytes/backends/default/ops.py +++ b/bitsandbytes/backends/default/ops.py @@ -1,9 +1,11 @@ +from collections.abc import Sequence from math import prod from typing import Optional import torch from ..._ops import register_kernel +from ..utils import CODE @register_kernel("bitsandbytes::int8_mm_dequant", "default") @@ -142,3 +144,160 @@ def _(A: torch.Tensor, threshold=0.0): A[outliers] = outlier_restore return out_row, row_stats, outlier_cols + + +@register_kernel("bitsandbytes::quantize_blockwise", "default") +def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + + n = A.numel() + rem = n % blocksize + has_rem = rem > 0 + blocks = n // blocksize + has_rem + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + A_reshaped = A.reshape(n) + A_com = A_reshaped[: n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + + diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device)) + out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape) + + return out, absmax + + +@register_kernel("bitsandbytes::dequantize_blockwise", "default") +def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: + torch._check_is_size(blocksize) + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") + + out = code[A.reshape(-1).int()] + blocks = out.shape[-1] // blocksize + res = out.shape[-1] % blocksize + if res != 0: + out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0) + out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1) + out = out[: blocks * blocksize + res] + out = out.reshape(A.shape) + + return out + + +@register_kernel("bitsandbytes::quantize_4bit", "default") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + full_blocks = n // blocksize + 
rem = n % blocksize + blocks = full_blocks + 1 if rem else full_blocks + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + A_flattened = A.reshape(n) + + # Scale full blocks of the tensor to [-1, 1] + A_full_blocks = A_flattened[: n - rem].reshape(n // blocksize, blocksize) + absmax[:full_blocks] = torch.abs(A_full_blocks).max(dim=-1)[0] + scaled = torch.clamp(A_full_blocks * (1 / absmax[:full_blocks].view(-1, 1)), -1, 1).reshape(-1) + + # Scale any partial block + if rem: + A_rem = A_flattened[-rem:] + absmax[-1] = torch.abs(A_rem).max() + scaled_rem = torch.clamp(A_rem * (1 / absmax[-1]), -1, 1) + scaled = torch.cat([scaled, scaled_rem], dim=0) + + # Quantize with the lookup table + code = CODE[quant_type].to(scaled.device).to(scaled.dtype) + quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - code), dim=-1, keepdim=True).to(torch.uint8) + + # Pack two quantized values per byte + packed = quantized[::2] << 4 | quantized[1::2] + + if quant_storage != torch.uint8: + packed = packed.squeeze().view(quant_storage).unsqueeze(1) + + return packed, absmax.float() + + +@register_kernel("bitsandbytes::dequantize_4bit", "default") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + torch._check_is_size(blocksize) + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4, got {quant_type}") + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + # Enable non uint8 dtype + if A.dtype != torch.uint8: + A = A.view(torch.uint8) + + A = A.reshape(-1) + # Map nf4 to [-1, 1] + out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) + n = out_dq.numel() + out_dq[1::2] = A & 0xF + out_dq[::2] = A >> 4 + # code is fp32, cast to dtype to avoid the mismatch issue + code = CODE[quant_type].to(dtype).to(A.device) + out_dq = code[out_dq] + + # Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + + out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) + if has_rem: + out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) + out[n - rem :] = out_dq[n - rem :] * absmax[-1] + else: + out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) + + out = out.reshape(-1, *shape[1:]).to(dtype) + + return out + + +@register_kernel("bitsandbytes::gemv_4bit", "default") +def _( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, +) -> torch.Tensor: + # Applied from dequantize_4bit + quant_type = "fp4" if code[1] > 0 else "nf4" + B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(B, absmax, blocksize, quant_type, shapeB, A.dtype) + + return torch.nn.functional.linear( + A, + B_dq, + bias=None, + ) diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py new file mode 100755 index 000000000..cc88ffae1 --- /dev/null +++ b/bitsandbytes/backends/utils.py @@ -0,0 +1,57 @@ +import torch + +try: + # to support Intel CPU/XPU (IPEX) backend + import intel_extension_for_pytorch as ipex + + ipex_cpu = ipex if ipex._C._has_cpu() else None + ipex_xpu = ipex if ipex._C._has_xpu() else None +except BaseException: + ipex_cpu = 
None + ipex_xpu = None + +_NF4_QUANT_TABLE = torch.tensor( + [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ], + dtype=torch.float32, + device="xpu" if torch.xpu.is_available() else "cpu", # Only cpu/xpu use this table for now. +) +_FP4_QUANT_TABLE = torch.tensor( + [ + 0.0000, + 0.0052, + 0.6667, + 1.0000, + 0.3333, + 0.5000, + 0.1667, + 0.2500, + 0.0000, + -0.0052, + -0.6667, + -1.0000, + -0.3333, + -0.5000, + -0.1667, + -0.2500, + ], + dtype=torch.float32, + device="xpu" if torch.xpu.is_available() else "cpu", # Only cpu/xpu use this table for now. +) +CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE} diff --git a/bitsandbytes/backends/xpu/__init__.py b/bitsandbytes/backends/xpu/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py new file mode 100755 index 000000000..47a3bd009 --- /dev/null +++ b/bitsandbytes/backends/xpu/ops.py @@ -0,0 +1,51 @@ +from collections.abc import Sequence + +import torch + +from ..._ops import register_kernel +from ..utils import ipex_xpu + +if torch.__version__ >= (2, 7): + + @register_kernel("bitsandbytes::int8_linear_matmul", "xpu") + def _(A: torch.Tensor, B: torch.Tensor): + return torch._int_mm( + A.reshape(-1, A.shape[-1]), + B.t(), + ).reshape(*A.shape[:-1], B.shape[0]) + + +if ipex_xpu: + + @register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu") + def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + shape: Sequence[int], + dtype: torch.dtype, + ) -> torch.Tensor: + return torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, blocksize).t().to(dtype) + + @register_kernel("bitsandbytes::dequantize_blockwise", "xpu") + def _( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + ) -> torch.Tensor: + shape = A.shape + out = torch.empty(A.reshape(-1).shape, dtype=dtype, device=A.device) + # void cdequantize_blockwise_fp32( + # float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) + if dtype == torch.float16: + ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel()) + elif dtype == torch.bfloat16: + ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel()) + elif dtype == torch.float32: + ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel()) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + + return out.reshape(shape) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index ebc363991..b112df2f7 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -286,11 +286,26 @@ def get_native_library() -> BNBNativeLibrary: return BNBNativeLibrary(dll) +try: + # to support Intel CPU/GPU (XPU) backend + import intel_extension_for_pytorch as ipex + + ipex_cpu = ipex if ipex._C._has_cpu() else None + ipex_xpu = ipex if ipex._C._has_xpu() else None +except BaseException: + ipex_cpu = None + ipex_xpu = None + + try: lib = get_native_library() except Exception as e: error_msg = str(e) - logger.error(f"bitsandbytes library load error: {error_msg}\n", 
exc_info=True) + if not (ipex_cpu or ipex_xpu): + logger.error( + f"bitsandbytes library load error: {error_msg}\n If you are using Intel CPU/XPU, please install intel_extension_for_pytorch to enable required ops", + exc_info=True, + ) # create a mock with error messaging as fallback lib = ErrorHandlerMockBNBNativeLibrary(error_msg) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py old mode 100644 new mode 100755 index 0bd4c8b4e..ffb66681a --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -13,9 +13,9 @@ from torch import Tensor from typing_extensions import deprecated -from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict +from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict -from .cextension import lib +from .cextension import ipex_cpu, ipex_xpu, lib name2qmap = {} @@ -1122,6 +1122,16 @@ def dequantize_4bit( if absmax.dtype != torch.float32: absmax = absmax.float() + # IPEX format is different, we need extra process. + if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4": + return torch.ops.bitsandbytes.dequantize_nf4_ipex( + A, + absmax, + quant_state.blocksize, + quant_state.shape, + quant_state.dtype, + ) + if out is not None: torch.ops.bitsandbytes.dequantize_4bit.out( A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out @@ -1709,6 +1719,25 @@ def gemv_4bit( if state.nested: absmax = dequantize_blockwise(absmax, state.state2) + state.offset + if getattr(state, "ipex", False) and state.quant_type == "nf4": + # compute_dtype: 1 indicates fp16, 2 indicates bf16 + compute_dtype = 2 if A.dtype == torch.bfloat16 else 1 + out = torch.ops.torch_ipex.woq_linear( + A, + B, + "nf4", + state.shape, + state.new_scales, + state.new_zeros, + None, + None, + state.blocksize, + compute_dtype, + 1, + state.compensation, + ) + return out + if out is not None: torch.ops.bitsandbytes.gemv_4bit.out( A, @@ -2507,3 +2536,49 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): return x.to(dtype) else: return None + + +def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor): + quant_state = linear.weight.quant_state + + if quant_state.nested: + absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + + quant_state.absmax = absmax + quant_state.nested = False + delattr(quant_state, "state2") + + if x.device.type == "cpu" and ipex_cpu: + converted_weight = _reverse_4bit_compress_format(linear.weight.data) + new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight( + converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]), + "nf4", + quant_state.shape, # weight shape + quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales + None, # zero_points + None, # bias + None, # batch_size + quant_state.blocksize, + 2, + ) + elif x.device.type == "xpu" and ipex_xpu: + new_weight = _reverse_4bit_compress_format(linear.weight.data) + new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize) + new_zeros = None + compensation = None + new_scales = list(new_scales) + if not linear.training and not x.requires_grad: + new_weight = new_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]) + else: + raise ValueError( + "Please check the 
device and ipex version. The device should be cpu or xpu and the ipex version should be >= 2.7"
+        )
+
+    linear.weight.data = new_weight.data
+    linear.weight.quant_state.ipex = True
+    linear.weight.quant_state.new_scales = new_scales
+    linear.weight.quant_state.new_zeros = new_zeros
+    linear.weight.quant_state.compensation = compensation
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 500102ab1..ccd842ce3 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -11,11 +11,12 @@
 import torch.nn.functional as F
 
 import bitsandbytes as bnb
-from bitsandbytes.functional import QuantState
+from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import (
     INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
     OutlierTracer,
+    _reverse_4bit_compress_format,
 )
 
 T = TypeVar("T", bound="torch.nn.Module")
@@ -444,6 +445,7 @@ def __init__(
         self.compute_type_is_set = False
         self.quant_state = None
         self.quant_storage = quant_storage
+        self.ipex_linear_is_set = False
 
     def set_compute_type(self, x):
         if x.dtype in [torch.float32, torch.bfloat16]:
@@ -470,13 +472,40 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         save weight and bias,
         then fill state_dict with components of quant_state
         """
+        if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "ipex", False):
+            if self.weight.device.type == "cpu":
+                original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
+                    self.weight, "nf4", self.weight.quant_state.shape, 2
+                )
+                self.weight.data = _reverse_4bit_compress_format(original_weight.data)
+            elif self.weight.device.type == "xpu":
+                self.weight.data = _reverse_4bit_compress_format(self.weight.data.reshape(1, -1))
+
+            self.weight.quant_state.ipex = False
+            self.ipex_linear_is_set = False
+
         super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
 
         if getattr(self.weight, "quant_state", None) is not None:
             for k, v in self.weight.quant_state.as_dict(packed=True).items():
                 destination[prefix + "weight."
+ k] = v if keep_vars else v.detach() + def set_ipex_linear(self, x: torch.Tensor): + if ( + not getattr(self.weight.quant_state, "ipex", False) + and self.weight.data.dtype == torch.uint8 + and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0 + and self.weight.quant_state.quant_type == "nf4" + ): + if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False): + _enable_ipex_fusion(self, x) + def forward(self, x: torch.Tensor): + # Check if ipex fusion can be used + if not self.ipex_linear_is_set and (ipex_cpu or ipex_xpu): + self.set_ipex_linear(x) + self.ipex_linear_is_set = True + fix_4bit_weight_quant_state_from_module(self) # weights are cast automatically as Int8Params, but the bias has to be cast manually @@ -492,8 +521,10 @@ def forward(self, x: torch.Tensor): x = x.to(self.compute_dtype) bias = None if self.bias is None else self.bias.to(self.compute_dtype) + # IPEX CPU will change weight to 4D so don't need transpose + weight = self.weight.t() if self.weight.dim() == 2 else self.weight - return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) + return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype) class LinearFP4(Linear4bit): @@ -644,17 +675,20 @@ def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) if device is not None and device.type != "meta" and self.data.device.type == "cpu": - return self._quantize(device) - else: - new_param = Int8Params( - super().to(device=device, dtype=dtype, non_blocking=non_blocking), - requires_grad=self.requires_grad, - has_fp16_weights=self.has_fp16_weights, - ) - new_param.CB = self.CB - new_param.SCB = self.SCB + if device.type != "cpu" or self.data.dtype != torch.int8: + return self._quantize(device) + elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu"): + self.CB = self.data - return new_param + new_param = Int8Params( + super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, + has_fp16_weights=self.has_fp16_weights, + ) + new_param.CB = self.CB + new_param.SCB = self.SCB + + return new_param def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 0828dd295..7920e2188 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -38,6 +38,14 @@ def outlier_hook(module, input): hook.remove() +# convert btw standard 4-bit compression format and ipex compression format +def _reverse_4bit_compress_format(weight: torch.Tensor): + out_1 = (weight & 0xF0) >> 4 + out_2 = (weight & 0xF) << 4 + out = out_1 | out_2 + return out + + class OutlierTracer: _instance = None diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 11dfbf5ea..e61ce4655 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -238,15 +238,24 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise #### Intel CPU + XPU -It does not need compile CPP codes, all required ops are in [intel_extension_for_pytorch](https://pytorch-extension.intel.com/), please follow the instruction to install ipex. 
+If you are using an Intel CPU on Linux or an Intel XPU on Linux/Windows, install [intel_extension_for_pytorch](https://pytorch-extension.intel.com/) for better performance, either by following the instructions at that link or by using one of the commands below.
 
-The below commands are for Linux. For installing on Windows, please adapt the below commands according to the same pattern as described [the section above on compiling from source under the Windows tab](#cuda-compile).
+CPU: `pip install intel_extension_for_pytorch`
+XPU: `pip install intel_extension_for_pytorch --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/`
 
-```bash
-pip install intel_extension_for_pytorch
-git clone --depth 1 -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/
-pip install -e . # `-e` for "editable" install, when developing BNB (otherwise leave that out)
+Install bitsandbytes:
+CPU: the C++ CPU backend needs to be built from source
+```
+git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/
+cmake -DCOMPUTE_BACKEND=cpu -S .
+make
+pip install .
+```
+XPU:
 ```
+pip install git+https://github.com/bitsandbytes-foundation/bitsandbytes.git
+```
+
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index fc2e7aa6f..5fbe1065f 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -180,9 +180,6 @@ def test_matmul_4bit(
     compress_statistics,
     quant_type,
 ):
-    if device == "cpu" and quant_type == "fp4":
-        pytest.xfail("Only nf4 is supported on CPU")
-
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     if has_bias == False:
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 8568d45f0..fa4a14ae9 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -103,10 +103,9 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
         if nested:
             pytest.skip("Not a typical use case.")
         if blocksize != 256:
-            pytest.skip("Only blocksize 256 is the typical one supported on CPU.")
-
+            pytest.skip("Only blocksize 256 is used on CPU/XPU")
         if dtype != torch.float32:
-            pytest.xfail(f"CPU implementation currently only supports float32, got {dtype}")
+            pytest.skip("Only float32 is used on CPU/XPU")
 
         diffs = []
         reldiffs = []
@@ -138,10 +137,11 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
         abserr = sum(diffs) / len(diffs)
         relerr = sum(reldiffs) / len(reldiffs)
         if signed:
-            assert abserr < 0.0035
+            threshold_abserr = 0.0036 if device in ("cpu", "xpu") else 0.0035
+            assert abserr < threshold_abserr
             assert relerr < 0.015
         else:
-            assert abserr < 0.00175
+            assert abserr < (0.0023 if device in ("cpu", "xpu") else 0.00175)
             assert relerr < 0.012
 
         assert A2.dtype == dtype
@@ -172,8 +172,8 @@ def test_blockwise_cpu_large(self, hidden, blocksize):
     @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
     @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"])
     def test_few_bit_quant(self, device, bits, method):
-        if device == "cpu" and bits != 8:
-            pytest.skip("CPU implementation only supports 8 bits")
+        if device in ("cpu", "xpu") and bits != 8:
+            pytest.skip("CPU/XPU implementation only supports 8 bits")
 
         abserrs = []
         relerrs = []
@@ -1080,9 +1080,6 @@ class TestQuantize4BitFunctional:
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
     def test_4bit_quant(self, device, dtype, quant_type, blocksize):
-        if
device == "cpu" and quant_type != "nf4": - pytest.xfail("fp4 quantization is not supported on CPU") - A1 = torch.randn(1024, 1024, device=device, dtype=dtype) qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) @@ -1115,9 +1112,6 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize")) def test_4bit_compressed_stats(self, device, quant_type, blocksize): - if device == "cpu" and quant_type != "nf4": - pytest.xfail("fp4 quantization is not supported on CPU") - errs1 = [] errs2 = [] for i in range(10): @@ -1190,12 +1184,6 @@ def test_bench_4bit_dequant(self, quant_type): ) @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim")) def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double_quant, kind): - if device == "cpu": - if storage_type != "nf4": - pytest.xfail("fp4 quantization is not supported on CPU") - if quant_storage != torch.uint8: - pytest.xfail("Only uint8 storage is supported on CPU") - errs1 = [] errs2 = [] errs3 = [] @@ -1342,13 +1330,6 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): - if device == "cpu": - if storage_type != "nf4": - pytest.xfail("fp4 quantization is not supported on CPU") - - if dtype == torch.bfloat16 and torch.__version__ < (2, 3): - pytest.xfail("eye doe not support bfloat16 on CPU in torch < 2.3") - dims = 10 torch.random.manual_seed(np.random.randint(0, 412424242)) dims = get_test_dims(0, 8192, n=dims) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index f3673797c..b5db2eb6f 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -32,12 +32,6 @@ @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) @pytest.mark.parametrize("save_before_forward", TRUE_FALSE, ids=id_formatter("save_before_forward")) def test_linear_serialization(device, quant_type, compress_statistics, bias, quant_storage, save_before_forward): - if device == "cpu": - if quant_type == "fp4": - pytest.xfail("FP4 is not supported for CPU") - if quant_storage != "uint8": - pytest.xfail("Only uint8 storage is supported for CPU") - original_dtype = torch.float16 compute_dtype = None layer_shape = (300, 400) @@ -194,13 +188,7 @@ def test_linear_serialization(device, quant_type, compress_statistics, bias, qua @pytest.mark.parametrize("blocksize", [64, 128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_copy_param(device, quant_type, blocksize, compress_statistics): - if device == "cpu": - if compress_statistics: - pytest.skip("Currently segfaults on CPU") - if quant_type == "fp4": - pytest.xfail("FP4 not supported on CPU") - - tensor = torch.linspace(1, blocksize, blocksize) + tensor = torch.randn(300, 400) param = bnb.nn.Params4bit( data=tensor, quant_type=quant_type, @@ -219,13 +207,7 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("blocksize", [64, 128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_deepcopy_param(device, 
quant_type, blocksize, compress_statistics):
-    if device == "cpu":
-        if compress_statistics:
-            pytest.skip("Currently segfaults on CPU")
-        if quant_type == "fp4":
-            pytest.xfail("FP4 not supported on CPU")
-
-    tensor = torch.linspace(1, blocksize, blocksize)
+    tensor = torch.randn(300, 400)
     param = bnb.nn.Params4bit(
         data=tensor,
         quant_type=quant_type,
@@ -251,13 +233,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
 @pytest.mark.parametrize("blocksize", [64, 128])
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
 def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
-    if device == "cpu":
-        if compress_statistics:
-            pytest.skip("Currently segfaults on CPU")
-        if quant_type == "fp4":
-            pytest.xfail("FP4 not supported on CPU")
-
-    original_tensor = torch.linspace(1, blocksize, blocksize, dtype=torch.float32)
+    original_tensor = torch.randn(300, 400)
     original_param = bnb.nn.Params4bit(
         data=original_tensor,
         quant_type=quant_type,
diff --git a/tests/test_modules.py b/tests/test_modules.py
index c8ec6311a..aa6f19c9e 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -391,12 +391,6 @@ def test_fp8linear():
     ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
 )
 def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim, quant_storage):
-    if device == "cpu":
-        if embedding_class is bnb.nn.EmbeddingFP4:
-            pytest.xfail("FP4 is not supported for CPU")
-        if quant_storage is not None and quant_storage != torch.uint8:
-            pytest.xfail("CPU only supports uint8 storage for 4bit")
-
     num_embeddings = 128
 
     src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
@@ -442,12 +436,6 @@ def test_embedding_lossless(device, embedding_class, input_shape, embedding_dim,
     ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
 )
 def test_embedding_error(device, embedding_class, input_shape, embedding_dim, quant_storage):
-    if device == "cpu":
-        if embedding_class is bnb.nn.EmbeddingFP4:
-            pytest.xfail("FP4 is not supported for CPU")
-        if quant_storage is not None and quant_storage != torch.uint8:
-            pytest.xfail("CPU only supports uint8 storage for 4bit")
-
     is_8bit = embedding_class is bnb.nn.Embedding8bit
 
     num_embeddings = 128
@@ -482,9 +470,6 @@ def test_embedding_error(device, embedding_class, input_shape, embedding_dim, qu
 
 @pytest.mark.parametrize("device", get_available_devices())
 def test_4bit_linear_warnings(device):
-    if device == "cpu":
-        pytest.xfail("gemv_4bit op is not yet implemented on CPU")
-
     dim1 = 64
 
     with pytest.warns(UserWarning, match=r"inference or training"):
diff --git a/tests/test_ops.py b/tests/test_ops.py
index e85bc0ef0..9a0ae3338 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -143,6 +143,10 @@ def test_dequantize_blockwise(self, device, dtype, blocksize):
         assert out.dtype == dtype
         assert out.device == A.device
 
+        # TODO: Enable it
+        if device == "xpu":
+            pytest.skip("The XPU implementation calls a torch op inside another torch op, which fails opcheck")
+
         opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))
 
 
@@ -153,15 +157,9 @@ class Test4bitBlockwiseQuantOps:
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
-        if device == "cpu" and quant_type != "nf4":
-            pytest.xfail("CPU implementation is only available for
nf4") - - if storage_dtype != torch.uint8: - pytest.xfail("Known issue with storage_dtype != uint8") - A = torch.randn(1024, 1024, dtype=dtype, device=device) - out, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, quant_type, storage_dtype) + out, absmax = torch.ops.bitsandbytes.quantize_4bit.default(A, blocksize, quant_type, storage_dtype) assert out.device == A.device assert out.dtype == storage_dtype @@ -169,6 +167,10 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize assert absmax.device == A.device assert absmax.dtype == torch.float32 + # TODO: Enable it + if device in ("cpu", "xpu") and storage_dtype == torch.bfloat16: + pytest.skip("CPU bf16 storage_dtype will fail on torch op check") + opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype)) @pytest.mark.parametrize("device", get_available_devices()) @@ -177,13 +179,6 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): - if device == "cpu": - if quant_type != "nf4": - pytest.xfail("CPU implementation is only available for nf4") - - if storage_dtype != torch.uint8: - pytest.xfail("CPU implementation only supports uint8 storage") - shape = (128, 128) n = prod(shape) @@ -215,9 +210,6 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize("blocksize", [64, 128, 256, 512]) def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): - if device == "cpu": - pytest.xfail("CPU implementation is not available") - out_features = 1024 in_features = 256