diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py
index d92f9a490..ccbe3549f 100644
--- a/bitsandbytes/backends/cuda/ops.py
+++ b/bitsandbytes/backends/cuda/ops.py
@@ -209,6 +209,7 @@ def _get_col_absmax(
 
 @register_kernel("bitsandbytes::quantize_blockwise", "cuda")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
+    A = A.contiguous()
     torch._check_is_size(blocksize)
 
     if ROCM_WARP_SIZE_64:
@@ -269,6 +270,7 @@ def _(
 def _dequantize_blockwise_impl(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
 ) -> None:
+    A = A.contiguous()
     if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
     else:
@@ -303,6 +305,7 @@ def _dequantize_blockwise_impl(
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    A = A.contiguous()
     if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
     else:
@@ -385,6 +388,7 @@ def _dequantize_4bit_impl(
     dtype: torch.dtype,
     out: torch.Tensor,
 ) -> None:
+    A = A.contiguous()
     if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
     else:
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index ee8bafe80..de40d158c 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -276,9 +276,7 @@ def test_quant_storage_shard_roundtrip(device, quant_type, quant_storage):
     reassembled = torch.cat(shards).reshape(qB.shape)
 
     assert reassembled.dtype == qB.dtype
-    assert torch.equal(
-        reassembled.view(torch.uint8), qB.view(torch.uint8)
-    ), "Bytes changed after shard roundtrip"
+    assert torch.equal(reassembled.view(torch.uint8), qB.view(torch.uint8)), "Bytes changed after shard roundtrip"
 
     out = bnb.functional.gemv_4bit(A, reassembled.t(), state=state)
     torch.testing.assert_close(out, ref)
diff --git a/tests/test_ops.py b/tests/test_ops.py
index 5f780f2ac..a33f9ab62 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -246,3 +246,108 @@ def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
         assert out.isreal().all()
 
         opcheck(torch.ops.bitsandbytes.gemv_4bit.default, (A, B_q, B.shape, absmax, code, blocksize))
+
+
+class TestNonContiguousInputs:
+    """Regression tests for #1342 and #1690: quantization must handle non-contiguous tensors correctly."""
+
+    @pytest.mark.parametrize("device", get_available_devices())
+    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
+    @pytest.mark.parametrize("blocksize", [64, 128, 256])
+    def test_quantize_blockwise_non_contiguous(self, device, dtype, blocksize):
+        if device == "cpu":
+            pytest.skip("Non-contiguous fix targets CUDA backend only")
+
+        code = bitsandbytes.functional.create_dynamic_map().to(device)
+
+        # Create non-contiguous tensor via slicing
+        A_full = torch.randn(3, 4, 6, 256, dtype=dtype, device=device)
+        A_noncontig = A_full[:, ::2, :, :]
+        assert not A_noncontig.is_contiguous()
+
+        A_contig = A_noncontig.contiguous()
+
+        out_nc, absmax_nc = torch.ops.bitsandbytes.quantize_blockwise(A_noncontig, code, blocksize)
+        out_c, absmax_c = torch.ops.bitsandbytes.quantize_blockwise(A_contig, code, blocksize)
+
+        torch.testing.assert_close(absmax_nc, absmax_c)
+        torch.testing.assert_close(out_nc, out_c)
+
+    @pytest.mark.parametrize("device", get_available_devices())
+    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
+    @pytest.mark.parametrize("blocksize", [64, 128, 256])
+    def test_dequantize_blockwise_non_contiguous(self, device, dtype, blocksize):
+        if device == "cpu":
+            pytest.skip("Non-contiguous fix targets CUDA backend only")
+
+        code = bitsandbytes.functional.create_dynamic_map().to(device, dtype=torch.float32)
+
+        # Quantize a contiguous tensor, then create non-contiguous uint8 via transpose
+        A = torch.randn(1024, 1024, dtype=dtype, device=device)
+        quantized, absmax = torch.ops.bitsandbytes.quantize_blockwise(A, code, blocksize)
+
+        # Create non-contiguous uint8 tensor by transposing and transposing back
+        q_noncontig = quantized.t().t()
+        # If that's still contiguous, use a different approach
+        if q_noncontig.is_contiguous():
+            # Pad and slice to force non-contiguity
+            q_padded = torch.zeros(1024, 1025, dtype=torch.uint8, device=device)
+            q_padded[:, :1024] = quantized
+            q_noncontig = q_padded[:, :1024]
+
+        assert not q_noncontig.is_contiguous()
+        q_contig = q_noncontig.contiguous()
+
+        out_nc = torch.ops.bitsandbytes.dequantize_blockwise(q_noncontig, absmax, code, blocksize, dtype)
+        out_c = torch.ops.bitsandbytes.dequantize_blockwise(q_contig, absmax, code, blocksize, dtype)
+
+        torch.testing.assert_close(out_nc, out_c)
+
+    @pytest.mark.parametrize("device", get_available_devices())
+    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
+    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
+    @pytest.mark.parametrize("blocksize", [64, 128, 256])
+    def test_quantize_4bit_non_contiguous(self, device, dtype, quant_type, blocksize):
+        if device != "cuda":
+            pytest.skip("Non-contiguous fix targets CUDA backend only")
+
+        # Reproduce issue #1342: non-contiguous tensor from slicing
+        A_full = torch.randn(3, 4, 6, 256, dtype=dtype, device=device)
+        A_noncontig = A_full[:, ::2, :, :]
+        assert not A_noncontig.is_contiguous()
+
+        A_contig = A_noncontig.contiguous()
+        storage_dtype = torch.uint8
+
+        out_nc, absmax_nc = torch.ops.bitsandbytes.quantize_4bit(A_noncontig, blocksize, quant_type, storage_dtype)
+        out_c, absmax_c = torch.ops.bitsandbytes.quantize_4bit(A_contig, blocksize, quant_type, storage_dtype)
+
+        torch.testing.assert_close(absmax_nc, absmax_c)
+        torch.testing.assert_close(out_nc, out_c)
+
+    @pytest.mark.parametrize("device", get_available_devices())
+    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))
+    @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
+    @pytest.mark.parametrize("blocksize", [64, 128, 256])
+    def test_quantize_4bit_roundtrip_non_contiguous(self, device, dtype, quant_type, blocksize):
+        """End-to-end test: quantize non-contiguous, dequantize, compare with contiguous path."""
+        if device != "cuda":
+            pytest.skip("Non-contiguous fix targets CUDA backend only")
+
+        A_full = torch.randn(3, 4, 6, 256, dtype=dtype, device=device)
+        A_noncontig = A_full[:, ::2, :, :]
+        assert not A_noncontig.is_contiguous()
+
+        A_contig = A_noncontig.contiguous()
+        storage_dtype = torch.uint8
+
+        # Quantize both
+        q_nc, absmax_nc = torch.ops.bitsandbytes.quantize_4bit(A_noncontig, blocksize, quant_type, storage_dtype)
+        q_c, absmax_c = torch.ops.bitsandbytes.quantize_4bit(A_contig, blocksize, quant_type, storage_dtype)
+
+        # Dequantize both
+        shape = A_contig.shape
+        deq_nc = torch.ops.bitsandbytes.dequantize_4bit(q_nc, absmax_nc, blocksize, quant_type, shape, dtype)
+        deq_c = torch.ops.bitsandbytes.dequantize_4bit(q_c, absmax_c, blocksize, quant_type, shape, dtype)
+
+        torch.testing.assert_close(deq_nc, deq_c)