diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 0f97cdd08..5df8a0979 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -84,6 +84,13 @@ def get_inverse_transform_indices( return permuted_tile_indices +# torch.compiler.is_compiling() is available only in torch >= 2.3 +if hasattr(torch.compiler, "is_compiling"): + _is_compiling = torch.compiler.is_compiling +else: + _is_compiling = torch._dynamo.is_compiling + + @deprecated( "This function is deprecated and will be removed in a future release.", category=FutureWarning, @@ -174,7 +181,7 @@ def forward( input_shape = A.shape # Cast A to fp16 - if A.dtype != torch.float16: + if A.dtype != torch.float16 and not _is_compiling(): warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") if len(A.shape) == 3: