From d9aa37aaf21f16768aa2e9c80ca7dfaa5c694d44 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Mon, 28 Apr 2025 14:59:47 -0400
Subject: [PATCH] Improve torch.compile support for int8 with torch>=2.8
 nightly

---
 bitsandbytes/autograd/_functions.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 85db6366b..c7ad3a82c 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -236,7 +236,8 @@ def forward(
 
         ctx.state = state
         ctx.grad_shape = input_shape
-        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
+        ctx.dtype_A = A.dtype
+        ctx.dtype_bias = None if bias is None else bias.dtype
 
         if any(ctx.needs_input_grad[:2]):
             ctx.tensors = (CAt, subA, A)
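
For context on the pattern: the patch splits the combined tuple assignment into separate per-attribute assignments on ctx and stops storing B.dtype there. Below is a minimal sketch of that style of custom autograd Function; it is an illustration only, not bitsandbytes' MatMul8bitLt, and the class name SimpleMatmul, the plain floating-point matmul, and the example shapes are invented for the sketch.

import torch


class SimpleMatmul(torch.autograd.Function):
    """Illustrative sketch: save on ctx only the dtypes backward needs."""

    @staticmethod
    def forward(ctx, A, B, bias=None):
        output = A @ B
        if bias is not None:
            output = output + bias
        # Per-attribute assignments; B.dtype is intentionally not kept on ctx.
        ctx.dtype_A = A.dtype
        ctx.dtype_bias = None if bias is None else bias.dtype
        ctx.save_for_backward(A, B)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        # Cast gradients back to the dtypes recorded in forward.
        grad_A = (grad_output @ B.t()).to(ctx.dtype_A)
        grad_B = A.t() @ grad_output
        grad_bias = None
        if ctx.dtype_bias is not None:
            grad_bias = grad_output.sum(dim=0).to(ctx.dtype_bias)
        return grad_A, grad_B, grad_bias


# Hypothetical usage with made-up shapes:
A = torch.randn(4, 8, requires_grad=True)
B = torch.randn(8, 16, requires_grad=True)
out = SimpleMatmul.apply(A, B)
out.sum().backward()

The patch itself also drops ctx.dtype_B entirely, which suggests the backward path no longer reads it; keeping only the plain dtypes that are actually needed on ctx is presumably what helps torch.compile here.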