Merged
4 changes: 4 additions & 0 deletions bitsandbytes/backends/cuda/ops.py
@@ -209,6 +209,7 @@ def _get_col_absmax(

@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
+A = A.contiguous()
torch._check_is_size(blocksize)

if ROCM_WARP_SIZE_64:
@@ -269,6 +270,7 @@ def _(
def _dequantize_blockwise_impl(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
+A = A.contiguous()
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
else:
@@ -303,6 +305,7 @@ def _dequantize_blockwise_impl(
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
+A = A.contiguous()
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
else:
@@ -385,6 +388,7 @@ def _dequantize_4bit_impl(
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
+A = A.contiguous()
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
else:
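Each of the four kernel registrations above now calls A = A.contiguous() before dispatching to the CUDA kernel. The kernels walk the input as a flat, densely packed buffer, so a strided view (for example, one produced by slicing) would otherwise be read in the wrong order. The snippet below is a minimal, self-contained illustration of that distinction; it is not part of the PR and uses only plain PyTorch with hypothetical tensor names.

import torch

# A strided view produced by slicing: same storage as the parent, non-dense strides.
full = torch.randn(4, 6, 8)
view = full[:, ::2, :]           # every other slice along dim 1
print(view.is_contiguous())      # False
print(view.stride())             # (48, 16, 1): strides still reflect the parent layout

# .contiguous() materializes a densely packed, row-major copy, which is what a
# kernel that indexes raw memory linearly expects to receive.
dense = view.contiguous()
print(dense.is_contiguous())     # True
print(torch.equal(dense, view))  # True: same values, different memory layout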
4 changes: 1 addition & 3 deletions tests/test_linear4bit.py
@@ -276,9 +276,7 @@ def test_quant_storage_shard_roundtrip(device, quant_type, quant_storage):
reassembled = torch.cat(shards).reshape(qB.shape)

assert reassembled.dtype == qB.dtype
-assert torch.equal(
-reassembled.view(torch.uint8), qB.view(torch.uint8)
-), "Bytes changed after shard roundtrip"
+assert torch.equal(reassembled.view(torch.uint8), qB.view(torch.uint8)), "Bytes changed after shard roundtrip"

out = bnb.functional.gemv_4bit(A, reassembled.t(), state=state)
torch.testing.assert_close(out, ref)
105 changes: 105 additions & 0 deletions tests/test_ops.py
@@ -246,3 +246,108 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
assert out.isreal().all()

opcheck(torch.ops.bitsandbytes.gemv_4bit.default, (A, B_q, B.shape, absmax, code, blocksize))


class TestNonContiguousInputs:
"""Regression tests for #1342 and #1690: quantization must handle non-contiguous tensors correctly."""

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("blocksize", [64, 128, 256])
def test_quantize_blockwise_non_contiguous(self, device, dtype, blocksize):
if device == "cpu":
pytest.skip("Non-contiguous fix targets CUDA backend only")

code = bitsandbytes.functional.create_dynamic_map().to(device)

# Create non-contiguous tensor via slicing
A_full = torch.randn(3, 4, 6, 256, dtype=dtype, device=device)
A_noncontig = A_full[:, ::2, :, :]
assert not A_noncontig.is_contiguous()

A_contig = A_noncontig.contiguous()

out_nc, absmax_nc = torch.ops.bitsandbytes.quantize_blockwise(A_noncontig, code, blocksize)
out_c, absmax_c = torch.ops.bitsandbytes.quantize_blockwise(A_contig, code, blocksize)

torch.testing.assert_close(absmax_nc, absmax_c)
torch.testing.assert_close(out_nc, out_c)

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("blocksize", [64, 128, 256])
def test_dequantize_blockwise_non_contiguous(self, device, dtype, blocksize):
if device == "cpu":
pytest.skip("Non-contiguous fix targets CUDA backend only")

code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)

# Quantize a contiguous tensor, then make the quantized uint8 output non-contiguous
A = torch.randn(1024, 1024, dtype=dtype, device=device)
quantized, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize)

# Try transposing twice; t().t() restores the original strides, so the result is typically still contiguous
q_noncontig = quantized.t().t()
# In that case, force non-contiguity by padding and slicing instead
if q_noncontig.is_contiguous():
# The sliced view shares storage with the wider padded buffer, so it is non-contiguous
q_padded = torch.zeros(1024, 1025, dtype=torch.uint8, device=device)
q_padded[:, :1024] = quantized
q_noncontig = q_padded[:, :1024]

assert not q_noncontig.is_contiguous()
q_contig = q_noncontig.contiguous()

out_nc = torch.ops.bitsandbytes.dequantize_blockwise(q_noncontig, absmax, code, blocksize, dtype)
out_c = torch.ops.bitsandbytes.dequantize_blockwise(q_contig, absmax, code, blocksize, dtype)

torch.testing.assert_close(out_nc, out_c)

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256])
def test_quantize_4bit_non_contiguous(self, device, dtype, quant_type, blocksize):
if device != "cuda":
pytest.skip("Non-contiguous fix targets CUDA backend only")

# Reproduce issue #1342: non-contiguous tensor from slicing
A_full = torch.randn(3, 4, 6, 256, dtype=dtype, device=device)
A_noncontig = A_full[:, ::2, :, :]
assert not A_noncontig.is_contiguous()

A_contig = A_noncontig.contiguous()
storage_dtype = torch.uint8

out_nc, absmax_nc = torch.ops.bitsandbytes.quantize_4bit(A_noncontig, blocksize, quant_type, storage_dtype)
out_c, absmax_c = torch.ops.bitsandbytes.quantize_4bit(A_contig, blocksize, quant_type, storage_dtype)

torch.testing.assert_close(absmax_nc, absmax_c)
torch.testing.assert_close(out_nc, out_c)

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256])
def test_quantize_4bit_roundtrip_non_contiguous(self, device, dtype, quant_type, blocksize):
"""End-to-end test: quantize non-contiguous, dequantize, compare with contiguous path."""
if device != "cuda":
pytest.skip("Non-contiguous fix targets CUDA backend only")

A_full = torch.randn(3, 4, 6, 256, dtype=dtype, device=device)
A_noncontig = A_full[:, ::2, :, :]
assert not A_noncontig.is_contiguous()

A_contig = A_noncontig.contiguous()
storage_dtype = torch.uint8

# Quantize both
q_nc, absmax_nc = torch.ops.bitsandbytes.quantize_4bit(A_noncontig, blocksize, quant_type, storage_dtype)
q_c, absmax_c = torch.ops.bitsandbytes.quantize_4bit(A_contig, blocksize, quant_type, storage_dtype)

# Dequantize both
shape = A_contig.shape
deq_nc = torch.ops.bitsandbytes.dequantize_4bit(q_nc, absmax_nc, blocksize, quant_type, shape, dtype)
deq_c = torch.ops.bitsandbytes.dequantize_4bit(q_c, absmax_c, blocksize, quant_type, shape, dtype)

torch.testing.assert_close(deq_nc, deq_c)
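For reference, the same code path can be exercised outside the test suite with a blockwise quantize/dequantize roundtrip on a sliced view. This is a hedged sketch rather than part of the PR: it assumes a CUDA build of bitsandbytes, reuses only the operator signatures that appear in the tests above, and its shapes, blocksize, and tolerances are illustrative.

import torch
import bitsandbytes
import bitsandbytes.functional

device = "cuda"
blocksize = 256
code = bitsandbytes.functional.create_dynamic_map().to(device)

# Non-contiguous view obtained by slicing, mirroring the pattern from issue #1342.
A_full = torch.randn(3, 4, 6, 256, dtype=torch.float16, device=device)
A = A_full[:, ::2, :, :]
assert not A.is_contiguous()

# With the fix, the strided view and its dense copy quantize identically.
q_nc, absmax_nc = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize)
q_c, absmax_c = torch.ops.bitsandbytes.quantize_blockwise(A.contiguous(), code, blocksize)
torch.testing.assert_close(q_nc, q_c)
torch.testing.assert_close(absmax_nc, absmax_c)

# Roundtrip back to float16; loose tolerance because 8-bit blockwise quantization is lossy.
deq = torch.ops.bitsandbytes.dequantize_blockwise(q_nc, absmax_nc, code, blocksize, torch.float16)
torch.testing.assert_close(deq, A.contiguous(), atol=0.1, rtol=0.1)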