Commit 6fd6209

sudhakarsingh27, pre-commit-ci[bot], and timmoon10 authored
[PyTorch] Make sure Float8Tensor.contiguous supports autograd (#2533)
* add early return back (removed in 2427)

  Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Make sure Float8Tensor.contiguous supports autograd

  Expand quantized tensor tests to check identity ops.

  Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
1 parent 3e69397 commit 6fd6209
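
In short, the commit makes `Float8Tensor.contiguous()` return `self` when the data is already in the requested memory format and otherwise rebuild the tensor through `_IdentityFunc` so that gradients keep flowing. The sketch below illustrates the behavior being fixed; it is not part of the commit, and it assumes an FP8-capable GPU plus import paths that match your Transformer Engine version (`transformer_engine_torch` for the bindings, `transformer_engine.pytorch.tensor.float8_tensor` for the quantizer). The quantizer arguments mirror the test diff below.

```python
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

# Build an FP8 tensor the same way the new test does (delayed-scaling quantizer).
quantizer = Float8Quantizer(
    scale=torch.ones(1, dtype=torch.float32, device="cuda").squeeze(),
    amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
x = quantizer(torch.randn(128, 128, dtype=torch.bfloat16, device="cuda"))
x.requires_grad_()

# contiguous() must stay inside the autograd graph instead of detaching the output.
y = x.contiguous()
y.backward(torch.randn(128, 128, dtype=torch.bfloat16, device="cuda"))
assert x.grad is not None
```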

File tree

2 files changed: +187 -73 lines changed


tests/pytorch/test_quantized_tensor.py

Lines changed: 165 additions & 58 deletions
@@ -20,6 +20,7 @@
     Float8Tensor,
     MXFP8Tensor,
     NVFP4Tensor,
+    QuantizedTensor,
 )

 from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported
@@ -50,14 +51,22 @@ def _to_list(x: Union[Iterable, Any]) -> List:
 # Types that can be interpreted as tensor dims
 DimsType = Union[Iterable[int], int]

-# Check if FP8 is supported
+# Supported quantization recipes
 fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
-
 fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
     return_reason=True
 )
 mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
 nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
+_quantization_list: List[str] = []
+if fp8_available:
+    _quantization_list.append("fp8")
+if fp8_block_scaling_available:
+    _quantization_list.append("fp8_blockwise")
+if mxfp8_available:
+    _quantization_list.append("mxfp8")
+if nvfp4_available:
+    _quantization_list.append("nvfp4")


 # delayed scaling
@@ -98,6 +107,79 @@ def to_float8_CS(
     return quantizer(tensor)


+@torch.no_grad()
+def make_reference_and_test_tensors(
+    shape: int | Iterable[int],
+    quantization: Optional[str] = None,
+    ref_dtype: torch.dtype = torch.float64,
+    ref_device: torch.device = "cpu",
+    test_dtype: torch.dtype = torch.float32,
+    test_device: torch.device = "cuda",
+    requires_grad: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Construct tensors with the same values
+
+    The reference tensor is intended for use in plain PyTorch
+    operations in high precision. The test tensor is intended for use
+    in Transformer Engine operations.
+
+    If a quantization scheme is provided, the tensor values are
+    quantized so that they are representable.
+
+    """
+
+    # Random reference tensor
+    ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
+
+    # Construct test tensor from reference tensor
+    test = ref.to(device=test_device, dtype=test_dtype)
+    if quantization is None:
+        if test.data_ptr() == ref.data_ptr():
+            test = test.clone()
+    elif quantization in ("fp8", "fp8_delayed_scaling"):
+        quantizer = Float8Quantizer(
+            scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(),
+            amax=torch.zeros(1, dtype=torch.float32, device=test_device),
+            fp8_dtype=tex.DType.kFloat8E4M3,
+        )
+        test = quantizer(test)
+    elif quantization == "fp8_current_scaling":
+        quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            device=test_device,
+        )
+        test = quantizer(test)
+    elif quantization == "fp8_blockwise":
+        quantizer = Float8BlockQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            rowwise=True,
+            columnwise=True,
+            force_pow_2_scales=True,
+            amax_epsilon=0.0,
+            block_scaling_dim=1,
+        )
+        test = quantizer(test)
+    elif quantization == "mxfp8":
+        test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
+    elif quantization == "nvfp4":
+        test = NVFP4Quantizer(
+            with_rht=False,
+            with_post_rht_amax=False,
+            with_2d_quantization=False,
+            stochastic_rounding=False,
+            with_random_sign_mask=False,
+        )(test)
+    else:
+        raise ValueError(f"Unsupported quantization scheme ({quantization})")
+
+    # Make sure reference and test tensors match each other
+    ref.copy_(test)
+
+    ref.requires_grad_(requires_grad)
+    test.requires_grad_(requires_grad)
+    return ref, test
+
+
 @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 class TestFloat8Tensor:

@@ -466,86 +548,111 @@ def test_quantize_dequantize(
         torch.testing.assert_close(x_fp8_dequantized, -x_hp, **_tols[fp8_dtype])


-class TestAllQuantizedTensors:
+class TestQuantizedTensor:
     @staticmethod
     def setup_class(cls) -> None:
         # Configure RNG
         seed = 1234
         torch.manual_seed(seed)
         torch.cuda.manual_seed(seed)

-    @pytest.mark.parametrize("quantization", ["fp8", "mxfp8", "nvfp4", "fp8_blockwise"])
+    @pytest.mark.parametrize("op", ("clone", "view", "reshape", "contiguous"))
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    def test_identity_op(
+        self,
+        *,
+        op: str,
+        quantization: str,
+        shape: Iterable[int] = (128, 128),
+        dtype: torch.dtype = torch.bfloat16,
+        device: torch.device = "cuda",
+    ) -> None:
+        """Test operations that do not affect tensor values.
+
+        These operations must produce outputs that are bit-wise
+        equivalent to the inputs. They must support autograd.
+
+        """
+
+        # Create reference and quantized tensor
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape=shape,
+            quantization=quantization,
+            test_dtype=dtype,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            shape=shape,
+            test_dtype=dtype,
+            requires_grad=False,
+        )
+
+        # Apply identity operation
+        if op == "clone":
+            y_ref = x_ref.clone()
+            y_test = x_test.clone()
+        elif op == "view":
+            y_ref = x_ref.view(shape)
+            y_test = x_test.view(shape)
+        elif op == "reshape":
+            y_ref = x_ref.reshape(shape)
+            y_test = x_test.reshape(shape)
+        elif op == "contiguous":
+            y_ref = x_ref.contiguous()
+            y_test = x_test.contiguous()
+
+        # Check autograd
+        y_test.backward(dy_test)
+        assert x_test.grad is not None
+
+        # Check values
+        tols = dict(rtol=0, atol=0)
+        if isinstance(y_test, QuantizedTensor):
+            y_test = y_test.dequantize()
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        dx_ref = dy_ref
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, dx_ref, **tols)
+
+    @pytest.mark.parametrize("quantization", _quantization_list)
     @pytest.mark.parametrize("dim", [0, 1])
     def test_chunk(
         self,
+        *,
         quantization: str,
         dim: int,
         shape: Iterable[int] = (128, 128),
         chunks: int = 2,
         dtype: torch.dtype = torch.bfloat16,
         device: torch.device = "cuda",
     ) -> None:
-        # Skip invalid configs
-        if quantization == "fp8" and not fp8_available:
-            pytest.skip(reason_for_no_fp8)
-        if quantization == "fp8_blockwise" and not fp8_block_scaling_available:
-            pytest.skip(reason_for_no_fp8_block_scaling)
-        if quantization == "mxfp8" and not mxfp8_available:
-            pytest.skip(reason_for_no_mxfp8)
-        if quantization == "nvfp4" and not nvfp4_available:
-            pytest.skip(reason_for_no_nvfp4)
-        # Create quantizer
-        if quantization == "fp8":
-            quantizer = Float8Quantizer(
-                scale=torch.ones(1, dtype=torch.float32, device=device).squeeze(),
-                amax=torch.zeros(1, dtype=torch.float32, device=device),
-                fp8_dtype=tex.DType.kFloat8E4M3,
-            )
-        elif quantization == "mxfp8":
-            quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
-        elif quantization == "fp8_blockwise":
-            quantizer = Float8BlockQuantizer(
-                fp8_dtype=tex.DType.kFloat8E4M3,
-                rowwise=True,
-                columnwise=True,
-                force_pow_2_scales=True,
-                amax_epsilon=0.0,
-                block_scaling_dim=1,
-            )
-        elif quantization == "nvfp4":
-            quantizer = NVFP4Quantizer(
-                with_rht=False,
-                with_post_rht_amax=False,
-                with_2d_quantization=False,
-                stochastic_rounding=False,
-                with_random_sign_mask=False,
-            )
-        else:
-            raise ValueError(f"Unknown quantizer ({quantizer})")
+
         # Create reference and quantized tensor
-        ref_tensor = torch.randn(shape, device=device, dtype=dtype)
-        quantized_tensor = quantizer(ref_tensor)
-        ref_tensor.copy_(quantized_tensor)
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape=shape,
+            quantization=quantization,
+            test_dtype=dtype,
+        )

         # Chunk tensors
-        ref_splits = torch.chunk(ref_tensor, chunks, dim=dim)
-        quantized_splits = torch.chunk(quantized_tensor, chunks, dim=dim)
+        ys_ref = torch.chunk(x_ref, chunks, dim=dim)
+        ys_test = torch.chunk(x_test, chunks, dim=dim)
+
         # Check splits
-        for ref_split, quantized_split in zip(ref_splits, quantized_splits):
+        for y_ref, y_test in zip(ys_ref, ys_test):
+
             # Check split shapes
-            assert ref_split.size() == quantized_split.size()
+            assert y_ref.size() == y_test.size()

             # Check that splits are quantized when expected
             if quantization == "fp8":
-                assert isinstance(quantized_split, Float8Tensor)
-                expected_value = quantized_split.dequantize()
+                assert isinstance(y_test, Float8Tensor)
+                y_test = y_test.dequantize()
             elif quantization == "mxfp8" and dim == 0:
-                assert isinstance(quantized_split, MXFP8Tensor)
-                expected_value = quantized_split.dequantize()
-            else:
-                # Otherwise torch dispatch would default to base implementation
-                # dequantize and computing output and hence output from torch chunk
-                # is already dequantized.
-                expected_value = quantized_split
+                assert isinstance(y_test, MXFP8Tensor)
+                y_test = y_test.dequantize()
+
             # Check values
-            torch.testing.assert_close(expected_value, ref_split)
+            tols = dict(rtol=0, atol=0)  # Chunking is exact
+            y_test = y_test.to(dtype=torch.float64, device="cpu")
+            torch.testing.assert_close(y_test, y_ref, **tols)
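
A note on the test changes above: the per-test `pytest.skip` calls are replaced by `_quantization_list`, which is built once at import time from the availability checks, so unsupported recipes never generate test cases at all. The snippet below is a generic illustration of that pattern (not the Transformer Engine code itself), using `torch.cuda.is_available()` as a stand-in capability check; the new tests themselves can be selected with, e.g., `pytest tests/pytorch/test_quantized_tensor.py -k identity_op`.

```python
import pytest
import torch

# Build the parameter list once at collection time; unavailable configs never
# become test cases, so no per-test skip logic is needed.
_devices = ["cpu"]
if torch.cuda.is_available():
    _devices.append("cuda")


@pytest.mark.parametrize("device", _devices)
def test_add_is_exact(device):
    x = torch.ones(4, device=device)
    assert torch.equal(x + x, torch.full((4,), 2.0, device=device))
```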

transformer_engine/pytorch/tensor/float8_tensor.py

Lines changed: 22 additions & 15 deletions
@@ -551,24 +551,31 @@ def contiguous(
     ) -> Float8Tensor:
         """Returns tensor with data in provided memory format

-        Returns `self` if data is already in correct memory format.
+        Returns ``self`` if data is already in correct memory format.

         """
-        # requires_grad remains unaltered when calling contiguous on
-        # torch tensor and so should be the case for our custom float8 tensor
-        # as well.
-        return Float8Tensor.make_like(
-            tensor=self,
-            data=self._data.contiguous(memory_format=memory_format),
-            data_transpose=(
-                self._transpose.contiguous(memory_format=memory_format)
-                if self._transpose is not None
-                else None
-            ),
-            requires_grad=self.requires_grad,
-        )

-        # raise ValueError("Float8Tensor does not support different memory formats!")
+        # Check if tensor already has correct memory format
+        if self._data is not None and not self._data.is_contiguous(memory_format=memory_format):
+            pass
+        elif self._transpose is not None and not self._transpose.is_contiguous(
+            memory_format=memory_format
+        ):
+            pass
+        else:
+            # Tensor has correct memory format, so return immediately
+            return self
+
+        # Construct tensor with correct data format
+        data, data_transpose = None, None
+        if self._data is not None:
+            data = self._data.contiguous(memory_format=memory_format)
+        if self._transpose is not None and not self._transpose_invalid:
+            data_transpose = self._transpose.contiguous(memory_format=memory_format)
+        return _IdentityFunc.apply(
+            self,
+            {"data": data, "data_transpose": data_transpose},
+        )

     def _reset_caches(self) -> None:
         """
