Handle non-contiguous tensors in quantize/dequantize ops #1859
matthewdouglas merged 2 commits into main
Conversation
…#1342, #1690) Add A.contiguous() calls at the top of quantize_blockwise, quantize_4bit, and their dequantize counterparts in the CUDA backend. The CUDA kernels use raw pointers and assume contiguous memory layout, so non-contiguous inputs (e.g. tensor slices with strides) produced silently incorrect results. Add regression tests verifying non-contiguous tensors produce identical results to their contiguous equivalents for all four ops. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
TimDettmers
left a comment
PR Review: #1859 — Handle non-contiguous tensors in quantize/dequantize ops
Bug fix: adds A = A.contiguous() at the top of the four core CUDA backend quantization functions (quantize_blockwise, _dequantize_blockwise_impl, quantize_4bit, _dequantize_4bit_impl). Non-contiguous tensors passed to these functions were silently producing incorrect results because the underlying CUDA kernels use get_ptr() which assumes contiguous memory layout. Includes a thorough regression test class (TestNonContiguousInputs) with 4 test methods covering all four ops.
Classification: [bug-fix] [test]
No blocking issues.
The fix is correct and well-placed. Placing .contiguous() at the CUDA backend layer (rather than in the public functional.py wrappers) is the right design — it keeps the fix close to the assumption it enforces (get_ptr() needs contiguous memory), and it naturally covers both direct op calls via torch.ops.bitsandbytes.* and calls through the higher-level functional.py API. The .contiguous() call is a no-op on already-contiguous tensors (returns self with no copy), so the common case pays zero cost.
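For orientation, the pattern looks roughly like the sketch below. This is a simplified illustration, not the actual `bitsandbytes/backends/cuda/ops.py` code: the real functions take more arguments, allocate outputs differently, and launch CUDA kernels through raw pointers; `quantize_blockwise_cuda` is a placeholder name and the kernel launch is elided.

```python
import torch

def quantize_blockwise_cuda(A: torch.Tensor, code: torch.Tensor, blocksize: int):
    # The kernel walks A's storage through a raw pointer in dense C order,
    # so a contiguous layout must be guaranteed up front. On an already-contiguous
    # tensor this returns A itself (no allocation, no copy).
    A = A.contiguous()

    out = torch.empty(A.shape, dtype=torch.uint8, device=A.device)
    n_blocks = (A.numel() + blocksize - 1) // blocksize
    absmax = torch.empty(n_blocks, dtype=torch.float32, device=A.device)

    # ... launch the CUDA kernel with get_ptr(A), get_ptr(code),
    #     get_ptr(out), get_ptr(absmax) ...
    return out, absmax
```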
Root cause analysis: The CUDA kernels receive raw data pointers via get_ptr() and iterate over them assuming contiguous C-order layout. When a tensor has non-unit strides (from slicing, transposing, etc.), the pointer still points to the start of the storage, but the kernel reads elements sequentially — skipping the stride logic — producing silently corrupted output. The fix materializes a contiguous copy before pointer extraction. This matches the root cause described in both #1342 and #1690.
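The failure mode is easy to reproduce conceptually with plain PyTorch (an illustration of the pointer-walk mismatch, not bitsandbytes code):

```python
import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
col = x[:, 1]                    # view with stride (4,) and storage offset 1
print(col.is_contiguous())       # False
print(col)                       # tensor([1., 5., 9.])

# A kernel that reads numel() elements sequentially from the raw data pointer
# would effectively see these values instead -- the parent's storage, not the slice:
start = col.storage_offset()
misread = x.flatten()[start : start + col.numel()]
print(misread)                   # tensor([1., 2., 3.])

# .contiguous() materializes a dense copy holding the values the caller intended:
print(col.contiguous())          # tensor([1., 5., 9.])
```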
Scope check — other get_ptr() call sites in cuda/ops.py: The file has additional get_ptr() calls in int8_linear_matmul, int8_mm_dequant, int8_vectorwise_quant, gemv_4bit, and the optimizer update functions. These are not covered by this PR. However, the scope is appropriate: the PR targets the four ops reported in the linked issues, and the other ops have different calling conventions (int8 ops receive already-quantized int8 tensors which are always freshly allocated and contiguous; gemv_4bit receives quantized weight tensors and activation vectors that come from nn.Linear.forward which produces contiguous outputs; optimizer tensors are parameter/gradient buffers managed by PyTorch which are contiguous). A follow-up to audit the remaining call sites would be a reasonable hardening measure but is not blocking.
Non-blocking suggestions:
- Consider whether a thin `_ensure_contiguous()` helper or a comment near `get_ptr()` documenting the contiguity requirement would help prevent regressions as new ops are added (a sketch follows below).
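As an illustration of that suggestion, the helper could be as small as the sketch below. `_ensure_contiguous` is the hypothetical name from the bullet above; it does not exist in bitsandbytes today.

```python
import torch

def _ensure_contiguous(A: torch.Tensor) -> torch.Tensor:
    """Return a C-contiguous tensor suitable for raw-pointer access via get_ptr().

    Returns A itself (no copy) when A is already contiguous; otherwise makes a
    dense copy. Every op that extracts a raw pointer should call this first.
    """
    return A.contiguous()
```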
Downstream Impact
Risk level: NONE (beneficial)
This fix only adds automatic contiguity enforcement. It does not change any function signatures, return types, class attributes, or serialization formats. Downstream projects that were already passing contiguous tensors see no change. Downstream projects that were inadvertently passing non-contiguous tensors (which would have produced silently wrong results) now get correct results.
- Transformers: not affected (beneficial)
- PEFT: not affected (beneficial)
- Accelerate: not affected (beneficial)
- TGI: not affected
- vLLM: not affected
Performance Impact
Hot path affected: yes (quantize/dequantize are in the forward path for 4-bit inference)
Changes:
- Four new `.contiguous()` calls at the top of the CUDA backend quantize/dequantize functions
- `.contiguous()` on an already-contiguous tensor is a no-op (returns `self`, no allocation or copy) — see the snippet below
- For non-contiguous inputs, a copy is made — but without this fix those inputs produced wrong results, so correctness trumps the copy cost
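The no-op claim is straightforward to verify in plain PyTorch:

```python
import torch

a = torch.randn(64, 64)
assert a.is_contiguous()
assert a.contiguous() is a           # already contiguous: same object, nothing copied

b = a.t()                            # transposed view: non-contiguous
c = b.contiguous()                   # only here is a new dense buffer allocated
assert c.data_ptr() != a.data_ptr()
assert torch.equal(c, b)             # same values, dense layout
```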
Expected impact: negligible for the common case; correctness fix for the uncommon case.
Recommendation: no concern
Cross-PR Conflicts
- PR #1858 (Add k-bit blockwise quantization): overlaps on `bitsandbytes/backends/cuda/ops.py`. The changes are in different functions (this PR modifies existing quantize/dequantize functions; #1858 adds new kbit functions). Merge conflicts are unlikely; no semantic conflict.
- PRs #1860, #1861, #1863, #1864, #1865, #1866 overlap on `tests/test_linear4bit.py`. This PR's change to that file is a one-line formatting fix (joining a multi-line assert onto one line) — trivially resolvable.
- Security: Clear
- Downstream impact: None (beneficial — prevents silent corruption)
- Tests: Adequate — 4 test methods covering all 4 affected ops, with 3 dtypes, multiple blocksizes, fp4/nf4 for 4-bit ops, and an end-to-end roundtrip test (a sketch of the core comparison follows this list)
- CI: All pass (lint, CPU builds/tests across platforms, CUDA builds/tests on L40S and T4 with multiple CUDA versions, Windows CUDA, ROCm builds)
- Performance: Negligible (no-op on contiguous tensors)
- Serialization: Not affected
- torch.compile: Not affected (no op registration changes)
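The core comparison the regression tests perform looks roughly like the sketch below. It assumes the `bitsandbytes.functional.quantize_blockwise` API returning a `(tensor, QuantState)` pair and a CUDA device; it is not the actual `TestNonContiguousInputs` code (no pytest parametrization, names simplified).

```python
import torch
import bitsandbytes.functional as F

def check_noncontiguous_matches_contiguous():
    """A strided slice must quantize identically to its contiguous copy."""
    base = torch.randn(128, 256, device="cuda", dtype=torch.float16)
    sliced = base[:, ::2]                          # non-contiguous view
    assert not sliced.is_contiguous()

    q_nc, state_nc = F.quantize_blockwise(sliced, blocksize=256)
    q_c, state_c = F.quantize_blockwise(sliced.contiguous(), blocksize=256)

    assert torch.equal(q_nc, q_c)                  # identical quantized codes
    torch.testing.assert_close(state_nc.absmax, state_c.absmax)
```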
Summary
- Adds `A.contiguous()` calls at the top of `quantize_blockwise`, `dequantize_blockwise`, `quantize_4bit`, and `dequantize_4bit` in the CUDA backend
- The CUDA kernels take raw pointers (via `get_ptr()`) which assume a contiguous memory layout; non-contiguous inputs (e.g. strided slices) produced silently incorrect results

Fixes #1342
Fixes #1690
Test plan
- New `TestNonContiguousInputs` class in `tests/test_ops.py` with 4 test methods (54 CUDA parametrizations)
- `test_4bit_quant_large`: unrelated to this change

🤖 Generated with Claude Code