8 changes: 4 additions & 4 deletions bitsandbytes/backends/cuda/ops.py
@@ -214,7 +214,7 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
    else:
-        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

@@ -272,7 +272,7 @@ def _dequantize_blockwise_impl(
    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
    else:
-        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    torch._check(
@@ -306,7 +306,7 @@ def _(
    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
    else:
-        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

    torch._check(quant_type in ["fp4", "nf4"])
    torch._check(
@@ -388,7 +388,7 @@ def _dequantize_4bit_impl(
    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
    else:
-        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

    torch._check(quant_type in ["fp4", "nf4"])
    torch._check(
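All four hunks above relax the same guard: 32 joins the set of accepted blocksizes on non-ROCm-warp-64 builds, while warp-64 ROCm still bottoms out at 128. A minimal sketch of the shared validation logic (check_blocksize is an illustrative name, not a helper in the library):

import torch

ROCM_WARP_SIZE_64 = False  # stands in for bitsandbytes' ROCm warp-64 detection

def check_blocksize(blocksize: int) -> None:
    # Mirrors the torch._check guards in the hunks above.
    if ROCM_WARP_SIZE_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32])

check_blocksize(32)  # passes on CUDA builds after this change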
4 changes: 2 additions & 2 deletions bitsandbytes/functional.py
@@ -842,7 +842,7 @@ def quantize_4bit(
        out (`torch.Tensor`, *optional*): A tensor to use to store the result.
        blocksize (`int`, *optional*):
            The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
-            Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
+            Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
        compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
        quant_storage (`torch.dtype`, *optional*): The dtype of the tensor used to store the result. Defaults to `torch.uint8`.
@@ -953,7 +953,7 @@ def dequantize_4bit(
        out (`torch.Tensor`, *optional*): A tensor to use to store the result.
        blocksize (`int`, *optional*):
            The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
-            Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
+            Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.

    Raises:
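With the docstrings updated, blocksize=32 becomes a legal argument end to end. A minimal round-trip sketch, assuming a CUDA build that includes the new kernel (the error expectations live in the tests/test_functional.py table further down):

import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")

# One absmax per 32 values instead of 64; quantize_4bit returns the packed
# uint8 tensor plus a QuantState carrying absmax and layout metadata.
packed, state = F.quantize_4bit(A, blocksize=32, quant_type="nf4")
A_dq = F.dequantize_4bit(packed, state)

# Smaller blocks mean finer-grained scaling, so the mean error should come
# in at or below the blocksize=64 figures in the test table.
print((A - A_dq).abs().mean().item())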
107 changes: 104 additions & 3 deletions csrc/kernels.cu
@@ -423,6 +423,92 @@ __global__ void kQuantizeBlockwise(
}
}

// Specialized kernel for blocksize=32 with 4-bit quantization
// Processes 2 blocks of 32 values per warp to maintain full thread utilization
// Uses 32 threads total: threads 0-15 handle block 0, threads 16-31 handle block 1
template <typename T, int DATA_TYPE>
__global__ void kQuantizeBlockwise32(
    float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
    const int rand_offset, const int n
) {
    // Fixed parameters for blocksize=32 with 4-bit
    constexpr int BLOCK_SIZE = 32;        // Size of each quantization block
    constexpr int NUM_PER_TH = 2;         // Values per thread (for 4-bit packing)
    constexpr int THREADS = 32;           // Total threads (full warp)
    constexpr int THREADS_PER_BLOCK = 16; // Threads handling each quantization block

    // Each CUDA thread block processes 2 quantization blocks of 32 values each
    const int base_idx = blockIdx.x * BLOCK_SIZE * 2; // 2 blocks per CUDA block

    T vals[NUM_PER_TH];
    unsigned char qvals[NUM_PER_TH / 2]; // For 4-bit: 2 values per byte
    float local_abs_max = 0.0f;

    // Determine which quantization block this thread belongs to (0 or 1)
    const int block_id = threadIdx.x / THREADS_PER_BLOCK;        // 0 for threads 0-15, 1 for threads 16-31
    const int local_thread_id = threadIdx.x % THREADS_PER_BLOCK; // Thread ID within the block (0-15)

    typedef cub::BlockLoad<T, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
    typedef cub::BlockStore<unsigned char, THREADS, NUM_PER_TH / 2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
    typedef cub::WarpReduce<float, THREADS_PER_BLOCK> WarpReduce; // 16-thread logical warps, one per half

    __shared__ typename LoadT::TempStorage loadt;
    __shared__ typename StoreChar::TempStorage storec;
    __shared__ typename WarpReduce::TempStorage warp_reduce[2]; // One for each warp half
    __shared__ float smem_absmax_value[2];                      // Store 2 absmax values

    const int i = base_idx + block_id * BLOCK_SIZE;

    // Early exit if this quantization block is out of bounds
    if (i >= n)
        return;

    // Load 64 values total (32 threads × 2 values each)
    __syncthreads();
    LoadT(loadt).Load(&(A[base_idx]), vals, min(BLOCK_SIZE * 2, n - base_idx), (T)0.0f);

    // Each thread computes the max of its NUM_PER_TH values
    local_abs_max = -FLT_MAX;
#pragma unroll NUM_PER_TH
    for (int j = 0; j < NUM_PER_TH; j++)
        local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

    // Warp-level reduction within each half (threads 0-15 and 16-31 separately)
    local_abs_max = WarpReduce(warp_reduce[block_id]).Reduce(local_abs_max, cub::Max());

    // First thread of each half stores the block absmax and its reciprocal
    if (local_thread_id == 0) {
        smem_absmax_value[block_id] = 1.0f / local_abs_max;
        absmax[blockIdx.x * 2 + block_id] = local_abs_max;
    }
    __syncthreads();

    // Broadcast the reciprocal absmax to all threads in each half
    local_abs_max = smem_absmax_value[block_id];

    // Quantize values based on data type
    switch (DATA_TYPE) {
    case FP4:
#pragma unroll
        for (int j = 0; j < NUM_PER_TH / 2; j++) {
            qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
            qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
        }
        break;
    case NF4:
#pragma unroll
        for (int j = 0; j < NUM_PER_TH / 2; j++) {
            qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
            qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
        }
        break;
    }

    // Store quantized values (all 32 threads write their outputs)
    __syncthreads();
    StoreChar(storec).Store(&(out[base_idx / 2]), qvals, min((BLOCK_SIZE * 2 + 1) / 2, (n - base_idx + 1) / 2));
}

template <typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n) {
@@ -2440,9 +2526,24 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 128, 2, 0, NF4)
MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, 0, NF4)

-template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(
-    float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
-);
+// Template instantiations for blocksize=32 specialized kernel (4-bit only)
+#define MAKE_kQuantizeBlockwise32(dtype, data_type_name) \
+    template __global__ void kQuantizeBlockwise32<dtype, data_type_name>( \
+        float* code, dtype* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand, \
+        const int rand_offset, const int n \
+    );
+
+// FP4 instantiations for blocksize=32
+MAKE_kQuantizeBlockwise32(half, FP4) MAKE_kQuantizeBlockwise32(float, FP4) MAKE_kQuantizeBlockwise32(__nv_bfloat16, FP4)
+
+// NF4 instantiations for blocksize=32
+MAKE_kQuantizeBlockwise32(half, NF4) MAKE_kQuantizeBlockwise32(float, NF4) MAKE_kQuantizeBlockwise32(
+    __nv_bfloat16, NF4
+)
+
+template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(
+    float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
+);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(
    float* code, unsigned char* A, float* absmax, half* out, const int blocksize, const int n
);
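The kernel's arithmetic is easier to check against a scalar model. The sketch below reproduces the per-block absmax scaling and two-codes-per-byte packing that kQuantizeBlockwise32 performs for NF4; the code table and nearest-entry search stand in for the device-side dQuantizeNF4, so treat it as an illustration of the math, not the CUDA code path:

import torch

# NF4 code values as published with QLoRA; used here as a stand-in for
# the device-side dQuantizeNF4 lookup.
NF4_CODE = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
])

def quantize_block32_nf4(block: torch.Tensor) -> tuple[torch.Tensor, float]:
    """Quantize one 32-value block: scale by 1/absmax, pick the nearest code, pack pairs."""
    assert block.numel() == 32
    absmax = block.abs().max().item()
    scaled = block / absmax  # matches vals[j] * local_abs_max in the kernel
    idx = (scaled[:, None] - NF4_CODE).abs().argmin(dim=1).to(torch.uint8)
    # Even-indexed value goes to the high nibble, as in the kernel's `<< 4`.
    return (idx[0::2] << 4) | idx[1::2], absmax

packed, absmax = quantize_block32_nf4(torch.randn(32))
assert packed.shape == (16,)  # 32 values -> 16 bytes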
5 changes: 5 additions & 0 deletions csrc/kernels.cuh
@@ -17,6 +17,11 @@ __global__ void kQuantizeBlockwise(
    float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
    const int rand_offset, const int n
);
+template <typename T, int DATA_TYPE>
+__global__ void kQuantizeBlockwise32(
+    float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
+    const int rand_offset, const int n
+);
template <typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n);
8 changes: 8 additions & 0 deletions csrc/ops.cu
@@ -50,6 +50,14 @@ void quantizeBlockwise(
        kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
    else if (blocksize == 64)
        kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
+    else if (blocksize == 32) {
+        // For 4-bit: use the specialized kernel (kQuantizeBlockwise32) that processes 2 blocks per warp.
+        // Each CUDA block handles 2 quantization blocks, so divide num_blocks by 2.
+        if (DATA_TYPE > 0) {
+            int num_blocks_adjusted = (num_blocks + 1) / 2;
+            kQuantizeBlockwise32<T, DATA_TYPE><<<num_blocks_adjusted, 32>>>(code, A, absmax, out, rand, rand_offset, n);
+        }
+    }

    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
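The subtle part of this hunk is the launch math: the host still writes one absmax entry per 32-value quantization block, but launches half as many CUDA blocks because each 32-thread warp covers two of them; DATA_TYPE > 0 restricts the path to FP4/NF4 (General8bit is 0 in the DataType_t enum). A quick sketch of the arithmetic, illustrative only:

def launch_dims(n: int, blocksize: int = 32) -> tuple[int, int]:
    """Grid sizing for kQuantizeBlockwise32: 2 quantization blocks per CUDA block."""
    num_blocks = (n + blocksize - 1) // blocksize  # absmax entries written
    num_blocks_adjusted = (num_blocks + 1) // 2    # warps launched, 2 blocks each
    return num_blocks, num_blocks_adjusted

# e.g. 96 values -> 3 absmax entries, 2 CUDA blocks of 32 threads
assert launch_dims(96) == (3, 2)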
14 changes: 11 additions & 3 deletions tests/test_functional.py
@@ -1098,7 +1098,7 @@ class TestQuantize4BitFunctional:
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize(
"blocksize",
[64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096],
[32, 64, 128, 256, 512, 1024, 2048, 4096] if not ROCM_WARP_SIZE_64 else [128, 256, 512, 1024, 2048, 4096],
)
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -1122,6 +1122,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
error_dict["fp4"] = dict()
error_dict["nf4"] = dict()
error_dict["fp4"]["err"] = {
32: 0.092737,
64: 0.096545,
128: 0.102947,
256: 0.108685,
@@ -1131,6 +1132,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
            4096: 0.129573,
        }
        error_dict["fp4"]["rel_err"] = {
+            32: 0.251279,
            64: 0.260130,
            128: 0.275734,
            256: 0.289842,
@@ -1141,6 +1143,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
        }

        error_dict["nf4"]["err"] = {
+            32: 0.070270,
            64: 0.072792,
            128: 0.076835,
            256: 0.080326,
@@ -1150,6 +1153,7 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
            4096: 0.092537,
        }
        error_dict["nf4"]["rel_err"] = {
+            32: 0.196508,
            64: 0.203299,
            128: 0.215252,
            256: 0.226044,
@@ -1168,7 +1172,9 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize"))
@pytest.mark.parametrize(
"blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -1205,7 +1211,9 @@ def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
    @pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No accelerator device")
    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
-    @pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize"))
+    @pytest.mark.parametrize(
+        "blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128], ids=id_formatter("blocksize")
+    )
    def test_4bit_quant_large(self, device, dtype, quant_type, blocksize):
        """
        Test that we can successfully quantize a large tensor. Note that the following limitations apply:
6 changes: 3 additions & 3 deletions tests/test_linear4bit.py
@@ -193,7 +193,7 @@ def test_linear_serialization(

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_copy_param(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -250,7 +250,7 @@ def test_params4bit_torch_chunk_split(device, quant_type):

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -279,7 +279,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("blocksize", [32, 64, 128] if not ROCM_WARP_SIZE_64 else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
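These three tests now also cover copying and serializing parameters quantized with 32-value blocks. A hedged sketch of the surface they exercise (the exact Params4bit constructor arguments here are assumptions based on the parametrization, not copied from the test bodies):

import copy
import torch
import bitsandbytes as bnb

w = torch.randn(64, 64, dtype=torch.float16)

# Assumed constructor shape; quantization happens on the device transfer.
p = bnb.nn.Params4bit(w, quant_type="nf4", blocksize=32, compress_statistics=True).to("cuda")

p2 = copy.deepcopy(p)  # test_deepcopy_param: the quant_state must survive the copy
assert p2.quant_state.blocksize == 32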
6 changes: 3 additions & 3 deletions tests/test_ops.py
@@ -152,7 +152,7 @@ class Test4bitBlockwiseQuantOps:
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
pytest.skip("This configuration is not supported on HPU.")
@@ -176,7 +176,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
pytest.skip("This configuration is not supported on HPU.")
@@ -210,7 +210,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
@pytest.mark.parametrize("blocksize", [32, 64, 128, 256, 512] if not ROCM_WARP_SIZE_64 else [128, 256, 512])
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):