From 77d75d98a3432057e00386432f216660f55ab78d Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Tue, 1 Apr 2025 13:13:57 -0400
Subject: [PATCH] Fix torch.compile issue for LLM.int8() with threshold=0

---
 bitsandbytes/autograd/_functions.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 0f97cdd08..5df8a0979 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -84,6 +84,13 @@ def get_inverse_transform_indices(
     return permuted_tile_indices
 
 
+# torch.compiler.is_compiling() is available only in torch >= 2.3
+if hasattr(torch.compiler, "is_compiling"):
+    _is_compiling = torch.compiler.is_compiling
+else:
+    _is_compiling = torch._dynamo.is_compiling
+
+
 @deprecated(
     "This function is deprecated and will be removed in a future release.",
     category=FutureWarning,
@@ -174,7 +181,7 @@ def forward(
         input_shape = A.shape
 
         # Cast A to fp16
-        if A.dtype != torch.float16:
+        if A.dtype != torch.float16 and not _is_compiling():
             warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
 
         if len(A.shape) == 3:
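
Note (not part of the patch): a minimal repro sketch, assuming a CUDA device
with bitsandbytes installed; the layer sizes and inputs are illustrative. With
a non-fp16 input, the dtype-cast warning fires inside the traced region of
MatMul8bitLt.forward, which can break torch.compile tracing (e.g. under
fullgraph=True); the _is_compiling() guard skips the warning while compiling.

    import torch
    import bitsandbytes as bnb

    # LLM.int8() inference path with threshold=0; moving to CUDA quantizes the weights.
    linear = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False, threshold=0.0).cuda()

    # bf16 input: A.dtype != torch.float16, so the cast warning would fire in forward().
    x = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")

    compiled = torch.compile(linear, fullgraph=True)
    out = compiled(x)  # with this patch, the warning is suppressed during tracing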