Merged
18 changes: 5 additions & 13 deletions bitsandbytes/functional.py
@@ -15,7 +15,7 @@

 from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

-from .cextension import ROCM_WARP_SIZE_64, lib
+from .cextension import lib

 name2qmap = {}

@@ -869,8 +869,6 @@ def quantize_fp4(
     compress_statistics=False,
     quant_storage=torch.uint8,
 ):
-    if blocksize is None:
-        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)


@@ -882,8 +880,6 @@ def quantize_nf4(
     compress_statistics=False,
     quant_storage=torch.uint8,
 ):
-    if blocksize is None:
-        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)


@@ -905,7 +901,7 @@ def quantize_4bit(
         absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
         out (`torch.Tensor`, *optional*): A tensor to use to store the result.
         blocksize (`int`, *optional*):
-            The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
+            The size of the blocks. Defaults to 64.
             Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
         compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
         quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
@@ -921,7 +917,7 @@
"""

if blocksize is None:
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
blocksize = 64

input_shape = A.shape

@@ -975,8 +971,6 @@ def dequantize_fp4(
     out: Optional[torch.Tensor] = None,
     blocksize: Optional[int] = None,
 ) -> torch.Tensor:
-    if blocksize is None:
-        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")


@@ -987,8 +981,6 @@ def dequantize_nf4(
     out: Optional[torch.Tensor] = None,
     blocksize: Optional[int] = None,
 ) -> torch.Tensor:
-    if blocksize is None:
-        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")


@@ -1016,7 +1008,7 @@ def dequantize_4bit(
             Required if `quant_state` is not provided and ignored otherwise.
         out (`torch.Tensor`, *optional*): A tensor to use to store the result.
         blocksize (`int`, *optional*):
-            The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
+            The size of the blocks. Defaults to 64.
             Valid values are 32, 64, 128, 256, 512, 1024, 2048, and 4096.
         quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.

@@ -1028,7 +1020,7 @@
"""

if blocksize is None:
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
blocksize = 64

if quant_state is None:
assert absmax is not None and out is not None
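For reference, the net effect of this file's changes is that `blocksize=None` now resolves to 64 on every platform, ROCm included, instead of 128 on ROCm builds. A minimal sketch of the resulting behavior, assuming a CUDA or ROCm device is available; the shapes and dtype are illustrative:

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")

# blocksize is left unset, so it falls back to 64 regardless of platform.
packed, quant_state = F.quantize_fp4(A)
assert quant_state.blocksize == 64

# Dequantization reads the blocksize back out of quant_state.
restored = F.dequantize_fp4(packed, quant_state)
assert restored.shape == A.shape
```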
3 changes: 1 addition & 2 deletions bitsandbytes/nn/modules.py
@@ -11,7 +11,6 @@
 import torch.nn.functional as F

 import bitsandbytes as bnb
-from bitsandbytes.cextension import ROCM_WARP_SIZE_64
 from bitsandbytes.functional import (
     QuantState,
     _convert_weight_packed_for_cpu,
@@ -226,7 +225,7 @@ def __new__(
         data = torch.empty(0)

         if blocksize is None:
-            blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
+            blocksize = 64

         self = torch.Tensor._make_subclass(cls, data, requires_grad)
         self.blocksize = blocksize
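The `Params4bit` default follows the same rule. A hedged sketch, assuming the public `bnb.nn.Params4bit` constructor; actual quantization is deferred until the parameter is moved to an accelerator:

```python
import torch
import bitsandbytes as bnb

# No blocksize passed, so __new__ falls back to 64 on all platforms.
p = bnb.nn.Params4bit(
    torch.randn(256, 256, dtype=torch.float16),
    requires_grad=False,
    quant_type="nf4",
)
assert p.blocksize == 64
```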
7 changes: 2 additions & 5 deletions tests/test_parametrize.py
@@ -37,10 +37,7 @@ def __init__(self, device="cpu", dtype=torch.float32):
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
@pytest.mark.parametrize(
"blocksize",
[64, 128, 256] if not ROCM_WARP_SIZE_64 else [128, 256],
)
@pytest.mark.parametrize("blocksize", [64, 128, 256])
def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics, blocksize):
"""Test basic parameter replacement with 4-bit quantization on different dtypes."""
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -267,7 +264,7 @@ def test_quant_state_preservation(device, dtype):

     module = ParametrizeTestModule(device=device, dtype=dtype)

-    blocksize = 128 if ROCM_WARP_SIZE_64 else 64
+    blocksize = 64

     # Apply parametrization with specific settings
     replace_parameter_4bit(module, "weight_2d", quant_type="nf4", compress_statistics=True, blocksize=blocksize)