From 15c8c48772fea2f010fdf060f32a49405f9f9d46 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Tue, 15 Apr 2025 10:23:34 -0400
Subject: [PATCH] Support LLM.int8() inference with torch.compile

---
 bitsandbytes/_ops.py                | 26 ++++++++++++++++++
 bitsandbytes/autograd/_functions.py | 41 +++++++++++-----------------
 bitsandbytes/backends/cuda/ops.py   | 42 +++++++++++++++++++++++++++++
 3 files changed, 84 insertions(+), 25 deletions(-)

diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py
index 2a12e40a1..ceba71572 100644
--- a/bitsandbytes/_ops.py
+++ b/bitsandbytes/_ops.py
@@ -15,6 +15,32 @@
 register_fake = torch.library.impl_abstract
 register_kernel = torch.library.impl
 
+# Int8 mixed precision matmul + dequant + bias
+torch.library.define(
+    "bitsandbytes::int8_mixed_scaled_mm",
+    "(Tensor A, Tensor CA, Tensor CB, Tensor SCA, Tensor SCB, Tensor? outlier_cols=None, Tensor? bias=None) -> (Tensor, Tensor?)",
+)
+
+
+@register_fake("bitsandbytes::int8_mixed_scaled_mm")
+def _(
+    A: torch.Tensor,
+    CA: torch.Tensor,
+    CB: torch.Tensor,
+    SCA: torch.Tensor,
+    SCB: torch.Tensor,
+    outlier_cols: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    shapeC = (*CA.shape[:-1], CB.shape[0])
+
+    out = torch.empty(shapeC, device=A.device, dtype=A.dtype)
+
+    outlier_cols = torch.library.get_ctx().new_dynamic_size()
+    subA = A.new_empty(outlier_cols, dtype=torch.int64)
+
+    return out, subA
+
 
 # Higher level op: int8 matmul + dequant + bias
 torch.library.define(
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 5df8a0979..7fa846d92 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -210,37 +210,28 @@ def forward(
                 # 2. Quantize B
                 state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
 
-        # Handle sparse decomposition. In some instances, we may have not found any
-        # outlier columns at all. In that case, we'll skip this part completely.
-        if state.threshold > 0.0 and outlier_cols is not None and outlier_cols.numel():
+        # Handle sparse decomposition
+        if state.threshold > 0.0:
             state.idx = outlier_cols
 
-            # Zero out the outliers in the transposed 8bit inputs.
-            if CAt is not None:
-                CAt[:, state.idx] = 0
-
-            # Extract the input outliers in original precision
-            subA = A[:, state.idx].contiguous()
+            # Mixed Int8 Matmul + Dequant + Bias
+            output, subA = torch.ops.bitsandbytes.int8_mixed_scaled_mm(
+                A,
+                CA,
+                state.CB,
+                SCA,
+                state.SCB,
+                outlier_cols,
+                bias,
+            )
 
-            # Extract the corresponding weights
-            if state.has_fp16_weights:
-                state.subB = B[:, state.idx].t()
-            else:
-                # To dequantize our weights associated with the input outliers,
-                # we want to divide by 127. It's however more performant to multiply
-                # by the reciprocal.
-                outliers = state.CB[:, state.idx]
-                state.subB = F.int8_vectorwise_dequant(outliers, state.SCB).to(A.dtype).t()
         else:
+            # Int8 Matmul + Dequant + Bias
+            output = torch.ops.bitsandbytes.int8_scaled_mm.default(
+                CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype
+            )
             subA = None
 
-        # 3. Int8 Matmul + Dequant + Bias
-        output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, state.CB, SCA, state.SCB, bias=bias, dtype=A.dtype)
-
-        # 4. Mixed-precision decomposition matmul
-        if subA is not None and state.subB is not None:
-            output = output.addmm(subA, state.subB)
-
         # 5. Save state
         ctx.state = state
 
diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py
index c921af53a..783a32894 100644
--- a/bitsandbytes/backends/cuda/ops.py
+++ b/bitsandbytes/backends/cuda/ops.py
@@ -22,6 +22,45 @@ def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     _int8_linear_matmul_impl(A, B, out)
 
 
+@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "cuda")
+def _(
+    A: torch.Tensor,
+    CA: torch.Tensor,
+    CB: torch.Tensor,
+    SCA: torch.Tensor,
+    SCB: torch.Tensor,
+    outlier_cols: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+    subB = None
+
+    if outlier_cols is not None and outlier_cols.numel():
+        # Extract the inputs with outliers in original precision
+        subA = A[:, outlier_cols].contiguous()
+
+        # Dequantize the corresponding weight columns
+        subB = (
+            torch.ops.bitsandbytes.int8_vectorwise_dequant.default(CB[:, outlier_cols].contiguous(), SCB)
+            .to(A.dtype)
+            .t()
+        )
+
+        # TODO: if state.has_fp16_weights: subB = B[:, outlier_cols].t()
+
+    else:
+        # Needed for torch.compile when there are no outliers.
+        subA = torch.empty(0, device=A.device, dtype=A.dtype)
+
+    # Int8 Matmul + Dequant + Bias
+    output = torch.ops.bitsandbytes.int8_scaled_mm.default(CA, CB, SCA, SCB, bias=bias, dtype=A.dtype)
+
+    if subB is not None:
+        # Add the outlier columns back to the output
+        output = output.addmm(subA, subB)
+
+    return output, subA
+
+
 def _int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     A, B = B, A
 
@@ -143,6 +182,9 @@ def _(A: torch.Tensor, threshold=0.0):
 
     if outliers.any():
         outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
+    else:
+        # Needed for torch.compile support.
+        outlier_cols = torch.empty(0, device=A.device, dtype=torch.int64)
 
     with _cuda_device_of(A):
         lib.cint8_vector_quant(
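
Below is a minimal usage sketch, not part of the patch, of how the new torch.ops.bitsandbytes.int8_mixed_scaled_mm op could be exercised under torch.compile once this patch is applied. It assumes a CUDA device with bitsandbytes installed; F.int8_vectorwise_quant is used as in the diff above, and the helper name int8_forward, the threshold value, and the tensor shapes are illustrative only.

import torch
import bitsandbytes.functional as F


@torch.compile
def int8_forward(A, CB, SCB, bias=None, threshold=6.0):
    # Quantize the fp16 activations row-wise; outlier_cols holds the column
    # indices whose magnitude exceeds the threshold (may be empty).
    CA, SCA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
    # Fused int8 matmul + dequant + bias with mixed-precision outlier handling;
    # the second return value (subA) is only needed for the backward pass.
    out, _ = torch.ops.bitsandbytes.int8_mixed_scaled_mm(A, CA, CB, SCA, SCB, outlier_cols, bias)
    return out


# Illustrative shapes; the weight is quantized once up front and reused.
A = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
W = torch.randn(11008, 4096, device="cuda", dtype=torch.float16)
CB, SCB, _ = F.int8_vectorwise_quant(W)
out = int8_forward(A, CB, SCB)  # shape (8, 11008)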