From 31440005442472fc2d776c886abc3629be4aaa12 Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Thu, 10 Apr 2025 14:19:22 -0400
Subject: [PATCH] Fix #1588 - torch compatibility for <=2.4

---
 bitsandbytes/_ops.py                 | 14 +++++++-------
 bitsandbytes/backends/cpu/ops.py     |  4 ++--
 bitsandbytes/backends/cuda/ops.py    |  4 ++--
 bitsandbytes/backends/default/ops.py | 17 +++++++++++------
 4 files changed, 22 insertions(+), 17 deletions(-)

diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py
index 2a12e40a1..451a1e0ef 100644
--- a/bitsandbytes/_ops.py
+++ b/bitsandbytes/_ops.py
@@ -19,7 +19,7 @@
 # Higher level op: int8 matmul + dequant + bias
 torch.library.define(
     "bitsandbytes::int8_scaled_mm",
-    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
+    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType? dtype=None) -> Tensor",
 )
 
 
@@ -30,10 +30,10 @@ def _(
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     shapeC = (*A.shape[:-1], B.shape[0])
-    return torch.empty(shapeC, device=A.device, dtype=dtype)
+    return torch.empty(shapeC, device=A.device, dtype=dtype or torch.float16)
 
 
 torch.library.define(
@@ -98,7 +98,7 @@ def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
 
 
 # Default PyTorch-native implementation
-@register_kernel("bitsandbytes::int8_vectorwise_dequant", None)
+@register_kernel("bitsandbytes::int8_vectorwise_dequant", "default")
 def _(A: torch.Tensor, stats: torch.Tensor):
     # To dequantize we divide by 127, or multiply by the reciprocal.
     return A * stats.view(-1, 1) * 7.874015718698502e-3
@@ -106,7 +106,7 @@ def _(A: torch.Tensor, stats: torch.Tensor):
 
 torch.library.define(
     "bitsandbytes::int8_mm_dequant",
-    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? bias=None) -> Tensor",
+    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType? dtype=None, Tensor? bias=None) -> Tensor",
 )
 
 
@@ -115,11 +115,11 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: "A must be int32")
-    return torch.empty_like(A, dtype=dtype)
+    return torch.empty_like(A, dtype=dtype or torch.float16)
 
 
 torch.library.define(
diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py
index ac906b7ec..b7513c4d3 100644
--- a/bitsandbytes/backends/cpu/ops.py
+++ b/bitsandbytes/backends/cpu/ops.py
@@ -28,7 +28,7 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
@@ -43,7 +43,7 @@ def _(
 
     if bias is not None:
         out += bias
-    return out.to(dtype)
+    return out.to(dtype or torch.float16)
 
 
 @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py
index c921af53a..dd9f2f9f7 100644
--- a/bitsandbytes/backends/cuda/ops.py
+++ b/bitsandbytes/backends/cuda/ops.py
@@ -90,7 +90,7 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
@@ -121,7 +121,7 @@ def _(
 
     if bias is not None and bias.dtype != torch.float16:
         out.add_(bias)
-    return out.to(dtype)
+    return out.to(dtype or torch.float16)
 
 
 @register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py
index 33bb97f8c..6e581038d 100644
--- a/bitsandbytes/backends/default/ops.py
+++ b/bitsandbytes/backends/default/ops.py
@@ -5,26 +5,31 @@
 from ..._ops import register_kernel
 
 
-@register_kernel("bitsandbytes::int8_scaled_mm", None)
+@register_kernel("bitsandbytes::int8_scaled_mm", "default")
 def _(
     A: torch.Tensor,
     B: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     out_i32 = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
-    out = torch.ops.bitsandbytes.int8_mm_dequant.default(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
-    return out
+    return torch.ops.bitsandbytes.int8_mm_dequant.default(
+        out_i32,
+        row_stats,
+        col_stats,
+        dtype=dtype or torch.float16,
+        bias=bias,
+    )
 
 
-@register_kernel("bitsandbytes::int8_linear_matmul", None)
+@register_kernel("bitsandbytes::int8_linear_matmul", "default")
 def _(A: torch.Tensor, B: torch.Tensor):
     return _int8_linear_matmul_impl(A, B)
 
 
-@register_kernel("bitsandbytes::int8_linear_matmul.out", None)
+@register_kernel("bitsandbytes::int8_linear_matmul.out", "default")
 def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     torch._check(out.dtype == torch.int32)
     _int8_linear_matmul_impl(A, B, out)
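
Note on the schema change: torch releases at or below 2.4 do not accept a ScalarType
default value such as `float16` inside a `torch.library.define` schema string, which
is why the ops above now declare `ScalarType? dtype=None` and resolve the default on
the Python side via `dtype or torch.float16`. A minimal, self-contained sketch of the
same pattern follows; the `demo::scale_cast` op is hypothetical, made up for this
sketch, and not part of bitsandbytes or torch.

    from typing import Optional

    import torch

    # Define the op with an optional ScalarType instead of a schema-level
    # default, mirroring the patch; older schema parsers accept this form.
    torch.library.define(
        "demo::scale_cast",
        "(Tensor A, ScalarType? dtype=None) -> Tensor",
    )

    @torch.library.impl("demo::scale_cast", "CPU")
    def _(A: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        # Resolve the default in Python; torch.dtype objects are always
        # truthy, so `or` cannot misfire on a legitimate dtype.
        return A.to(dtype or torch.float16)

    x = torch.ones(4)
    assert torch.ops.demo.scale_cast(x).dtype == torch.float16
    assert torch.ops.demo.scale_cast(x, dtype=torch.float32).dtype == torch.float32

For callers that pass no dtype, the observable behavior is unchanged from the old
float16 default; only where that default lives has moved.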
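
The `None` to `"default"` change in the `register_kernel` calls follows the same
compatibility theme: the catch-all implementations are now registered under an
explicit `"default"` device key rather than `None`. As a rough illustration only
(a hypothetical registry sketch, not the actual `register_kernel` helper in
`bitsandbytes/_ops.py`), a string-keyed fallback could work like this:

    from typing import Callable, Dict, Tuple

    # Hypothetical device-keyed kernel registry; names and structure are
    # assumptions for illustration, not bitsandbytes' real implementation.
    _KERNELS: Dict[Tuple[str, str], Callable] = {}

    def register_kernel(name: str, device: str) -> Callable[[Callable], Callable]:
        def decorator(fn: Callable) -> Callable:
            _KERNELS[(name, device)] = fn
            return fn
        return decorator

    def lookup(name: str, device: str) -> Callable:
        # Prefer a device-specific kernel; otherwise fall back to the one
        # registered under the explicit "default" key.
        if (name, device) in _KERNELS:
            return _KERNELS[(name, device)]
        return _KERNELS[(name, "default")]

Keying the fallback on a plain string keeps every registry key the same type,
which avoids special-casing `None` at lookup time.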