From 26077d250597c7c8891cab03044ca12a8bda4270 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Fri, 6 Jun 2025 17:26:44 -0400 Subject: [PATCH] Improvement for torch.compile support on Params4bit --- bitsandbytes/nn/modules.py | 7 ------- tests/test_linear4bit.py | 5 +---- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 5a3bc9e04..a9cc60dc1 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -291,13 +291,6 @@ def from_prequantized( return self - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - def _quantize(self, device): w = self.data.contiguous().to(device) w_4bit, quant_state = bnb.functional.quantize_4bit( diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index b5db2eb6f..f28bfa29e 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -270,10 +270,7 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode")) @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4") def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode): - if device == "cpu" and quant_type == "fp4": - pytest.skip("FP4 is not supported for CPU") - - if fullgraph and torch.__version__ < (2, 8): + if fullgraph and torch.__version__ < (2, 8, 0, "dev"): pytest.skip("fullgraph mode requires torch 2.8 or higher") if device == "cuda" and platform.system() == "Windows":