diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py
index 436676c99..107f26c84 100644
--- a/bitsandbytes/backends/cpu/ops.py
+++ b/bitsandbytes/backends/cpu/ops.py
@@ -12,6 +12,8 @@
 logger = logging.getLogger(__name__)
 
+_has_avx512 = torch.backends.cpu.get_cpu_capability() == "AVX512"
+
 # torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
 # However, we can overflow if we use this without AVX512_VNNI support.
 # This is fixed in torch 2.6+, so we set this as the minimum to be safe.
@@ -134,8 +136,14 @@ def _(
         lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
     )
 
+    # Fallback as AVX512 implementation has accuracy issues with fp16/fp32 and blocksize >= 2048
+    # Note: this is not a common use case.
+    avx512_fallback = _has_avx512 and blocksize >= 2048 and dtype != torch.bfloat16
+
     # Odd shape is not supported by this kernel; fallback to generic implementation
-    if shape[-1] % 2 != 0:
+    shape_fallback = shape[-1] % 2 != 0
+
+    if avx512_fallback or shape_fallback:
         from ..default.ops import _dequantize_4bit_impl
 
         return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
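
For context, a minimal sketch (not part of the patch) of how the new fallback condition could be exercised from the public API. It assumes a bitsandbytes build where bitsandbytes.functional.quantize_4bit / dequantize_4bit dispatch to this CPU backend for CPU tensors; the tensor shape, blocksize, and quant_type below are illustrative only.

    import torch
    import bitsandbytes.functional as F

    # PyTorch reports the highest ISA level it was built to use, e.g. "AVX2" or "AVX512".
    # The patch only takes the new fallback path when this string equals "AVX512".
    print(torch.backends.cpu.get_cpu_capability())

    # An fp16 input with blocksize >= 2048 on an AVX512 machine should route through
    # the generic _dequantize_4bit_impl instead of the AVX512 kernel (the last dim is
    # even, so only the avx512_fallback condition is involved).
    A = torch.randn(4096, 4096, dtype=torch.float16)  # illustrative shape
    packed, state = F.quantize_4bit(A, blocksize=2048, quant_type="nf4")
    out = F.dequantize_4bit(packed, state)

    # 4-bit quantization is lossy; just report the reconstruction error.
    print((out - A).abs().max())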