From 6e55c34f0b56161d0ae3e81f5c309207f5cada27 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 6 Jun 2025 12:40:56 +0000
Subject: [PATCH 1/9] enable fp16/bf16 absmax

Signed-off-by: jiqing-feng
---
 bitsandbytes/backends/cpu/ops.py     | 2 +-
 bitsandbytes/backends/default/ops.py | 6 +++---
 bitsandbytes/functional.py           | 9 ---------
 3 files changed, 4 insertions(+), 13 deletions(-)

diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py
index 5f009ea40..1727bcb46 100644
--- a/bitsandbytes/backends/cpu/ops.py
+++ b/bitsandbytes/backends/cpu/ops.py
@@ -49,7 +49,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
     rem = n % blocksize
     has_rem = rem > 0
     blocks = n // blocksize + has_rem
-    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+    absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
     A_reshaped = A.reshape(n)
     A_com = A_reshaped[: n - rem]
     A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py
index ce5926979..48d30ced4 100644
--- a/bitsandbytes/backends/default/ops.py
+++ b/bitsandbytes/backends/default/ops.py
@@ -154,7 +154,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
     rem = n % blocksize
     has_rem = rem > 0
     blocks = n // blocksize + has_rem
-    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+    absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
     A_reshaped = A.reshape(n)
     A_com = A_reshaped[: n - rem]
     A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
@@ -204,7 +204,7 @@ def _(
     full_blocks = n // blocksize
     rem = n % blocksize
     blocks = full_blocks + 1 if rem else full_blocks
-    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+    absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
     A_flattened = A.reshape(n)
 
     # Scale full blocks of the tensor to [-1, 1]
@@ -229,7 +229,7 @@ def _(
     if quant_storage != torch.uint8:
         packed = packed.squeeze().view(quant_storage).unsqueeze(1)
 
-    return packed, absmax.float()
+    return packed, absmax
 
 
 @register_kernel("bitsandbytes::dequantize_4bit", "default")
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 6893752c9..34ef8f268 100755
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -759,8 +759,6 @@ def dequantize_blockwise(
     if quant_state.nested:
         absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
         absmax += quant_state.offset
-        if absmax.dtype != torch.float32:
-            absmax = absmax.float()
 
     if out is not None:
         torch.ops.bitsandbytes.dequantize_blockwise.out(
@@ -1034,8 +1032,6 @@ def dequantize_4bit(
     if quant_state.nested:
         absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
         absmax += quant_state.offset
-        if absmax.dtype != torch.float32:
-            absmax = absmax.float()
 
     # IPEX format is different, we need extra process.
     if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4":
@@ -1079,8 +1075,6 @@ def quantize(
         code = code.to(A.device)
 
     absmax = torch.abs(A).max()
-    if absmax.dtype != torch.float32:
-        absmax = absmax.float()
     inp = A / absmax
     out = quantize_no_absmax(inp, code, out)
     return out, (absmax, code)
@@ -2328,9 +2322,6 @@ def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
     if quant_state.nested:
         absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
         absmax += quant_state.offset
-        if absmax.dtype != torch.float32:
-            absmax = absmax.float()
-
         quant_state.absmax = absmax
         quant_state.nested = False
         delattr(quant_state, "state2")
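The net effect of PATCH 1 is that blockwise `absmax` now stays in the input dtype (fp16/bf16/fp32) instead of always being promoted to float32. A minimal self-contained sketch of that behavior (illustrative only; `blockwise_absmax` is not a bitsandbytes API):

    import torch

    def blockwise_absmax(A: torch.Tensor, blocksize: int = 256) -> torch.Tensor:
        # Per-block max of |A|, allocated in A.dtype as in the patch,
        # rather than as float32.
        n = A.numel()
        rem = n % blocksize
        blocks = n // blocksize + (rem > 0)
        absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
        flat = A.reshape(n)
        full = flat[: n - rem].reshape(-1, blocksize)
        absmax[: full.shape[0]] = full.abs().amax(dim=1)
        if rem:
            absmax[-1] = flat[n - rem :].abs().max()
        return absmax

    A = torch.randn(1000, dtype=torch.bfloat16)
    print(blockwise_absmax(A).dtype)  # torch.bfloat16, not torch.float32

Keeping absmax in the working dtype avoids upcast/downcast round trips on CPU/XPU, at the cost of slightly coarser scale factors in half precision.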
From 73543ef2aef2c92a387706a4e49665ae25333a89 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 6 Jun 2025 13:24:12 +0000
Subject: [PATCH 2/9] fix absmax dtype

Signed-off-by: jiqing-feng
---
 bitsandbytes/functional.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 34ef8f268..b8b6bbb4f 100755
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -2320,7 +2320,7 @@ def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
     quant_state = linear.weight.quant_state
 
     if quant_state.nested:
-        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
+        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2).to(x.dtype)
         absmax += quant_state.offset
         quant_state.absmax = absmax
         quant_state.nested = False

From 18f971576472daad4d7a965ef51daad16ec90db2 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 6 Jun 2025 14:22:05 +0000
Subject: [PATCH 3/9] fix ipex op

Signed-off-by: jiqing-feng
---
 bitsandbytes/backends/xpu/ops.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py
index 47a3bd009..591147e42 100755
--- a/bitsandbytes/backends/xpu/ops.py
+++ b/bitsandbytes/backends/xpu/ops.py
@@ -42,9 +42,9 @@ def _(
     if dtype == torch.float16:
         ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
     elif dtype == torch.bfloat16:
-        ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
+        ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax.float(), out, blocksize, A.numel())
     elif dtype == torch.float32:
-        ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
+        ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax.float(), out, blocksize, A.numel())
     else:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")

From 6f4854842d37c382a31a0e60d359f8ee58acced2 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 6 Jun 2025 15:07:24 +0000
Subject: [PATCH 4/9] fix tests

Signed-off-by: jiqing-feng
---
 tests/test_functional.py | 2 --
 tests/test_ops.py        | 9 ---------
 2 files changed, 11 deletions(-)

diff --git a/tests/test_functional.py b/tests/test_functional.py
index 1aa2e1d37..5f2d6a811 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -104,8 +104,6 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
             pytest.skip("Not a typical use case.")
         if blocksize != 256:
             pytest.skip("Only blocksize 256 is used in CPU/XPU")
-        if dtype != torch.float32:
-            pytest.skip("Only float32 is used in CPU/XPU")
 
         diffs = []
         reldiffs = []
diff --git a/tests/test_ops.py b/tests/test_ops.py
index 7da19c012..9b413a974 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -53,7 +53,6 @@ def test_int8_vectorwise_quant(self, threshold, device):
         assert out_row.dtype == torch.int8
         assert out_row.device == A.device
         assert row_stats.shape == (10,)
-        assert row_stats.dtype == torch.float32
         assert row_stats.device == A.device
 
         if threshold > 0.0:
@@ -104,9 +103,6 @@ class TestInt8BlockwiseQuantOps:
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_quantize_blockwise(self, device, dtype, blocksize):
         if device == "cpu":
-            if dtype != torch.float32:
-                pytest.skip("CPU implementation is only available for float32")
-
             if blocksize != 256:
                 pytest.skip("CPU implementation is slow; only test blocksize=256")
 
@@ -119,7 +115,6 @@ def test_quantize_blockwise(self, device, dtype, blocksize):
 
         assert out.device == A.device
         assert absmax.device == A.device
-        assert absmax.dtype == torch.float32
 
         opcheck(torch.ops.bitsandbytes.quantize_blockwise, (A, code, blocksize))
@@ -127,9 +122,6 @@ def test_quantize_blockwise(self, device, dtype, blocksize):
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512])
     def test_dequantize_blockwise(self, device, dtype, blocksize):
-        if device == "cpu" and dtype != torch.float32:
-            pytest.skip("CPU implementation is only available for float32")
-
         A = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8, device=device)
         code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)
@@ -165,7 +157,6 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
         assert out.dtype == storage_dtype
 
         assert absmax.device == A.device
-        assert absmax.dtype == torch.float32
 
         if storage_dtype != torch.uint8:
             pytest.xfail("opcheck fails for storage_dtype != torch.uint8")
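PATCH 2 and PATCH 3 are follow-ups to PATCH 1: once `absmax` may be fp16/bf16, call sites that hand it to fused kernels compiled against `float *absmax` (see the C signature quoted in xpu/ops.py) have to upcast explicitly. A sketch of the resulting dispatch pattern (the `kernels` mapping and `dequantize_blockwise_sketch` are placeholders, not the ipex API):

    import torch

    def dequantize_blockwise_sketch(kernels, code, A, absmax, out, blocksize):
        # kernels: dtype -> fused entry point taking (code, A, absmax, out,
        # blocksize, n), where absmax must arrive as float32.
        kernel = kernels.get(out.dtype)
        if kernel is None:
            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
        if absmax.dtype != torch.float32:
            absmax = absmax.float()  # the cast PATCH 3 and PATCH 5 add per branch
        kernel(code, A, absmax, out, blocksize, A.numel())
        return out

PATCH 5 below completes this by moving the cast to the fp16 branch and dropping the redundant one on the fp32 branch, where absmax is already float32.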
From c4b3cca2723b6037f38740b2be17ecdbbd2e6ae0 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 6 Jun 2025 15:18:32 +0000
Subject: [PATCH 5/9] fix ipex input dtype

Signed-off-by: jiqing-feng
---
 bitsandbytes/backends/xpu/ops.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py
index 591147e42..62d61f10b 100755
--- a/bitsandbytes/backends/xpu/ops.py
+++ b/bitsandbytes/backends/xpu/ops.py
@@ -40,11 +40,11 @@ def _(
     # void cdequantize_blockwise_fp32(
     #     float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream)
     if dtype == torch.float16:
-        ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
+        ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax.float(), out, blocksize, A.numel())
     elif dtype == torch.bfloat16:
         ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax.float(), out, blocksize, A.numel())
     elif dtype == torch.float32:
-        ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax.float(), out, blocksize, A.numel())
+        ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
     else:
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")

From d24f47d13cf125ae05583e1e81bc1ee7934e3b06 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 6 Jun 2025 15:31:05 +0000
Subject: [PATCH 6/9] fix meta register dtype

Signed-off-by: jiqing-feng
---
 bitsandbytes/_ops.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py
index a260852f5..65b815e38 100644
--- a/bitsandbytes/_ops.py
+++ b/bitsandbytes/_ops.py
@@ -225,7 +225,7 @@ def _(
     n = A.numel()
     blocks = -(n // -blocksize)
 
-    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
     out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
 
     return out, absmax
@@ -268,7 +268,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
     torch._check_is_size(blocksize)
     n = A.numel()
     blocks = -(n // -blocksize)
-    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
     out = torch.empty_like(A, dtype=torch.uint8)
     return out, absmax
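PATCH 6 keeps the meta ("fake") registrations in `_ops.py` consistent with the eager kernels: under FakeTensorMode only shapes and dtypes are computed, so `absmax` must be declared with the dtype the backend really returns, or `opcheck` and `torch.compile` will report a mismatch. A condensed standalone view of the changed registration (a sketch; the real function is registered on the op, not called directly):

    import torch

    def quantize_blockwise_meta(A: torch.Tensor, code: torch.Tensor, blocksize: int):
        # Shape/dtype inference only; no tensor data is read.
        n = A.numel()
        blocks = -(n // -blocksize)  # ceil(n / blocksize) via negated floor division
        absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)  # was float32
        out = torch.empty_like(A, dtype=torch.uint8)
        return out, absmax

    # e.g. -(1000 // -256) == 4 == ceil(1000 / 256)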
From 3ce11b3ffe6d493373b6ac735f9e1099a6257e7f Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 6 Jun 2025 15:34:15 +0000
Subject: [PATCH 7/9] fix test threshold

Signed-off-by: jiqing-feng
---
 tests/test_functional.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/test_functional.py b/tests/test_functional.py
index 5f2d6a811..38780dce6 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -135,11 +135,10 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
         abserr = sum(diffs) / len(diffs)
         relerr = sum(reldiffs) / len(reldiffs)
         if signed:
-            threshold_abserr = 0.0036 if device in ("cpu", "xpu") else 0.0035
             assert abserr < 0.0036
             assert relerr < 0.015
         else:
-            assert abserr < 0.00175 if device in ("cpu", "xpu") else 0.0023
+            assert abserr < 0.0023
             assert relerr < 0.012
 
         assert A2.dtype == dtype

From 8799041b028f218c38cad453516e98db5b282519 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 6 Jun 2025 16:30:24 +0000
Subject: [PATCH 8/9] revert mistaken change

Signed-off-by: jiqing-feng
---
 tests/test_ops.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_ops.py b/tests/test_ops.py
index 9b413a974..a29970480 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -53,6 +53,7 @@ def test_int8_vectorwise_quant(self, threshold, device):
         assert out_row.dtype == torch.int8
         assert out_row.device == A.device
         assert row_stats.shape == (10,)
+        assert row_stats.dtype == torch.float32
         assert row_stats.device == A.device
 
         if threshold > 0.0:
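The threshold change in PATCH 7 fixes a real bug, not just a number: `assert` binds looser than a conditional expression, so the old line asserted the ternary's result rather than a comparison. Off CPU/XPU that result was the constant `0.0023`, which is always truthy:

    abserr, device = 0.5, "cuda"  # deliberately enormous error

    # Old form parses as: assert ((abserr < 0.00175) if device in ("cpu", "xpu") else 0.0023)
    assert abserr < 0.00175 if device in ("cpu", "xpu") else 0.0023  # passes silently

    # Fixed form actually enforces a bound (would raise AssertionError here):
    # assert abserr < 0.0023

The leftover `threshold_abserr` assignment on the signed branch was dead code and is removed in the same hunk.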
From daad33d80af527692c1bf5f560150a712d6eb481 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Mon, 9 Jun 2025 10:36:19 +0000
Subject: [PATCH 9/9] keep cuda op

Signed-off-by: jiqing-feng
---
 bitsandbytes/_ops.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py
index 65b815e38..2a1d7aac3 100644
--- a/bitsandbytes/_ops.py
+++ b/bitsandbytes/_ops.py
@@ -225,7 +225,8 @@ def _(
     n = A.numel()
     blocks = -(n // -blocksize)
 
-    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
+    dtype = torch.float32 if torch.cuda.is_available() else A.dtype
+    absmax = torch.empty((blocks,), device=A.device, dtype=dtype)
     out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)
 
     return out, absmax
@@ -268,7 +269,8 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
     torch._check_is_size(blocksize)
     n = A.numel()
     blocks = -(n // -blocksize)
-    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
+    dtype = torch.float32 if torch.cuda.is_available() else A.dtype
+    absmax = torch.empty((blocks,), device=A.device, dtype=dtype)
     out = torch.empty_like(A, dtype=torch.uint8)
     return out, absmax
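PATCH 9 narrows PATCH 6: the CUDA kernels still produce float32 `absmax`, so the meta registration now keys on CUDA availability and only CPU/XPU keep the input dtype. A sketch of the resulting rule (`absmax_dtype` is an illustrative helper, not the library API):

    import torch

    def absmax_dtype(A: torch.Tensor) -> torch.dtype:
        # CUDA kernels still emit float32 absmax; CPU/XPU keep the input dtype.
        return torch.float32 if torch.cuda.is_available() else A.dtype

    A = torch.randn(16, dtype=torch.float16)
    print(absmax_dtype(A))  # torch.float32 on a CUDA build, else torch.float16

Note that `torch.cuda.is_available()` is process-global rather than per-tensor, so a CPU tensor in a CUDA-enabled process is still declared with float32 absmax; that matches the CUDA-centric fake registration, but it is a trade-off worth flagging.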