|
20 | 20 | Float8Tensor, |
21 | 21 | MXFP8Tensor, |
22 | 22 | NVFP4Tensor, |
| 23 | + QuantizedTensor, |
23 | 24 | ) |
24 | 25 |
|
25 | 26 | from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported |
@@ -50,14 +51,22 @@ def _to_list(x: Union[Iterable, Any]) -> List: |
50 | 51 | # Types that can be interpreted as tensor dims |
51 | 52 | DimsType = Union[Iterable[int], int] |
52 | 53 |
|
53 | | -# Check if FP8 is supported |
| 54 | +# Supported quantization recipes |
54 | 55 | fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) |
55 | | - |
56 | 56 | fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( |
57 | 57 | return_reason=True |
58 | 58 | ) |
59 | 59 | mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) |
60 | 60 | nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) |
| 61 | +_quantization_list: List[str] = [] |
| 62 | +if fp8_available: |
| 63 | + _quantization_list.append("fp8") |
| 64 | +if fp8_block_scaling_available: |
| 65 | + _quantization_list.append("fp8_blockwise") |
| 66 | +if mxfp8_available: |
| 67 | + _quantization_list.append("mxfp8") |
| 68 | +if nvfp4_available: |
| 69 | + _quantization_list.append("nvfp4") |
61 | 70 |
|
62 | 71 |
|
63 | 72 | # delayed scaling |
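The `_quantization_list` built above is consumed by `@pytest.mark.parametrize` further down, so recipes that are unavailable on the current GPU simply never become test cases. A possible alternative, sketched here for comparison only, keeps every recipe visible in the test report by attaching `skipif` marks through `pytest.param`:

```python
import pytest

# Sketch only: parametrize over all recipes and mark unavailable ones as skipped,
# so they show up as "skipped" in the report instead of being silently omitted.
_quantization_params = [
    pytest.param(
        "fp8", marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    ),
    pytest.param(
        "fp8_blockwise",
        marks=pytest.mark.skipif(
            not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
        ),
    ),
    pytest.param(
        "mxfp8", marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
    ),
    pytest.param(
        "nvfp4", marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4)
    ),
]
```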
@@ -98,6 +107,79 @@ def to_float8_CS( |
98 | 107 | return quantizer(tensor) |
99 | 108 |
|
100 | 109 |
|
| 110 | +@torch.no_grad() |
| 111 | +def make_reference_and_test_tensors( |
| 112 | + shape: int | Iterable[int], |
| 113 | + quantization: Optional[str] = None, |
| 114 | + ref_dtype: torch.dtype = torch.float64, |
| 115 | + ref_device: torch.device = "cpu", |
| 116 | + test_dtype: torch.dtype = torch.float32, |
| 117 | + test_device: torch.device = "cuda", |
| 118 | + requires_grad: bool = True, |
| 119 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 120 | + """Construct tensors with the same values |
| 121 | +
|
| 122 | + The reference tensor is intended for use in plain PyTorch |
| 123 | + operations in high precision. The test tensor is intended for use |
| 124 | + in Transformer Engine operations. |
| 125 | +
|
| 126 | + If a quantization scheme is provided, the tensor values are |
| 127 | + quantized so that they are representable. |
| 128 | +
|
| 129 | + """ |
| 130 | + |
| 131 | + # Random reference tensor |
| 132 | + ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) |
| 133 | + |
| 134 | + # Construct test tensor from reference tensor |
| 135 | + test = ref.to(device=test_device, dtype=test_dtype) |
| 136 | + if quantization is None: |
| 137 | + if test.data_ptr() == ref.data_ptr(): |
| 138 | + test = test.clone() |
| 139 | + elif quantization in ("fp8", "fp8_delayed_scaling"): |
| 140 | + quantizer = Float8Quantizer( |
| 141 | + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), |
| 142 | + amax=torch.zeros(1, dtype=torch.float32, device=test_device), |
| 143 | + fp8_dtype=tex.DType.kFloat8E4M3, |
| 144 | + ) |
| 145 | + test = quantizer(test) |
| 146 | + elif quantization == "fp8_current_scaling": |
| 147 | + quantizer = Float8CurrentScalingQuantizer( |
| 148 | + fp8_dtype=tex.DType.kFloat8E4M3, |
| 149 | + device=test_device, |
| 150 | + ) |
| 151 | + test = quantizer(test) |
| 152 | + elif quantization == "fp8_blockwise": |
| 153 | + quantizer = Float8BlockQuantizer( |
| 154 | + fp8_dtype=tex.DType.kFloat8E4M3, |
| 155 | + rowwise=True, |
| 156 | + columnwise=True, |
| 157 | + force_pow_2_scales=True, |
| 158 | + amax_epsilon=0.0, |
| 159 | + block_scaling_dim=1, |
| 160 | + ) |
| 161 | + test = quantizer(test) |
| 162 | + elif quantization == "mxfp8": |
| 163 | + test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) |
| 164 | + elif quantization == "nvfp4": |
| 165 | + test = NVFP4Quantizer( |
| 166 | + with_rht=False, |
| 167 | + with_post_rht_amax=False, |
| 168 | + with_2d_quantization=False, |
| 169 | + stochastic_rounding=False, |
| 170 | + with_random_sign_mask=False, |
| 171 | + )(test) |
| 172 | + else: |
| 173 | + raise ValueError(f"Unsupported quantization scheme ({quantization})") |
| 174 | + |
| 175 | + # Make sure reference and test tensors match each other |
| 176 | + ref.copy_(test) |
| 177 | + |
| 178 | + ref.requires_grad_(requires_grad) |
| 179 | + test.requires_grad_(requires_grad) |
| 180 | + return ref, test |
| 181 | + |
| 182 | + |
101 | 183 | @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) |
102 | 184 | class TestFloat8Tensor: |
103 | 185 |
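A minimal usage sketch of the `make_reference_and_test_tensors` helper added above (the shape, recipe, and dtype here are arbitrary; when a recipe is given, `x_test` is the Transformer Engine tensor on the GPU and `x_ref` is a plain float64 CPU tensor holding the same dequantized values):

```python
# Sketch: typical use of the helper inside a test body.
x_ref, x_test = make_reference_and_test_tensors(
    shape=(128, 128),
    quantization="mxfp8",      # any recipe from _quantization_list, or None
    test_dtype=torch.bfloat16,
)
# The helper copies the quantized values back into the reference tensor,
# so the two match bit-exactly after dequantization.
torch.testing.assert_close(
    x_test.dequantize().to(dtype=torch.float64, device="cpu"),
    x_ref,
    rtol=0,
    atol=0,
)
```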
|
@@ -466,86 +548,111 @@ def test_quantize_dequantize( |
466 | 548 | torch.testing.assert_close(x_fp8_dequantized, -x_hp, **_tols[fp8_dtype]) |
467 | 549 |
|
468 | 550 |
|
469 | | -class TestAllQuantizedTensors: |
| 551 | +class TestQuantizedTensor: |
470 | 552 | @staticmethod |
471 | 553 | def setup_class(cls) -> None: |
472 | 554 | # Configure RNG |
473 | 555 | seed = 1234 |
474 | 556 | torch.manual_seed(seed) |
475 | 557 | torch.cuda.manual_seed(seed) |
476 | 558 |
|
477 | | - @pytest.mark.parametrize("quantization", ["fp8", "mxfp8", "nvfp4", "fp8_blockwise"]) |
| 559 | + @pytest.mark.parametrize("op", ("clone", "view", "reshape", "contiguous")) |
| 560 | + @pytest.mark.parametrize("quantization", _quantization_list) |
| 561 | + def test_identity_op( |
| 562 | + self, |
| 563 | + *, |
| 564 | + op: str, |
| 565 | + quantization: str, |
| 566 | + shape: Iterable[int] = (128, 128), |
| 567 | + dtype: torch.dtype = torch.bfloat16, |
| 568 | + device: torch.device = "cuda", |
| 569 | + ) -> None: |
| 570 | + """Test operations that do not affect tensor values. |
| 571 | +
|
| 572 | + These operations must produce outputs that are bit-wise
| 573 | + equivalent to the inputs. They must support autograd. |
| 574 | +
|
| 575 | + """ |
| 576 | + |
| 577 | + # Create reference and quantized tensor |
| 578 | + x_ref, x_test = make_reference_and_test_tensors( |
| 579 | + shape=shape, |
| 580 | + quantization=quantization, |
| 581 | + test_dtype=dtype, |
| 582 | + ) |
| 583 | + dy_ref, dy_test = make_reference_and_test_tensors( |
| 584 | + shape=shape, |
| 585 | + test_dtype=dtype, |
| 586 | + requires_grad=False, |
| 587 | + ) |
| 588 | + |
| 589 | + # Apply identity operation |
| 590 | + if op == "clone": |
| 591 | + y_ref = x_ref.clone() |
| 592 | + y_test = x_test.clone() |
| 593 | + elif op == "view": |
| 594 | + y_ref = x_ref.view(shape) |
| 595 | + y_test = x_test.view(shape) |
| 596 | + elif op == "reshape": |
| 597 | + y_ref = x_ref.reshape(shape) |
| 598 | + y_test = x_test.reshape(shape) |
| 599 | + elif op == "contiguous": |
| 600 | + y_ref = x_ref.contiguous() |
| 601 | + y_test = x_test.contiguous() |
| 602 | + |
| 603 | + # Check autograd |
| 604 | + y_test.backward(dy_test) |
| 605 | + assert x_test.grad is not None |
| 606 | + |
| 607 | + # Check values |
| 608 | + tols = dict(rtol=0, atol=0) |
| 609 | + if isinstance(y_test, QuantizedTensor): |
| 610 | + y_test = y_test.dequantize() |
| 611 | + y_test = y_test.to(dtype=torch.float64, device="cpu") |
| 612 | + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") |
| 613 | + dx_ref = dy_ref |
| 614 | + torch.testing.assert_close(y_test, y_ref, **tols) |
| 615 | + torch.testing.assert_close(dx_test, dx_ref, **tols) |
| 616 | + |
| 617 | + @pytest.mark.parametrize("quantization", _quantization_list) |
478 | 618 | @pytest.mark.parametrize("dim", [0, 1]) |
479 | 619 | def test_chunk( |
480 | 620 | self, |
| 621 | + *, |
481 | 622 | quantization: str, |
482 | 623 | dim: int, |
483 | 624 | shape: Iterable[int] = (128, 128), |
484 | 625 | chunks: int = 2, |
485 | 626 | dtype: torch.dtype = torch.bfloat16, |
486 | 627 | device: torch.device = "cuda", |
487 | 628 | ) -> None: |
488 | | - # Skip invalid configs |
489 | | - if quantization == "fp8" and not fp8_available: |
490 | | - pytest.skip(reason_for_no_fp8) |
491 | | - if quantization == "fp8_blockwise" and not fp8_block_scaling_available: |
492 | | - pytest.skip(reason_for_no_fp8_block_scaling) |
493 | | - if quantization == "mxfp8" and not mxfp8_available: |
494 | | - pytest.skip(reason_for_no_mxfp8) |
495 | | - if quantization == "nvfp4" and not nvfp4_available: |
496 | | - pytest.skip(reason_for_no_nvfp4) |
497 | | - # Create quantizer |
498 | | - if quantization == "fp8": |
499 | | - quantizer = Float8Quantizer( |
500 | | - scale=torch.ones(1, dtype=torch.float32, device=device).squeeze(), |
501 | | - amax=torch.zeros(1, dtype=torch.float32, device=device), |
502 | | - fp8_dtype=tex.DType.kFloat8E4M3, |
503 | | - ) |
504 | | - elif quantization == "mxfp8": |
505 | | - quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) |
506 | | - elif quantization == "fp8_blockwise": |
507 | | - quantizer = Float8BlockQuantizer( |
508 | | - fp8_dtype=tex.DType.kFloat8E4M3, |
509 | | - rowwise=True, |
510 | | - columnwise=True, |
511 | | - force_pow_2_scales=True, |
512 | | - amax_epsilon=0.0, |
513 | | - block_scaling_dim=1, |
514 | | - ) |
515 | | - elif quantization == "nvfp4": |
516 | | - quantizer = NVFP4Quantizer( |
517 | | - with_rht=False, |
518 | | - with_post_rht_amax=False, |
519 | | - with_2d_quantization=False, |
520 | | - stochastic_rounding=False, |
521 | | - with_random_sign_mask=False, |
522 | | - ) |
523 | | - else: |
524 | | - raise ValueError(f"Unknown quantizer ({quantizer})") |
| 629 | + |
525 | 630 | # Create reference and quantized tensor |
526 | | - ref_tensor = torch.randn(shape, device=device, dtype=dtype) |
527 | | - quantized_tensor = quantizer(ref_tensor) |
528 | | - ref_tensor.copy_(quantized_tensor) |
| 631 | + x_ref, x_test = make_reference_and_test_tensors( |
| 632 | + shape=shape, |
| 633 | + quantization=quantization, |
| 634 | + test_dtype=dtype, |
| 635 | + ) |
529 | 636 |
|
530 | 637 | # Chunk tensors |
531 | | - ref_splits = torch.chunk(ref_tensor, chunks, dim=dim) |
532 | | - quantized_splits = torch.chunk(quantized_tensor, chunks, dim=dim) |
| 638 | + ys_ref = torch.chunk(x_ref, chunks, dim=dim) |
| 639 | + ys_test = torch.chunk(x_test, chunks, dim=dim) |
| 640 | + |
533 | 641 | # Check splits |
534 | | - for ref_split, quantized_split in zip(ref_splits, quantized_splits): |
| 642 | + for y_ref, y_test in zip(ys_ref, ys_test): |
| 643 | + |
535 | 644 | # Check split shapes |
536 | | - assert ref_split.size() == quantized_split.size() |
| 645 | + assert y_ref.size() == y_test.size() |
537 | 646 |
|
538 | 647 | # Check that splits are quantized when expected |
539 | 648 | if quantization == "fp8": |
540 | | - assert isinstance(quantized_split, Float8Tensor) |
541 | | - expected_value = quantized_split.dequantize() |
| 649 | + assert isinstance(y_test, Float8Tensor) |
| 650 | + y_test = y_test.dequantize() |
542 | 651 | elif quantization == "mxfp8" and dim == 0: |
543 | | - assert isinstance(quantized_split, MXFP8Tensor) |
544 | | - expected_value = quantized_split.dequantize() |
545 | | - else: |
546 | | - # Otherwise torch dispatch would default to base implementation |
547 | | - # dequantize and computing output and hence output from torch chunk |
548 | | - # is already dequantized. |
549 | | - expected_value = quantized_split |
| 652 | + assert isinstance(y_test, MXFP8Tensor) |
| 653 | + y_test = y_test.dequantize() |
| 654 | + |
550 | 655 | # Check values |
551 | | - torch.testing.assert_close(expected_value, ref_split) |
| 656 | + tols = dict(rtol=0, atol=0) # Chunking is exact |
| 657 | + y_test = y_test.to(dtype=torch.float64, device="cpu") |
| 658 | + torch.testing.assert_close(y_test, y_ref, **tols) |
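Both tests end with the same dequantize-and-compare step: dequantize if the output is still a `QuantizedTensor`, move to float64 on CPU, then compare with zero tolerance. A small shared helper, sketched below as a possible follow-up rather than part of this change, would capture that pattern:

```python
def _assert_values_equal(test: torch.Tensor, ref: torch.Tensor) -> None:
    """Sketch: dequantize the test tensor if needed, then compare bit-exactly."""
    if isinstance(test, QuantizedTensor):
        test = test.dequantize()
    test = test.to(dtype=torch.float64, device="cpu")
    torch.testing.assert_close(test, ref, rtol=0, atol=0)
```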