From 5102319dadec2719a7e6cad85b2e450ccf064321 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 26 Mar 2025 13:46:42 -0400 Subject: [PATCH 1/4] Testing cleanup --- benchmarking/optimizer_benchmark.py | 56 ++++++++++++++++++ tests/conftest.py | 11 ++++ tests/test_autograd.py | 91 ++--------------------------- tests/test_deprecated.py | 87 +++++++++++++++++++++++++++ tests/test_functional.py | 2 +- tests/test_generation.py | 2 +- tests/test_optim.py | 41 ------------- tests/test_triton.py | 1 + 8 files changed, 161 insertions(+), 130 deletions(-) create mode 100644 benchmarking/optimizer_benchmark.py diff --git a/benchmarking/optimizer_benchmark.py b/benchmarking/optimizer_benchmark.py new file mode 100644 index 000000000..27dae7ae0 --- /dev/null +++ b/benchmarking/optimizer_benchmark.py @@ -0,0 +1,56 @@ +""" +Extracted from tests/test_optim.py + +Usage: pytest benchmarking/optimizer_benchmark.py +""" + +import time + +import pytest +from tests.helpers import describe_dtype, id_formatter +import torch + +import bitsandbytes as bnb + +str2optimizers = {"paged_adamw": (torch.optim.AdamW, bnb.optim.PagedAdamW)} + + +@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1")) +@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype) +@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name")) +@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode")) +@pytest.mark.benchmark +def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): + layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)])) + layers1 = layers1.to(gtype) + layers1 = layers1.cuda() + + large_tensor = None + if mode == "torch": + optim = str2optimizers[optim_name][0](layers1.parameters()) + else: + optim = str2optimizers[optim_name][1](layers1.parameters()) + # 12 GB + large_tensor = torch.empty((int(4.5e9),), device="cuda") + + torch.cuda.synchronize() + time.sleep(5) + + num_batches = 5 + batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype) + lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda() + + for i in range(num_batches): + print(i) + b = batches[i] + if i == 2: + torch.cuda.synchronize() + t0 = time.time() + + out1 = layers1(b) + + loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean() + loss1.backward() + optim.step() + torch.cuda.synchronize() + print(mode, time.time() - t0) diff --git a/tests/conftest.py b/tests/conftest.py index c029c3cb5..a514e1284 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,22 @@ import gc +import random +import numpy as np import pytest import torch +def _set_seed(): + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.mps.manual_seed(0) + np.random.seed(0) + random.seed(0) + + def pytest_runtest_call(item): try: + _set_seed() item.runtest() except AssertionError as ae: if str(ae) == "Torch not compiled with CUDA enabled": diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 4b93ebcbe..347a93131 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -6,7 +6,6 @@ BOOLEAN_TRIPLES, TRUE_FALSE, describe_dtype, - get_test_dims, id_formatter, ) @@ -136,10 +135,10 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec torch.testing.assert_close(gradBias1, gradBias2) -@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 
0], ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) -@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("dim1", [48], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [64, 0], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [64], ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", [96], ids=id_formatter("dim4")) @pytest.mark.parametrize("funcs", [(torch.matmul, bnb.matmul_4bit)], ids=["func=matmul"]) @pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) @pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) @@ -231,85 +230,3 @@ def test_matmul_4bit( if req_grad[2]: torch.testing.assert_close(gradBias1, gradBias2) - - -@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) -@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) -@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) -@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) -@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) -@pytest.mark.parametrize( - "funcs", - [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], - ids=["matmul_fp8_mixed", "matmul_fp8_global"], -) -def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): - dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) - dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) - req_grad = list(req_grad) - req_grad[2] = False - - for i in range(3): - # normal multiply - if funcs[0] in [torch.mm, torch.matmul]: - A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) - B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) - target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype) - - torch.nn.init.xavier_uniform_(B) - - fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device) - bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device) - - if not transpose[0] and transpose[1]: - out_torch = funcs[0](A, B.t()) - out_bnb = funcs[1](A, B.t(), fw_code, bw_code) - elif not transpose[0] and not transpose[1]: - out_torch = funcs[0](A, B) - out_bnb = funcs[1](A, B, fw_code, bw_code) - - assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" - - n = out_bnb.numel() - err = torch.abs(out_bnb - out_torch).float().mean().item() - if n > 0: - assert err < 0.115 - # assert err < 0.20 - if any(req_grad): - out_bnb.data.copy_(out_torch) - torch.cuda.synchronize() - loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() - loss_bnb.backward() - gradA1 = A.grad - gradB1 = B.grad - A.grad = None - B.grad = None - - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() - loss_torch.backward() - gradA2 = A.grad - gradB2 = B.grad - A.grad = None - B.grad = None - - if req_grad[0]: - torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) - - if req_grad[1]: - n = gradB1.numel() - if dim2 > 0: - assert torch.abs(gradB1).sum() > 0.0 - assert torch.abs(gradB2).sum() > 0.0 - else: - assert 
torch.abs(gradB1).sum() == 0.0 - assert torch.abs(gradB2).sum() == 0.0 - idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) - - assert (idx == 0).sum().item() <= n * 0.1 - idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) - assert (idx == 0).sum().item() <= n * 0.02 - grad_err = (gradB1 - gradB2).abs().mean() - assert grad_err.item() < 0.003 - torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 9872cdfca..b950fa70c 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -3,7 +3,10 @@ from scipy.stats import norm import torch +import bitsandbytes as bnb from bitsandbytes import functional as F +from tests.helpers import BOOLEAN_TRIPLES, describe_dtype, get_test_dims, id_formatter +from tests.test_autograd import TRANSPOSE_VALS @pytest.mark.deprecated @@ -121,3 +124,87 @@ def test_percentile_clipping(gtype): torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2)) torch.testing.assert_close(clip1, clip2) torch.testing.assert_close(gnorm1, gnorm2) + + +@pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim4", get_test_dims(32, 96, n=1), ids=id_formatter("dim4")) +@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad")) +@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose")) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize( + "funcs", + [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)], + ids=["matmul_fp8_mixed", "matmul_fp8_global"], +) +@pytest.mark.deprecated +@pytest.mark.skip("Deprecated functionality, to be removed.") +def test_matmul_fp8(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + req_grad = list(req_grad) + req_grad[2] = False + + for i in range(3): + # normal multiply + if funcs[0] in [torch.mm, torch.matmul]: + A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) + target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype) + + torch.nn.init.xavier_uniform_(B) + + fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device) + bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device) + + if not transpose[0] and transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B.t(), fw_code, bw_code) + elif not transpose[0] and not transpose[1]: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B, fw_code, bw_code) + + assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" + + n = out_bnb.numel() + err = torch.abs(out_bnb - out_torch).float().mean().item() + if n > 0: + assert err < 0.115 + # assert err < 0.20 + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + 
loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1) + + if req_grad[1]: + n = gradB1.numel() + if dim2 > 0: + assert torch.abs(gradB1).sum() > 0.0 + assert torch.abs(gradB2).sum() > 0.0 + else: + assert torch.abs(gradB1).sum() == 0.0 + assert torch.abs(gradB2).sum() == 0.0 + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + + assert (idx == 0).sum().item() <= n * 0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx == 0).sum().item() <= n * 0.02 + grad_err = (gradB1 - gradB2).abs().mean() + assert grad_err.item() < 0.003 + torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3) diff --git a/tests/test_functional.py b/tests/test_functional.py index b4172dd35..66cddb661 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -893,7 +893,7 @@ def test_spmm_coo_very_sparse(self, dim1, dim2, dtype, out_func): @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [256, 1024], ids=id_formatter("dim2")) - @pytest.skip("No longer supported") + @pytest.mark.skip("No longer supported") def test_integrated_sparse_decomp(self, dim1, dim2): threshold = 3.0 for _ in range(k): diff --git a/tests/test_generation.py b/tests/test_generation.py index 911aa14da..38b5ce9bd 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -60,7 +60,7 @@ def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_f return tokenizer.decode(outputs[0], skip_special_tokens=True) -models = ["huggyllama/llama-7b", "bigscience/bloom-1b7"] +models = ["bigscience/bloom-1b7"] dtypes = ["nf4", "fp4"] diff --git a/tests/test_optim.py b/tests/test_optim.py index 3384f4c1d..2bc3752f3 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -604,44 +604,3 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): params = (total_steps - total_steps // 5) * dim1 * dim2 print(optim_name, gtype, s, params, s / params) # assert s < 3.9 - - -@pytest.mark.parametrize("dim1", [2 * 1024], ids=id_formatter("dim1")) -@pytest.mark.parametrize("gtype", [torch.float16], ids=describe_dtype) -@pytest.mark.parametrize("optim_name", ["paged_adamw"], ids=id_formatter("optim_name")) -@pytest.mark.parametrize("mode", ["bnb"], ids=id_formatter("mode")) -@pytest.mark.benchmark -def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): - layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)])) - layers1 = layers1.to(gtype) - layers1 = layers1.cuda() - - large_tensor = None - if mode == "torch": - optim = str2optimizers[optim_name][0](layers1.parameters()) - else: - optim = str2optimizers[optim_name][1](layers1.parameters()) - # 12 GB - large_tensor = torch.empty((int(4.5e9),), device="cuda") - - torch.cuda.synchronize() - time.sleep(5) - - num_batches = 5 - batches = torch.randn(num_batches, 128, dim1, device="cuda").to(gtype) - lbls = torch.randint(0, 10, size=(num_batches, 128)).cuda() - - for i in range(num_batches): - print(i) - b = batches[i] - if i == 2: - torch.cuda.synchronize() - t0 = time.time() - - out1 = layers1(b) - - loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean() - loss1.backward() - optim.step() - torch.cuda.synchronize() - print(mode, time.time() - t0) diff --git a/tests/test_triton.py b/tests/test_triton.py index 3624fb5e9..70656a56f 100644 --- a/tests/test_triton.py +++ b/tests/test_triton.py @@ -11,6 
+11,7 @@ not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, reason="This test requires triton and a GPU with compute capability 8.0 or higher.", ) +@pytest.mark.skip("No longer supported.") @pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE) def test_switchback(vector_wise_quantization): for dim in [83]: From feb113935d74ce46d5110a1b1a119c2bdcc44e36 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 26 Mar 2025 14:34:33 -0400 Subject: [PATCH 2/4] More test cleanup --- tests/test_functional.py | 337 +++++++++++++++++++-------------------- tests/test_ops.py | 3 + 2 files changed, 169 insertions(+), 171 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 66cddb661..7af5e4754 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -369,9 +369,9 @@ def test_approx_igemm(self, dim1, dim2, quant_methods, batched): # print(mean(errors)) # print(mean(relerrors)) - @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 256, n=2), ids=id_formatter("hidden_dim")) - @pytest.mark.parametrize("batch_dim", get_test_dims(16, 256, n=2), ids=id_formatter("batch_dim")) - @pytest.mark.parametrize("seq_dim", get_test_dims(16, 256, n=2), ids=id_formatter("seq_dim")) + @pytest.mark.parametrize("hidden_dim", [32, 256], ids=id_formatter("hidden_dim")) + @pytest.mark.parametrize("batch_dim", [16, 256], ids=id_formatter("batch_dim")) + @pytest.mark.parametrize("seq_dim", [16, 256], ids=id_formatter("seq_dim")) @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) def test_igemm(self, hidden_dim, batch_dim, transpose, seq_dim): hidden_dim = hidden_dim - (hidden_dim % 32) @@ -415,9 +415,9 @@ def test_igemm(self, hidden_dim, batch_dim, transpose, seq_dim): torch.testing.assert_close(out.float(), out2) - @pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=3), ids=id_formatter("seq_dim")) - @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=3), ids=id_formatter("hidden_dim")) - @pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=3), ids=id_formatter("batch_dim")) + @pytest.mark.parametrize("seq_dim", [32, 256, 512], ids=id_formatter("seq_dim")) + @pytest.mark.parametrize("hidden_dim", [64, 1024, 4096], ids=id_formatter("hidden_dim")) + @pytest.mark.parametrize("batch_dim", [2, 8, 16], ids=id_formatter("batch_dim")) def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim): seq_dim = seq_dim - (seq_dim % 32) hidden_dim = hidden_dim - (hidden_dim % 32) @@ -431,9 +431,9 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim): torch.testing.assert_close(out.float(), out2) - @pytest.mark.parametrize("seq_dim", get_test_dims(32, 512, n=2), ids=id_formatter("seq_dim")) - @pytest.mark.parametrize("hidden_dim", get_test_dims(32, 1024 * 4, n=2), ids=id_formatter("hidden_dim")) - @pytest.mark.parametrize("batch_dim", get_test_dims(2, 16, n=2), ids=id_formatter("batch_dim")) + @pytest.mark.parametrize("seq_dim", [32, 512], ids=id_formatter("seq_dim")) + @pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim")) + @pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim")) @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose")) def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose): def min_max(x): @@ -501,10 +501,10 @@ def min_max(x): assert mean(errs) < 0.015 assert mean(relerrs) < 0.3 - 
@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=2), ids=id_formatter("dim1")) - @pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=2), ids=id_formatter("dim2")) - @pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=2), ids=id_formatter("dim3")) - @pytest.mark.parametrize("dim4", get_test_dims(32, 256, n=2), ids=id_formatter("dim4")) + @pytest.mark.parametrize("dim1", [1, 64], ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim2", [32, 128], ids=id_formatter("dim2")) + @pytest.mark.parametrize("dim3", [32, 256], ids=id_formatter("dim3")) + @pytest.mark.parametrize("dim4", [32, 256], ids=id_formatter("dim4")) @pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose")) def test_ibmm(self, dim1, dim2, dim3, dim4, transpose): dim2 = dim2 - (dim2 % 16) @@ -760,8 +760,8 @@ def test_coo_int8_vectorwise_quant(self, dim1, dim2): class TestSpMMFunctional: - @pytest.mark.parametrize("dim1", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim1")) - @pytest.mark.parametrize("dim2", get_test_dims(1, 1 * 1024, n=2), ids=id_formatter("dim2")) + @pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1")) + @pytest.mark.parametrize("dim2", [128, 512], ids=id_formatter("dim2")) @pytest.mark.parametrize("transposed_B", TRUE_FALSE, ids=id_formatter("transposed_B")) def test_spmm_coo(self, dim1, dim2, transposed_B): threshold = 1.5 @@ -1096,37 +1096,34 @@ def test_4bit_quant(self, dtype, quant_type, blocksize): assert err.item() < math.log2(blocksize) * 8e-2 @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) - def test_4bit_compressed_stats(self, quant_type): - for blocksize in [128, 64]: - errs1 = [] - errs2 = [] - for i in range(10): - A1 = torch.randn(1024, 1024, device="cuda").half() - q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) - q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) - A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) - A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type) - - err = (A1 - A2).abs().float() - relerr = (err / (A1.abs().float() + 1e-15)).mean() - err = err.mean() + @pytest.mark.parametrize("blocksize", [64, 128], ids=id_formatter("blocksize")) + def test_4bit_compressed_stats(self, quant_type, blocksize): + errs1 = [] + errs2 = [] + for i in range(10): + A1 = torch.randn(1024, 1024, device="cuda").half() + q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) + q3, SA3 = F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) + A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) + A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type) - errs1.append(err.item()) + err = (A1 - A2).abs().float() + relerr = (err / (A1.abs().float() + 1e-15)).mean() + err = err.mean() - assert err.item() < 0.11 - assert relerr.item() < 0.28 + errs1.append(err.item()) - err = (A1 - A3).abs().float() - relerr = (err / (A1.abs().float() + 1e-15)).mean() - err = err.mean() + assert err.item() < 0.11 + assert relerr.item() < 0.28 - errs2.append(err.item()) + err = (A1 - A3).abs().float() + relerr = (err / (A1.abs().float() + 1e-15)).mean() + err = err.mean() - assert err.item() < 0.11 - assert relerr.item() < 0.28 + errs2.append(err.item()) - # print(sum(errs1)/len(errs1), blocksize, quant_type) - # print(sum(errs2)/len(errs2), blocksize, quant_type) + assert err.item() < 0.11 + assert relerr.item() < 0.28 # @pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) 
@pytest.mark.parametrize("quant_type", ["nf4"]) @@ -1169,135 +1166,133 @@ def test_bench_4bit_dequant(self, quant_type): [torch.uint8, torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype, ) - def test_gemv_4bit(self, dtype, storage_type, quant_storage, double_quant, kind): - for dim in [128, 256, 512, 1024]: - # for dim in [4*1024]: - # for dim in [1*16]: - errs1 = [] - errs2 = [] - errs3 = [] - relerrs1 = [] - relerrs2 = [] - relerrs3 = [] - max_errs1 = [] - max_errs2 = [] - max_errs3 = [] + @pytest.mark.parametrize("dim", [128, 256, 512, 1024], ids=id_formatter("dim")) + def test_gemv_4bit(self, dim, dtype, storage_type, quant_storage, double_quant, kind): + errs1 = [] + errs2 = [] + errs3 = [] + relerrs1 = [] + relerrs2 = [] + relerrs3 = [] + max_errs1 = [] + max_errs2 = [] + max_errs3 = [] - for i in range(100): - if kind == "fc1": - A = torch.randn(1, dim, dtype=dtype, device="cuda") - B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim) - elif kind == "fc2": - A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda") - B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim) - elif kind == "attn": - A = torch.randn(1, dim, dtype=dtype, device="cuda") - B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim) - elif kind == "attn_packed": - A = torch.randn(1, dim, dtype=dtype, device="cuda") - B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim) - - qB, state = F.quantize_4bit( - B, - quant_type=storage_type, - compress_statistics=double_quant, - quant_storage=quant_storage, - ) - C3 = torch.matmul(A, B.t()) - C2 = F.gemv_4bit(A, qB.t(), state=state) - A.requires_grad = True - C1 = bnb.matmul_4bit(A, qB.t(), state) - - err1 = (C1 - C2).abs().float() - err2 = (C3 - C2).abs().float() - err3 = (C3 - C1).abs().float() - - mag1 = torch.abs(C1).float() + 1e-5 - mag2 = torch.abs(C3).float() + 1e-5 - mag3 = torch.abs(C3).float() + 1e-5 - - relerr1 = err1 / mag1 - relerr2 = err2 / mag2 - relerr3 = err3 / mag3 - - max_err1 = err1.max() - max_err2 = err2.max() - max_err3 = err3.max() - - errs1.append(err1.mean().item()) - errs2.append(err2.mean().item()) - errs3.append(err3.mean().item()) - - relerrs1.append(relerr1.mean().item()) - relerrs2.append(relerr2.mean().item()) - relerrs3.append(relerr3.mean().item()) - - max_errs1.append(max_err1.item()) - max_errs2.append(max_err2.item()) - max_errs3.append(max_err3.item()) - - c = int(C1.numel() * 0.0014 * (dim / 256)) + 1 - - c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False) - err1 = sum(errs1) / len(errs1) / math.sqrt(dim) - err2 = sum(errs2) / len(errs2) / math.sqrt(dim) - err3 = sum(errs3) / len(errs3) / math.sqrt(dim) - relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim) - relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim) - relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim) - maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim) - maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim) - maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim) - absratio = err2 / err3 - relratio = relerr2 / relerr3 - maxratio = relerr2 / relerr3 - - # for debugging if the tests fails - # - # print('='*80) - # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:') - # print(C1.flatten()[-20:]) - # print(C2.flatten()[-20:]) - # print(f'inference vs training abs: {err1}') - # print(f'inference vs training rel: {relerr1}') - # print(f'inference vs training max: {maxerr1}') - # 
print(f'inference vs training vs torch err ratio abs: {absratio}') - # print(f'inference vs training vs torch err ratio rel: {relratio}') - # print(f'inference vs training vs torch err ratio max: {maxratio}') - if dtype == torch.float16: - if dim <= 512: - assert err1 < 7e-5 - assert relerr1 < 0.0008 - else: - assert err1 < 6e-5 - assert relerr1 < 2e-4 - assert absratio < 1.005 and absratio > 0.995 - assert relratio < 1.005 and relratio > 0.995 - assert maxratio < 1.005 and maxratio > 0.995 - elif dtype == torch.float32: - if dim <= 512: - assert err1 < 5e-8 - assert relerr1 < 1e-6 - assert maxerr1 < 1e-7 - else: - assert err1 < 5e-8 - assert relerr1 < 8e-6 - assert maxerr1 < 1e-7 - assert absratio < 1.005 and absratio > 0.995 - assert relratio < 1.005 and relratio > 0.995 - assert maxratio < 1.005 and maxratio > 0.995 - elif dtype == torch.bfloat16: - if dim <= 512: - assert err1 < 6e-4 - assert relerr1 < 0.007 - assert maxerr1 < 0.015 - else: - assert err1 < 2e-4 - assert relerr1 < 0.002 - assert maxerr1 < 0.0012 - assert absratio < 1.005 and absratio > 0.995 - assert relratio < 1.04 and relratio > 0.96 - assert maxratio < 1.02 and maxratio > 0.98 + for i in range(100): + if kind == "fc1": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim * 4, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "fc2": + A = torch.randn(1, 4 * dim, dtype=dtype, device="cuda") + B = torch.randn(dim, 4 * dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "attn": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + elif kind == "attn_packed": + A = torch.randn(1, dim, dtype=dtype, device="cuda") + B = torch.randn(dim * 3, dim, dtype=dtype, device="cuda") / math.sqrt(dim) + + qB, state = F.quantize_4bit( + B, + quant_type=storage_type, + compress_statistics=double_quant, + quant_storage=quant_storage, + ) + C3 = torch.matmul(A, B.t()) + C2 = F.gemv_4bit(A, qB.t(), state=state) + A.requires_grad = True + C1 = bnb.matmul_4bit(A, qB.t(), state) + + err1 = (C1 - C2).abs().float() + err2 = (C3 - C2).abs().float() + err3 = (C3 - C1).abs().float() + + mag1 = torch.abs(C1).float() + 1e-5 + mag2 = torch.abs(C3).float() + 1e-5 + mag3 = torch.abs(C3).float() + 1e-5 + + relerr1 = err1 / mag1 + relerr2 = err2 / mag2 + relerr3 = err3 / mag3 + + max_err1 = err1.max() + max_err2 = err2.max() + max_err3 = err3.max() + + errs1.append(err1.mean().item()) + errs2.append(err2.mean().item()) + errs3.append(err3.mean().item()) + + relerrs1.append(relerr1.mean().item()) + relerrs2.append(relerr2.mean().item()) + relerrs3.append(relerr3.mean().item()) + + max_errs1.append(max_err1.item()) + max_errs2.append(max_err2.item()) + max_errs3.append(max_err3.item()) + + c = int(C1.numel() * 0.0014 * (dim / 256)) + 1 + + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=0, throw=False) + err1 = sum(errs1) / len(errs1) / math.sqrt(dim) + err2 = sum(errs2) / len(errs2) / math.sqrt(dim) + err3 = sum(errs3) / len(errs3) / math.sqrt(dim) + relerr1 = sum(relerrs1) / len(relerrs1) / math.sqrt(dim) + relerr2 = sum(relerrs2) / len(relerrs2) / math.sqrt(dim) + relerr3 = sum(relerrs3) / len(relerrs3) / math.sqrt(dim) + maxerr1 = sum(max_errs1) / len(max_errs1) / math.sqrt(dim) + maxerr2 = sum(max_errs2) / len(max_errs2) / math.sqrt(dim) + maxerr3 = sum(max_errs3) / len(max_errs3) / math.sqrt(dim) + absratio = err2 / err3 + relratio = relerr2 / relerr3 + maxratio = relerr2 / relerr3 + + # for debugging if the tests fails + 
# + # print('='*80) + # print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:') + # print(C1.flatten()[-20:]) + # print(C2.flatten()[-20:]) + # print(f'inference vs training abs: {err1}') + # print(f'inference vs training rel: {relerr1}') + # print(f'inference vs training max: {maxerr1}') + # print(f'inference vs training vs torch err ratio abs: {absratio}') + # print(f'inference vs training vs torch err ratio rel: {relratio}') + # print(f'inference vs training vs torch err ratio max: {maxratio}') + if dtype == torch.float16: + if dim <= 512: + assert err1 < 7e-5 + assert relerr1 < 0.0008 + else: + assert err1 < 6e-5 + assert relerr1 < 2e-4 + assert absratio < 1.005 and absratio > 0.995 + assert relratio < 1.005 and relratio > 0.995 + assert maxratio < 1.005 and maxratio > 0.995 + elif dtype == torch.float32: + if dim <= 512: + assert err1 < 5e-8 + assert relerr1 < 1e-6 + assert maxerr1 < 1e-7 + else: + assert err1 < 5e-8 + assert relerr1 < 8e-6 + assert maxerr1 < 1e-7 + assert absratio < 1.005 and absratio > 0.995 + assert relratio < 1.005 and relratio > 0.995 + assert maxratio < 1.005 and maxratio > 0.995 + elif dtype == torch.bfloat16: + if dim <= 512: + assert err1 < 6e-4 + assert relerr1 < 0.007 + assert maxerr1 < 0.015 + else: + assert err1 < 2e-4 + assert relerr1 < 0.002 + assert maxerr1 < 0.0012 + assert absratio < 1.005 and absratio > 0.995 + assert relratio < 1.04 and relratio > 0.96 + assert maxratio < 1.02 and maxratio > 0.98 @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @@ -1363,9 +1358,9 @@ def test_managed(): assert (A == 17 * (2**3)).sum().item() == n * n -@pytest.mark.parametrize("dim1", get_test_dims(1, 64, n=1), ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", get_test_dims(32, 128, n=1), ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", get_test_dims(32, 256, n=1), ids=id_formatter("dim3")) +@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [64], ids=id_formatter("dim2")) +@pytest.mark.parametrize("dim3", [128], ids=id_formatter("dim3")) @pytest.mark.deprecated def test_vector_quant(dim1, dim2, dim3): dim2 = dim2 - (dim2 % 16) diff --git a/tests/test_ops.py b/tests/test_ops.py index 93a2fb68e..8c9c6a646 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -149,6 +149,9 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize if device == "cpu" and quant_type != "nf4": pytest.skip("CPU implementation is only available for nf4") + if storage_dtype != torch.uint8: + pytest.xfail("Known issue with storage_dtype != uint8") + A = torch.randn(1024, 1024, dtype=dtype, device=device) out, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, quant_type, storage_dtype) From 76fb84a24785521b216b22babf3a8b4d81433ff5 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Wed, 26 Mar 2025 16:44:39 -0400 Subject: [PATCH 3/4] Additional deprecations/removals. 
--- bitsandbytes/functional.py | 297 +++++++++++++++++-------------------- bitsandbytes/optim/adam.py | 207 -------------------------- tests/test_functional.py | 15 -- 3 files changed, 134 insertions(+), 385 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c0e139e03..c81047802 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -251,12 +251,6 @@ def fill(A, value, device=None, prefetch=True): elementwise_func("fill", A, None, value) -@deprecated("Function will be removed in a future release.", category=FutureWarning) -def arange(A, device=None): - elementwise_func("arange", A, None, 0) - - -@deprecated("Function will be removed in a future release.", category=FutureWarning) def _mul(A, B, device=None): elementwise_func("_mul", A, B, 0) @@ -407,6 +401,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): return torch.tensor(data, dtype=torch.float32) +@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) def create_quantile_map(A, total_bits=8): q = estimate_quantiles(A, num_quantiles=2**total_bits - 1) q = q.tolist() @@ -480,17 +475,6 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def pre_call(device): - prev_device = torch.cuda.current_device() - torch.cuda.set_device(device) - return prev_device - - -@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) -def post_call(prev_device): - torch.cuda.set_device(prev_device) - - def estimate_quantiles( A: Tensor, out: Optional[torch.Tensor] = None, @@ -539,15 +523,16 @@ def estimate_quantiles( if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) - is_on_gpu([A, out]) - device = pre_call(A.device) - if A.dtype == torch.float32: - lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) - elif A.dtype == torch.float16: - lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) - else: - raise NotImplementedError(f"Not supported data type {A.dtype}") - post_call(device) + + with _cuda_device_of(A): + is_on_gpu([A, out]) + + if A.dtype == torch.float32: + lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) + elif A.dtype == torch.float16: + lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) + else: + raise NotImplementedError(f"Not supported data type {A.dtype}") if num_quantiles < 256: step = round(256 / num_quantiles) @@ -1219,12 +1204,12 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No torch.Tensor: Quantized 8-bit tensor. """ - prev_device = pre_call(A.device) - if out is None: - out = torch.zeros_like(A, dtype=torch.uint8) - is_on_gpu([A, out]) - lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) - post_call(prev_device) + with _cuda_device_of(A): + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) + is_on_gpu([A, out]) + lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) + return out @@ -1250,13 +1235,13 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = torch.Tensor: 32-bit output tensor. 
""" - prev_device = pre_call(A.device) - if out is None: - out = torch.zeros_like(A, dtype=torch.float32) - is_on_gpu([code, A, out]) - stream = _get_tensor_stream(A) - lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream) - post_call(prev_device) + with _cuda_device_of(A): + if out is None: + out = torch.zeros_like(A, dtype=torch.float32) + is_on_gpu([code, A, out]) + stream = _get_tensor_stream(A) + lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream) + return out @@ -1444,61 +1429,60 @@ def optimizer_update_8bit( if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) - prev_device = pre_call(g.device) - is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2]) - if g.dtype == torch.float32 and state1.dtype == torch.uint8: - str2optimizer8bit[optimizer_name][0]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(max1), - get_ptr(max2), - get_ptr(new_max1), - get_ptr(new_max2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_int32(g.numel()), - ) - elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - str2optimizer8bit[optimizer_name][1]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(max1), - get_ptr(max2), - get_ptr(new_max1), - get_ptr(new_max2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_int32(g.numel()), - ) - else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", - ) - post_call(prev_device) + with _cuda_device_of(g): + is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2]) + if g.dtype == torch.float32 and state1.dtype == torch.uint8: + str2optimizer8bit[optimizer_name][0]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(max1), + get_ptr(max2), + get_ptr(new_max1), + get_ptr(new_max2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_int32(g.numel()), + ) + elif g.dtype == torch.float16 and state1.dtype == torch.uint8: + str2optimizer8bit[optimizer_name][1]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(max1), + get_ptr(max2), + get_ptr(new_max1), + get_ptr(new_max2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_int32(g.numel()), + ) + else: + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + ) def optimizer_update_8bit_blockwise( @@ -1577,25 +1561,24 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: The current optimization steps (number of past gradient norms). 
""" - prev_device = pre_call(grad.device) - is_on_gpu([grad, gnorm_vec]) - if grad.dtype == torch.float32: - lib.cpercentile_clipping_g32( - get_ptr(grad), - get_ptr(gnorm_vec), - ct.c_int32(step), - ct.c_int32(grad.numel()), - ) - elif grad.dtype == torch.float16: - lib.cpercentile_clipping_g16( - get_ptr(grad), - get_ptr(gnorm_vec), - ct.c_int32(step), - ct.c_int32(grad.numel()), - ) - else: - raise ValueError(f"Gradient type {grad.dtype} not supported!") - post_call(prev_device) + with _cuda_device_of(grad): + is_on_gpu([grad, gnorm_vec]) + if grad.dtype == torch.float32: + lib.cpercentile_clipping_g32( + get_ptr(grad), + get_ptr(gnorm_vec), + ct.c_int32(step), + ct.c_int32(grad.numel()), + ) + elif grad.dtype == torch.float16: + lib.cpercentile_clipping_g16( + get_ptr(grad), + get_ptr(gnorm_vec), + ct.c_int32(step), + ct.c_int32(grad.numel()), + ) + else: + raise ValueError(f"Gradient type {grad.dtype} not supported!") current_gnorm = torch.sqrt(gnorm_vec[step % 100]) vals, idx = torch.sort(gnorm_vec) @@ -2333,7 +2316,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): if out is None: out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype) nnz = cooA.nnz - prev_device = pre_call(B.device) + assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz assert cooA.values.numel() == nnz @@ -2370,43 +2353,43 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): cldb = ct.c_int32(ldb) cldc = ct.c_int32(ldc) - is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats]) - if B.dtype == torch.float16: - lib.cspmm_coo_very_sparse_naive_fp16( - ptrMaxCount, - ptrMaxIdx, - ptrOffset, - ptrRowidx, - ptrColidx, - ptrValues, - ptrB, - ptrC, - ptrDequantStats, - cnnz_rows, - cnnz, - crowsA, - crowsB, - ccolsB, - ) - elif B.dtype == torch.int8: - lib.cspmm_coo_very_sparse_naive_int8( - ptrMaxCount, - ptrMaxIdx, - ptrOffset, - ptrRowidx, - ptrColidx, - ptrValues, - ptrB, - ptrC, - ptrDequantStats, - cnnz_rows, - cnnz, - crowsA, - crowsB, - ccolsB, - ) - # else: assertion error - post_call(prev_device) + with _cuda_device_of(B): + is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats]) + if B.dtype == torch.float16: + lib.cspmm_coo_very_sparse_naive_fp16( + ptrMaxCount, + ptrMaxIdx, + ptrOffset, + ptrRowidx, + ptrColidx, + ptrValues, + ptrB, + ptrC, + ptrDequantStats, + cnnz_rows, + cnnz, + crowsA, + crowsB, + ccolsB, + ) + elif B.dtype == torch.int8: + lib.cspmm_coo_very_sparse_naive_int8( + ptrMaxCount, + ptrMaxIdx, + ptrOffset, + ptrRowidx, + ptrColidx, + ptrValues, + ptrB, + ptrC, + ptrDequantStats, + cnnz_rows, + cnnz, + crowsA, + crowsB, + ccolsB, + ) + # else: assertion error return out @@ -2463,18 +2446,6 @@ def vectorwise_quant(x, dim=1, quant_type="vector"): return None -@deprecated( - "This function is deprecated and will be removed in a future release.", - category=FutureWarning, -) -def vectorwise_dequant(xq, max1, quant_type="vector"): - if quant_type == "vector": - x = (xq / C * max1).to(torch.float32) - return x - else: - return None - - @deprecated( "This function is deprecated and will be removed in a future release.", category=FutureWarning, diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 740db26ac..1a8800843 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -3,13 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import math -import os -import torch -import torch.distributed as dist - -import bitsandbytes.functional as F from bitsandbytes.optim.optimizer import Optimizer2State @@ -377,204 +371,3 @@ def __init__( block_wise, is_paged=True, ) - - -class AnalysisAdam(torch.optim.Optimizer): - """Adam that performs 8-bit vs 32-bit error analysis. - - This implementation is modified from torch.optim.Adam based on: - `Fixed Weight Decay Regularization in Adam` - (see https://arxiv.org/abs/1711.05101) - - It has been proposed in `Adam: A Method for Stochastic Optimization`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - - .. _Adam: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - bnb_analysis="dynamic-blockwise", - savedir=None, - ): - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - amsgrad=amsgrad, - ) - super().__init__(params, defaults) - self.analysis = bnb_analysis - self.savedir = savedir - - @property - def supports_memory_efficient_fp16(self): - return True - - @property - def supports_flat_params(self): - return True - - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - for p_id, p in enumerate(group["params"]): - if p.grad is None: - continue - grad = p.grad.data - if grad.dtype in {torch.float16, torch.bfloat16}: - grad = grad.float() - if grad.is_sparse: - raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead") - amsgrad = group.get("amsgrad", False) - assert not amsgrad - - p_data_fp32 = p.data - if p.data.dtype in {torch.float16, torch.bfloat16}: - p_data_fp32 = p_data_fp32.float() - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like(p_data_fp32) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) - state["abserrors"] = torch.zeros((256, 256), device=p_data_fp32.device) - state["relerrors"] = torch.zeros((256, 256), device=p_data_fp32.device) - state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device) - if amsgrad: - # Maintains max of all exp. moving avg. of sq. grad. 
values - state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) - else: - state["exp_avg"] = state["exp_avg"].to(p_data_fp32) - state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) - if amsgrad: - state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(p_data_fp32) - - state["step"] += 1 - beta1, beta2 = group["betas"] - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 - e = state["abserrors"] - rele = state["relerrors"] - counts = state["counts"] - - if group["weight_decay"] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - if amsgrad: - max_exp_avg_sq = state["max_exp_avg_sq"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - denom = exp_avg_sq.sqrt().add_(group["eps"]) - update_fp32 = exp_avg / denom - - if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000: - # embedding layer or too small - p_data_fp32 += -step_size * update_fp32 - else: - if self.analysis == "dynamic-blockwise": - code1 = F.create_dynamic_map(signed=True).to(p.device) - code2 = F.create_dynamic_map(signed=False).to(p.device) - C1, S1 = F.quantize_blockwise(exp_avg, code=code1) - state1 = F.dequantize_blockwise(C1, S1) - C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2) - state2 = F.dequantize_blockwise(C2, S2) - elif self.analysis == "dynamic": - code1 = F.create_dynamic_map(signed=True).to(p.device) - code2 = F.create_dynamic_map(signed=False).to(p.device) - C1, S1 = F.quantize(exp_avg, code=code1) - state1 = F.dequantize(C1, S1) - C2, S2 = F.quantize(exp_avg_sq, code=code2) - state2 = F.dequantize(C2, S2) - elif self.analysis == "linear": - code1 = F.create_linear_map(signed=True).to(p.device) - code2 = F.create_linear_map(signed=False).to(p.device) - C1, S1 = F.quantize(exp_avg, code=code1) - state1 = F.dequantize(C1, S1) - C2, S2 = F.quantize(exp_avg_sq, code=code2) - state2 = F.dequantize(C2, S2) - elif self.analysis == "quantile": - code1 = F.estimate_quantiles(exp_avg) - code2 = F.estimate_quantiles(exp_avg_sq) - C1 = F.quantize_no_absmax(exp_avg, code=code1) - state1 = F.dequantize_no_absmax(C1, code1) - C2 = F.quantize_no_absmax(exp_avg_sq, code=code2) - state2 = F.dequantize_no_absmax(C2, code2) - elif self.analysis == "my-quantization-routine": - pass - # 1. get code - # 2. quantize - # 3. dequantize - # Error will be calculated automatically! 
- else: - raise ValueError(f"Invalid analysis value: {self.analysis}!") - - denom = state2.sqrt().add_(group["eps"]) - update_8bit = state1 / denom - - abserr = torch.abs(update_8bit - update_fp32) - relerr = abserr / torch.abs(update_fp32 + 1e-6) - - C1, C2 = C1.int(), C2.int() - - F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr) - F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr) - F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr)) - - p_data_fp32 += -step_size * update_fp32 - - if not dist.is_initialized() or dist.get_rank() == 0: - if self.savedir != "" and state["step"] % 100 == 0: - if not os.path.exists(self.savedir): - os.makedirs(self.savedir) - shapestr = "_".join([str(dim) for dim in p_data_fp32.shape]) - pathe = os.path.join(self.savedir, f"{p_id}_{shapestr}_abserr.pkl") - pathrele = os.path.join(self.savedir, f"{p_id}_{shapestr}_relerr.pkl") - pathcounts = os.path.join(self.savedir, f"{p_id}_{shapestr}_counts.pkl") - torch.save(e, pathe) - torch.save(rele, pathrele) - torch.save(counts, pathcounts) - - if p.data.dtype in {torch.float16, torch.bfloat16}: - p.data.copy_(p_data_fp32) - - return loss diff --git a/tests/test_functional.py b/tests/test_functional.py index 7af5e4754..9396f5240 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1356,18 +1356,3 @@ def test_managed(): F._mul(A, B2) F._mul(A, B2) assert (A == 17 * (2**3)).sum().item() == n * n - - -@pytest.mark.parametrize("dim1", [32], ids=id_formatter("dim1")) -@pytest.mark.parametrize("dim2", [64], ids=id_formatter("dim2")) -@pytest.mark.parametrize("dim3", [128], ids=id_formatter("dim3")) -@pytest.mark.deprecated -def test_vector_quant(dim1, dim2, dim3): - dim2 = dim2 - (dim2 % 16) - dim3 = dim3 - (dim3 % 16) - for i in range(k): - A = torch.randn(size=(dim2, dim3), device="cuda") - qA, SA = F.vectorwise_quant(A, dim=0) - A1 = F.vectorwise_dequant(qA, SA) - n = A1.numel() - assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002)) From 7a78a4869ea0d9b9b33274c3fed9584f692556c1 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Thu, 27 Mar 2025 13:14:09 -0400 Subject: [PATCH 4/4] Skip benchmark, deprecated, slow tests by default --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f4ae66a8e..2c83d7e0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,7 @@ include = ["bitsandbytes*"] version = {attr = "bitsandbytes.__version__"} [tool.pytest.ini_options] -addopts = "-rP" +addopts = "-rP -m 'not slow and not benchmark and not deprecated'" # ; --cov=bitsandbytes # ; # contexts: record which test ran which line; can be seen in html coverage report # ; --cov-context=test
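
Note on the pyproject.toml change above (a usage sketch, not part of the patch itself): with "-m 'not slow and not benchmark and not deprecated'" added to addopts, those test groups are deselected by default. Assuming pytest's usual option precedence, where a -m given on the command line takes priority over the one injected from addopts, the opted-out groups can still be selected explicitly, e.g.:

    pytest -m benchmark benchmarking/optimizer_benchmark.py
    pytest -m deprecated tests/test_deprecated.py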