diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index c9341230f..fa032af91 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -958,7 +958,34 @@ def quantize_fp4(
     compress_statistics=False,
     quant_storage=torch.uint8,
 ):
-    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)
+    """Quantize tensor A in blocks to fp4 values.
+
+    Quantizes tensor A by dividing it into blocks which are independently quantized.
+
+    Args:
+        A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes.
+        absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
+        out (`torch.Tensor`, *optional*): A tensor to use to store the result.
+        blocksize (`int`, *optional*):
+            The size of the blocks. Defaults to 64.
+            Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
+        compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
+        quant_storage (`torch.dtype`, *optional*): The dtype of the tensor used to store the result. Defaults to `torch.uint8`.
+
+    Returns:
+        tuple[`torch.Tensor`, `QuantState`]:
+            A tuple containing the quantization results.
+            - `torch.Tensor`: The quantized tensor with packed 4-bit values.
+            - [`QuantState`]: The state object used to undo the quantization.
+    """
+    try:
+        return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)
+    except Exception as e:
+        raise RuntimeError(
+            f"Error during 4-bit FP4 quantization: {str(e)}\n"
+            "This is likely due to missing CUDA libraries or compatibility issues.\n"
+            "For troubleshooting, run 'python -m bitsandbytes' and check if CUDA libraries are properly detected."
+        ) from e
 
 
 def quantize_nf4(
@@ -969,7 +996,34 @@ def quantize_nf4(
     compress_statistics=False,
     quant_storage=torch.uint8,
 ):
-    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)
+    """Quantize tensor A in blocks to nf4 values.
+
+    Quantizes tensor A by dividing it into blocks which are independently quantized.
+
+    Args:
+        A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes.
+        absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
+        out (`torch.Tensor`, *optional*): A tensor to use to store the result.
+        blocksize (`int`, *optional*):
+            The size of the blocks. Defaults to 64.
+            Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
+        compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
+        quant_storage (`torch.dtype`, *optional*): The dtype of the tensor used to store the result. Defaults to `torch.uint8`.
+
+    Returns:
+        tuple[`torch.Tensor`, `QuantState`]:
+            A tuple containing the quantization results.
+            - `torch.Tensor`: The quantized tensor with packed 4-bit values.
+            - [`QuantState`]: The state object used to undo the quantization.
+    """
+    try:
+        return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)
+    except Exception as e:
+        raise RuntimeError(
+            f"Error during 4-bit NF4 quantization: {str(e)}\n"
+            "This is likely due to missing CUDA libraries or compatibility issues.\n"
+            "For troubleshooting, run 'python -m bitsandbytes' and check if CUDA libraries are properly detected."
+        ) from e
 
 
 def quantize_4bit(
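Both wrappers share one signature and differ only in the 4-bit data type they pass to `quantize_4bit`. A minimal round-trip sketch of the API above, assuming a CUDA-enabled bitsandbytes build and a GPU (the shapes and blocksize are illustrative):

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(64, 64, dtype=torch.float16, device="cuda")

# Quantize to packed 4-bit values plus the QuantState needed to invert it.
packed, quant_state = F.quantize_fp4(A, blocksize=64)

# Round trip: the result approximates A up to blockwise quantization error.
A_restored = F.dequantize_fp4(packed, quant_state)
print((A - A_restored).abs().max())
```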
@@ -998,55 +1052,157 @@ def quantize_4bit(
 
     Raises:
         ValueError: Raised when the input data type is not supported.
+        RuntimeError: Raised when the CUDA extension could not be loaded.
 
     Returns:
         Tuple[`torch.Tensor`, `QuantState`]: A tuple containing the quantization results.
         - `torch.Tensor`: The quantized tensor with packed 4-bit values.
         - [`QuantState`]: The state object used to undo the quantization.
     """
-    input_shape = A.shape
-
-    _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default(
-        A,
-        blocksize,
-        quant_type,
-        quant_storage,
-    )
-
-    code = get_4bit_type(quant_type, device=A.device)
-
-    if compress_statistics:
-        offset = _absmax.mean()
-        qabsmax, state2 = quantize_blockwise(_absmax - offset, blocksize=256)
-        del _absmax
-        state = QuantState(
-            absmax=qabsmax,
-            shape=input_shape,
-            dtype=A.dtype,
-            blocksize=blocksize,
-            code=code,
-            quant_type=quant_type,
-            offset=offset,
-            state2=state2,
-        )
-    else:
-        state = QuantState(
-            absmax=_absmax,
-            shape=input_shape,
-            dtype=A.dtype,
-            blocksize=blocksize,
-            code=code,
-            quant_type=quant_type,
-        )
-
-    # TODO(matthewdouglas): Deprecate out kwarg
-    out = out.copy_(_out) if out is not None else _out
-
-    # TODO(matthewdouglas): Deprecate absmax kwarg
-    if absmax is not None:
-        state.absmax = absmax.copy_(state.absmax)
+    # Check if torch.ops.bitsandbytes is available, otherwise fall back to the C extension
+    try:
+        if hasattr(torch.ops, 'bitsandbytes') and hasattr(torch.ops.bitsandbytes, 'quantize_4bit'):
+            input_shape = A.shape
+
+            _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default(
+                A,
+                blocksize,
+                quant_type,
+                quant_storage,
+            )
-
-    return out, state
+            code = get_4bit_type(quant_type, device=A.device)
+
+            if compress_statistics:
+                offset = _absmax.mean()
+                qabsmax, state2 = quantize_blockwise(_absmax - offset, blocksize=256)
+                del _absmax
+                state = QuantState(
+                    absmax=qabsmax,
+                    shape=input_shape,
+                    dtype=A.dtype,
+                    blocksize=blocksize,
+                    code=code,
+                    quant_type=quant_type,
+                    offset=offset,
+                    state2=state2,
+                )
+            else:
+                state = QuantState(
+                    absmax=_absmax,
+                    shape=input_shape,
+                    dtype=A.dtype,
+                    blocksize=blocksize,
+                    code=code,
+                    quant_type=quant_type,
+                )
+
+            # TODO(matthewdouglas): Deprecate out kwarg
+            out = out.copy_(_out) if out is not None else _out
+
+            # TODO(matthewdouglas): Deprecate absmax kwarg
+            if absmax is not None:
+                state.absmax = absmax.copy_(state.absmax)
+
+            return out, state
+        else:
+            # Fall back to the C extension
+            from .cextension import lib
+
+            if lib is None:
+                raise RuntimeError(
+                    "BitsAndBytes CUDA extension could not be loaded. "
+                    "4-bit quantization requires the CUDA extension. "
+                    "Please run 'python -m bitsandbytes' to diagnose installation issues or check "
+                    "https://github.com/bitsandbytes-foundation/bitsandbytes/issues for known problems."
+                )
+
+            input_shape = A.shape
+
+            if blocksize > 4096 or blocksize < 64:
+                raise ValueError(f"Blocksize {blocksize} not supported. Supported values: [64, 128, 256, 512, 1024, 2048, 4096]")
+
+            if blocksize & (blocksize - 1) != 0:
+                raise ValueError(f"Blocksize {blocksize} not supported. Blocksize must be a power of 2.")
+
+            if quant_type not in ["fp4", "nf4"]:
+                raise ValueError(f"quant_type {quant_type} not supported. Supported values: ['fp4', 'nf4']")
+
+            code = get_4bit_type(quant_type, device=A.device)
+
+            # Calculate output shape
+            numel = A.numel()
+            blocks = (numel + blocksize - 1) // blocksize
+            output_numel = (numel + 1) // 2
+
+            out_shape = A.shape
+            if out is None:
+                out = torch.zeros(output_numel, dtype=quant_storage, device=A.device)
+
+            absmax_tensor = torch.zeros(blocks, dtype=torch.float32, device=A.device)
+
+            is_on_gpu([A, out, absmax_tensor])
+            with _cuda_device_of(A):
+                if A.dtype == torch.float32:
+                    if quant_type == "fp4":
+                        lib.cquantize_blockwise_fp32_fp4(
+                            get_ptr(A), get_ptr(out), get_ptr(absmax_tensor),
+                            ct.c_int32(numel), ct.c_int32(blocksize)
+                        )
+                    else:  # nf4
+                        lib.cquantize_blockwise_fp32_nf4(
+                            get_ptr(A), get_ptr(out), get_ptr(absmax_tensor),
+                            ct.c_int32(numel), ct.c_int32(blocksize)
+                        )
+                elif A.dtype == torch.float16:
+                    if quant_type == "fp4":
+                        lib.cquantize_blockwise_fp16_fp4(
+                            get_ptr(A), get_ptr(out), get_ptr(absmax_tensor),
+                            ct.c_int32(numel), ct.c_int32(blocksize)
+                        )
+                    else:  # nf4
+                        lib.cquantize_blockwise_fp16_nf4(
+                            get_ptr(A), get_ptr(out), get_ptr(absmax_tensor),
+                            ct.c_int32(numel), ct.c_int32(blocksize)
+                        )
+                else:
+                    raise ValueError(f"Datatype {A.dtype} not supported for quantization.")
+
+            if compress_statistics:
+                offset = absmax_tensor.mean()
+                qabsmax, state2 = quantize_blockwise(absmax_tensor - offset, blocksize=256)
+                state = QuantState(
+                    absmax=qabsmax,
+                    shape=input_shape,
+                    dtype=A.dtype,
+                    blocksize=blocksize,
+                    code=code,
+                    quant_type=quant_type,
+                    offset=offset,
+                    state2=state2,
+                )
+            else:
+                state = QuantState(
+                    absmax=absmax_tensor,
+                    shape=input_shape,
+                    dtype=A.dtype,
+                    blocksize=blocksize,
+                    code=code,
+                    quant_type=quant_type,
+                )
+
+            if absmax is not None:
+                state.absmax = absmax.copy_(state.absmax)
+
+            return out, state
+
+    except Exception as e:
+        raise RuntimeError(
+            f"Error during 4-bit quantization: {str(e)}\n"
+            "This is likely due to missing CUDA libraries or compatibility issues.\n"
+            "For troubleshooting, run 'python -m bitsandbytes' and check if CUDA libraries are properly detected.\n"
+            "Consider updating to a newer PyTorch version or checking CUDA library paths."
+        ) from e
 
 
 def dequantize_fp4(
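The fallback path validates `blocksize` with a bit trick before calling into the C library. A standalone sketch of that check (the helper name is illustrative; the logic is copied from the hunk above):

```python
def validate_blocksize(blocksize: int) -> None:
    # Must lie in the supported range [64, 4096].
    if blocksize > 4096 or blocksize < 64:
        raise ValueError(
            f"Blocksize {blocksize} not supported. "
            "Supported values: [64, 128, 256, 512, 1024, 2048, 4096]"
        )
    # n & (n - 1) == 0 holds exactly for powers of two (n > 0): a power of
    # two has a single set bit, and subtracting 1 clears it.
    if blocksize & (blocksize - 1) != 0:
        raise ValueError(f"Blocksize {blocksize} not supported. Blocksize must be a power of 2.")

for bs in (64, 256, 4096):
    validate_blocksize(bs)  # all pass; 96 or 8192 would raise
```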
@@ -1056,7 +1212,36 @@ def dequantize_fp4(
     out: Optional[torch.Tensor] = None,
     blocksize: int = 64,
 ) -> torch.Tensor:
-    return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
+    """Dequantizes a packed 4-bit fp4 quantized tensor.
+
+    The input tensor is dequantized by dividing it into blocks of `blocksize` values.
+    The absolute maximum value within each block is used for scaling
+    the non-linear dequantization.
+
+    Args:
+        A (`torch.Tensor`): The quantized input tensor.
+        quant_state ([`QuantState`], *optional*):
+            The quantization state as returned by [`quantize_4bit`].
+            Required if `absmax` is not provided.
+        absmax (`torch.Tensor`, *optional*):
+            A tensor containing the scaling values.
+            Required if `quant_state` is not provided and ignored otherwise.
+        out (`torch.Tensor`, *optional*): A tensor to use to store the result.
+        blocksize (`int`, *optional*):
+            The size of the blocks. Defaults to 64.
+            Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
+
+    Returns:
+        `torch.Tensor`: The dequantized tensor.
+    """
+    try:
+        return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
+    except Exception as e:
+        raise RuntimeError(
+            f"Error during 4-bit FP4 dequantization: {str(e)}\n"
+            "This is likely due to missing CUDA libraries or compatibility issues.\n"
+            "For troubleshooting, run 'python -m bitsandbytes' and check if CUDA libraries are properly detected."
+        ) from e
 
 
 def dequantize_nf4(
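As the docstring notes, `quant_state` may be omitted when both `absmax` and `out` are supplied; the implementation of `dequantize_4bit` asserts exactly that. A sketch of this calling convention, reusing `packed` and `quant_state` from the round trip above and assuming `compress_statistics=False` so that `quant_state.absmax` holds the raw per-block scales:

```python
out = torch.empty(64, 64, dtype=torch.float16, device="cuda")
A_restored = F.dequantize_fp4(
    packed,
    quant_state=None,
    absmax=quant_state.absmax,  # raw scales; a nested (compressed) state would need unpacking first
    out=out,                    # also required when quant_state is None
    blocksize=64,
)
```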
+ """ + try: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") + except Exception as e: + raise RuntimeError( + f"Error during 4-bit FP4 dequantization: {str(e)}\n" + "This is likely due to missing CUDA libraries or compatibility issues.\n" + "For troubleshooting, run 'python -m bitsandbytes' and check if CUDA libraries are properly detected." + ) from e def dequantize_nf4( @@ -1066,7 +1251,36 @@ def dequantize_nf4( out: Optional[torch.Tensor] = None, blocksize: int = 64, ) -> torch.Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") + """Dequantizes a packed 4-bit nf4 quantized tensor. + + The input tensor is dequantized by dividing it into blocks of `blocksize` values. + The the absolute maximum value within these blocks is used for scaling + the non-linear dequantization. + + Args: + A (`torch.Tensor`): The quantized input tensor. + quant_state ([`QuantState`], *optional*): + The quantization state as returned by [`quantize_4bit`]. + Required if `absmax` is not provided. + absmax (`torch.Tensor`, *optional*): + A tensor containing the scaling values. + Required if `quant_state` is not provided and ignored otherwise. + out (`torch.Tensor`, *optional*): A tensor to use to store the result. + blocksize (`int`, *optional*): + The size of the blocks. Defaults to 64. + Valid values are 64, 128, 256, 512, 1024, 2048, and 4096. + + Returns: + `torch.Tensor`: The dequantized tensor. + """ + try: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") + except Exception as e: + raise RuntimeError( + f"Error during 4-bit NF4 dequantization: {str(e)}\n" + "This is likely due to missing CUDA libraries or compatibility issues.\n" + "For troubleshooting, run 'python -m bitsandbytes' and check if CUDA libraries are properly detected." + ) from e def dequantize_4bit( @@ -1099,47 +1313,136 @@ def dequantize_4bit( Raises: ValueError: Raised when the input data type or blocksize is not supported. + RuntimeError: Raised when the CUDA extension could not be loaded. Returns: `torch.Tensor`: The dequantized tensor. 
""" - if quant_state is None: - assert absmax is not None and out is not None - - quant_state = QuantState( - absmax=absmax, - shape=out.shape, - dtype=out.dtype, - blocksize=blocksize, - quant_type=quant_type, - ) - - else: - absmax = quant_state.absmax - - if quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) - absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() - - if out is not None: - torch.ops.bitsandbytes.dequantize_4bit.out( - A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out - ) - else: - out = torch.ops.bitsandbytes.dequantize_4bit.default( - A, - absmax, - quant_state.blocksize, - quant_state.quant_type, - quant_state.shape, - quant_state.dtype, - ) + try: + if hasattr(torch.ops, 'bitsandbytes') and hasattr(torch.ops.bitsandbytes, 'dequantize_4bit'): + if quant_state is None: + assert absmax is not None and out is not None + + quant_state = QuantState( + absmax=absmax, + shape=out.shape, + dtype=out.dtype, + blocksize=blocksize, + quant_type=quant_type, + ) - if A.shape[0] == 1: # is transposed, transpose back - return out.t() - return out + else: + absmax = quant_state.absmax + + if quant_state.nested: + absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + + if out is not None: + torch.ops.bitsandbytes.dequantize_4bit.out( + A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out + ) + else: + out = torch.ops.bitsandbytes.dequantize_4bit.default( + A, + absmax, + quant_state.blocksize, + quant_state.quant_type, + quant_state.shape, + quant_state.dtype, + ) + + if A.shape[0] == 1: # is transposed, transpose back + return out.t() + return out + else: + # Fallback to C extension + from .cextension import lib + + if lib is None: + raise RuntimeError( + "BitsAndBytes CUDA extension could not be loaded. " + "4-bit dequantization requires the CUDA extension. " + "Please run 'python -m bitsandbytes' to diagnose installation issues or check " + "https://github.com/bitsandbytes-foundation/bitsandbytes/issues for known problems." + ) + + if quant_state is None: + assert absmax is not None and out is not None + + quant_state = QuantState( + absmax=absmax, + shape=out.shape, + dtype=out.dtype, + blocksize=blocksize, + quant_type=quant_type, + ) + else: + absmax = quant_state.absmax + + if quant_state.nested: + absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + + # Validate parameters + if blocksize > 4096 or blocksize < 64: + raise ValueError(f"Blocksize {blocksize} not supported. Supported values: [64, 128, 256, 512, 1024, 2048, 4096]") + + if blocksize & (blocksize - 1) != 0: + raise ValueError(f"Blocksize {blocksize} not supported. Blocksize must be a power of 2.") + + if quant_type not in ["fp4", "nf4"]: + raise ValueError(f"quant_type {quant_type} not supported. 
+
+            output_shape = quant_state.shape
+            numel = prod(output_shape)
+
+            if out is None:
+                out = torch.zeros(output_shape, dtype=quant_state.dtype, device=A.device)
+
+            # Since A is a packed 4-bit tensor, its numel should be half the output numel
+            is_on_gpu([A, out, absmax])
+            with _cuda_device_of(A):
+                if quant_state.dtype == torch.float32:
+                    if quant_type == "fp4":
+                        lib.cdequantize_blockwise_fp32_fp4(
+                            get_ptr(A), get_ptr(out), get_ptr(absmax),
+                            ct.c_int32(numel), ct.c_int32(blocksize)
+                        )
+                    else:  # nf4
+                        lib.cdequantize_blockwise_fp32_nf4(
+                            get_ptr(A), get_ptr(out), get_ptr(absmax),
+                            ct.c_int32(numel), ct.c_int32(blocksize)
+                        )
+                elif quant_state.dtype == torch.float16:
+                    if quant_type == "fp4":
+                        lib.cdequantize_blockwise_fp16_fp4(
+                            get_ptr(A), get_ptr(out), get_ptr(absmax),
+                            ct.c_int32(numel), ct.c_int32(blocksize)
+                        )
+                    else:  # nf4
+                        lib.cdequantize_blockwise_fp16_nf4(
+                            get_ptr(A), get_ptr(out), get_ptr(absmax),
+                            ct.c_int32(numel), ct.c_int32(blocksize)
+                        )
+                else:
+                    raise ValueError(f"Datatype {quant_state.dtype} not supported for dequantization.")
+
+            if A.shape[0] == 1:  # is transposed, transpose back
+                return out.t()
+            return out
+
+    except Exception as e:
+        raise RuntimeError(
+            f"Error during 4-bit dequantization: {str(e)}\n"
+            "This is likely due to missing CUDA libraries or compatibility issues.\n"
+            "For troubleshooting, run 'python -m bitsandbytes' and check if CUDA libraries are properly detected.\n"
+            "Consider updating to a newer PyTorch version or checking CUDA library paths."
+        ) from e
 
 
 @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
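A closing note on the size bookkeeping both fallback paths rely on: with the default `torch.uint8` storage, two 4-bit values share one byte, so the packed tensor holds `(numel + 1) // 2` elements, while `QuantState.shape` preserves the original shape for dequantization. A quick check under the same assumptions as the sketches above:

```python
A = torch.randn(4, 64, dtype=torch.float16, device="cuda")  # 256 elements
packed, state = F.quantize_nf4(A, blocksize=64)

assert packed.numel() == (A.numel() + 1) // 2  # 128 bytes of packed nibbles
assert state.shape == A.shape                  # original shape kept in the state

restored = F.dequantize_nf4(packed, state)
assert restored.shape == A.shape
```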