diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 85db6366b..c7ad3a82c 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -236,7 +236,8 @@ def forward( ctx.state = state ctx.grad_shape = input_shape - ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype + ctx.dtype_A = A.dtype + ctx.dtype_bias = None if bias is None else bias.dtype if any(ctx.needs_input_grad[:2]): ctx.tensors = (CAt, subA, A)