diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py
index 0da9eac94..d5ab9aa88 100644
--- a/bitsandbytes/backends/cpu/ops.py
+++ b/bitsandbytes/backends/cpu/ops.py
@@ -1,6 +1,5 @@
 from collections.abc import Sequence
 import ctypes as ct
-from typing import Optional
 
 import torch
 
@@ -24,29 +23,6 @@ def _(A: torch.Tensor, B: torch.Tensor):
     ).reshape(*A.shape[:-1], B.shape[0])
 
 
-@register_kernel("bitsandbytes::int8_mm_dequant", "cpu")
-def _(
-    A: torch.Tensor,
-    row_stats: torch.Tensor,
-    col_stats: torch.Tensor,
-    dtype: Optional[torch.dtype] = None,
-    bias: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
-    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
-    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")
-
-    A_calc = A.view(-1, A.shape[-1])
-    row_stats = row_stats.reshape(-1).unsqueeze(-1)
-    col_stats = col_stats.reshape(-1).unsqueeze(0)
-
-    out = A_calc * (row_stats * col_stats) * 6.200124e-05
-    if bias is not None:
-        out += bias
-
-    return out.to(dtype or torch.float16)
-
-
 @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py
index 20e596f25..729c2b047 100644
--- a/bitsandbytes/backends/default/ops.py
+++ b/bitsandbytes/backends/default/ops.py
@@ -6,6 +6,29 @@
 from ..._ops import register_kernel
 
 
+@register_kernel("bitsandbytes::int8_mm_dequant", "default")
+def _(
+    A: torch.Tensor,
+    row_stats: torch.Tensor,
+    col_stats: torch.Tensor,
+    dtype: Optional[torch.dtype] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
+    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
+    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")
+
+    A_calc = A.view(-1, A.shape[-1])
+    row_stats = row_stats.reshape(-1).unsqueeze(-1)
+    col_stats = col_stats.reshape(-1).unsqueeze(0)
+
+    out = A_calc * (row_stats * col_stats) * 6.200124e-05
+    if bias is not None:
+        out += bias
+
+    return out.to(dtype or torch.float16)
+
+
 @register_kernel("bitsandbytes::int8_mixed_scaled_mm", "default")
 def _(
     A: torch.Tensor,