diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index fe687e1e8..27bbc3b45 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -15,7 +15,7 @@
 from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

-from .cextension import ROCM_WARP_SIZE_64, lib
+from .cextension import lib

 name2qmap = {}
@@ -869,8 +869,6 @@ def quantize_fp4(
     compress_statistics=False,
     quant_storage=torch.uint8,
 ):
-    if blocksize is None:
-        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)


@@ -882,8 +880,6 @@ def quantize_nf4(
     compress_statistics=False,
     quant_storage=torch.uint8,
 ):
-    if blocksize is None:
-        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)


@@ -905,7 +901,7 @@ def quantize_4bit(
         absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
         out (`torch.Tensor`, *optional*): A tensor to use to store the result.
         blocksize (`int`, *optional*):
-            The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
+            The size of the blocks. Defaults to 64.
             Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
         compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
         quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
@@ -921,7 +917,7 @@ def quantize_4bit(
     """

     if blocksize is None:
-        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
+        blocksize = 64

     input_shape = A.shape

@@ -975,8 +971,6 @@ def dequantize_fp4(
     out: Optional[torch.Tensor] = None,
     blocksize: Optional[int] = None,
 ) -> torch.Tensor:
-    if blocksize is None:
-        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")


@@ -987,8 +981,6 @@ def dequantize_nf4(
     out: Optional[torch.Tensor] = None,
     blocksize: Optional[int] = None,
 ) -> torch.Tensor:
-    if blocksize is None:
-        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")


@@ -1016,7 +1008,7 @@ def dequantize_4bit(
             Required if `quant_state` is not provided and ignored otherwise.
         out (`torch.Tensor`, *optional*): A tensor to use to store the result.
         blocksize (`int`, *optional*):
-            The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
+            The size of the blocks. Defaults to 64.
             Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
         quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.

@@ -1028,7 +1020,7 @@ def dequantize_4bit(
     """

     if blocksize is None:
-        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
+        blocksize = 64

     if quant_state is None:
         assert absmax is not None and out is not None
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 67847f40c..6f705ab19 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -11,7 +11,6 @@
 import torch.nn.functional as F

 import bitsandbytes as bnb
-from bitsandbytes.cextension import ROCM_WARP_SIZE_64
 from bitsandbytes.functional import (
     QuantState,
     _convert_weight_packed_for_cpu,
@@ -226,7 +225,7 @@ def __new__(
             data = torch.empty(0)

         if blocksize is None:
-            blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
+            blocksize = 64

         self = torch.Tensor._make_subclass(cls, data, requires_grad)
         self.blocksize = blocksize
diff --git a/tests/test_parametrize.py b/tests/test_parametrize.py
index be4a6b52c..2d9c87a03 100644
--- a/tests/test_parametrize.py
+++ b/tests/test_parametrize.py
@@ -37,10 +37,7 @@ def __init__(self, device="cpu", dtype=torch.float32):
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
 @pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
 @pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
-@pytest.mark.parametrize(
-    "blocksize",
-    [64, 128, 256] if not ROCM_WARP_SIZE_64 else [128, 256],
-)
+@pytest.mark.parametrize("blocksize", [64, 128, 256])
 def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics, blocksize):
     """Test basic parameter replacement with 4-bit quantization on different dtypes."""
     if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -267,7 +264,7 @@ def test_quant_state_preservation(device, dtype):
     module = ParametrizeTestModule(device=device, dtype=dtype)

-    blocksize = 128 if ROCM_WARP_SIZE_64 else 64
+    blocksize = 64

     # Apply parametrization with specific settings
     replace_parameter_4bit(module, "weight_2d", quant_type="nf4", compress_statistics=True, blocksize=blocksize)
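
For reference, a minimal usage sketch (not part of the patch) of the behavior this diff settles on: `quantize_fp4`/`dequantize_fp4` now fall back to a blocksize of 64 on every backend, ROCm included. The device string, tensor shape, and dtype below are illustrative assumptions, not requirements of the change.

```python
import torch
import bitsandbytes.functional as F

# Assumes a CUDA or ROCm GPU is available; shape and dtype are arbitrary examples.
A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

packed, quant_state = F.quantize_fp4(A)   # blocksize not passed -> defaults to 64 on all backends
assert quant_state.blocksize == 64        # same default on CUDA and ROCm after this change

restored = F.dequantize_fp4(packed, quant_state)  # blocksize is read back from quant_state
```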