From f176eabd266a837f1eb6bdd29afee29e500cc721 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Fri, 23 May 2025 20:10:48 -0400
Subject: [PATCH 1/3] Add torch.compile tests

---
 bitsandbytes/functional.py |  4 +-
 bitsandbytes/nn/modules.py |  2 +-
 tests/test_linear4bit.py   | 78 +++++++++++++++++++++++++++++++++++++-
 tests/test_linear8bitlt.py | 55 +++++++++++++++++++++++++++
 4 files changed, 135 insertions(+), 4 deletions(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index b0092ffd1..f84f16c21 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -771,14 +771,14 @@ def quantize_blockwise(
         qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False)
         quant_state = QuantState(
             absmax=qabsmax,
-            code=code,
+            code=code.to(A.device, copy=True),
             blocksize=blocksize,
             dtype=A.dtype,
             offset=offset,
             state2=state2,
         )
     else:
-        quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype)
+        quant_state = QuantState(absmax=_absmax, code=code.to(A.device, copy=True), blocksize=blocksize, dtype=A.dtype)
 
     # TODO(matthewdouglas): Deprecate out kwarg
     out = out.copy_(_out) if out is not None else _out
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 937084cf1..500102ab1 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -493,7 +493,7 @@ def forward(self, x: torch.Tensor):
 
         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
 
-        return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
 
 
 class LinearFP4(Linear4bit):
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index 67b61cb05..d665b0017 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -7,7 +7,14 @@
 import torch
 
 import bitsandbytes as bnb
-from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer
+from tests.helpers import (
+    TRUE_FALSE,
+    describe_dtype,
+    get_available_devices,
+    id_formatter,
+    torch_load_from_buffer,
+    torch_save_to_buffer,
+)
 
 storage = {
     "uint8": torch.uint8,
@@ -275,3 +282,72 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
     # there was a bug where deepcopy would modify the original object
     assert dict_keys_before == dict_keys_after
     assert dict_keys_before == dict_keys_deserialized
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("compute_dtype", [torch.bfloat16, torch.float32], ids=describe_dtype)
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
+@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
+@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
+@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
+def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
+    if device == "cpu" and quant_type == "fp4":
+        pytest.skip("FP4 is not supported for CPU")
+
+    if fullgraph and torch.__version__ < (2, 8):
+        pytest.skip("fullgraph mode requires torch 2.8 or higher")
+
+    dim = 256
+    batch_size = 16
+
+    torch.compiler.reset()
+
+    # Create a small network with Linear4bit layers
+    net = torch.nn.Sequential(
+        *[
+            bnb.nn.Linear4bit(
+                dim,
+                dim,
+                bias=bias,
+                compute_dtype=compute_dtype,
+                compress_statistics=compress_statistics,
+                quant_type=quant_type,
+            )
+            for _ in range(4)
+        ]
+    ).to(device)
+
+    # Create input tensor
+    x = torch.randn(batch_size, dim, dtype=compute_dtype, device=device)
+
+    # Get reference output before compilation
+    with torch.no_grad():
+        ref_output = net(x)
+
+    # Compile the model
+    compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
+
+    # Get output from compiled model
+    with torch.no_grad():
+        compiled_output = compiled_net(x)
+
+    # Check outputs match
+    assert compiled_output.shape == ref_output.shape
+    assert compiled_output.device == ref_output.device
+    assert compiled_output.dtype == ref_output.dtype
+    torch.testing.assert_close(compiled_output, ref_output)
+
+    # Test with gradients
+    x.requires_grad_(True)
+    y1 = net(x).sum()
+    y1.backward()
+    grad_ref = x.grad.clone()
+
+    x.grad = None
+    y2 = compiled_net(x).sum()
+    y2.backward()
+    grad_compiled = x.grad.clone()
+
+    torch.testing.assert_close(grad_compiled, grad_ref)
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 8c08cfa2c..58705fde3 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -224,3 +224,58 @@ def test_linear8bit_serialization(linear8bit):
     # check for a bug where SCB and CB were not copied
     assert (linear8bit.weight.SCB == deserialized.weight.SCB).all()
     assert (linear8bit.weight.CB == deserialized.weight.CB).all()
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("threshold", [0.0, 6.0], ids=id_formatter("threshold"))
+@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
+@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
+@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
+@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
+def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
+    dim = 256
+    batch_size = 16
+
+    torch.compiler.reset()
+
+    torch._dynamo.config.patch()
+    # Create a small network with Linear8bitLt layers
+    net = torch.nn.Sequential(
+        *[bnb.nn.Linear8bitLt(dim, dim, bias=bias, has_fp16_weights=False, threshold=threshold) for _ in range(4)]
+    ).to(device)
+
+    dynamic_output_shapes = fullgraph and threshold > 0
+    with torch._dynamo.config.patch("capture_dynamic_output_shape_ops", dynamic_output_shapes):
+        # Create input tensor
+        x = torch.randn(batch_size, dim, dtype=torch.float16, device=device)
+
+        # Get reference output before compilation
+        with torch.no_grad():
+            ref_output = net(x)
+
+        # Compile the model
+        compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
+
+        # Get output from compiled model
+        with torch.no_grad():
+            compiled_output = compiled_net(x)
+
+        # Check outputs match
+        assert compiled_output.shape == ref_output.shape
+        assert compiled_output.device == ref_output.device
+        assert compiled_output.dtype == ref_output.dtype
+        torch.testing.assert_close(compiled_output, ref_output)
+
+        # Test with gradients. Currently only works with threshold=0.
+        if threshold == 0:
+            x.requires_grad_(True)
+            y1 = net(x).sum()
+            y1.backward()
+            grad_ref = x.grad.clone()
+
+            x.grad = None
+            y2 = compiled_net(x).sum()
+            y2.backward()
+            grad_compiled = x.grad.clone()
+
+            torch.testing.assert_close(grad_compiled, grad_ref)

From 9c49d098c4235b7bc606716f51d2cfe957638db1 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Fri, 23 May 2025 22:05:00 -0400
Subject: [PATCH 2/3] Tests: Work around aarch64 CPU regressions for torch 2.6.0; add Windows torch==2.7.0+cu118 test config

---
 .github/workflows/tests.yml | 13 +++++++++++++
 tests/test_linear4bit.py    | 11 +++++++++++
 tests/test_linear8bitlt.py  | 12 ++++++++++--
 3 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 5d2a2708b..abd188c2c 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -129,6 +129,10 @@ jobs:
         with:
           python-version: 3.9
 
+      - name: Setup MSVC
+        if: startsWith(matrix.os, 'windows')
+        uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl for torch.compile
+
      - name: Install dependencies
        run: |
          pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/cpu
@@ -188,6 +192,15 @@ jobs:
             torch_version: "2.7.0"
             pypi_index: "https://download.pytorch.org/whl/cu128"
 
+          # Add torch 2.7+cu118 for Windows.
+          - os: windows-2025
+            arch: x86_64
+            gpu: T4
+            runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.7.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
+
           # L40S runners
           - os: ubuntu-22.04
             gpu: L40S
             runner: bandb-aws-g6e-4xlarge-plus-use1-public-80
 
           # T4 runners
           - os: ubuntu-22.04
             gpu: T4
             runner: bandb-aws-g4dn-4xlarge-plus-use1-public-80
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index d665b0017..d7355d531 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -1,6 +1,7 @@
 import copy
 import os
 import pickle
+import platform
 from tempfile import TemporaryDirectory
 
 import pytest
@@ -299,6 +300,16 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st
     if fullgraph and torch.__version__ < (2, 8):
         pytest.skip("fullgraph mode requires torch 2.8 or higher")
 
+    # Has a strange regression on Linux aarch64 CPU in torch==2.6.0 when fullgraph=False.
+    if (
+        not fullgraph
+        and device == "cpu"
+        and platform.machine() == "aarch64"
+        and platform.system() == "Linux"
+        and ((2, 7) > torch.__version__ >= (2, 6))
+    ):
+        pytest.xfail("Regression in torch==2.6.0 on Linux aarch64 CPU")
+
     dim = 256
     batch_size = 16
 
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 58705fde3..2dd222789 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -2,6 +2,7 @@
 import copy
 import os
 import pickle
+import platform
 from tempfile import TemporaryDirectory
 
 import pytest
@@ -238,7 +239,6 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
 
     torch.compiler.reset()
 
-    torch._dynamo.config.patch()
     # Create a small network with Linear8bitLt layers
     net = torch.nn.Sequential(
         *[bnb.nn.Linear8bitLt(dim, dim, bias=bias, has_fp16_weights=False, threshold=threshold) for _ in range(4)]
     ).to(device)
@@ -267,7 +267,15 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
         torch.testing.assert_close(compiled_output, ref_output)
 
         # Test with gradients. Currently only works with threshold=0.
-        if threshold == 0:
+        # Has a strange regression on Linux aarch64 CPU in torch==2.6.0.
+ is_broken_platform = ( + device == "cpu" + and platform.machine() == "aarch64" + and platform.system() == "Linux" + and ((2, 7) > torch.__version__ >= (2, 6)) + ) + + if threshold == 0 and not is_broken_platform: x.requires_grad_(True) y1 = net(x).sum() y1.backward() From fef62f71928d6922cd3fbe5934b948e51822e29f Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Fri, 23 May 2025 23:08:43 -0400 Subject: [PATCH 3/3] Tests: skip torch.compile for cuda on windows --- .github/workflows/tests.yml | 33 +++++++++++++++++++++++---------- tests/test_linear4bit.py | 3 +++ tests/test_linear8bitlt.py | 3 +++ 3 files changed, 29 insertions(+), 10 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2bf9cbc27..b93bff4f0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -205,27 +205,40 @@ jobs: torch_version: "2.7.0" pypi_index: "https://download.pytorch.org/whl/cu128" - # Add torch 2.7+cu118 for Windows. - - os: windows-2025 - arch: x86_64 - gpu: T4 - runner: CUDA-Windows-x64 - cuda_version: "11.8.0" - torch_version: "2.7.0" - pypi_index: "https://download.pytorch.org/whl/cu118" - # L40S runners + # Linux L40S runners - os: ubuntu-22.04 gpu: L40S runner: bandb-aws-g6e-4xlarge-plus-use1-public-80 - # T4 runners + # Linux T4 runners - os: ubuntu-22.04 gpu: T4 runner: bandb-aws-g4dn-4xlarge-plus-use1-public-80 + + # Specific Windows runners using cu118 + - os: windows-2025 + arch: x86_64 + gpu: T4 + runner: CUDA-Windows-x64 + cuda_version: "11.8.0" + torch_version: "2.2.0" + pypi_index: "https://download.pytorch.org/whl/cu118" - os: windows-2025 + arch: x86_64 gpu: T4 runner: CUDA-Windows-x64 + cuda_version: "11.8.0" + torch_version: "2.6.0" + pypi_index: "https://download.pytorch.org/whl/cu118" + - os: windows-2025 + arch: x86_64 + gpu: T4 + runner: CUDA-Windows-x64 + cuda_version: "11.8.0" + torch_version: "2.7.0" + pypi_index: "https://download.pytorch.org/whl/cu118" + exclude: # Our current T4 Windows runner has a driver too old (471.11) # and cannot support CUDA 12+. Skip for now. diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index d7355d531..f3673797c 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -300,6 +300,9 @@ def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_st if fullgraph and torch.__version__ < (2, 8): pytest.skip("fullgraph mode requires torch 2.8 or higher") + if device == "cuda" and platform.system() == "Windows": + pytest.skip("Triton is not officially supported on Windows") + # Has a strange regression on Linux aarch64 CPU in torch==2.6.0 when fullgraph=False. if ( not fullgraph diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 2dd222789..a77c693e0 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -234,6 +234,9 @@ def test_linear8bit_serialization(linear8bit): @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode")) @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4") def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode): + if device == "cuda" and platform.system() == "Windows": + pytest.skip("Triton is not officially supported on Windows") + dim = 256 batch_size = 16
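
For reference, the pattern these patches lock in — running bitsandbytes quantized layers under torch.compile and checking the compiled path against eager execution — reduces to the minimal sketch below. This is illustrative only, not part of the patch series: it assumes a CUDA device, torch >= 2.4, and a bitsandbytes build with GPU support, and the dimensions, depth, and dtype simply mirror test_linear4bit_torch_compile rather than anything required by the API.

import torch
import bitsandbytes as bnb

dim, batch_size = 256, 16

# Four stacked 4-bit NF4 layers, as in the test network; weights are
# quantized when the module is moved to the device.
net = torch.nn.Sequential(
    *[bnb.nn.Linear4bit(dim, dim, compute_dtype=torch.bfloat16, quant_type="nf4") for _ in range(4)]
).to("cuda")

x = torch.randn(batch_size, dim, dtype=torch.bfloat16, device="cuda")

with torch.no_grad():
    ref_output = net(x)  # eager reference

# Default mode shown here; per the skips above, fullgraph=True additionally
# requires torch >= 2.8, and Windows CUDA is excluded (no official Triton).
compiled_net = torch.compile(net)

with torch.no_grad():
    compiled_output = compiled_net(x)

# The compiled path must agree with eager execution.
torch.testing.assert_close(compiled_output, ref_output)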