From cee93db6ebb98eeed6c1e9b69e5586fb6cc98a0c Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Tue, 21 Oct 2025 16:16:38 -0400
Subject: [PATCH] Fix indexing overflow issue for blockwise quantization with
 large tensor sizes

---
 bitsandbytes/backends/cuda/ops.py |  4 +-
 csrc/kernels.cu                   | 15 ++++---
 csrc/ops.cu                       | 13 +++---
 tests/test_functional.py          | 72 ++++++++++++++++++++++++++++---
 4 files changed, 83 insertions(+), 21 deletions(-)

diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py
index 30cad3e34..8d6b55366 100644
--- a/bitsandbytes/backends/cuda/ops.py
+++ b/bitsandbytes/backends/cuda/ops.py
@@ -326,7 +326,7 @@ def _(
         get_ptr(absmax),
         get_ptr(out),
         ct.c_int32(blocksize),
-        ct.c_int(n),
+        ct.c_int32(n),
     )
 
     if A.dtype == torch.bfloat16:
@@ -403,7 +403,7 @@ def _dequantize_4bit_impl(
             get_ptr(absmax),
             get_ptr(out),
             ct.c_int(blocksize),
-            ct.c_int(out.numel()),
+            ct.c_int32(out.numel()),
             _get_tensor_stream(A),
         )
 
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 2c232da80..de48f5e82 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -328,14 +328,16 @@ __global__ void kQuantizeBlockwise(
     float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
     const int rand_offset, const int n
 ) {
-    const int n_full = gridDim.x * BLOCK_SIZE;
+    // This can overflow, so we clamp to INT32_MAX. We won't have more elements than this.
+    const int n_full = min(gridDim.x * BLOCK_SIZE, INT32_MAX);
+
+    const int base_idx = blockIdx.x * BLOCK_SIZE;
     int valid_items = 0;
-    const int base_idx = (blockIdx.x * BLOCK_SIZE);
 
     T vals[NUM_PER_TH];
     float rand_vals[NUM_PER_TH];
     unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH / 2 : NUM_PER_TH];
-    // float local_abs_max = -FLT_MAX;
+
     float local_abs_max = 0.0f;
     int local_rand_idx = 0;
 
@@ -358,8 +360,8 @@ __global__ void kQuantizeBlockwise(
     for (int i = threadIdx.x; i < 256; i += blockDim.x)
         smem_code[i] = code[i];
 
-    for (int i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
-        valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
+    for (int64_t i = base_idx; i < n_full; i += gridDim.x * BLOCK_SIZE) {
+        valid_items = min(BLOCK_SIZE, static_cast<int>(n - i));
         local_abs_max = -FLT_MAX;
 
         __syncthreads();
@@ -442,7 +444,8 @@ __global__ void
     for (int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) {
 
         if (DATA_TYPE > 0) {
-            valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i);
+            // Cast n to int64_t to avoid overflow for large n
+            valid_items_load = min(TILE_SIZE, static_cast<int>((static_cast<int64_t>(n) + 1) / 2) - i);
             valid_items_store = min(TILE_SIZE * 2, n - i * 2);
         } else {
             valid_items_load = min(TILE_SIZE, n - i);
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 71256719f..aafab7522 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -61,16 +61,17 @@ template <typename T, int DATA_TYPE>
 void dequantizeBlockwise(
     float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, cudaStream_t stream
 ) {
-    // printf("stream==%d\n",stream);
-    int num_blocks = n / blocksize;
-    num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
-    int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
+    constexpr int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
+
+    // Upcast to int64 to avoid overflow for large n
+    int grid_blocks = ((int64_t)n + tile_size - 1) / tile_size;
+
     if (DATA_TYPE > 0)
         kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>
-            <<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize / 2, n);
+            <<<grid_blocks, 64, 0, stream>>>(code, A, absmax, out, blocksize / 2, n);
     else
         kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>
-            <<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n);
+            <<<grid_blocks, 64, 0, stream>>>(code, A, absmax, out, blocksize, n);
 
     CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 072e3b4f5..08de12008 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -151,6 +151,34 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
         assert relerr < 0.012
         assert A2.dtype == dtype
 
+    @pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
+    @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+    @pytest.mark.parametrize("blocksize", [256], ids=id_formatter("blocksize"))
+    def test_dynamic_blockwise_quantization_large(self, device, dtype, blocksize):
+        """
+        Test that we can successfully quantize a large tensor. Note that the following limitations apply:
+        - On CUDA/XPU/ROCm, the maximum number of elements is limited to 2**31 - 1 due to int32 indexing in C++ kernels.
+        - On CPU, there is a significantly higher memory overhead for the quantization, so we skip this test.
+        - Verification of the accuracy for dequantization has too high memory overhead for this test.
+        """
+        if device not in ["cuda", "xpu"]:
+            pytest.skip("This test is only for CUDA and XPU devices due to memory constraints.")
+
+        data = torch.randn(2**31 - 1, device=device, dtype=dtype)
+        q_data, q_stats = F.quantize_blockwise(data, blocksize=blocksize)
+
+        assert q_data is not None
+        assert q_data.dtype == torch.uint8
+        assert q_data.numel() == data.numel()
+
+        # Dequant
+        del data
+        dq = F.dequantize_blockwise(q_data, q_stats)
+
+        assert dq.dtype == dtype
+        assert dq.numel() == q_data.numel()
+
     @pytest.mark.skipif("cpu" not in get_available_devices(), reason="CPU is required")
     @pytest.mark.parametrize("hidden", [128])
     @pytest.mark.parametrize("blocksize", [4096, 16384])
@@ -1118,18 +1146,17 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
         A1 = torch.randn(1024, 1024, device=device, dtype=dtype)
         qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
         A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
+        del qa, SA
+
+        assert A2.dtype == dtype
 
         err = (A1 - A2).abs().float()
+        del A2
+
         relerr = (err / (A1.abs().float() + 1e-8)).mean()
         err = err.mean()
 
-        assert A2.dtype == dtype
-
-        # With larger block sizes, we can expect this to blow up.
-        # At blocksize>=1024, don't even bother looking at relerr.
-        #
-        # Actually, the above is not true anymore after fixing the integer packing bug.
-        # The following values were taken from averaging 1k samples per test configuration after fixing the bug.
+        # The following values were taken from averaging 1k samples per test configuration.
         error_dict = dict()
         error_dict["fp4"] = dict()
         error_dict["nf4"] = dict()
@@ -1213,6 +1240,37 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
         assert err.item() < 0.11
         assert relerr.item() < 0.28
 
+    @pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
+    @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
+    @pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize"))
+    def test_4bit_quant_large(self, device, dtype, quant_type, blocksize):
+        """
+        Test that we can successfully quantize a large tensor. Note that the following limitations apply:
+        - On CUDA/XPU/ROCm, the maximum number of elements is limited to 2**31 - 1 due to int32 indexing in C++ kernels.
+        - On CUDA, this test requires ~10GiB of memory for fp32
+        - On CPU, there is a significantly higher memory overhead for the quantization, so we skip this test.
+        - Verification of the accuracy for dequantization has too high memory overhead for this test.
+        """
+
+        if device not in ["cuda", "xpu"]:
+            pytest.skip("This test is only for CUDA and XPU devices due to memory constraints.")
+
+        A1 = torch.randn(2**31 - 1, device=device, dtype=dtype)
+        qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
+
+        assert qa is not None
+        assert qa.dtype == torch.uint8
+        assert qa.numel() == (2**31 - 1 + 1) // 2  # each byte holds 2 quantized values
+
+        # Dequant
+        del A1
+        dq = F.dequantize_4bit(qa, SA)
+
+        assert dq.dtype == dtype
+        assert dq.numel() == 2**31 - 1
+
     # @pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
     @pytest.mark.parametrize("quant_type", ["nf4"])
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
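
For context on the arithmetic this patch changes: with n close to 2**31 - 1 elements, 32-bit intermediates such as gridDim.x * BLOCK_SIZE in kQuantizeBlockwise or (n + tile_size - 1) in dequantizeBlockwise wrap around, so grid sizes and loop bounds go wrong. The standalone host-only C++ sketch below is illustrative only, not part of the patch or of bitsandbytes; the file and variable names are made up, and it simply mirrors the grid-size rounding shown above to demonstrate the overflow and the int64_t upcast.

// overflow_sketch.cpp -- illustrative only; mirrors the rounding-up grid-size
// math in dequantizeBlockwise(), but is not bitsandbytes code.
#include <cstdint>
#include <cstdio>

int main() {
    const int n = INT32_MAX;   // 2**31 - 1, the largest element count the int32 interface accepts
    const int tile_size = 512; // tile size used for the 8-bit path

    // 32-bit rounding-up division: n + tile_size - 1 exceeds INT32_MAX and
    // overflows (undefined behavior for signed int), producing a bogus grid size.
    // int bad_grid = (n + tile_size - 1) / tile_size;

    // Upcasting the numerator to 64 bits first keeps the intermediate sum in
    // range; the quotient (4194304 here) fits comfortably back into an int.
    int grid_blocks = static_cast<int>(((int64_t)n + tile_size - 1) / tile_size);
    std::printf("grid_blocks = %d\n", grid_blocks);
    return 0;
}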