From 97247931424af4fbafb7b5584dce1a7b19583565 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 28 Nov 2025 09:44:02 +0000 Subject: [PATCH 1/8] enable hf kernel Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 83 ++++++++++++++++++++------------ 1 file changed, 51 insertions(+), 32 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index def87045c..23966c0f3 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -219,6 +219,21 @@ def _( return out if has_avx512bf16(): + gemm_4bit_forward_kernel = None + try: + from pathlib import Path + + from kernels import get_local_kernel + + gemm_4bit_forward_kernel = get_local_kernel( + repo_path=Path( + "/workspace/nix/nix/store/vvsb2xvj5zkzfd37r1k1d5j23hpa9n86-quantization_bitsandbytes-torch-ext" + ), + package_name="quantization_bitsandbytes", + ).gemm_4bit_forward_kernel + except Exception as exc: # pragma: no cover - best effort fallback + gemm_4bit_forward_kernel = None + logger.warning("Failed to load CPU gemm_4bit kernel: %s", exc) @register_kernel("bitsandbytes::gemv_4bit", "cpu") def _( @@ -239,38 +254,42 @@ def _( final_out_shape = (*A.shape[:-1], shapeB[0]) A = A.reshape(-1, A.shape[-1]) out_shape = (*A.shape[:-1], shapeB[0]) - out = torch.empty(out_shape, dtype=A.dtype, device=A.device) - M = A.shape[0] - N = shapeB[0] - K = A.shape[1] - x_strideM = A.stride(0) - out_strideM = out.stride(0) - if quant_type == "fp4": - lib.gemv_4bit_inference_cpu_fp4_bf16( - ct.c_int64(M), - ct.c_int64(N), - ct.c_int64(K), - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(out), - ct.c_int64(blocksize), - ct.c_int64(x_strideM), - ct.c_int64(out_strideM), - ) - elif quant_type == "nf4": - lib.gemv_4bit_inference_cpu_nf4_bf16( - ct.c_int64(M), - ct.c_int64(N), - ct.c_int64(K), - get_ptr(A), - get_ptr(B), - get_ptr(absmax), - get_ptr(out), - ct.c_int64(blocksize), - ct.c_int64(x_strideM), - ct.c_int64(out_strideM), - ) + quant_type_num = 1 if quant_type == "fp4" else 0 + if gemm_4bit_forward_kernel is not None: + out = gemm_4bit_forward_kernel(A, B, absmax, blocksize, quant_type_num) + else: + out = torch.empty(out_shape, dtype=A.dtype, device=A.device) + M = A.shape[0] + N = shapeB[0] + K = A.shape[1] + x_strideM = A.stride(0) + out_strideM = out.stride(0) + if quant_type == "fp4": + lib.gemv_4bit_inference_cpu_fp4_bf16( + ct.c_int64(M), + ct.c_int64(N), + ct.c_int64(K), + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(out), + ct.c_int64(blocksize), + ct.c_int64(x_strideM), + ct.c_int64(out_strideM), + ) + elif quant_type == "nf4": + lib.gemv_4bit_inference_cpu_nf4_bf16( + ct.c_int64(M), + ct.c_int64(N), + ct.c_int64(K), + get_ptr(A), + get_ptr(B), + get_ptr(absmax), + get_ptr(out), + ct.c_int64(blocksize), + ct.c_int64(x_strideM), + ct.c_int64(out_strideM), + ) if dtype != torch.bfloat16: out = out.to(dtype) From 2e45f759dfdefeff7308b46799d10196026cfa7e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 28 Nov 2025 09:48:07 +0000 Subject: [PATCH 2/8] fix typo Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 23966c0f3..62bfd6b2d 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -230,7 +230,7 @@ def _( "/workspace/nix/nix/store/vvsb2xvj5zkzfd37r1k1d5j23hpa9n86-quantization_bitsandbytes-torch-ext" ), package_name="quantization_bitsandbytes", - ).gemm_4bit_forward_kernel + ).gemm_4bit_forward except Exception as exc: # pragma: no cover - best effort fallback gemm_4bit_forward_kernel = None logger.warning("Failed to load CPU gemm_4bit kernel: %s", exc) From eafb770c9b0c9da16e0d05f1b614b846b11f17e8 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 28 Nov 2025 09:54:23 +0000 Subject: [PATCH 3/8] add kernels dep Signed-off-by: jiqing-feng --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 4f807e04e..d6723b7e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "torch>=2.3,<3", "numpy>=1.17", "packaging>=20.9" + "kernels>=0.11.1" ] [project.urls] From bba2e245a37647bba514ed3725c714e5053a9d0a Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 28 Nov 2025 09:59:41 +0000 Subject: [PATCH 4/8] fix typo Signed-off-by: jiqing-feng --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d6723b7e2..fb18fe81a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ classifiers = [ dependencies = [ "torch>=2.3,<3", "numpy>=1.17", - "packaging>=20.9" + "packaging>=20.9", "kernels>=0.11.1" ] From 3c2729ff028d44dd08e2ff0c022f6760275d4835 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 28 Nov 2025 15:50:40 +0000 Subject: [PATCH 5/8] update tests Signed-off-by: jiqing-feng --- tests/test_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_ops.py b/tests/test_ops.py index 8d9aa5ab2..3218b9215 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -237,7 +237,6 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): quant_type=quant_type, ) B_q, state = bitsandbytes.functional._convert_weight_packed_for_cpu(B_q, state) - B_q = B_q.t() absmax = state.absmax out = torch.ops.bitsandbytes.gemv_4bit.default(A, B_q, B.shape, absmax, code, blocksize) From 3465403adb13ef6c1c68a6b097888807e90ccfb7 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 3 Dec 2025 16:18:52 +0000 Subject: [PATCH 6/8] optional for kernels Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 5 ++++- pyproject.toml | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 62bfd6b2d..1bce60aca 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -233,7 +233,10 @@ def _( ).gemm_4bit_forward except Exception as exc: # pragma: no cover - best effort fallback gemm_4bit_forward_kernel = None - logger.warning("Failed to load CPU gemm_4bit kernel: %s", exc) + logger.warning( + "Failed to load CPU gemm_4bit kernel: %s. Please make sure you already `pip install kernels` and the kernels >= 0.11.1", + exc, + ) @register_kernel("bitsandbytes::gemv_4bit", "cpu") def _( diff --git a/pyproject.toml b/pyproject.toml index fb18fe81a..43c98be8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,6 @@ dependencies = [ "torch>=2.3,<3", "numpy>=1.17", "packaging>=20.9", - "kernels>=0.11.1" ] [project.urls] From f7f18f557fd05d69556c99392c1e9b7597008a8c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 5 Dec 2025 09:19:46 +0000 Subject: [PATCH 7/8] update kernel Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index 1bce60aca..a2c74488b 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -221,16 +221,9 @@ def _( if has_avx512bf16(): gemm_4bit_forward_kernel = None try: - from pathlib import Path + from kernels import get_kernel - from kernels import get_local_kernel - - gemm_4bit_forward_kernel = get_local_kernel( - repo_path=Path( - "/workspace/nix/nix/store/vvsb2xvj5zkzfd37r1k1d5j23hpa9n86-quantization_bitsandbytes-torch-ext" - ), - package_name="quantization_bitsandbytes", - ).gemm_4bit_forward + gemm_4bit_forward_kernel = get_kernel("kernels-community/quantization_bitsandbytes").gemm_4bit_forward except Exception as exc: # pragma: no cover - best effort fallback gemm_4bit_forward_kernel = None logger.warning( From 2c667a434da8b7273853e2f598762c6260873ecf Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 9 Dec 2025 09:03:58 +0000 Subject: [PATCH 8/8] fix format Signed-off-by: jiqing-feng --- bitsandbytes/backends/cpu/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index a2c74488b..436676c99 100644 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -227,7 +227,7 @@ def _( except Exception as exc: # pragma: no cover - best effort fallback gemm_4bit_forward_kernel = None logger.warning( - "Failed to load CPU gemm_4bit kernel: %s. Please make sure you already `pip install kernels` and the kernels >= 0.11.1", + "Failed to load CPU gemm_4bit_forward from kernels-community: %s. Please make sure you already `pip install kernels` and the kernels >= 0.11.1", exc, ) @@ -250,8 +250,8 @@ def _( final_out_shape = (*A.shape[:-1], shapeB[0]) A = A.reshape(-1, A.shape[-1]) out_shape = (*A.shape[:-1], shapeB[0]) - quant_type_num = 1 if quant_type == "fp4" else 0 if gemm_4bit_forward_kernel is not None: + quant_type_num = 1 if quant_type == "fp4" else 0 out = gemm_4bit_forward_kernel(A, B, absmax, blocksize, quant_type_num) else: out = torch.empty(out_shape, dtype=A.dtype, device=A.device)