From 5ff74e93a4a92784912b0a487dabc886a1f199b5 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 3 Jun 2025 19:20:01 -0400 Subject: [PATCH 1/7] Deprecation cleanup: remove histogram_scatter_add_2d --- bitsandbytes/functional.py | 19 ------------------- csrc/kernels.cu | 13 ------------- csrc/kernels.cuh | 4 ---- csrc/ops.cu | 9 --------- csrc/ops.cuh | 2 -- csrc/pythonInterface.cpp | 1 - 6 files changed, 48 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index ffb66681a..9ac154655 100755 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1601,25 +1601,6 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: return current_gnorm, clip_value, gnorm_scale -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor): - assert len(histogram.shape) == 2 - assert histogram.dtype == torch.float32 - assert source.dtype == torch.float32 - assert index1.dtype == torch.int32 - assert index2.dtype == torch.int32 - - assert histogram.device.type == "cuda" - assert index1.device.type == "cuda" - assert index2.device.type == "cuda" - assert source.device.type == "cuda" - - maxdim1 = ct.c_int32(histogram.shape[0]) - n = ct.c_int32(index1.numel()) - is_on_gpu([histogram, index1, index2, source]) - lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) - - def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): if not torch.cuda.is_initialized(): torch.cuda.init() diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 22ee756d9..968f062ab 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -357,19 +357,6 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran } } - -__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) -{ - const int tid = threadIdx.x + (blockDim.x*blockIdx.x); - const int numThreads = blockDim.x*gridDim.x; - - for(int i = tid; i < n; i+=numThreads) - { - int idx = (index1[i]*maxidx1) + index2[i]; - atomicAdd(&histogram[idx], src[i]); - } -} - #define THREADS_ESTIMATE 512 #define NUM_ESTIMATE 8 #define BLOCK_ESTIMATE 4096 diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index a701481d3..9e49e55ad 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -106,10 +106,6 @@ template __global__ voi template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); - -__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); - - template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template __global__ void kdequant_mm_int32_fp16( diff --git a/csrc/ops.cu b/csrc/ops.cu index 775984553..bb407876d 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -18,15 +18,6 @@ using namespace BinSearch; using std::cout; using std::endl; -void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) -{ - int threads = 512; - int num_blocks = n/threads; - num_blocks = n % threads == 0 ? 
num_blocks : num_blocks + 1; - kHistogramScatterAdd2D<<>>(histogram, index1, index2, src, maxidx1, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); -} - template void estimateQuantiles(T *A, float *code, float offset, int n) { int num_blocks = n/4096; diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 48a6a3c74..6556c7315 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -165,8 +165,6 @@ template void optimizerStatic8bitBlockwise(T* p, T* g template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); -void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); - void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 56bec82e8..3b56a4733 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -271,7 +271,6 @@ extern "C" void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } - void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); } void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) { gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); } From 7d214a14b48d310af96e40245a71e89947b1a599 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 3 Jun 2025 19:30:21 -0400 Subject: [PATCH 2/7] Deprecation cleanup: vectorwise_mm_dequant --- bitsandbytes/functional.py | 63 -------------------------------------- tests/test_functional.py | 21 +++++++++++-- 2 files changed, 19 insertions(+), 65 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9ac154655..607bbf4a2 100755 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2456,69 +2456,6 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): return None -@deprecated( - "This function is deprecated and will be removed in a future release.", - category=FutureWarning, -) -def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): - if quant_type == "linear": - norm = S1 * S2 / (C * C) - # double cast needed to prevent overflows - return (xq.float() * norm).to(dtype) - elif quant_type == "zeropoint": - norm = 1.0 / (S1 * S2) - return (xq.float() * norm).to(dtype) - elif quant_type == "row-zeropoint": - norm = 1.0 / (S1 * S2) - x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: - S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: - S2 = S2.squeeze(0) - if len(S1.shape) == 2: - x *= norm - else: - x *= norm - return x.to(dtype) - elif quant_type == "vector-zeropoint": - x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: - S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: - S2 = S2.squeeze(0) - if len(S1.shape) == 2: - x *= 1.0 / S1 - else: - x *= 1.0 / S1 - x *= 1.0 / S2.t() - return x.to(dtype) - 
elif quant_type == "row": - x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: - S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: - S2 = S2.squeeze(0) - if len(S1.shape) == 2: - x *= S1 * S2 / (C * C) - else: - x *= S1 * S2 / (C * C) - return x.to(dtype) - elif quant_type in ["truncated-vector", "vector"]: - x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: - S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: - S2 = S2.squeeze(0) - if len(S1.shape) == 2: - x *= S1 / C - else: - x *= S1 / C - x *= S2 / C - return x.to(dtype) - else: - return None - - def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor): quant_state = linear.weight.quant_state diff --git a/tests/test_functional.py b/tests/test_functional.py index 6a94205e8..209ccf816 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -564,6 +564,23 @@ def test_ibmm(self, dim1, dim2, dim3, dim4, transpose): class TestLLMInt8Functional: + @staticmethod + def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half): + """Reference implementation for the F.int8_mm_dequant function.""" + C = 127.0 + + x = xq.float() + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) + if len(S1.shape) == 2: + x *= S1 / C + else: + x *= S1 / C + x *= S2 / C + return x.to(dtype) + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) @@ -630,7 +647,7 @@ def test_dequant_mm(self, device, dim1, dim4, dims, has_bias): C2 = F.int8_linear_matmul(A1, B1) - C4 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) + C4 = self.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) if has_bias: C4 += bias @@ -759,7 +776,7 @@ def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): C2 = F.int8_linear_matmul(A1, B1) - out3 = F.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) + out3 = self.vectorwise_mm_dequant(C2.float(), maxA, maxB.t()) err1 = torch.abs(out1 - out2).mean().item() err2 = torch.abs(out1 - out3).mean().item() From a2fe201df051519ad92724d58a724e63659a981b Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 3 Jun 2025 19:43:25 -0400 Subject: [PATCH 3/7] Deprecation cleanup: vectorwise_quant --- bitsandbytes/functional.py | 49 -------------------------------------- tests/test_functional.py | 24 ++++++++++++------- 2 files changed, 16 insertions(+), 57 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 607bbf4a2..310115c14 100755 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2407,55 +2407,6 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): C = 127.0 -@deprecated( - "This function is deprecated and will be removed in a future release. 
" - "Consider using `int8_vectorwise_quant` instead.", - category=FutureWarning, -) -def vectorwise_quant(x, dim=1, quant_type="vector"): - if quant_type == "linear": - max1 = torch.abs(x).max().float() - xq = torch.round(x / max1 * 127).to(torch.int8) - return xq, max1 - elif quant_type in ["vector", "row"]: - max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) - xq = torch.round(x * (C / max1)).to(torch.int8) - return xq, max1 - elif quant_type == "zeropoint": - dtype = x.dtype - x = x.float() - dyna = x.max() - x.min() - if dyna == 0: - dyna = 1 - qx = 255.0 / dyna - minx = x.min() - zpx = torch.round(minx * qx) - x = torch.round(qx * x - zpx) + zpx - return x, qx - elif quant_type in ["vector-zeropoint", "row-zeropoint"]: - dtype = x.dtype - x = x.float() - dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True) - dyna[dyna == 0] = 1 - qx = 255.0 / dyna - minx = torch.amin(x, dim=dim, keepdim=True) - zpx = torch.round(minx * qx) - x = torch.round(qx * x - zpx) + zpx - return x, qx - elif quant_type == "truncated-vector": - with torch.no_grad(): - absx = torch.abs(x) - max1 = torch.amax(absx, dim=dim, keepdim=True) - max1 = max1 * 0.7 - idx = absx > max1.expand_as(absx) - sign = torch.sign(x[idx]) - x[idx] = max1.expand_as(absx)[idx] * sign - xq = torch.round(x / max1 * C).to(torch.int8) - return xq, max1 - else: - return None - - def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor): quant_state = linear.weight.quant_state diff --git a/tests/test_functional.py b/tests/test_functional.py index 209ccf816..19c15fe5b 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -581,6 +581,13 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half): x *= S2 / C return x.to(dtype) + @staticmethod + def vectorwise_quant(x, dim=1): + """Reference implementation""" + max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) + xq = torch.round(x * (127.0 / max1)).to(torch.int8) + return xq, max1 + @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("dim1", [128], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256], ids=id_formatter("dim2")) @@ -642,8 +649,8 @@ def test_dequant_mm(self, device, dim1, dim4, dims, has_bias): if has_bias: C1 += bias - A1, maxA = F.vectorwise_quant(A, dim=1) - B1, maxB = F.vectorwise_quant(B, dim=1) + A1, maxA = self.vectorwise_quant(A, dim=1) + B1, maxB = self.vectorwise_quant(B, dim=1) C2 = F.int8_linear_matmul(A1, B1) @@ -711,8 +718,8 @@ def test_colrow_absmax(self, dim1, dim2, dims, threshold): def test_int8_double_quant(self, dim1, dim2): for i in range(k): A = torch.randn(dim1, dim2, device="cuda").half() - out_col1, Scol = F.vectorwise_quant(A, dim=0) - out_row1, Srow = F.vectorwise_quant(A, dim=1) + out_col1, Scol = self.vectorwise_quant(A, dim=0) + out_row1, Srow = self.vectorwise_quant(A, dim=1) CA, CAt, statsA, statsAt, _ = F.int8_double_quant(A) @@ -764,8 +771,8 @@ def test_integrated_int8_linear_matmul(self, device, dim1, dim4, inner): C1a, stats1a, _ = F.int8_vectorwise_quant(A) C2a, stats2a, _ = F.int8_vectorwise_quant(B) - A1, maxA = F.vectorwise_quant(A, dim=1) - B1, maxB = F.vectorwise_quant(B, dim=1) + A1, maxA = self.vectorwise_quant(A, dim=1) + B1, maxB = self.vectorwise_quant(B, dim=1) torch.testing.assert_close(maxA.flatten().float(), stats1a) torch.testing.assert_close(maxB.flatten().float(), stats2a) @@ -909,8 +916,9 @@ def test_spmm_coo_very_sparse(self, dim1, dim2, dtype, out_func): else: B = torch.randn(dim2, dim2 * 4, device="cuda").half() 
torch.nn.init.xavier_uniform_(B) - B, SB = F.vectorwise_quant(B, quant_type="linear") - # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8) + + SB = torch.abs(B).max().float() + B = torch.round(B / SB * 127).to(torch.int8) print("") idx = torch.abs(A) >= threshold From fb25afe1b7cd65a766e8bacf99963d9e461a15ed Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 3 Jun 2025 19:52:02 -0400 Subject: [PATCH 4/7] Remove unused test --- tests/test_functional.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 19c15fe5b..d721f4129 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1393,26 +1393,3 @@ def test_normal_map_tree(): for i in idx: pivots.append((values[i - 1] + values[i]) / 2) # print(pivots) - - -@pytest.mark.skip("Row scale has some bugs for ampere") -def test_managed(): - n = 32 * 10 - A = F.get_paged(n, n, dtype=torch.float32) - B = F.get_paged(n, n, dtype=torch.uint8) - B2 = F.get_paged(n, n, dtype=torch.float32) - assert A.is_paged - assert B.is_paged - assert A.page_deviceid == 0 - assert B.page_deviceid == 0 - F.fill(A, 17.0) - F.fill(B, 17) - F.fill(B2, 2) - assert (A == 17).sum().item() == n * n - assert (B == 17).sum().item() == n * n - C = A * B.float() - assert (C == 289).sum().item() == n * n - F._mul(A, B2) - F._mul(A, B2) - F._mul(A, B2) - assert (A == 17 * (2**3)).sum().item() == n * n From d3e5c20d599d8f982049211cd7abf3af9cc40b43 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Tue, 3 Jun 2025 19:57:53 -0400 Subject: [PATCH 5/7] Optimizer test cleanup --- tests/test_optim.py | 106 +++++++++++++++++--------------------------- 1 file changed, 41 insertions(+), 65 deletions(-) diff --git a/tests/test_optim.py b/tests/test_optim.py index 0d86da7d8..75e5a1714 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -289,11 +289,6 @@ def test_global_config(requires_cuda, dim1, dim2, gtype): optimizer_names_8bit = [ - # Non-blockwise optimizers are deprecated. - # "adam8bit", - # "lion8bit", - # "momentum8bit", - # "rmsprop8bit", "adam8bit_blockwise", "lion8bit_blockwise", "momentum8bit_blockwise", @@ -310,11 +305,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype): def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): torch.set_printoptions(precision=6) - if gtype == torch.bfloat16 and "blockwise" not in optim_name: - pytest.skip() - if dim1 == 1 and dim2 == 1: return + p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 p2 = p1.clone() p1 = p1.float() @@ -349,39 +342,31 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): dequant_states = [] for name1, name2, qmap, max_val in str2statenames[optim_name]: - # print(bnb_optimizer.state[p2][max_val], name1) - if "blockwise" in optim_name: - ## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1] - ## separately and then stack them. The qmap is shared, but absmax is also stacked. 
- if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2": - m1 = F.dequantize_blockwise( - code=bnb_optimizer.state[p2][qmap], - absmax=bnb_optimizer.state[p2][max_val][0], - A=bnb_optimizer.state[p2][name2][0], - blocksize=blocksize, - ) - m2 = F.dequantize_blockwise( - code=bnb_optimizer.state[p2][qmap], - absmax=bnb_optimizer.state[p2][max_val][1], - A=bnb_optimizer.state[p2][name2][1], - blocksize=blocksize, - ) - - s1 = torch.stack((m1, m2)) + ## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1] + ## separately and then stack them. The qmap is shared, but absmax is also stacked. + if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2": + m1 = F.dequantize_blockwise( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val][0], + A=bnb_optimizer.state[p2][name2][0], + blocksize=blocksize, + ) + m2 = F.dequantize_blockwise( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val][1], + A=bnb_optimizer.state[p2][name2][1], + blocksize=blocksize, + ) - else: - s1 = F.dequantize_blockwise( - code=bnb_optimizer.state[p2][qmap], - absmax=bnb_optimizer.state[p2][max_val], - A=bnb_optimizer.state[p2][name2], - blocksize=blocksize, - ) + s1 = torch.stack((m1, m2)) else: - s1 = F.dequantize( + s1 = F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], + blocksize=blocksize, ) + num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0 # assert num_not_close.sum().item() < 20 dequant_states.append(s1.clone()) @@ -414,39 +399,33 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2]) torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap]) - if "blockwise" in optim_name: - ## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1] - ## separately and then stack them. The qmap is shared, but absmax is also stacked. - if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2": - s1 = torch.stack( - ( - F.dequantize_blockwise( - code=bnb_optimizer.state[p2][qmap], - absmax=bnb_optimizer.state[p2][max_val][0], - A=bnb_optimizer.state[p2][name2][0], - blocksize=blocksize, - ), - F.dequantize_blockwise( - code=bnb_optimizer.state[p2][qmap], - absmax=bnb_optimizer.state[p2][max_val][1], - A=bnb_optimizer.state[p2][name2][1], - blocksize=blocksize, - ), - ) - ) - else: - s1 = F.dequantize_blockwise( - code=bnb_optimizer.state[p2][qmap], - absmax=bnb_optimizer.state[p2][max_val], - A=bnb_optimizer.state[p2][name2], - blocksize=blocksize, + ## For AdEMAMix, we need to dequantize [p2][name2][0] and [p2][name2][1] + ## separately and then stack them. The qmap is shared, but absmax is also stacked. 
+ if optim_name == "ademamix8bit_blockwise" and name1 == "m1_m2": + s1 = torch.stack( + ( + F.dequantize_blockwise( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val][0], + A=bnb_optimizer.state[p2][name2][0], + blocksize=blocksize, + ), + F.dequantize_blockwise( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val][1], + A=bnb_optimizer.state[p2][name2][1], + blocksize=blocksize, + ), ) + ) else: - s1 = F.dequantize( + s1 = F.dequantize_blockwise( code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], + blocksize=blocksize, ) + torch.testing.assert_close(s1cpy, s1) num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0 @@ -463,9 +442,6 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): torch_optimizer.state[p1][name1].copy_(s.data) - # print(sum(errors)/len(errors)) - # print(sum(relerrors)/len(relerrors)) - @pytest.mark.parametrize("optim_bits", [32, 8], ids=id_formatter("optim_bits")) @pytest.mark.parametrize("gtype", [torch.float32], ids=describe_dtype) From 69d495859f3e05ebce3ef1f6bf75d3f3f6a66ccc Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 4 Jun 2025 15:00:55 -0400 Subject: [PATCH 6/7] Deprecations: remove estimate_quantiles, create_quantile_map --- bitsandbytes/functional.py | 85 -------------------------------------- csrc/kernels.cu | 76 ---------------------------------- csrc/kernels.cuh | 2 - csrc/ops.cu | 11 ----- csrc/ops.cuh | 3 -- csrc/pythonInterface.cpp | 5 --- tests/test_deprecated.py | 66 ----------------------------- tests/test_functional.py | 8 +--- 8 files changed, 2 insertions(+), 254 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 310115c14..6893752c9 100755 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -401,23 +401,6 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): return torch.tensor(data, dtype=torch.float32) -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def create_quantile_map(A, total_bits=8): - q = estimate_quantiles(A, num_quantiles=2**total_bits - 1) - q = q.tolist() - q.append(0) - - gap = 256 - len(q) - for i in range(gap): - q.append(0) - - q.sort() - - q = Tensor(q) - q = q / q.abs().max() - return q - - def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): """Verifies that the input tensors are all on the same device. @@ -474,74 +457,6 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: return ct.c_void_p(A.data_ptr()) -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def estimate_quantiles( - A: Tensor, - out: Optional[torch.Tensor] = None, - offset: float = 1 / 512, - num_quantiles=256, -) -> Tensor: - """ - Estimates 256 equidistant quantiles on the input tensor eCDF. - - Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles - via the eCDF of the input tensor `A`. This is a fast but approximate algorithm - and the extreme quantiles close to 0 and 1 have high variance / large estimation - errors. These large errors can be avoided by using the offset variable which trims - the distribution. 
The default offset value of 1/512 ensures minimum entropy encoding -- it - trims 1/512 = 0.2% from each side of the distrivution. An offset value of 0.01 to 0.02 - usually has a much lower error but is not a minimum entropy encoding. Given an offset - of 0.02 equidistance points in the range [0.02, 0.98] are used for the quantiles. - - Parameters - ---------- - A : torch.Tensor - The input tensor. Any shape. - out : torch.Tensor - Tensor with the 256 estimated quantiles. - offset : float - The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles) - num_quantiles : int - The number of equally spaced quantiles. - - Returns - ------- - torch.Tensor: - The 256 quantiles in float32 datatype. - """ - if A.numel() < 256: - raise NotImplementedError( - f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.", - ) - if num_quantiles > 256: - raise NotImplementedError( - f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}", - ) - if num_quantiles < 256 and offset == 1 / (512): - # override default arguments - offset = 1 / (2 * num_quantiles) - - if out is None: - out = torch.zeros((256,), dtype=torch.float32, device=A.device) - - with _cuda_device_of(A): - is_on_gpu([A, out]) - - if A.dtype == torch.float32: - lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) - elif A.dtype == torch.float16: - lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) - else: - raise NotImplementedError(f"Not supported data type {A.dtype}") - - if num_quantiles < 256: - step = round(256 / num_quantiles) - idx = torch.linspace(0, 255, num_quantiles).long().to(A.device) - out = out[idx] - - return out - - class QuantState: """container for quantization state components to work with Params4bit and similar classes""" diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 968f062ab..7eba3f884 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -357,79 +357,6 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran } } -#define THREADS_ESTIMATE 512 -#define NUM_ESTIMATE 8 -#define BLOCK_ESTIMATE 4096 - -template -__launch_bounds__(THREADS_ESTIMATE, 1) -__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n) -{ - const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); - int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; - const int base_idx = (blockIdx.x * BLOCK_ESTIMATE); - const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE)); - - T vals[NUM_ESTIMATE]; - - typedef cub::BlockRadixSort BlockRadixSort; - typedef cub::BlockLoad LoadFloat; - - __shared__ union { - typename LoadFloat::TempStorage loadf; - typename BlockRadixSort::TempStorage sort; - int smem_qidx[BLOCK_ESTIMATE]; - } temp_storage; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE) - { - valid_items = n - i > BLOCK_ESTIMATE ? 
BLOCK_ESTIMATE : n - i; - - // do not process half-blocks - if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; } - - #pragma unroll 4 - for(int j = 0; j < NUM_ESTIMATE; j++) - vals[j] = max_val; - - __syncthreads(); - LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); - - #pragma unroll 4 - for(int j = 0; j < NUM_ESTIMATE; j++) - vals[j] = ((float)vals[j]) * reciprocal_num_blocks; - - - __syncthreads(); - // sort into striped pattern to mitigate bank conflicts - // striped pattern index for thread 0 [0, 1024, 2048, 3096] - // striped pattern index for thread 1 [1, 1025, 2049, 3097] - BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); - - __syncthreads(); - for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x) - temp_storage.smem_qidx[j] = -1; - - __syncthreads(); - - if(threadIdx.x < 256) - { - float q_interval = (1.0f-(2.0f*offset))/255.0f; - int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1))); - temp_storage.smem_qidx[local_idx] = threadIdx.x; - } - - __syncthreads(); - - for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x) - { - if(temp_storage.smem_qidx[i] != -1) - atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]); - } - } -} - - __launch_bounds__(TH, 4) __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n) { @@ -2985,9 +2912,6 @@ template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); -template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n); -template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n); - #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ float* state1, float *unorm, \ diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 9e49e55ad..c5b996262 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -10,8 +10,6 @@ #define kernels -template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); - __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index bb407876d..a99df1a06 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -18,14 +18,6 @@ using namespace BinSearch; using std::cout; using std::endl; -template void estimateQuantiles(T *A, float *code, float offset, int n) -{ - int num_blocks = n/4096; - num_blocks = n % 4096 == 0 ? 
num_blocks : num_blocks + 1; - CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float))); - kEstimateQuantiles<<>>(A, code, offset, std::numeric_limits::max(), n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); -} void quantize(float *code, float *A, unsigned char *out, int n) { @@ -609,9 +601,6 @@ template int igemmlt<32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, cons template int igemmlt<8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); template int igemmlt<8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream); -template void estimateQuantiles(half *A, float *code, float offset, int n); -template void estimateQuantiles(float *A, float *code, float offset, int n); - template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 6556c7315..99a24a209 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -136,9 +136,6 @@ class ContextCusparse }; - -template void estimateQuantiles(T *A, float *code, float offset, int n); - void quantize(float *code, float *A, unsigned char *out, int n); void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream); template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 3b56a4733..0b8b1942b 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -19,9 +19,6 @@ //=================================================================================== #if BUILD_CUDA -void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } -void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } - //void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) //{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } @@ -169,8 +166,6 @@ void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_r extern "C" { #if BUILD_CUDA - void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } - void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } void cdequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream){ dequantize(code, A, out, n, stream); } diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index b950fa70c..9d5d04f23 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -1,6 +1,4 @@ -import numpy as np import pytest -from scipy.stats import norm import torch import bitsandbytes as bnb @@ -9,70 +7,6 @@ from tests.test_autograd import TRANSPOSE_VALS -@pytest.mark.deprecated -def 
test_kbit_quantile_estimation(): - for i in range(100): - data = torch.randn(1024, 1024, device="cuda") - for bits in range(2, 9): - p = np.linspace(1.3e-4, 1 - 1.3e-4, 2**bits) - val1 = torch.Tensor(norm.ppf(p)).cuda() - val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) - err = torch.abs(val1 - val2).mean() - assert err < 0.038 - - for i in range(100): - data = torch.randn(1024, 1024, device="cuda") - for bits in range(2, 4): - total_values = 2**bits - 1 - p = np.linspace(0, 1, 2 * total_values + 1) - idx = np.arange(1, 2 * total_values + 1, 2) - p = p[idx] - offset = 1 / (2 * total_values) - p = np.linspace(offset, 1 - offset, total_values) - val1 = torch.Tensor(norm.ppf(p)).cuda() - val2 = F.estimate_quantiles(data, num_quantiles=2**bits - 1) - err = torch.abs(val1 - val2).mean() - assert err < 0.035 - - -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"]) -@pytest.mark.deprecated -def test_estimate_quantiles(dtype): - A = torch.rand(1024, 1024, device="cuda") - A = A.to(dtype) - code = F.estimate_quantiles(A) - - percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device) - torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2) - - A = torch.randn(1024, 1024, device="cuda") - A = A.to(dtype) - code = F.estimate_quantiles(A) - - quantiles = torch.quantile(A.float(), percs) - diff = torch.abs(code - quantiles) - assert (diff > 5e-02).sum().item() == 0 - - -@pytest.mark.deprecated -def test_quantile_quantization(): - for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") - code = F.estimate_quantiles(A1) - C = F.quantize_no_absmax(A1, code) - A2 = F.dequantize_no_absmax(C, code) - diff = torch.abs(A1 - A2).mean().item() - assert diff < 0.0075 - - A1 = torch.rand(1024, 1024, device="cuda") - code = F.estimate_quantiles(A1) - C = F.quantize_no_absmax(A1, code) - A2 = F.dequantize_no_absmax(C, code) - diff = torch.abs(A1 - A2).mean().item() - torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0) - assert diff < 0.001 - - @pytest.mark.deprecated def test_dynamic_quantization(): diffs = [] diff --git a/tests/test_functional.py b/tests/test_functional.py index d721f4129..1aa2e1d37 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -170,7 +170,7 @@ def test_blockwise_cpu_large(self, hidden, blocksize): @pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) - @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"]) + @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic"]) def test_few_bit_quant(self, device, bits, method): if device in ("cpu", "xpu") and bits != 8: pytest.skip("CPU/XPU implementation only supports 8 bits") @@ -186,11 +186,7 @@ def test_few_bit_quant(self, device, bits, method): code = F.create_fp8_map(True, ebits, pbits, bits).to(device) elif method == "dynamic": code = F.create_dynamic_map(True, bits - 0, bits).to(device) - elif method == "quantile": - if device != "cuda": - pytest.skip("Quantile map only works on CUDA") - values = torch.randn(2048, 2048, device="cuda") - code = F.create_quantile_map(values, bits).cuda() + # for some data types we have no zero # for some data types we have one zero # for some data types we have two zeros From d1d0dfe82a607d2b1c9575e53357a464b523bf64 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 4 Jun 2025 15:10:31 -0400 Subject: [PATCH 7/7] Move deprecated test --- 
 tests/test_deprecated.py | 31 +++++++++++++++++++++++++++++++
 tests/test_modules.py    | 31 -------------------------------
 2 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py
index 9d5d04f23..f469ff351 100644
--- a/tests/test_deprecated.py
+++ b/tests/test_deprecated.py
@@ -142,3 +142,34 @@ def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
         grad_err = (gradB1 - gradB2).abs().mean()
         assert grad_err.item() < 0.003
         torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)
+
+
+@pytest.mark.deprecated
+def test_fp8linear():
+    b = 10
+    h = 1024
+    inp = torch.randn(b, h).cuda()
+    fp32 = torch.nn.Linear(h, h * 2).cuda()
+    fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
+    fp32b = torch.nn.Linear(h * 2, h).cuda()
+    fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()
+
+    fp8.weight.data.copy_(fp32.weight.data)
+    fp8.bias.data.copy_(fp32.bias.data)
+    fp8b.weight.data.copy_(fp32b.weight.data)
+    fp8b.bias.data.copy_(fp32b.bias.data)
+
+    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
+    b = fp8b(torch.nn.functional.gelu(fp8(inp)))
+
+    err = (a - b).abs().mean()
+
+    a.mean().backward()
+    b.mean().backward()
+
+    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
+    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
+
+    assert err < 0.05
+    assert graderr < 0.00002
+    assert bgraderr < 0.00002
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 319e67714..adbebd09e 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -343,37 +343,6 @@ def test_kbit_backprop(device, module):
         assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0
 
 
-@pytest.mark.deprecated
-def test_fp8linear():
-    b = 10
-    h = 1024
-    inp = torch.randn(b, h).cuda()
-    fp32 = torch.nn.Linear(h, h * 2).cuda()
-    fp8 = bnb.research.nn.LinearFP8Mixed(h, h * 2).cuda()
-    fp32b = torch.nn.Linear(h * 2, h).cuda()
-    fp8b = bnb.research.nn.LinearFP8Mixed(h * 2, h).cuda()
-
-    fp8.weight.data.copy_(fp32.weight.data)
-    fp8.bias.data.copy_(fp32.bias.data)
-    fp8b.weight.data.copy_(fp32b.weight.data)
-    fp8b.bias.data.copy_(fp32b.bias.data)
-
-    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
-    b = fp8b(torch.nn.functional.gelu(fp8(inp)))
-
-    err = (a - b).abs().mean()
-
-    a.mean().backward()
-    b.mean().backward()
-
-    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
-    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
-
-    assert err < 0.05
-    assert graderr < 0.00002
-    assert bgraderr < 0.00002
-
-
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("embedding_dim", [64, 65])
 @pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)