diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py index cb2796b1f..6305fc0c7 100644 --- a/bitsandbytes/backends/cuda/ops.py +++ b/bitsandbytes/backends/cuda/ops.py @@ -306,7 +306,7 @@ def _( if ROCM_WARP_SIZE_64: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( @@ -388,7 +388,7 @@ def _dequantize_4bit_impl( if ROCM_WARP_SIZE_64: torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128]) else: - torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64]) + torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]) torch._check(quant_type in ["fp4", "nf4"]) torch._check( diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 2a2a40273..bca3dd66d 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -842,7 +842,7 @@ def quantize_4bit( out (`torch.Tensor`, *optional*): A tensor to use to store the result. blocksize (`int`, *optional*): The size of the blocks. 
Defaults to 128 on ROCm and 64 otherwise. - Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096. compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False. quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. quant_storage (`torch.dtype`, *optional*): The dtype of the tensor used to store the result. Defaults to `torch.uint8`. @@ -953,7 +953,7 @@ def dequantize_4bit( out (`torch.Tensor`, *optional*): A tensor to use to store the result. blocksize (`int`, *optional*): The size of the blocks. Defaults to 128 on ROCm and 64 otherwise. - Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096. quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`. Raises: diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 8af27075e..7100b5bd2 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -423,6 +423,92 @@ __global__ void kQuantizeBlockwise( } } +// Specialized kernel for blocksize=32 with 4-bit quantization +// Processes 2 blocks of 32 values per warp to maintain full thread utilization +// Uses 32 threads total: threads 0-15 handle block 0, threads 16-31 handle block 1 +template <typename T, int DATA_TYPE> +__global__ void kQuantizeBlockwise32( + float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, + const int rand_offset, const int n +) { + // Fixed parameters for blocksize=32 with 4-bit + constexpr int BLOCK_SIZE = 32; // Size of each quantization block + constexpr int NUM_PER_TH = 2; // Values per thread (for 4-bit packing) + constexpr int THREADS = 32; // Total threads (full warp) + constexpr int THREADS_PER_BLOCK = 16; // Threads handling each quantization block + + // Each CUDA thread block processes 2 quantization blocks of 32 values each + const int base_idx = blockIdx.x * BLOCK_SIZE * 2; 
// 2 blocks per CUDA block + + T vals[NUM_PER_TH]; + unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte + float local_abs_max = 0.0f; + + // Determine which quantization block this thread belongs to (0 or 1) + const int block_id = threadIdx.x / THREADS_PER_BLOCK; // 0 for threads 0-15, 1 for threads 16-31 + const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the block (0-15) + + typedef cub::BlockLoad<T, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT; + typedef cub::BlockStore<unsigned char, THREADS, NUM_PER_TH / 2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; + typedef cub::WarpReduce<float, THREADS_PER_BLOCK> WarpReduce; + + __shared__ typename LoadT::TempStorage loadt; + __shared__ typename StoreChar::TempStorage storec; + __shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One for each warp half + __shared__ float smem_absmax_value[2]; // Store 2 absmax values + + const int i = base_idx + block_id * BLOCK_SIZE; + + // NOTE: no early exit here even when this half-warp's quantization block is + // out of range (i >= n). All 32 threads must reach __syncthreads() and the + // cub BlockLoad/BlockStore collectives; a divergent return would be + // undefined behavior. Out-of-range elements are masked by the valid_items + // arguments to Load/Store, and the global absmax write is guarded below. + + // Load 64 values total (32 threads × 2 values each) + __syncthreads(); + LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f); + + // Each thread computes max of its NUM_PER_TH values + local_abs_max = -FLT_MAX; +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) + local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); + + // Warp-level reduction within each half (threads 0-15 and 16-31 separately) + local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, CUB_REDUCTIONOP_MAX); + + // First thread of each warp half stores the absmax + if (local_thread_id == 0) { + smem_absmax_value[block_id] = 1.0f / local_abs_max; + // Guard the global write: the second quantization block of the last CUDA + // block may lie entirely past n; writing absmax there would be out of bounds. + if (i < n) + absmax[blockIdx.x * 2 + block_id] = local_abs_max; + } + __syncthreads(); + + // Broadcast absmax to all threads in each half + local_abs_max = smem_absmax_value[block_id]; + + // Quantize values based on data type + switch (DATA_TYPE) { + case FP4: +#pragma unroll + for (int j = 0; j < NUM_PER_TH / 2; j++) { + qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * 
local_abs_max) << 4; + qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max); + } + break; + case NF4: +#pragma unroll + for (int j = 0; j < NUM_PER_TH / 2; j++) { + qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4; + qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max); + } + break; + } + + // Store quantized values (all 32 threads write their outputs) + __syncthreads(); + StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2)); +} + template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE> __global__ void kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n) { @@ -2440,9 +2526,24 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4) MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4) -template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>( - float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n -); +// Template instantiations for blocksize=32 specialized kernel (4-bit only) +#define MAKE_kQuantizeBlockwise32(dtype, data_type_name) \ + template __global__ void kQuantizeBlockwise32<dtype, data_type_name>( \ + float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \ + const int rand_offset, const int n \ + ); + +// FP4 instantiations for blocksize=32 +MAKE_kQuantizeBlockwise32(half, FP4) MAKE_kQuantizeBlockwise32(float, FP4) MAKE_kQuantizeBlockwise32(__nv_bfloat16, FP4) + +// NF4 instantiations for blocksize=32 +MAKE_kQuantizeBlockwise32(half, NF4) MAKE_kQuantizeBlockwise32(float, NF4) MAKE_kQuantizeBlockwise32(__nv_bfloat16, NF4) + +template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>( + float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n +); template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>( float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n ); diff --git 
a/csrc/kernels.cuh b/csrc/kernels.cuh index 558b46236..e7a1282bc 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -17,6 +17,11 @@ __global__ void kQuantizeBlockwise( float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, const int rand_offset, const int n ); +template <typename T, int DATA_TYPE> +__global__ void kQuantizeBlockwise32( + float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, + const int rand_offset, const int n +); template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE> __global__ void kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index 226ed10f6..875c82b1c 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -50,6 +50,14 @@ void quantizeBlockwise( kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n); else if (blocksize == 64) kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n); + else if (blocksize == 32) { + // For 4-bit: use specialized kernel (kQuantizeBlockwise32) that processes 2 blocks per warp + // Each CUDA block handles 2 quantization blocks, so divide num_blocks by 2 + // NOTE(review): General8bit (DATA_TYPE == 0) has no blocksize-32 kernel; without the + // guard below this branch would silently skip the launch and leave `out` uninitialized. + // The Python-side checks must therefore not allow blocksize 32 on the 8-bit code path. + if (DATA_TYPE > 0) { + int num_blocks_adjusted = (num_blocks + 1) / 2; + kQuantizeBlockwise32<T, DATA_TYPE><<<num_blocks_adjusted, 32>>>(code, A, absmax, out, rand, rand_offset, n); + } + } CUDA_CHECK_RETURN(cudaPeekAtLastError()); } diff --git a/tests/test_functional.py b/tests/test_functional.py index 55964818c..d65c603ed 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1098,7 +1098,7 @@ class TestQuantize4BitFunctional: @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) @pytest.mark.parametrize( "blocksize", - [64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096], + [32, 64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096], ) def test_4bit_quant(self, device, dtype, quant_type, blocksize): if device == "hpu" and not 
is_supported_on_hpu(quant_type, dtype): @@ -1122,6 +1122,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): error_dict["fp4"] = dict() error_dict["nf4"] = dict() error_dict["fp4"]["err"] = { + 32: 0.092737, 64: 0.096545, 128: 0.102947, 256: 0.108685, @@ -1131,6 +1132,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): 4096: 0.129573, } error_dict["fp4"]["rel_err"] = { + 32: 0.251279, 64: 0.260130, 128: 0.275734, 256: 0.289842, @@ -1141,6 +1143,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): } error_dict["nf4"]["err"] = { + 32: 0.070270, 64: 0.072792, 128: 0.076835, 256: 0.080326, @@ -1150,6 +1153,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): 4096: 0.092537, } error_dict["nf4"]["rel_err"] = { + 32: 0.196508, 64: 0.203299, 128: 0.215252, 256: 0.226044, @@ -1168,7 +1172,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")) + @pytest.mark.parametrize( + "blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize") + ) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype) def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype): @@ -1205,7 +1211,9 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype): @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device") @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")) + 
@pytest.mark.parametrize( + "blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize") + ) def test_4bit_quant_large(self, device, dtype, quant_type, blocksize): """ Test that we can successfully quantize a large tensor. Note that the following limitations apply: diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 2b92ee4f1..aa693713c 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -193,7 +193,7 @@ def test_linear_serialization( @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128]) +@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_copy_param(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): @@ -250,7 +250,7 @@ def test_params4bit_torch_chunk_split(device, quant_type): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128]) +@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): @@ -279,7 +279,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) -@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128]) +@pytest.mark.parametrize("blocksize", [32, 64, 128] if 
not ROCM_WARP_SIZE_64 else [128]) @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics")) def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics): if device == "hpu" and not is_supported_on_hpu(quant_type): diff --git a/tests/test_ops.py b/tests/test_ops.py index 3218b9215..aa20995ee 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -152,7 +152,7 @@ class Test4bitBlockwiseQuantOps: @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on HPU.") @@ -176,7 +176,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype): pytest.skip("This configuration is not supported on 
HPU.") @@ -210,7 +210,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype")) @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype")) @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) + @pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512]) @pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet") def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize): if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):