diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index f1a5dca69..b93bff4f0 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -137,6 +137,10 @@ jobs:
         with:
           python-version: 3.9
 
+      - name: Setup MSVC
+        if: startsWith(matrix.os, 'windows')
+        uses: ilammy/msvc-dev-cmd@v1.13.0 # to use cl for torch.compile
+
       - name: Install dependencies
         run: |
           pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/cpu
@@ -201,18 +205,40 @@ jobs:
             torch_version: "2.7.0"
             pypi_index: "https://download.pytorch.org/whl/cu128"
 
-          # L40S runners
+
+          # Linux L40S runners
           - os: ubuntu-22.04
             gpu: L40S
             runner: bandb-aws-g6e-4xlarge-plus-use1-public-80
-          # T4 runners
+          # Linux T4 runners
           - os: ubuntu-22.04
             gpu: T4
             runner: bandb-aws-g4dn-4xlarge-plus-use1-public-80
+
+          # Specific Windows runners using cu118
+          - os: windows-2025
+            arch: x86_64
+            gpu: T4
+            runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.2.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
           - os: windows-2025
+            arch: x86_64
+            gpu: T4
+            runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.6.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
+          - os: windows-2025
+            arch: x86_64
             gpu: T4
             runner: CUDA-Windows-x64
+            cuda_version: "11.8.0"
+            torch_version: "2.7.0"
+            pypi_index: "https://download.pytorch.org/whl/cu118"
+
         exclude:
           # Our current T4 Windows runner has a driver too old (471.11)
           # and cannot support CUDA 12+. Skip for now.
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index b0092ffd1..f84f16c21 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -771,14 +771,14 @@ def quantize_blockwise(
         qabsmax, state2 = quantize_blockwise(_absmax, blocksize=blocksize, nested=False)
         quant_state = QuantState(
             absmax=qabsmax,
-            code=code,
+            code=code.to(A.device, copy=True),
             blocksize=blocksize,
             dtype=A.dtype,
             offset=offset,
             state2=state2,
         )
     else:
-        quant_state = QuantState(absmax=_absmax, code=code.to(A.device), blocksize=blocksize, dtype=A.dtype)
+        quant_state = QuantState(absmax=_absmax, code=code.to(A.device, copy=True), blocksize=blocksize, dtype=A.dtype)
 
     # TODO(matthewdouglas): Deprecate out kwarg
     out = out.copy_(_out) if out is not None else _out
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 937084cf1..500102ab1 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -493,7 +493,7 @@ def forward(self, x: torch.Tensor):
 
         bias = None if self.bias is None else self.bias.to(self.compute_dtype)
 
-        return bnb.matmul_4bit(x, self.weight.data.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
+        return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
 
 
 class LinearFP4(Linear4bit):
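Note on the `quantize_blockwise` change: `Tensor.to(device)` returns `self` unchanged when the tensor is already on the target device, so without `copy=True` every `QuantState` built on that device aliases the shared module-level codebook tensor. Forcing a copy gives each state its own buffer, presumably to keep `torch.compile` from tripping over shared or mutated storage; the `Linear4bit.forward` change is in the same spirit, since going through `.data` hides the tensor from autograd and Dynamo tracking. A minimal standalone sketch of the aliasing behavior (the `codebook` name is illustrative, not bitsandbytes code):

```python
import torch

codebook = torch.linspace(-1.0, 1.0, 256)  # stand-in for the shared quantization code tensor

aliased = codebook.to(codebook.device)           # no conversion needed: returns `self`
owned = codebook.to(codebook.device, copy=True)  # always materializes a new tensor

assert aliased.data_ptr() == codebook.data_ptr()  # same storage as the original
assert owned.data_ptr() != codebook.data_ptr()    # independent storage per QuantState
```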
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index 67b61cb05..f3673797c 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -1,13 +1,21 @@
 import copy
 import os
 import pickle
+import platform
 from tempfile import TemporaryDirectory
 
 import pytest
 import torch
 
 import bitsandbytes as bnb
-from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, torch_load_from_buffer, torch_save_to_buffer
+from tests.helpers import (
+    TRUE_FALSE,
+    describe_dtype,
+    get_available_devices,
+    id_formatter,
+    torch_load_from_buffer,
+    torch_save_to_buffer,
+)
 
 storage = {
     "uint8": torch.uint8,
@@ -275,3 +283,85 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
     # there was a bug where deepcopy would modify the original object
     assert dict_keys_before == dict_keys_after
     assert dict_keys_before == dict_keys_deserialized
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("compute_dtype", [torch.bfloat16, torch.float32], ids=describe_dtype)
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias"))
+@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
+@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
+@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
+def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
+    if device == "cpu" and quant_type == "fp4":
+        pytest.skip("FP4 is not supported for CPU")
+
+    if fullgraph and torch.__version__ < (2, 8):
+        pytest.skip("fullgraph mode requires torch 2.8 or higher")
+
+    if device == "cuda" and platform.system() == "Windows":
+        pytest.skip("Triton is not officially supported on Windows")
+
+    # Has a strange regression on Linux aarch64 CPU in torch==2.6.0 when fullgraph=False.
+    if (
+        not fullgraph
+        and device == "cpu"
+        and platform.machine() == "aarch64"
+        and platform.system() == "Linux"
+        and ((2, 7) > torch.__version__ >= (2, 6))
+    ):
+        pytest.xfail("Regression in torch==2.6.0 on Linux aarch64 CPU")
+
+    dim = 256
+    batch_size = 16
+
+    torch.compiler.reset()
+
+    # Create a small network with Linear4bit layers
+    net = torch.nn.Sequential(
+        *[
+            bnb.nn.Linear4bit(
+                dim,
+                dim,
+                bias=bias,
+                compute_dtype=compute_dtype,
+                compress_statistics=compress_statistics,
+                quant_type=quant_type,
+            )
+            for _ in range(4)
+        ]
+    ).to(device)
+
+    # Create input tensor
+    x = torch.randn(batch_size, dim, dtype=compute_dtype, device=device)
+
+    # Get reference output before compilation
+    with torch.no_grad():
+        ref_output = net(x)
+
+    # Compile the model
+    compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode)
+
+    # Get output from compiled model
+    with torch.no_grad():
+        compiled_output = compiled_net(x)
+
+    # Check outputs match
+    assert compiled_output.shape == ref_output.shape
+    assert compiled_output.device == ref_output.device
+    assert compiled_output.dtype == ref_output.dtype
+    torch.testing.assert_close(compiled_output, ref_output)
+
+    # Test with gradients
+    x.requires_grad_(True)
+    y1 = net(x).sum()
+    y1.backward()
+    grad_ref = x.grad.clone()
+
+    x.grad = None
+    y2 = compiled_net(x).sum()
+    y2.backward()
+    grad_compiled = x.grad.clone()
+
+    torch.testing.assert_close(grad_compiled, grad_ref)
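The `fullgraph` and `mode` parametrizations above exercise the two main `torch.compile` knobs: `fullgraph=True` turns graph breaks into hard errors instead of silent fallbacks to eager, and `torch.compiler.reset()` clears compiler caches so each parametrized case starts fresh. A toy illustration of the `fullgraph` distinction (hypothetical function `f`, not from this PR):

```python
import torch

def f(x):
    # Branching on a tensor's runtime value is a classic graph break:
    # Dynamo cannot know x.sum() while tracing with fake tensors.
    if x.sum() > 0:
        return x + 1
    return x - 1

x = torch.randn(8)
torch.compile(f)(x)  # fine: Dynamo splits the graph at the branch

torch.compiler.reset()  # drop cached compilation artifacts
try:
    torch.compile(f, fullgraph=True)(x)  # raises instead of graph-breaking
except Exception as exc:
    print(f"fullgraph=True rejected the graph break: {type(exc).__name__}")
```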
ids=id_formatter("bias")) +@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph")) +@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode")) +@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4") +def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode): + if device == "cuda" and platform.system() == "Windows": + pytest.skip("Triton is not officially supported on Windows") + + dim = 256 + batch_size = 16 + + torch.compiler.reset() + + # Create a small network with Linear8bitLt layers + net = torch.nn.Sequential( + *[bnb.nn.Linear8bitLt(dim, dim, bias=bias, has_fp16_weights=False, threshold=threshold) for _ in range(4)] + ).to(device) + + dynamic_output_shapes = fullgraph and threshold > 0 + with torch._dynamo.config.patch("capture_dynamic_output_shape_ops", dynamic_output_shapes): + # Create input tensor + x = torch.randn(batch_size, dim, dtype=torch.float16, device=device) + + # Get reference output before compilation + with torch.no_grad(): + ref_output = net(x) + + # Compile the model + compiled_net = torch.compile(net, fullgraph=fullgraph, mode=mode) + + # Get output from compiled model + with torch.no_grad(): + compiled_output = compiled_net(x) + + # Check outputs match + assert compiled_output.shape == ref_output.shape + assert compiled_output.device == ref_output.device + assert compiled_output.dtype == ref_output.dtype + torch.testing.assert_close(compiled_output, ref_output) + + # Test with gradients. Currently only works with threshold=0. + # Has a strange regression on Linux aarch64 CPU in torch==2.6.0. + is_broken_platform = ( + device == "cpu" + and platform.machine() == "aarch64" + and platform.system() == "Linux" + and ((2, 7) > torch.__version__ >= (2, 6)) + ) + + if threshold == 0 and not is_broken_platform: + x.requires_grad_(True) + y1 = net(x).sum() + y1.backward() + grad_ref = x.grad.clone() + + x.grad = None + y2 = compiled_net(x).sum() + y2.backward() + grad_compiled = x.grad.clone() + + torch.testing.assert_close(grad_compiled, grad_ref)