diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index cd2d85c91c..e0ad09200d 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -32,6 +32,7 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8.xml $TE_PATH/tests/pytorch/mxfp8 || test_fail "test_mxfp8" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 8b084ca452..6ade3c6e6b 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -385,28 +385,41 @@ void performTest(const ProcessingMethod processing_method, NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast(itype), logical_shape_}; NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), logical_shape_}; - nvte_set_grouped_tensor_param(&in_group_tensor, 
NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor); - nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &grad_data_tensor); + nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, + &in_data_tensor, sizeof(in_data_tensor)); + nvte_set_grouped_tensor_param(grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, + &grad_data_tensor, sizeof(grad_data_tensor)); if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_}; - nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor); - nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor); - nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor); + nvte_set_grouped_tensor_param(grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, + &first_dims_tensor, sizeof(first_dims_tensor)); + nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, + &first_dims_tensor, sizeof(first_dims_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, + &first_dims_tensor, sizeof(first_dims_tensor)); } if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape_}; - nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor); - nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor); - nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor); + nvte_set_grouped_tensor_param(grad_group_tensor, 
NVTEGroupedTensorParam::kNVTEGroupedLastDims, + &last_dims_tensor, sizeof(last_dims_tensor)); + nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, + &last_dims_tensor, sizeof(last_dims_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, + &last_dims_tensor, sizeof(last_dims_tensor)); } if (shape_rep != SAME_BOTH_DIMS) { NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_}; - nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); - nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); - nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor); + nvte_set_grouped_tensor_param(grad_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, + &offsets_tensor, sizeof(offsets_tensor)); + nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, + &offsets_tensor, sizeof(offsets_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, + &offsets_tensor, sizeof(offsets_tensor)); } if (rowwise) { @@ -417,8 +430,11 @@ void performTest(const ProcessingMethod processing_method, NVTEBasicTensor out_data_rowwise_tensor = {out_data_rowwise_d, static_cast(otype), logical_shape_}; NVTEShape scales_rowwise_shape_ = nvte_make_shape(scales_rowwise_shape.data(), scales_rowwise_shape.size()); NVTEBasicTensor out_scales_rowwise_tensor = {out_scales_rowwise_d, NVTEDType::kNVTEFloat8E8M0, scales_rowwise_shape_}; - nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &out_data_rowwise_tensor); - nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, &out_scales_rowwise_tensor); + 
nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, + &out_data_rowwise_tensor, sizeof(out_data_rowwise_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, + &out_scales_rowwise_tensor, sizeof(out_scales_rowwise_tensor)); } if (colwise) { @@ -429,8 +445,12 @@ void performTest(const ProcessingMethod processing_method, NVTEBasicTensor out_data_colwise_tensor = {out_data_colwise_d, static_cast(otype), logical_shape_}; NVTEShape scales_colwise_shape_ = nvte_make_shape(scales_colwise_shape.data(), scales_colwise_shape.size()); NVTEBasicTensor out_scales_colwise_tensor = {out_scales_colwise_d, NVTEDType::kNVTEFloat8E8M0, scales_colwise_shape_}; - nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, &out_data_colwise_tensor); - nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor); + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, + &out_data_colwise_tensor, sizeof(out_data_colwise_tensor)); + nvte_set_grouped_tensor_param(out_group_tensor, + NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, + &out_scales_colwise_tensor, sizeof(out_scales_colwise_tensor)); } Tensor output_dbias("output_dbias", std::vector{ cols }, itype); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index af99d9c42f..b64ae24131 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1157,7 +1157,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, NVTEBasicTensor data_tensor{grouped.data.get(), static_cast(dtype), grouped.logical_shape}; NVTEGroupedTensor h = grouped.handle.get(); - nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor); + nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor)); const bool 
include_columnwise = isFp8Type(dtype) || isFp4Type(dtype); if (include_columnwise) { @@ -1172,7 +1172,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, NVTEBasicTensor col_tensor{grouped.columnwise_data.get(), static_cast(dtype), grouped.logical_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor); + nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseData, &col_tensor, sizeof(col_tensor)); } if (!same_first) { @@ -1181,7 +1181,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1); NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor); + nvte_set_grouped_tensor_param(h, kNVTEGroupedFirstDims, &fd_tensor, sizeof(fd_tensor)); } if (!same_last) { @@ -1190,7 +1190,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1); NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor); + nvte_set_grouped_tensor_param(h, kNVTEGroupedLastDims, &ld_tensor, sizeof(ld_tensor)); } if (!same_first || !same_last) { @@ -1199,7 +1199,7 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice)); NVTEShape off_shape = nvte_make_shape(&num_tensors, 1); NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedTensorOffsets, &off_tensor); + nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor)); } if (isFp8Type(dtype)) { @@ -1213,8 +1213,10 @@ GroupedBuffers build_grouped_tensor(const std::vector& tensors, sizeof(float) * num_tensors, 
cudaMemcpyHostToDevice)); NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1); NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape}; - nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor); - nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor); + nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &scale_tensor, + sizeof(scale_tensor)); + nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor, + sizeof(scale_tensor)); } return grouped; diff --git a/tests/pytorch/mxfp8/mxfp8_utils.py b/tests/pytorch/mxfp8/mxfp8_utils.py new file mode 100644 index 0000000000..99e088a201 --- /dev/null +++ b/tests/pytorch/mxfp8/mxfp8_utils.py @@ -0,0 +1,62 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch +import math + + +# Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization without padding +def get_mxfp8_scale_shape_no_padding(shape, columnwise): + M, K = 1, 1 + M = math.prod(shape[:-1]) + K = shape[-1] + + if columnwise: + outer = M // 32 + inner = K + return (outer, inner) + # rowwise + outer = M + inner = K // 32 + return (outer, inner) + + +def _rowwise_swizzle_mxfp8_scale(input_M, input_N, scale: torch.Tensor) -> torch.Tensor: + assert scale.dim() == 2 + assert input_M == scale.shape[0] + assert input_N // 32 == scale.shape[1] + + x = scale.view(input_M // 128, 4, 32, input_N // 128, 4) + x = x.permute(0, 3, 2, 1, 4) + x = x.contiguous() + # View back as original 2D shape + x = x.view(input_M, input_N // 32) + return x + + +def _columnwise_swizzle_mxfp8_scale(input_M, input_N, scale: torch.Tensor) -> torch.Tensor: + assert scale.dim() == 2 + assert input_M // 32 == scale.shape[0] + assert input_N == scale.shape[1] + + x = scale.view(input_M // 128, 4, input_N // 128, 4, 32) + x = x.permute(2, 0, 4, 3, 1) + x = x.contiguous() + + # alternative 
way: transpose the scale and do rowwise swizzle with M, N swapped + x1 = _rowwise_swizzle_mxfp8_scale(input_N, input_M, scale.transpose(0, 1).contiguous()) + torch.testing.assert_close( + x.view(-1), x1.view(-1), atol=0.0, rtol=0.0, msg="columnwise swizzle sanity check failed" + ) + + # View back as original 2D shape + x = x.view(input_M // 32, input_N) + return x + + +def swizzle_mxfp8_scale(input_M, input_N, scale: torch.Tensor, columnwise: bool) -> torch.Tensor: + if not columnwise: + return _rowwise_swizzle_mxfp8_scale(input_M, input_N, scale) + else: + return _columnwise_swizzle_mxfp8_scale(input_M, input_N, scale) diff --git a/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py b/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py new file mode 100644 index 0000000000..3c197bc6f3 --- /dev/null +++ b/tests/pytorch/mxfp8/test_mxfp8_group_quantize_graph_safe.py @@ -0,0 +1,475 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import MXFP8Quantizer + +import pytest +import torch +import random +import math + +from mxfp8_utils import swizzle_mxfp8_scale, get_mxfp8_scale_shape_no_padding + +recipe_available, reason_for_no_recipe = te.is_mxfp8_available(return_reason=True) + + +def generate_random_multiples_sum(total=8192, n=4, multiple=64): + if total % multiple != 0: + raise ValueError(f"Total ({total}) must be a multiple of {multiple}") + if (total // multiple) < n: + raise ValueError("Total too small for given n and multiple.") + + # Work in units of multiples + total_units = total // multiple + + # choose n−1 random cut points in [1, total_units−1) + cuts = sorted(random.sample(range(1, total_units), n - 1)) + + # convert to segment lengths + parts = ( + [cuts[0]] + [cuts[i] - cuts[i - 1] for i in range(1, len(cuts))] + [total_units - cuts[-1]] + ) + + # convert back to multiples + return [p * multiple for p in parts] + + +def generate_split_sections(M: int, N: int, edge_cases: str) -> list[int]: + least_multiple = 128 + num_chunks = 4 + split_sections = None + + avg_split = M // num_chunks + + if M == 0 or N == 0: + # all zeros + return [0] * num_chunks + if edge_cases == "regular": + split_sections = [avg_split] * num_chunks + elif edge_cases == "zero_tokens_all": + split_sections = [0] * num_chunks + elif edge_cases == "zero_tokens_front": + split_sections = [0] + [avg_split] * (num_chunks - 2) + [avg_split * 2] + elif edge_cases == "zero_tokens_end": + split_sections = [avg_split * 2] + [avg_split] * (num_chunks - 2) + [0] + elif edge_cases == "zero_tokens_middle": + split_sections = [avg_split] * (num_chunks - 2) + [0] + [avg_split * 2] + elif edge_cases == "random_uneven_split": + split_sections = generate_random_multiples_sum(M, num_chunks, least_multiple) + else: + raise ValueError(f"Invalid edge case: {edge_cases}") + + # adds up the split_sections to make it M 
+ assert sum(split_sections) == M, "The split_sections do not add up to M" + + # make sure every split_section is a multiple of least_multiple + for split_section in split_sections: + assert ( + split_section % least_multiple == 0 + ), "The split_sections are not multiples of least_multiple" + + return split_sections + + +def reference_group_quantize( + x: torch.Tensor, + quantizers: list[MXFP8Quantizer], + split_sections: list[int], + return_identity: bool, + return_transpose: bool, +) -> torch.Tensor: + x_chunks = torch.split(x, split_sections) + + # rowwise quantization + x_qx = [] + x_sx = [] + # columnwise quantization + x_qx_t = [] + x_sx_t = [] + + for i in range(len(x_chunks)): + x_chunk = x_chunks[i] + x_mxfp8_res = quantizers[i](x_chunk) + if return_identity: + x_qx.append(x_mxfp8_res._rowwise_data.view(dtype=torch.uint8)) + x_sx.append(x_mxfp8_res._rowwise_scale_inv) + else: + x_qx.append(None) + x_sx.append(None) + if return_transpose: + x_qx_t.append(x_mxfp8_res._columnwise_data.view(dtype=torch.uint8)) + x_sx_t.append(x_mxfp8_res._columnwise_scale_inv) + else: + x_qx_t.append(None) + x_sx_t.append(None) + + return x_qx, x_sx, x_qx_t, x_sx_t + + +def fused_grouped_quantize( + x: torch.Tensor, split_section_tensor: torch.Tensor, quantizer: MXFP8Quantizer +): + + # view x as a 2D tensor + hidden_dim = x.shape[-1] + x = x.view(-1, hidden_dim) + num_tensors = split_section_tensor.shape[0] + + grouped_output = tex.group_quantize(x, quantizer, num_tensors, split_section_tensor) + + return grouped_output + + +def assert_same_shape_and_dtype(x: torch.Tensor, y: torch.Tensor) -> None: + assert x.shape == y.shape + assert x.dtype == y.dtype + + +def check_grouped_tensor_mxfp8_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_identity: bool, + return_transpose: bool, + split_sections: list[int], + optimize_for_gemm: bool = False, +) -> None: + + te_dtype = tex.DType.kFloat8E4M3 + + split_section_tensor = torch.tensor(split_sections, 
dtype=torch.int64, device="cuda") + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + x_splits = torch.split(x, split_sections) + + # Quantize + quantizers = [ + MXFP8Quantizer( + fp8_dtype=te_dtype, + rowwise=return_identity, + columnwise=return_transpose, + ) + for _ in range(len(split_sections)) + ] + + grouped_quantizer = quantizers[0].copy() + # configure grouped quantizer with swizzle fusion + # and compare with reference without swizzle fusion + grouped_quantizer.optimize_for_gemm = optimize_for_gemm + + x_qx_ref, x_sx_ref, x_qx_t_ref, x_sx_t_ref = reference_group_quantize( + x, quantizers, split_sections, return_identity, return_transpose + ) + + group_quantized_output = fused_grouped_quantize(x, split_section_tensor, grouped_quantizer) + # get a list of MXFP8 quantized tensors for testing + split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() + + if return_identity: + x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] + x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] + + for i in range(len(x_qx)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i]) + assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i]) + else: + torch.testing.assert_close(x_qx[i], x_qx_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_mxfp8_scale_shape_no_padding(x_splits[i].shape, False) + assert ( + valid_scale_shape == x_sx[i].shape + ), "The scale shape is not correctly aligned" + x_sx_i = x_sx[i].clone() + x_sx_ref_i = x_sx_ref[i].clone() + if optimize_for_gemm: + x_sx_ref_i = swizzle_mxfp8_scale( + split_sections[i], N, x_sx_ref_i, columnwise=False + ) + torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=0.0, rtol=0.0) + + if return_transpose: + 
x_qx_t = [ + output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs + ] + x_sx_t = [output._columnwise_scale_inv for output in split_quantize_outputs] + # assert with zero tolerance + for i in range(len(x_qx_t)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i]) + assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i]) + else: + torch.testing.assert_close(x_qx_t[i], x_qx_t_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_mxfp8_scale_shape_no_padding(x_splits[i].shape, True) + assert ( + valid_scale_shape == x_sx_t[i].shape + ), "The scale shape is not correctly aligned" + x_sx_t_i = x_sx_t[i].clone() + x_sx_t_ref_i = x_sx_t_ref[i].clone() + if optimize_for_gemm: + x_sx_t_ref_i = swizzle_mxfp8_scale( + split_sections[i], N, x_sx_t_ref_i, columnwise=True + ) + torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=0.0, rtol=0.0) + + +def check_grouped_tensor_mxfp8_with_paged_stashing( + x_dtype: torch.dtype, + M: int, + N: int, + return_identity: bool, + return_transpose: bool, + split_sections: list[int], + valid_M: int = None, + optimize_for_gemm: bool = False, +) -> None: + + te_dtype = tex.DType.kFloat8E4M3 + + assert valid_M is not None, "valid_M must be provided when with_paged_stashing is True" + assert valid_M < M, "valid_M must be less than M when with_paged_stashing is True" + + split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input (fill the entire tensor with garbage too) + x = torch.randn((M, N), dtype=x_dtype, device=device) + valid_x = x[:valid_M, :].clone() + x_splits = torch.split(valid_x, split_sections) + + # Quantize + quantizers = [ + MXFP8Quantizer( + fp8_dtype=te_dtype, + rowwise=return_identity, + columnwise=return_transpose, + ) + 
for _ in range(len(split_sections)) + ] + + grouped_quantizer = quantizers[0].copy() + # configure grouped quantizer with swizzle fusion + # and compare with reference without swizzle fusion + grouped_quantizer.optimize_for_gemm = optimize_for_gemm + + x_qx_ref, x_sx_ref, x_qx_t_ref, x_sx_t_ref = reference_group_quantize( + valid_x, quantizers, split_sections, return_identity, return_transpose + ) + + # Note: for grouped quantize with paged stashing + # it's expected that we can just pass in the regular input x, not the valid_x + # the kernel is expected to process it correctly by becoming no-op for cuda graph + group_quantized_output = fused_grouped_quantize(x, split_section_tensor, grouped_quantizer) + + # get a list of MXFP8 quantized tensors for testing + split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() + + if return_identity: + x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] + x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] + + for i in range(len(x_qx)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i]) + assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i]) + else: + torch.testing.assert_close(x_qx[i], x_qx_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_mxfp8_scale_shape_no_padding(x_splits[i].shape, False) + assert ( + valid_scale_shape == x_sx[i].shape + ), "The scale shape is not correctly aligned" + x_sx_i = x_sx[i].clone() + x_sx_ref_i = x_sx_ref[i].clone() + if optimize_for_gemm: + x_sx_ref_i = swizzle_mxfp8_scale( + split_sections[i], N, x_sx_ref_i, columnwise=False + ) + torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=0.0, rtol=0.0) + + if return_transpose: + x_qx_t = [ + output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs + ] + x_sx_t = [output._columnwise_scale_inv for output in split_quantize_outputs]
+ # assert with zero tolerance + for i in range(len(x_qx_t)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i]) + assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i]) + else: + torch.testing.assert_close(x_qx_t[i], x_qx_t_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_mxfp8_scale_shape_no_padding(x_splits[i].shape, True) + assert ( + valid_scale_shape == x_sx_t[i].shape + ), "The scale shape is not correctly aligned" + x_sx_t_i = x_sx_t[i].clone() + x_sx_t_ref_i = x_sx_t_ref[i].clone() + if optimize_for_gemm: + x_sx_t_ref_i = swizzle_mxfp8_scale( + split_sections[i], N, x_sx_t_ref_i, columnwise=True + ) + torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # edge case, zero tokens for all + (0, 512), + # full tile cases + (1024, 256), + # larger sizes + (8192, 1024), + (16384, 8192), + (16384, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "edge_cases", + [ + "regular", + "zero_tokens_front", + "zero_tokens_end", + "zero_tokens_middle", + "random_uneven_split", + ], +) +@pytest.mark.parametrize( + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] +) +@pytest.mark.parametrize( + "optimize_for_gemm", [True, False], ids=["optimize_for_gemm", "no_optimize_for_gemm"] +) +def test_grouped_tensor_mxfp8_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + edge_cases: str, + quantize_mode: str, + optimize_for_gemm: bool, +) -> None: + + split_sections = generate_split_sections(M, N, edge_cases) + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == 
"quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + + check_grouped_tensor_mxfp8_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + return_identity=return_identity, + return_transpose=return_transpose, + split_sections=split_sections, + optimize_for_gemm=optimize_for_gemm, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # M won't be empty in paged stashing + # full tile cases + (1024, 256), + # larger sizes + (8192, 1024), + (16384, 8192), + (16384, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "edge_cases", + [ + "regular", + # even if buffer is not empty, but the token splits are all zero + "zero_tokens_all", + # partially zero tokens + "zero_tokens_front", + "zero_tokens_end", + "zero_tokens_middle", + "random_uneven_split", + ], +) +@pytest.mark.parametrize( + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] +) +@pytest.mark.parametrize( + "optimize_for_gemm", [True, False], ids=["optimize_for_gemm", "no_optimize_for_gemm"] +) +def test_grouped_tensor_mxfp8_with_paged_stashing( + x_dtype: torch.dtype, + M: int, + N: int, + edge_cases: str, + quantize_mode: str, + optimize_for_gemm: bool, +) -> None: + + # paged stashing means that the sum of total tokens is less than + # or equal to the buffer size, you can have buffer [2048, 1024] + # and when you only receive 1024 tokens, the last half is garbage + # so input has shape [2048, 1024] + # split sections can be [256, 256, 256, 256], sums to 1024 + valid_M = 0 if edge_cases == "zero_tokens_all" else M // 2 + split_sections = generate_split_sections(valid_M, N, edge_cases) + + # sanity check + if edge_cases == "zero_tokens_all": + assert valid_M == 0, "valid_M must be 0 when edge_cases is zero_tokens_all" + else: + assert valid_M == M // 2, "valid_M must 
be M // 2 when edge_cases is not zero_tokens_all" + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + + check_grouped_tensor_mxfp8_with_paged_stashing( + x_dtype=x_dtype, + M=M, + N=N, + return_identity=return_identity, + return_transpose=return_transpose, + split_sections=split_sections, + valid_M=valid_M, + optimize_for_gemm=optimize_for_gemm, + ) diff --git a/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py b/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py new file mode 100644 index 0000000000..94ea699d14 --- /dev/null +++ b/tests/pytorch/mxfp8/test_mxfp8_quantize_swizzle_fusion.py @@ -0,0 +1,134 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import MXFP8Quantizer +from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage + +import pytest +import torch +import random +import math + +from typing import Tuple + +from mxfp8_utils import swizzle_mxfp8_scale, get_mxfp8_scale_shape_no_padding + +recipe_available, reason_for_no_recipe = te.is_mxfp8_available(return_reason=True) + + +def unpack_quantized_tensor( + quantized_tensor: MXFP8TensorStorage, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + qx, sx, qx_t, sx_t = None, None, None, None + if quantized_tensor._rowwise_data is not None: + qx = quantized_tensor._rowwise_data.view(dtype=torch.uint8) + if quantized_tensor._rowwise_scale_inv is not None: + sx = quantized_tensor._rowwise_scale_inv + if quantized_tensor._columnwise_data is not None: + qx_t = quantized_tensor._columnwise_data.view(dtype=torch.uint8) + if quantized_tensor._columnwise_scale_inv is not None: + sx_t = quantized_tensor._columnwise_scale_inv + return qx, sx, qx_t, sx_t + + +def check_mxfp8_quantize_swizzle_fusion( + x_dtype: torch.dtype, + M: int, + N: int, + return_identity: bool, + return_transpose: bool, +) -> None: + + te_dtype = tex.DType.kFloat8E4M3 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + + # Quantize + quantizer = MXFP8Quantizer( + fp8_dtype=te_dtype, + rowwise=return_identity, + columnwise=return_transpose, + ) + + quantizer_swizzle_fusion = quantizer.copy() + quantizer_swizzle_fusion.optimize_for_gemm = True + + x_qx_swf, x_sx_swf, x_qx_t_swf, x_sx_t_swf = unpack_quantized_tensor( + quantizer_swizzle_fusion(x) + ) + x_qx_ref, x_sx_ref, x_qx_t_ref, x_sx_t_ref = unpack_quantized_tensor(quantizer(x)) + + if return_identity: + 
torch.testing.assert_close(x_qx_swf, x_qx_ref, atol=0.0, rtol=0.0) + valid_scale_shape = get_mxfp8_scale_shape_no_padding(x.shape, False) + assert valid_scale_shape == x_sx_swf.shape, ( + "The scale shape is not correctly aligned, this test assumes no padding is needed for" + " scaling factors" + ) + x_sx_ref_swizzled = swizzle_mxfp8_scale(M, N, x_sx_ref, columnwise=False) + torch.testing.assert_close(x_sx_swf, x_sx_ref_swizzled, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(x_qx_t_swf, x_qx_t_ref, atol=0.0, rtol=0.0) + valid_scale_shape = get_mxfp8_scale_shape_no_padding(x.shape, True) + assert valid_scale_shape == x_sx_t_swf.shape, ( + "The scale shape is not correctly aligned, this test assumes no padding is needed for" + " scaling factors" + ) + x_sx_t_ref_swizzled = swizzle_mxfp8_scale(M, N, x_sx_t_ref, columnwise=True) + torch.testing.assert_close(x_sx_t_swf, x_sx_t_ref_swizzled, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (1024, 256), + # larger sizes + (8192, 1024), + (16384, 8192), + (16384, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] +) +def test_mxfp8_quantize_swizzle_fusion( + x_dtype: torch.dtype, + M: int, + N: int, + quantize_mode: str, +) -> None: + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + + check_mxfp8_quantize_swizzle_fusion( + x_dtype=x_dtype, + M=M, + N=N, + return_identity=return_identity, + return_transpose=return_transpose, + ) diff --git 
a/tests/pytorch/nvfp4/nvfp4_utils.py b/tests/pytorch/nvfp4/nvfp4_utils.py new file mode 100644 index 0000000000..5f1b5ac36c --- /dev/null +++ b/tests/pytorch/nvfp4/nvfp4_utils.py @@ -0,0 +1,159 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer + +import torch +import math +import random + + +# Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization without padding +def get_nvfp4_scale_shape_no_padding(shape, columnwise): + M, K = 1, 1 + M = math.prod(shape[:-1]) + K = shape[-1] + + if columnwise: + outer = K + inner = math.ceil(M / 16) + return (outer, inner) + # rowwise + outer = M + inner = math.ceil(K / 16) + return (outer, inner) + + +def _rowwise_swizzle_nvfp4_scale(input_M, input_N, scale: torch.Tensor) -> torch.Tensor: + assert scale.dim() == 2 + assert input_M == scale.shape[0] + assert input_N // 16 == scale.shape[1] + + x = scale.view(input_M // 128, 4, 32, input_N // 64, 4) + x = x.permute(0, 3, 2, 1, 4) + x = x.contiguous() + # View back as original 2D shape + x = x.view(input_M, input_N // 16) + return x + + +# TN-only layout for NVFP4 means that there is only rowwise swizzle +# just need to switch the M, N which means transposing the input +def swizzle_nvfp4_scale(input_M, input_N, scale: torch.Tensor, columnwise: bool) -> torch.Tensor: + if not columnwise: + return _rowwise_swizzle_nvfp4_scale(input_M, input_N, scale) + else: + return _rowwise_swizzle_nvfp4_scale(input_N, input_M, scale) + + +# Helper function to generate random multiples sum +def _generate_random_multiples_sum(total=8192, n=4, multiple=64): + if total % multiple != 0: + raise ValueError(f"Total ({total}) must be a multiple of {multiple}") + if (total // multiple) < n: + raise ValueError("Total too small for given n and multiple.") + + # Work in 
units of multiples + total_units = total // multiple + + # choose n−1 random cut points in [1, total_units−1) + cuts = sorted(random.sample(range(1, total_units), n - 1)) + + # convert to segment lengths + parts = ( + [cuts[0]] + [cuts[i] - cuts[i - 1] for i in range(1, len(cuts))] + [total_units - cuts[-1]] + ) + + # convert back to multiples + return [p * multiple for p in parts] + + +# Generate split sections for NVFP4 1D blockwise quantization +def generate_split_sections( + M: int, N: int, edge_cases: str, least_multiple: int = 128 +) -> list[int]: + num_chunks = 4 + split_sections = None + + avg_split = M // num_chunks + + if M == 0 or N == 0: + # all zeros + return [0] * num_chunks + if edge_cases == "regular": + split_sections = [avg_split] * num_chunks + elif edge_cases == "zero_tokens_all": + split_sections = [0] * num_chunks + elif edge_cases == "zero_tokens_front": + split_sections = [0] + [avg_split] * (num_chunks - 2) + [avg_split * 2] + elif edge_cases == "zero_tokens_end": + split_sections = [avg_split * 2] + [avg_split] * (num_chunks - 2) + [0] + elif edge_cases == "zero_tokens_middle": + split_sections = [avg_split] * (num_chunks - 2) + [0] + [avg_split * 2] + elif edge_cases == "random_uneven_split": + split_sections = _generate_random_multiples_sum(M, num_chunks, least_multiple) + else: + raise ValueError(f"Invalid edge case: {edge_cases}") + + # adds up the split_sections to make it M + assert sum(split_sections) == M, "The split_sections do not add up to M" + + # make sure every split_section is a multiple of least_multiple + for split_section in split_sections: + assert ( + split_section % least_multiple == 0 + ), "The split_sections are not multiples of least_multiple" + + return split_sections + + +# Reference implementation of group quantization for NVFP4 1D blockwise quantization +def reference_group_quantize( + x: torch.Tensor, + quantizers: list[NVFP4Quantizer], + split_sections: list[int], + return_identity: bool, + return_transpose: 
bool, +) -> torch.Tensor: + x_view = x.reshape(-1, x.size(-1)) + x_chunks = torch.split(x, split_sections) + + # rowwise quantization + x_qx = [] + x_sx = [] + x_amax_rowwise = [] + # columnwise quantization + x_qx_t = [] + x_sx_t = [] + x_amax_colwise = [] + + for i in range(len(x_chunks)): + x_chunk = x_chunks[i] + x_nvfp4_res = quantizers[i](x_chunk) + if return_identity: + x_qx.append(x_nvfp4_res._rowwise_data.view(dtype=torch.uint8)) + x_sx.append(x_nvfp4_res._rowwise_scale_inv) + x_amax_rowwise.append(x_nvfp4_res._amax_rowwise) + else: + x_qx.append(None) + x_sx.append(None) + x_amax_rowwise.append(None) + if return_transpose: + x_qx_t.append(x_nvfp4_res._columnwise_data.view(dtype=torch.uint8)) + x_sx_t.append(x_nvfp4_res._columnwise_scale_inv) + x_amax_colwise.append(x_nvfp4_res._amax_columnwise) + else: + x_qx_t.append(None) + x_sx_t.append(None) + x_amax_colwise.append(None) + + return x_qx, x_sx, x_amax_rowwise, x_qx_t, x_sx_t, x_amax_colwise + + +# Function to assert that two tensors have the same shape and dtype +def assert_same_shape_and_dtype(x: torch.Tensor, y: torch.Tensor) -> None: + assert x.shape == y.shape + assert x.dtype == y.dtype diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 01a4a01205..5f35e9ad10 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -23,126 +23,14 @@ import random import math -recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) - - -def generate_random_multiples_sum(total=8192, n=4, multiple=64): - if total % multiple != 0: - raise ValueError(f"Total ({total}) must be a multiple of {multiple}") - if (total // multiple) < n: - raise ValueError("Total too small for given n and multiple.") - - # Work in units of multiples - total_units = total // multiple - - # choose n−1 random cut points in [1, total_units−1) - cuts = sorted(random.sample(range(1, 
total_units), n - 1)) - - # convert to segment lengths - parts = ( - [cuts[0]] + [cuts[i] - cuts[i - 1] for i in range(1, len(cuts))] + [total_units - cuts[-1]] - ) - - # convert back to multiples - return [p * multiple for p in parts] - - -def generate_split_sections(M: int, N: int, edge_cases: str) -> list[int]: - least_multiple = 64 - num_chunks = 4 - split_sections = None - - avg_split = M // num_chunks - - if M == 0 or N == 0: - # all zeros - return [0] * num_chunks - if edge_cases == "regular": - split_sections = [avg_split] * num_chunks - elif edge_cases == "zero_tokens_front": - split_sections = [0] + [avg_split] * (num_chunks - 2) + [avg_split * 2] - elif edge_cases == "zero_tokens_end": - split_sections = [avg_split * 2] + [avg_split] * (num_chunks - 2) + [0] - elif edge_cases == "zero_tokens_middle": - split_sections = [avg_split] * (num_chunks - 2) + [0] + [avg_split * 2] - elif edge_cases == "random_uneven_split": - split_sections = generate_random_multiples_sum(M, num_chunks, least_multiple) - else: - raise ValueError(f"Invalid edge case: {edge_cases}") - - # adds up the split_sections to make it M - assert sum(split_sections) == M, "The split_sections do not add up to M" - - # make sure every split_section is a multiple of least_multiple - for split_section in split_sections: - assert ( - split_section % least_multiple == 0 - ), "The split_sections are not multiples of least_multiple" - - return split_sections - - -# Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization without padding -def get_nvfp4_scale_shape_no_padding(shape, columnwise): - M, K = 1, 1 - M = math.prod(shape[:-1]) - K = shape[-1] - - if columnwise: - outer = K - inner = math.ceil(M / 16) - return (outer, inner) - # rowwise - outer = M - inner = math.ceil(K / 16) - return (outer, inner) - - -def reference_group_quantize( - x: torch.Tensor, - quantizers: list[NVFP4Quantizer], - split_sections: list[int], - return_identity: bool, - return_transpose: bool, -) 
-> torch.Tensor: - x_view = x.reshape(-1, x.size(-1)) - x_chunks = torch.split(x, split_sections) - - # rowwise quantization - x_qx = [] - x_sx = [] - x_amax_rowwise = [] - # columnwise quantization - x_qx_t = [] - x_sx_t = [] - x_amax_colwise = [] - - for i in range(len(x_chunks)): - x_chunk = x_chunks[i] - x_nvfp4_res = quantizers[i](x_chunk) - if return_identity: - x_qx.append(x_nvfp4_res._rowwise_data.view(dtype=torch.uint8)) - x_sx.append(x_nvfp4_res._rowwise_scale_inv) - x_amax_rowwise.append(x_nvfp4_res._amax_rowwise) - else: - x_qx.append(None) - x_sx.append(None) - x_amax_rowwise.append(None) - if return_transpose: - x_qx_t.append(x_nvfp4_res._columnwise_data.view(dtype=torch.uint8)) - x_sx_t.append(x_nvfp4_res._columnwise_scale_inv) - x_amax_colwise.append(x_nvfp4_res._amax_columnwise) - else: - x_qx_t.append(None) - x_sx_t.append(None) - x_amax_colwise.append(None) - - return x_qx, x_sx, x_amax_rowwise, x_qx_t, x_sx_t, x_amax_colwise - +from nvfp4_utils import ( + get_nvfp4_scale_shape_no_padding, + generate_split_sections, + assert_same_shape_and_dtype, + reference_group_quantize, +) -def assert_same_shape_and_dtype(x: torch.Tensor, y: torch.Tensor) -> None: - assert x.shape == y.shape - assert x.dtype == y.dtype +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) def check_group_quantization_nvfp4_versus_reference( @@ -279,7 +167,7 @@ def test_rht_with_quantization_block_tiling_versus_reference( with_rht: bool, ) -> None: - split_sections = generate_split_sections(M, N, edge_cases) + split_sections = generate_split_sections(M, N, edge_cases, least_multiple=64) # currently disable pre-RHT amax with_post_rht_amax = with_rht diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py new file mode 100644 index 0000000000..1e62f91eb8 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py @@ -0,0 +1,451 @@ +# Copyright (c) 
2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes import utils +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.common.recipe import NVFP4BlockScaling +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor + +import pytest +import torch +import random +import math + +from nvfp4_utils import ( + get_nvfp4_scale_shape_no_padding, + swizzle_nvfp4_scale, + generate_split_sections, + assert_same_shape_and_dtype, + reference_group_quantize, +) + +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) + + +def fused_grouped_quantize( + x: torch.Tensor, split_section_tensor: torch.Tensor, quantizer: NVFP4Quantizer +): + + # view x as a 2D tensor + hidden_dim = x.shape[-1] + x = x.view(-1, hidden_dim) + num_tensors = split_section_tensor.shape[0] + + grouped_output = tex.group_quantize(x, quantizer, num_tensors, split_section_tensor) + + return grouped_output + + +def check_grouped_tensor_nvfp4_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_identity: bool, + return_transpose: bool, + split_sections: list[int], + with_rht: bool = True, + with_post_rht_amax: bool = True, + with_random_sign_mask: bool = True, + optimize_for_gemm: bool = False, +) -> None: + + te_dtype = tex.DType.kFloat4E2M1 + + split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + num_chunks = len(split_sections) + + x_splits = torch.split(x, 
split_sections) + + # Quantize + quantizers = [ + NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=return_identity, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + ) + for _ in range(len(split_sections)) + ] + + grouped_quantizer = quantizers[0].copy() + # configure grouped quantizer with swizzle fusion + # and compare with reference without swizzle fusion + grouped_quantizer.optimize_for_gemm = optimize_for_gemm + + x_qx_ref, x_sx_ref, x_amax_rowwise_ref, x_qx_t_ref, x_sx_t_ref, x_amax_colwise_ref = ( + reference_group_quantize(x, quantizers, split_sections, return_identity, return_transpose) + ) + + group_quantized_output = fused_grouped_quantize(x, split_section_tensor, grouped_quantizer) + # get a list of nvfp4 quantized tensors for testing + split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() + + if return_identity: + x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] + x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] + x_amax_rowwise = [output._amax_rowwise for output in split_quantize_outputs] + + for i in range(len(x_qx)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i]) + assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i]) + assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i]) + else: + torch.testing.assert_close( + x_amax_rowwise[i], x_amax_rowwise_ref[i], atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(x_qx[i], x_qx_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, False) + assert ( + valid_scale_shape == x_sx[i].shape + ), "The scale shape is not correctly aligned" + x_sx_i = x_sx[i].clone() + x_sx_ref_i = x_sx_ref[i].clone() + if 
optimize_for_gemm: + x_sx_ref_i = swizzle_nvfp4_scale( + split_sections[i], N, x_sx_ref_i, columnwise=False + ) + torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=0.0, rtol=0.0) + + if return_transpose: + x_qx_t = [ + output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs + ] + x_sx_t = [output._columnwise_scale_inv for output in split_quantize_outputs] + x_amax_colwise = [output._amax_columnwise for output in split_quantize_outputs] + # assert with zero tolerance + for i in range(len(x_qx_t)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i]) + assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i]) + assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i]) + else: + torch.testing.assert_close( + x_amax_colwise[i], x_amax_colwise_ref[i], atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(x_qx_t[i], x_qx_t_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, True) + assert ( + valid_scale_shape == x_sx_t[i].shape + ), "The scale shape is not correctly aligned" + x_sx_t_i = x_sx_t[i].clone() + x_sx_t_ref_i = x_sx_t_ref[i].clone() + if optimize_for_gemm: + x_sx_t_ref_i = swizzle_nvfp4_scale( + split_sections[i], N, x_sx_t_ref_i, columnwise=True + ) + torch.testing.assert_close(x_sx_t_i, x_sx_t_ref_i, atol=0.0, rtol=0.0) + + +def check_grouped_tensor_nvfp4_with_paged_stashing( + x_dtype: torch.dtype, + M: int, + N: int, + return_identity: bool, + return_transpose: bool, + split_sections: list[int], + with_rht: bool = True, + with_post_rht_amax: bool = True, + with_random_sign_mask: bool = True, + valid_M: int = None, + optimize_for_gemm: bool = False, +) -> None: + + te_dtype = tex.DType.kFloat4E2M1 + + assert valid_M is not None, "valid_M must be provided when with_paged_stashing is True" + assert valid_M < M, "valid_M must be less than M when 
with_paged_stashing is True" + + split_section_tensor = torch.tensor(split_sections, dtype=torch.int64, device="cuda") + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input (fill the entire tensor with garbage too) + x = torch.randn((M, N), dtype=x_dtype, device=device) + valid_x = x[:valid_M, :].clone() + num_chunks = len(split_sections) + + x_splits = torch.split(valid_x, split_sections) + + # Quantize + quantizers = [ + NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=return_identity, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + ) + for _ in range(len(split_sections)) + ] + + grouped_quantizer = quantizers[0].copy() + # configure grouped quantizer with swizzle fusion + # and compare with reference without swizzle fusion + grouped_quantizer.optimize_for_gemm = optimize_for_gemm + + x_qx_ref, x_sx_ref, x_amax_rowwise_ref, x_qx_t_ref, x_sx_t_ref, x_amax_colwise_ref = ( + reference_group_quantize( + valid_x, quantizers, split_sections, return_identity, return_transpose + ) + ) + + # Note: for grouped quantize with paged stashing + # it's expected that we can just pass in the regular input x, not the valid_x + # the kernel is expected to porcess it correctly by becoming no-op for cuda graph + group_quantized_output = fused_grouped_quantize(x, split_section_tensor, grouped_quantizer) + + # get a list of nvfp4 quantized tensors for testing + split_quantize_outputs = group_quantized_output.split_into_quantized_tensors() + + if return_identity: + x_qx = [output._rowwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs] + x_sx = [output._rowwise_scale_inv for output in split_quantize_outputs] + x_amax_rowwise = [output._amax_rowwise for output in split_quantize_outputs] + + for i in range(len(x_qx)): + if split_sections[i] == 
0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_amax_rowwise[i], x_amax_rowwise_ref[i]) + assert_same_shape_and_dtype(x_qx[i], x_qx_ref[i]) + assert_same_shape_and_dtype(x_sx[i], x_sx_ref[i]) + else: + torch.testing.assert_close( + x_amax_rowwise[i], x_amax_rowwise_ref[i], atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(x_qx[i], x_qx_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, False) + assert ( + valid_scale_shape == x_sx[i].shape + ), "The scale shape is not correctly aligned" + x_sx_i = x_sx[i].clone() + x_sx_ref_i = x_sx_ref[i].clone() + if optimize_for_gemm: + x_sx_ref_i = swizzle_nvfp4_scale( + split_sections[i], N, x_sx_ref_i, columnwise=False + ) + torch.testing.assert_close(x_sx_i, x_sx_ref_i, atol=0.0, rtol=0.0) + + if return_transpose: + x_qx_t = [ + output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs + ] + x_sx_t = [output._columnwise_scale_inv for output in split_quantize_outputs] + x_amax_colwise = [output._amax_columnwise for output in split_quantize_outputs] + # assert with zero tolerance + for i in range(len(x_qx_t)): + if split_sections[i] == 0: + # then just assert the same shape and dtype because the buffer won't be zero out + assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i]) + assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i]) + assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i]) + else: + torch.testing.assert_close( + x_amax_colwise[i], x_amax_colwise_ref[i], atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(x_qx_t[i], x_qx_t_ref[i], atol=0.0, rtol=0.0) + valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, True) + x_sx_t_i = x_sx_t[i].clone() + x_sx_t_ref_i = x_sx_t_ref[i].clone() + if optimize_for_gemm: + x_sx_t_ref_i = swizzle_nvfp4_scale( + split_sections[i], N, x_sx_t_ref_i, columnwise=True + ) + torch.testing.assert_close(x_sx_t_i, 
x_sx_t_ref_i, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # edge case, zero tokens for all + (0, 512), + # full tile cases + (1024, 256), + # larger sizes + (8192, 1024), + (16384, 8192), + (16384, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "edge_cases", + [ + "regular", + "zero_tokens_front", + "zero_tokens_end", + "zero_tokens_middle", + "random_uneven_split", + ], +) +@pytest.mark.parametrize( + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] +) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +@pytest.mark.parametrize("with_rht", [True], ids=["with_rht"]) +@pytest.mark.parametrize( + "optimize_for_gemm", [True, False], ids=["optimize_for_gemm", "no_optimize_for_gemm"] +) +def test_grouped_tensor_nvfp4_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + edge_cases: str, + quantize_mode: str, + with_random_sign_mask: bool, + with_rht: bool, + optimize_for_gemm: bool, +) -> None: + + split_sections = generate_split_sections(M, N, edge_cases, least_multiple=128) + + # currently disable pre-RHT amax + with_post_rht_amax = with_rht + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + + check_grouped_tensor_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + return_identity=return_identity, + return_transpose=return_transpose, + split_sections=split_sections, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + 
optimize_for_gemm=optimize_for_gemm, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # M won't be empty in paged stashing + # full tile cases + (1024, 256), + # larger sizes + (8192, 1024), + (16384, 8192), + (16384, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "edge_cases", + [ + "regular", + # even if buffer is not empty, but the token splits are all zero + "zero_tokens_all", + # partially zero tokens + "zero_tokens_front", + "zero_tokens_end", + "zero_tokens_middle", + "random_uneven_split", + ], +) +@pytest.mark.parametrize( + "quantize_mode", ["quantize", "quantize_transpose", "quantize_colwise_only"] +) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +@pytest.mark.parametrize("with_rht", [True], ids=["with_rht"]) +@pytest.mark.parametrize( + "optimize_for_gemm", [True, False], ids=["optimize_for_gemm", "no_optimize_for_gemm"] +) +def test_grouped_tensor_nvfp4_with_paged_stashing( + x_dtype: torch.dtype, + M: int, + N: int, + edge_cases: str, + quantize_mode: str, + with_random_sign_mask: bool, + with_rht: bool, + optimize_for_gemm: bool, +) -> None: + + # paged stashing means that the sum of total tokens is less than + # or equal to the buffer size, you can have buffer [2048, 1024] + # and when you only receive 1024 tokens, the last half is garbage + # so input has shape [2048, 1024] + # split sections can be [256, 256, 256, 256], sums to 1024 + valid_M = 0 if edge_cases == "zero_tokens_all" else M // 2 + split_sections = generate_split_sections(valid_M, N, edge_cases, least_multiple=128) + + # sanity check + if edge_cases == "zero_tokens_all": + assert valid_M == 0, "valid_M must be 0 when edge_cases is zero_tokens_all" + else: + assert valid_M == M // 2, "valid_M must be M // 2 when edge_cases is not zero_tokens_all" + + # currently disable 
pre-RHT amax + with_post_rht_amax = with_rht + + if quantize_mode == "quantize": + return_identity = True + return_transpose = False + elif quantize_mode == "quantize_transpose": + return_identity = True + return_transpose = True + elif quantize_mode == "quantize_colwise_only": + return_identity = False + return_transpose = True + else: + raise ValueError(f"Invalid quantize mode: {quantize_mode}") + + check_grouped_tensor_nvfp4_with_paged_stashing( + x_dtype=x_dtype, + M=M, + N=N, + return_identity=return_identity, + return_transpose=return_transpose, + split_sections=split_sections, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + valid_M=valid_M, + optimize_for_gemm=optimize_for_gemm, + ) diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index 318009c669..ad08c0474d 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -55,7 +55,7 @@ def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer: - """Create quantizers for given quantization scheme""" + """Create quantizer for given quantization scheme""" if quantization == "fp8_delayed_scaling": quantizer = Float8Quantizer( @@ -203,12 +203,12 @@ def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None """Test split_into_quantized_tensors for quantized tensors""" num_tensors = 3 shape = [(512, 512) for _ in range(num_tensors)] - quantizers = make_quantizer(quantization, num_tensors, shape) + quantizer = make_quantizer(quantization, num_tensors, shape) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, shape=shape, - quantizer=quantizers, + quantizer=quantizer, device="cuda", ) @@ -260,12 +260,12 @@ def test_quantize_inplace(self, quantization: str) -> None: """Test that quantize is done in-place for all recipes""" num_tensors = 3 shape = [(512, 512) for _ in 
range(num_tensors)] - quantizers = make_quantizer(quantization, num_tensors, shape) + quantizer = make_quantizer(quantization, num_tensors, shape) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, shape=shape, - quantizer=quantizers, + quantizer=quantizer, device="cuda", ) @@ -300,12 +300,12 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: """Test quantize with varying shapes""" num_tensors = 3 shape = [(256, 512), (512, 512), (768, 512)] - quantizers = make_quantizer(quantization, num_tensors, shape) + quantizer = make_quantizer(quantization, num_tensors, shape) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, shape=shape, - quantizer=quantizers, + quantizer=quantizer, device="cuda", ) @@ -334,7 +334,7 @@ def test_static_quantize_method(self, quantization: str) -> None: """Test the static quantize method""" num_tensors = 3 shape = [(512, 512) for _ in range(num_tensors)] - quantizers = make_quantizer(quantization, num_tensors, shape) + quantizer = make_quantizer(quantization, num_tensors, shape) # Create input tensors input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] @@ -342,7 +342,7 @@ def test_static_quantize_method(self, quantization: str) -> None: # Use static quantize method grouped_tensor = GroupedTensor.create_and_quantize( tensors=input_tensors, - quantizer=quantizers, + quantizer=quantizer, device="cuda", ) @@ -361,6 +361,99 @@ def test_static_quantize_method(self, quantization: str) -> None: expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + @pytest.mark.parametrize( + "shape", + [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]], + ) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: + """Test grouped quantization 
for MXFP8 against per-tensor quantization.""" + # Test wont pass until the grouped quantization PR from Oleg is merged. + num_tensors = 2 + shape = [(512, 1024) for _ in range(num_tensors)] + + # Create BF16 input tensors and pack into a 2D tensor + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] + quantized_tensors = [ + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(tensor) for tensor in input_tensors + ] + grouped_input = torch.cat(input_tensors, dim=0) + + # Create MXFP8 output grouped tensor (rowwise only for easier validation) + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer.set_usage(rowwise=True, columnwise=False) + first_dims = torch.tensor( + [shape[0][0] for _ in range(num_tensors)], + dtype=torch.int64, + device="cuda", + ) + + # Quantize using grouped API + grouped_output = tex.group_quantize( + grouped_input, + quantizer, + num_tensors, + first_dims, + ) + # Build expected output by quantizing each tensor independently + expected_data = [] + expected_scale_inv = [] + for tensor in input_tensors: + qtensor = quantizer(tensor) + expected_data.append(qtensor._rowwise_data.reshape(-1)) + expected_scale_inv.append(qtensor._rowwise_scale_inv.reshape(-1)) + + expected_data = torch.cat(expected_data) + expected_scale_inv = torch.cat(expected_scale_inv) + + assert torch.equal(grouped_output.data, expected_data) + assert torch.equal(grouped_output.scale_inv, expected_scale_inv) + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_group_quantize_cudagraph_capturable(self) -> None: + """Ensure group_quantize is CUDA graph capturable.""" + num_tensors = 2 + shape = [(512, 1024) for _ in range(num_tensors)] + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] + grouped_input = torch.cat(input_tensors, dim=0) + + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer.set_usage(rowwise=True, columnwise=False) + first_dims = 
torch.tensor( + [shape[0][0] for _ in range(num_tensors)], + dtype=torch.int64, + device="cuda", + ) + + torch.cuda.synchronize() + static_input = grouped_input.clone() + static_first_dims = first_dims.clone() + + # Warmup to initialize kernels and allocator state + _ = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + static_output = tex.group_quantize( + static_input, + quantizer, + num_tensors, + static_first_dims, + ) + + fresh_input = torch.cat( + [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape], + dim=0, + ) + static_input.copy_(fresh_input) + graph.replay() + torch.cuda.synchronize() + + expected = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims) + assert torch.equal(static_output.data, expected.data) + assert torch.equal(static_output.scale_inv, expected.scale_inv) + def test_clear(self) -> None: """Test clear method""" num_tensors = 3 diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 582172a88e..57404ae8a5 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -124,7 +124,8 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, } // Group quantize assumes contiguous inputs and outputs in memory allocation -// TODO (zhongbo): find a better way to make it a more generalized API +// Note: this API assumes knowing split sections from the host, if split information +// comes from D2H copy, it will break cuda graph capture void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections, const size_t num_tensors, const NVTEQuantizationConfig quant_config, @@ -134,6 +135,6 @@ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *out constexpr bool IS_ACT = false; - dispatch::group_quantize_fwd_helper(input, 
outputs, split_sections, - num_tensors, quant_config, stream); + dispatch::group_quantize_fwd_host_aware_helper( + input, outputs, split_sections, num_tensors, quant_config, stream); } diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index b83df1dedf..98a3fb8cba 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -308,10 +308,12 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens } } +// Host-aware and not graph-safe: group quantization with split section info from the host. template -void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs, - const size_t *split_sections, const size_t num_tensors, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { +void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *outputs, + const size_t *split_sections, const size_t num_tensors, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { using namespace detail; const Tensor *input_tensor = convertNVTETensorCheck(input); diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index a29a09836e..bdc400d87e 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -277,6 +277,30 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel // grouped tensor can be treated as continuous tensor for MXFP8 const size_t tensor_base = is_single_tensor ? 0 : static_cast(offsets_ptr[tensor_id]); + // For grouped tensors represented as a single logical tensor, scale swizzle must still be + // computed per tensor (expert) and then concatenated along dim-0. + const size_t tensor_base_for_scales = (is_single_tensor && num_tensors > 1) + ? 
static_cast(offsets_ptr[tensor_id]) + : tensor_base; + + // In graph-safe paged stashing, the logical shape can include trailing garbage. Skip CTAs that + // map outside the current tensor's valid [rows, cols] region. + if (rows == 0 || cols == 0) { + return; + } + if (shape_rep != SAME_BOTH_DIMS) { + const size_t tensor_start_offset = static_cast(offsets_ptr[tensor_id]); + const size_t tensor_end_offset = static_cast(offsets_ptr[tensor_id + 1]); + if (block_global_offset >= tensor_end_offset) { + return; + } + const size_t tensor_offset_from_start = block_global_offset - tensor_start_offset; + const size_t block_offset_Y_in_tensor = tensor_offset_from_start / cols; + const size_t block_offset_X_in_tensor = tensor_offset_from_start % cols; + if (block_offset_Y_in_tensor >= rows || block_offset_X_in_tensor >= cols) { + return; + } + } const CUtensorMap &tensor_map_input = is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id]; @@ -481,7 +505,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel size_t scale_idx = 0; if constexpr (WITH_GEMM_SWIZZLED_SCALES) { - scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y, + const size_t tensor_base_row = tensor_base_for_scales / cols; + const size_t tensor_scales_offset_Y_base = tensor_base_row / SCALE_DIM_Y; + const size_t tensor_scales_offset_colwise_base = tensor_base_for_scales / SCALE_DIM_Y; + const size_t local_scales_offset_Y = global_scales_offset_Y - tensor_scales_offset_Y_base; + scale_idx = tensor_scales_offset_colwise_base + + gemm_swizzled_scale_idx(global_scales_offset_X, local_scales_offset_Y, DIVUP(rows, static_cast(128))); } else { scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; @@ -827,9 +856,9 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations "First dimension of a grouped tensor should be divisible by 128."); } - const int64_t *const offsets_ptr = 
reinterpret_cast(input->tensor_offsets.dptr); - const int64_t *const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); - const int64_t *const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); + const int64_t *const offsets_ptr = reinterpret_cast(output->tensor_offsets.dptr); + const int64_t *const first_dims_ptr = reinterpret_cast(output->first_dims.dptr); + const int64_t *const last_dims_ptr = reinterpret_cast(output->last_dims.dptr); float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; float *const amax_ptr = reinterpret_cast(output->amax.dptr); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 2d7f0e7e8c..8437693c1e 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -341,6 +341,21 @@ struct GroupedTensor { */ bool with_gemm_swizzled_scales = false; + /*! Map from NVTEGroupedTensorParam to parameter sizes */ + static constexpr size_t attr_sizes[] = { + sizeof(NVTEBasicTensor), // kNVTEGroupedRowwiseData + sizeof(NVTEBasicTensor), // kNVTEGroupedColumnwiseData + sizeof(NVTEBasicTensor), // kNVTEGroupedScale + sizeof(NVTEBasicTensor), // kNVTEGroupedAmax + sizeof(NVTEBasicTensor), // kNVTEGroupedRowwiseScaleInv + sizeof(NVTEBasicTensor), // kNVTEGroupedColumnwiseScaleInv + sizeof(NVTEBasicTensor), // kNVTEGroupedColumnwiseAmax + sizeof(NVTEBasicTensor), // kNVTEGroupedFirstDims + sizeof(NVTEBasicTensor), // kNVTEGroupedLastDims + sizeof(NVTEBasicTensor), // kNVTEGroupedTensorOffsets + sizeof(uint8_t) // kNVTEGroupedWithGEMMSwizzledScales + }; + GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors) : data(), columnwise_data(), diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 986229aabf..871aeb0373 100644 --- 
a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -251,6 +251,11 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( // calculate the global offset to get tensor id size_t global_offset = blockIdx.y * CHUNK_DIM_Y * last_logical_dim; + // paged stashing: will have input buffer [M, N], where M is larger than sum(first_dims) + // also need to early return if this CTA is processing a region larger than the last offsets[num_tensors] + if (global_offset >= offsets_ptr[num_tensors]) { + return; + } int tensor_id = get_current_tensor_id(shape_rep, num_tensors, global_offset, first_logical_dim, last_logical_dim, offsets_ptr); output_pre_rht_amax_ptr = static_cast(amax_rowwise_ptr) + tensor_id; @@ -440,9 +445,8 @@ void group_hadamard_transform_amax_graph_safe(const GroupedTensor* input, Groupe float* const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); float* const amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); - const int64_t* const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); - const int64_t* const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); - // const int64_t *const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); + const int64_t* const offsets_ptr = reinterpret_cast(output->tensor_offsets.dptr); + const int64_t* const first_dims_ptr = reinterpret_cast(output->first_dims.dptr); // some sanity checks if (all_return_pre_rht_amax) { diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu index 030dddfce4..19583b3afb 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu +++ 
b/transformer_engine/common/hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu @@ -1428,9 +1428,8 @@ void group_hadamard_transform_cast_fusion_graph_safe(const GroupedTensor *input, float *const amax_rowwise_base_ptr = reinterpret_cast(output->amax.dptr); float *const amax_colwise_base_ptr = reinterpret_cast(output->columnwise_amax.dptr); - const int64_t *const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); - const int64_t *const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); - // const int64_t *const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); + const int64_t *const offsets_ptr = reinterpret_cast(output->tensor_offsets.dptr); + const int64_t *const first_dims_ptr = reinterpret_cast(output->first_dims.dptr); const bool is_const_last_dim = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS || shape_rep == ShapeRepresentation::VARYING_FIRST_DIM); @@ -1441,7 +1440,7 @@ void group_hadamard_transform_cast_fusion_graph_safe(const GroupedTensor *input, int k_tile_size = 1024; - const bool use_swizzle_sf_output = false; + const bool use_swizzle_sf_output = output->with_gemm_swizzled_scales; TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, kEnableStochasticRounding, diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ae41f238a4..e316f8be8c 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -449,6 +449,8 @@ enum NVTEGroupedTensorParam { kNVTEGroupedLastDims = 8, /*!< Last dimension sizes (device pointer to int64_t array) */ kNVTEGroupedTensorOffsets = 9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */ + kNVTEGroupedWithGEMMSwizzledScales = + 10, /*!< Whether scaling factors are in format expected by GEMM */ 
kNVTENumGroupedTensorParams }; @@ -479,25 +481,30 @@ NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_ void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor); /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ -/*! \brief Set a parameter of the grouped tensor. +/*! \brief Set a grouped tensor parameter. * - * \param[in/out] tensor Grouped tensor. - * \param[in] param_name The parameter to be set. - * \param[in] param The value to be set (NVTEBasicTensor). + * \param[in/out] tensor Grouped tensor. + * \param[in] param Grouped tensor parameter type. + * \param[in] buf Memory address to read parameter value. + * \param[in] size_in_bytes Size of buf. */ -void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name, - const NVTEBasicTensor *param); +void nvte_set_grouped_tensor_param(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, + const void *buf, size_t size_in_bytes); /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ -/*! \brief Get a value of the parameter of the grouped tensor. - * - * \param[in] tensor Grouped tensor. - * \param[in] param_name The parameter to be queried. +/*! \brief Query a grouped tensor parameter. * - * \return NVTEBasicTensor containing the parameter data. + * \param[in] tensor Grouped tensor. + * \param[in] param Grouped tensor parameter type. + * \param[out] buf Memory address to write parameter value. + * Ignored if NULL. + * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. */ -NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, - NVTEGroupedTensorParam param_name); +void nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, + void *buf, size_t size_in_bytes, size_t *size_written); /* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */ /*! 
\brief Get the number of tensors in a grouped tensor. @@ -957,8 +964,235 @@ class TensorWrapper { NVTETensor tensor_ = nullptr; }; -/*! \warning Deprecated */ -enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; +/*! \struct GroupedTensorWrapper + * \brief C++ wrapper for the NVTEGroupedTensor class. + */ + +class GroupedTensorWrapper { + public: + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * TE grouped tensors are just wrappers on top of raw data and do not + * own memory. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} + + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : GroupedTensorWrapper(num_tensors, + nvte_make_shape(logical_shape.data(), logical_shape.size()), + scaling_mode) {} + + /*! \brief GroupedTensorWrapper destructor. */ + ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } + + GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; + GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; + + /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. 
*/ + GroupedTensorWrapper(GroupedTensorWrapper &&other) { + tensor_ = other.tensor_; + other.tensor_ = nullptr; + } + + /*! \brief Assign the data from existing GroupedTensorWrapper. */ + GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { + if (this == &other) return *this; + nvte_destroy_grouped_tensor(tensor_); + tensor_ = other.tensor_; + other.tensor_ = nullptr; + return *this; + } + + // Parameter setters + template + GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_grouped_tensor_param(tensor_, param, &data, sizeof(data)); + return *this; + } + + template + GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedScale, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, + const ShapeType 
&shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); + } + + void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) { + const auto val = static_cast(with_gemm_swizzled_scales); + nvte_set_grouped_tensor_param(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val)); + } + + // Parameter getters + NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { + NVTEBasicTensor ret; + nvte_get_grouped_tensor_param(tensor_, param, &ret, sizeof(ret), nullptr); + return ret; + } + + NVTEBasicTensor get_rowwise_data() const noexcept { + return get_parameter(kNVTEGroupedRowwiseData); + } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedRowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseAmax); + } + + NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } + + NVTEBasicTensor 
get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); } + + NVTEBasicTensor get_tensor_offsets() const noexcept { + return get_parameter(kNVTEGroupedTensorOffsets); + } + + bool get_with_gemm_swizzled_scales() const { + uint8_t val = 0; + nvte_get_grouped_tensor_param(tensor_, kNVTEGroupedWithGEMMSwizzledScales, &val, sizeof(val), + nullptr); + return static_cast(val); + } + + /*! \brief Get an underlying NVTEGroupedTensor. + * + * \return NVTEGroupedTensor held by this GroupedTensorWrapper. + */ + NVTEGroupedTensor data() const noexcept { return tensor_; } + + /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ + size_t num_tensors() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_grouped_tensor_num_tensors(tensor_); + } + + /*! \brief Get the data type of this GroupedTensorWrapper. */ + DType dtype() const noexcept { + if (tensor_ == nullptr) return DType::kNumTypes; + return static_cast(nvte_grouped_tensor_type(tensor_)); + } + + /*! \brief Get a scaling mode of the grouped tensor. */ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_grouped_tensor_scaling_mode(tensor_); + } + + /*! \brief Get the logical shape of this GroupedTensorWrapper. */ + const NVTEShape logical_shape() const noexcept { + if (tensor_ == nullptr) { + return emptyShape; + } + return nvte_get_grouped_tensor_logical_shape(tensor_); + } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = { + {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { + return nvte_make_shape(s.data(), s.size()); + } + + /*! \brief Wrapped NVTEGroupedTensor. */ + NVTEGroupedTensor tensor_ = nullptr; +}; + +/*! 
\enum Float8BlockScaleTensorFormat + * \brief Data format for an FP8 block-scaled tensor + */ +enum class Float8BlockScaleTensorFormat { + /*! FP8 data is transposed if needed and scales are swizzled */ + GEMM_READY = 0, + /*! FP8 data is untransposed and scales are not swizzled or padded */ + COMPACT = 1, + INVALID +}; /*! \struct QuantizationConfigWrapper * \brief C++ wrapper for NVTEQuantizationConfigWrapper. diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 06971443dd..cd02074fbd 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1145,8 +1145,8 @@ NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_ NVTEShape logical_shape) { NVTE_CHECK(num_tensors > 0, "Number of tensors must be greater than 0"); NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D"); - NVTE_CHECK(logical_shape.data[0] > 0 && logical_shape.data[1] > 0, - "Logical shape must have positive dimensions"); + // NVTE_CHECK(logical_shape.data[0] > 0 && logical_shape.data[1] > 0, + // "Logical shape must have positive dimensions"); NVTEGroupedTensor ret = transformer_engine::GroupedTensorAllocator::instance().Allocate( scaling_mode, num_tensors, logical_shape); return ret; @@ -1156,88 +1156,178 @@ void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor) { transformer_engine::GroupedTensorAllocator::instance().Free(tensor); } -void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name, - const NVTEBasicTensor *param) { +void nvte_set_grouped_tensor_param(NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, + const void *buf, size_t size_in_bytes) { + using namespace transformer_engine; + + // Check attribute and buffer + NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", + static_cast(param), ")"); NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer 
can't be NULL."); - auto *t = transformer_engine::convertNVTEGroupedTensor(*tensor); - NVTE_CHECK(t != nullptr, "Grouped tensor is not allocated."); - NVTE_CHECK(param != nullptr, "Grouped tensor param can't be NULL."); + auto &t = *convertNVTEGroupedTensorCheck(tensor); + const auto &attr_size = GroupedTensor::attr_sizes[param]; + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for grouped tensor parameter " + "(parameter ", + static_cast(param), " needs ", attr_size, " bytes, but buffer has ", + size_in_bytes, " bytes)"); + NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); - switch (param_name) { - case kNVTEGroupedRowwiseData: - t->data = *param; + // Read from buffer + switch (param) { + case kNVTEGroupedRowwiseData: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.data = *basic_tensor; break; - case kNVTEGroupedColumnwiseData: - t->columnwise_data = *param; + } + case kNVTEGroupedColumnwiseData: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_data = *basic_tensor; break; - case kNVTEGroupedScale: - t->scale = *param; + } + case kNVTEGroupedScale: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.scale = *basic_tensor; break; - case kNVTEGroupedAmax: - t->amax = *param; + } + case kNVTEGroupedAmax: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.amax = *basic_tensor; break; - case kNVTEGroupedRowwiseScaleInv: - t->scale_inv = *param; + } + case kNVTEGroupedRowwiseScaleInv: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.scale_inv = *basic_tensor; break; - case kNVTEGroupedColumnwiseScaleInv: - t->columnwise_scale_inv = *param; + } + case kNVTEGroupedColumnwiseScaleInv: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.columnwise_scale_inv = *basic_tensor; break; - case kNVTEGroupedColumnwiseAmax: - t->columnwise_amax = *param; + } + case kNVTEGroupedColumnwiseAmax: { + const NVTEBasicTensor *basic_tensor = 
reinterpret_cast(buf); + t.columnwise_amax = *basic_tensor; break; - case kNVTEGroupedFirstDims: - t->first_dims = *param; - // Validate it's Int64 - NVTE_CHECK(t->first_dims.dtype == transformer_engine::DType::kInt64, - "first_dims must have dtype Int64"); + } + case kNVTEGroupedFirstDims: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.first_dims = *basic_tensor; + NVTE_CHECK(t.first_dims.dtype == DType::kInt64, "first_dims must have dtype Int64"); break; - case kNVTEGroupedLastDims: - t->last_dims = *param; - // Validate it's Int64 - NVTE_CHECK(t->last_dims.dtype == transformer_engine::DType::kInt64, - "last_dims must have dtype Int64"); + } + case kNVTEGroupedLastDims: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.last_dims = *basic_tensor; + NVTE_CHECK(t.last_dims.dtype == DType::kInt64, "last_dims must have dtype Int64"); break; - case kNVTEGroupedTensorOffsets: - t->tensor_offsets = *param; - // Validate it's Int64 - NVTE_CHECK(t->tensor_offsets.dtype == transformer_engine::DType::kInt64, - "tensor_offsets must have dtype Int64"); + } + case kNVTEGroupedTensorOffsets: { + const NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + t.tensor_offsets = *basic_tensor; + NVTE_CHECK(t.tensor_offsets.dtype == DType::kInt64, "tensor_offsets must have dtype Int64"); + break; + } + case kNVTEGroupedWithGEMMSwizzledScales: + t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); break; default: - NVTE_ERROR("Unknown grouped tensor parameter!"); + NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); } } -NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, - NVTEGroupedTensorParam param_name) { - if (tensor == nullptr) { - return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)}; +void nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor, NVTEGroupedTensorParam param, + void *buf, size_t size_in_bytes, size_t *size_written) { + using namespace 
transformer_engine; + + // Check param + NVTE_CHECK(param < kNVTENumGroupedTensorParams, "Invalid NVTEGroupedTensorParam (got ", + static_cast(param), ")"); + + // Write attribute size if provided + const auto &attr_size = GroupedTensor::attr_sizes[param]; + if (size_written != nullptr) { + *size_written = attr_size; } - const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor); - switch (param_name) { - case kNVTEGroupedRowwiseData: - return t.data; - case kNVTEGroupedColumnwiseData: - return t.columnwise_data; - case kNVTEGroupedScale: - return t.scale; - case kNVTEGroupedAmax: - return t.amax; - case kNVTEGroupedRowwiseScaleInv: - return t.scale_inv; - case kNVTEGroupedColumnwiseScaleInv: - return t.columnwise_scale_inv; - case kNVTEGroupedColumnwiseAmax: - return t.columnwise_amax; - case kNVTEGroupedFirstDims: - return t.first_dims; - case kNVTEGroupedLastDims: - return t.last_dims; - case kNVTEGroupedTensorOffsets: - return t.tensor_offsets; + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } + + // Check buffer size + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for grouped tensor parameter " + "(parameter ", + static_cast(param), " needs ", attr_size, " bytes, but buffer has ", + size_in_bytes, " bytes)"); + + // Get C++ grouped tensor + const GroupedTensor *t = convertNVTEGroupedTensor(tensor); + std::optional dummy; + if (t == nullptr) { + // Make dummy grouped tensor if provided tensor is invalid + dummy.emplace(NVTE_DELAYED_TENSOR_SCALING, 1); + t = &(*dummy); + } + + // Write to buffer + switch (param) { + case kNVTEGroupedRowwiseData: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->data); + break; + } + case kNVTEGroupedColumnwiseData: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_data); + break; + } + case kNVTEGroupedScale: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); 
+ *basic_tensor = static_cast(t->scale); + break; + } + case kNVTEGroupedAmax: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->amax); + break; + } + case kNVTEGroupedRowwiseScaleInv: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->scale_inv); + break; + } + case kNVTEGroupedColumnwiseScaleInv: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_scale_inv); + break; + } + case kNVTEGroupedColumnwiseAmax: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->columnwise_amax); + break; + } + case kNVTEGroupedFirstDims: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->first_dims); + break; + } + case kNVTEGroupedLastDims: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->last_dims); + break; + } + case kNVTEGroupedTensorOffsets: { + NVTEBasicTensor *basic_tensor = reinterpret_cast(buf); + *basic_tensor = static_cast(t->tensor_offsets); + break; + } + case kNVTEGroupedWithGEMMSwizzledScales: + *reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); + break; default: - NVTE_ERROR("Unknown grouped tensor parameter!"); + NVTE_ERROR("Unsupported grouped tensor parameter (", static_cast(param), ")"); } } diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index bc22e03097..6aab9938b3 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -103,6 +103,12 @@ class Quantizer { virtual std::pair create_tensor(const std::vector& shape, DType dtype) const = 0; + /*! 
@brief Construct a grouped tensor with uninitialized data */ + virtual std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const = 0; + /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor * * The PyTorch tensor's attributes are modified to match the @@ -138,6 +144,11 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; @@ -164,6 +175,11 @@ class Float8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const std::vector& shape, DType dtype, std::optional data, @@ -196,6 +212,11 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. * * The amax is zeroed out. 
Most TE kernels that output amax expect @@ -253,6 +274,11 @@ class Float8BlockQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, @@ -274,6 +300,11 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, @@ -308,6 +339,11 @@ class NVFP4Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer * * The amax is zeroed out. 
Most TE kernels that output amax expect diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e0ea3d6b78..8f6189fc8d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -250,6 +250,9 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob py::object dequantize(const py::handle &input, DType otype); +py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, + std::optional first_dims); + std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 5c9d0f5b07..f8f793f036 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -80,6 +80,157 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob return output_py; } +namespace { + +// helper functions for NVFP4 grouped quantization (cuda graph safe with shapes stored in device without D2H copy) +void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor, + GroupedTensorWrapper &grouped_output_tensor, + NVFP4Quantizer *nvfp4_quantizer_cpp, cudaStream_t stream) { + size_t num_tensors = grouped_input_tensor.num_tensors(); + + // assert the 2D scaling case, since 2D scaling grouped quant kernel is not ready yet + NVTE_CHECK(!nvfp4_quantizer_cpp->with_2d_quantization, + "2D scaling grouped quant kernel is not ready yet"); + + auto quant_config_cpp = QuantizationConfigWrapper(); + + // stochastic rounding + bool need_stochastic_rounding = nvfp4_quantizer_cpp->stochastic_rounding; + auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + at::Tensor rng_states_tensor; // Declare tensor outside, do not allocate yet + TensorWrapper te_rng_state; + + if 
(need_stochastic_rounding) { + // in fused kernel, one rng state will be used by the grouped kernel to generate random + // number for different tensors in the group, so we only need to allocate one rng state + const size_t rng_elts_per_thread = 1024 * num_tensors; + rng_states_tensor = torch::empty({2}, opts); + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); + philox_unpack(philox_args, static_cast(rng_states_tensor.data_ptr())); + + te_rng_state = makeTransformerEngineTensor(rng_states_tensor); + quant_config_cpp.set_rng_state(te_rng_state.data()); + quant_config_cpp.set_stochastic_rounding(true); + } + + // fast math + const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); + if (use_fast_math) { + quant_config_cpp.set_use_fast_math(true); + } + + // so far, only the RHT path has grouped kernel support + // grouped kernels for non-RHT path will be added later + + if (nvfp4_quantizer_cpp->with_rht) { + // post-RHT amax or not + if (nvfp4_quantizer_cpp->with_post_rht_amax) { + NVTE_SCOPED_GIL_RELEASE({ + nvte_group_hadamard_transform_amax_graph_safe( + grouped_input_tensor.data(), grouped_output_tensor.data(), 0, + nvfp4_quantizer_cpp->rht_matrix_random_sign_mask_t, stream); + }); + } else { + NVTE_ERROR("graph safe grouped quant kernel for non-RHT path is not ready yet"); + } + + // RHT cast fusion + auto tile_scheduler_workspace_torch = + at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32)); + auto nvte_tile_scheduler_workspace = + makeTransformerEngineTensor(tile_scheduler_workspace_torch); + + auto rht_matrix_nvte = makeTransformerEngineTensor(nvfp4_quantizer_cpp->rht_matrix); + NVTE_SCOPED_GIL_RELEASE({ + nvte_group_hadamard_transform_cast_fusion_graph_safe( + grouped_input_tensor.data(), grouped_output_tensor.data(), rht_matrix_nvte.data(), + quant_config_cpp, nvte_tile_scheduler_workspace.data(), 
stream); + }); + + } else { + NVTE_ERROR("graph safe grouped quant kernel for non-RHT path is not ready yet"); + } +} + +} // namespace + +// NOTE: Only supports varying first dim. +py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, + std::optional first_dims) { + using namespace transformer_engine::pytorch::detail; + init_extension(); + + NVTE_CHECK(tensor.dim() == 2, "Tensor must be 2D"); + + std::vector logical_shape; + for (const auto &d : tensor.sizes()) { + logical_shape.push_back(d); + } + const auto logical_first_dim = logical_shape[0]; + const auto logical_last_dim = logical_shape[1]; + + bool empty_input_buffer = logical_first_dim == 0 || logical_last_dim == 0; + + auto quantizer_cpp = convert_quantizer(quantizer); + + // Create input GroupedTensor. + auto grouped_input_tensor = GroupedTensorWrapper(num_tensors, logical_shape); + grouped_input_tensor.set_rowwise_data( + tensor.data_ptr(), GetTransformerEngineDType(tensor.scalar_type()), getTensorShape(tensor)); + + // Create output GroupedTensor. 
+ auto [grouped_output_tensor_cpp, grouped_output_py] = quantizer_cpp->create_grouped_tensor( + num_tensors, logical_shape, GetTransformerEngineDType(tensor.scalar_type()), + py::reinterpret_borrow(quantizer), first_dims, logical_first_dim, + logical_last_dim); + + // dispatch to scaling methods + enum class GroupedQuantizationMode { + MXFP8_GROUPED_QUANTIZE, + NVFP4_GROUPED_QUANTIZE, + INVALID_FOR_GROUPED_QUANTIZE + }; + GroupedQuantizationMode grouped_quantization_mode = + GroupedQuantizationMode::INVALID_FOR_GROUPED_QUANTIZE; + if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + grouped_quantization_mode = GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE; + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + grouped_quantization_mode = GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE; + } + + if (empty_input_buffer) { + // early return for empty input buffer + // just return the output tensor as is + // no need to quantize + return py::reinterpret_borrow(grouped_output_py); + } + + switch (grouped_quantization_mode) { + case GroupedQuantizationMode::NVFP4_GROUPED_QUANTIZE: { + // NVFP4 grouped quantization + NVFP4Quantizer *nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + group_quantize_nvfp4_impl(grouped_input_tensor, grouped_output_tensor_cpp, + nvfp4_quantizer_cpp, at::cuda::getCurrentCUDAStream()); + break; + } + case GroupedQuantizationMode::MXFP8_GROUPED_QUANTIZE: { + NVTE_SCOPED_GIL_RELEASE({ + nvte_group_quantize(grouped_input_tensor.data(), grouped_output_tensor_cpp.data(), + at::cuda::getCurrentCUDAStream()); + }); + break; + } + case GroupedQuantizationMode::INVALID_FOR_GROUPED_QUANTIZE: + default: + NVTE_ERROR("group_quantize: only support NVFP4 or MXFP8 quantizer."); + break; + } + + return py::reinterpret_borrow(grouped_output_py); +} + py::object dequantize(const py::handle &input, transformer_engine::DType otype) { init_extension(); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp 
b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 1e907d9bc0..5e9eccced0 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -35,6 +35,7 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; PyTypeObject *NVFP4TensorPythonClass = nullptr; PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; PyTypeObject *NVFP4QuantizerClass = nullptr; +PyTypeObject *GroupedTensorStoragePythonClass = nullptr; void init_float8_extension() { if (Float8TensorPythonClass) return; @@ -104,11 +105,22 @@ void init_nvfp4_extensions() { "Internal error: could not initialize pyTorch NVFP4 extension."); } +void init_grouped_tensor_extension() { + if (GroupedTensorStoragePythonClass) return; + auto grouped_tensor_module = + py::module_::import("transformer_engine.pytorch.tensor.storage.grouped_tensor"); + GroupedTensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(grouped_tensor_module.ptr(), "GroupedTensor")); + NVTE_CHECK(GroupedTensorStoragePythonClass != nullptr, + "Internal error: could not initialize pyTorch grouped tensor extension."); +} + void init_extension() { init_float8_extension(); init_mxfp8_extension(); init_float8blockwise_extension(); init_nvfp4_extensions(); + init_grouped_tensor_extension(); } } // namespace transformer_engine::pytorch @@ -121,7 +133,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("output") = py::none(), py::arg("noop") = py::none()); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); - + m.def("group_quantize", transformer_engine::pytorch::group_quantize, py::arg("tensor"), + py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix 
multiply)", diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 25ffef0588..059eb5e3fb 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -43,6 +43,7 @@ extern PyTypeObject *Float8BlockwiseQuantizerClass; extern PyTypeObject *NVFP4TensorPythonClass; extern PyTypeObject *NVFP4TensorStoragePythonClass; extern PyTypeObject *NVFP4QuantizerClass; +extern PyTypeObject *GroupedTensorStoragePythonClass; void init_extension(); @@ -95,6 +96,8 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer); +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1c968e276d..e715d8f5ba 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -42,6 +42,35 @@ std::vector convert_shape_for_fp4(const std::vector& shape) { return ret; } +std::optional build_grouped_tensor_offsets(const size_t num_tensors, + const std::optional& first_dims, + const size_t logical_last_dim) { + if (!first_dims.has_value()) { + return std::nullopt; + } + + const auto& first_dims_tensor = first_dims.value(); + NVTE_CHECK(first_dims_tensor.scalar_type() == at::kLong, "first_dims must have dtype int64."); + NVTE_CHECK(static_cast(first_dims_tensor.numel()) == num_tensors, + "first_dims must have length ", num_tensors, "."); + + const int64_t logical_last_dim_i64 = static_cast(logical_last_dim); + auto scaled_first_dims = first_dims_tensor * logical_last_dim_i64; + + // Single kernel needed for these ops. 
+ auto cumsum = at::cumsum(scaled_first_dims, 0); + auto zero = at::zeros({1}, cumsum.options()); + return at::cat({zero, cumsum}); +} + +at::TensorOptions grouped_tensor_data_options(const DType dtype) { + return at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); +} + +py::object maybe_tensor_to_py(const std::optional& tensor) { + return tensor ? py::cast(*tensor) : py::none(); +} + } // namespace constexpr size_t NVFP4_BLOCK_SIZE = 16; @@ -88,6 +117,60 @@ std::pair NoneQuantizer::create_tensor(const std::vec return {std::move(out_cpp), py::cast(data)}; } +std::pair NoneQuantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + std::optional rowwise_data; + std::optional columnwise_data; + const bool with_rowwise_data = rowwise_usage; + const bool with_columnwise_data = columnwise_usage; + if (with_rowwise_data) { + rowwise_data = at::empty({total_elements}, grouped_tensor_data_options(dtype)); + } + if (with_columnwise_data) { + columnwise_data = at::empty({total_elements}, grouped_tensor_data_options(dtype)); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (with_rowwise_data) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), dtype, getTensorShape(*rowwise_data)); + } + if (with_columnwise_data) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), dtype, + getTensorShape(*columnwise_data)); + } + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + 
out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::object out_py = GroupedTensorClass( + "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), + "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), "scale_inv"_a = py::none(), + "columnwise_scale_inv"_a = py::none(), "amax"_a = py::none(), + "columnwise_amax"_a = py::none(), "scale"_a = py::none(), + "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), + "last_dims"_a = py::none(), + "tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(), + "logical_shape"_a = std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair NoneQuantizer::convert_and_update_tensor( py::object tensor) const { auto tensor_pyt = tensor.cast(); @@ -184,6 +267,73 @@ std::pair Float8Quantizer::create_tensor( return {std::move(out_cpp), std::move(out_py)}; } +std::pair Float8Quantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional 
columnwise_scale_inv; + at::Tensor amax = at::empty({static_cast(num_tensors)}, float_opts); + + if (rowwise_usage) { + rowwise_data = at::empty({total_elements}, uint8_opts); + rowwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); + } + if (columnwise_usage) { + columnwise_data = at::empty({total_elements}, uint8_opts); + columnwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*columnwise_scale_inv)); + } + out_cpp.set_amax(amax.data_ptr(), DType::kFloat32, getTensorShape(amax)); + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::object out_py = GroupedTensorClass( + "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), + "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), + "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), + "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = amax, + "columnwise_amax"_a = py::none(), "scale"_a = py::none(), + "first_dims"_a = first_dims.has_value() ? 
py::cast(*first_dims) : py::none(), + "last_dims"_a = py::none(), + "tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(), + "logical_shape"_a = std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); @@ -390,6 +540,75 @@ std::pair Float8CurrentScalingQuantizer::create_tenso return {std::move(out_cpp), std::move(out_py)}; } +std::pair Float8CurrentScalingQuantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + at::Tensor scale = at::empty({static_cast(num_tensors)}, float_opts); + at::Tensor amax = at::empty({static_cast(num_tensors)}, float_opts); + + if (rowwise_usage) { + rowwise_data = at::empty({total_elements}, uint8_opts); + rowwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); + } + if (columnwise_usage) { + columnwise_data = at::empty({total_elements}, uint8_opts); + columnwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, 
this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*columnwise_scale_inv)); + } + out_cpp.set_scale(scale.data_ptr(), DType::kFloat32, getTensorShape(scale)); + out_cpp.set_amax(amax.data_ptr(), DType::kFloat32, getTensorShape(amax)); + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::object out_py = GroupedTensorClass( + "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), + "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), + "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), + "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = amax, + "columnwise_amax"_a = py::none(), "scale"_a = scale, + "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), + "last_dims"_a = py::none(), + "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(), + "logical_shape"_a = std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, DType dtype, @@ -638,6 +857,77 @@ std::pair Float8BlockQuantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } +std::pair Float8BlockQuantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; + + if (rowwise_usage) { + rowwise_data = at::empty({total_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, false); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + rowwise_scale_inv = at::empty({total_scale_elements}, float_opts); + } + + if (columnwise_usage) { + columnwise_data = at::empty({total_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, true); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + columnwise_scale_inv = at::empty({total_scale_elements}, float_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, 
logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*columnwise_scale_inv)); + } + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::object out_py = GroupedTensorClass( + "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), + "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), + "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), + "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = py::none(), + "columnwise_amax"_a = py::none(), "scale"_a = py::none(), + "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), + "last_dims"_a = py::none(), + "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(), + "logical_shape"_a = std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair Float8BlockQuantizer::convert_and_update_tensor( py::object tensor) const { const DType dtype = tensor.attr("_fp8_dtype").cast(); @@ -940,6 +1230,78 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } +std::pair MXFP8Quantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; + + if (rowwise_usage) { + rowwise_data = at::empty({total_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, false); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + rowwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); + } + + if (columnwise_usage) { + columnwise_data = at::empty({total_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, true); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + columnwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + 
out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*columnwise_scale_inv)); + } + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + out_cpp.set_with_gemm_swizzled_scales(this->optimize_for_gemm); + + py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::object out_py = GroupedTensorClass( + "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), + "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), + "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), + "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = py::none(), + "columnwise_amax"_a = py::none(), "scale"_a = py::none(), + "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), + "last_dims"_a = py::none(), + "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(), + "logical_shape"_a = std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair MXFP8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); @@ -1240,6 +1602,90 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } +std::pair NVFP4Quantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + NVTE_CHECK(total_elements % 2 == 0, "NVFP4 data size must be divisible by 2."); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + std::optional rowwise_amax; + std::optional columnwise_amax; + const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; + + const int64_t total_data_elements = total_elements / 2; + + if (rowwise_usage) { + rowwise_data = at::empty({total_data_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, false); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + rowwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); + rowwise_amax = at::empty({static_cast(num_tensors)}, float_opts); + } + + if 
(columnwise_usage) { + columnwise_data = at::empty({total_data_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, true); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + columnwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); + columnwise_amax = at::empty({static_cast(num_tensors)}, float_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*rowwise_scale_inv)); + out_cpp.set_amax(rowwise_amax->data_ptr(), DType::kFloat32, getTensorShape(*rowwise_amax)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*columnwise_scale_inv)); + out_cpp.set_columnwise_amax(columnwise_amax->data_ptr(), DType::kFloat32, + getTensorShape(*columnwise_amax)); + } + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + out_cpp.set_with_gemm_swizzled_scales(this->optimize_for_gemm); + + py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::object out_py = GroupedTensorClass( + "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), + "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), + "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), + "columnwise_scale_inv"_a = 
maybe_tensor_to_py(columnwise_scale_inv), + "amax"_a = maybe_tensor_to_py(rowwise_amax), + "columnwise_amax"_a = maybe_tensor_to_py(columnwise_amax), "scale"_a = py::none(), + "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), + "last_dims"_a = py::none(), + "tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(), + "logical_shape"_a = std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair NVFP4Quantizer::create_unquantized_tensor_with_amax( TensorWrapper& quantized_tensor, DType dtype) { // Construct tensor diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 3f998bb66f..eda5e8fc54 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -170,6 +170,121 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) return ret; } +NVTEScalingMode ScalingModeFromQuantizer(py::handle quantizer) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return NVTE_MXFP8_1D_SCALING; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return NVTE_NVFP4_1D_SCALING; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + const int block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); + return (block_scaling_dim == 2) ? 
NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D; + } + return NVTE_DELAYED_TENSOR_SCALING; +} + +DType GetTransformerEngineDTypeForScaleInv(py::handle quantizer, at::Tensor scale_inv) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return DType::kFloat8E8M0; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + return DType::kFloat32; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return DType::kFloat8E4M3; + } + return GetTransformerEngineDType(scale_inv.scalar_type()); +} + +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { + // Returns a GroupedTensorWrapper from a PyTorch GroupedTensor. + const auto num_tensors = tensor.attr("num_tensors").cast(); + const auto logical_shape = tensor.attr("logical_shape").cast>(); + py::handle quantizer = py::none(); + DType quantizer_dtype = DType::kNumTypes; + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + if (!tensor.attr("quantizer").is_none()) { + quantizer = tensor.attr("quantizer"); + if (!quantizer.is_none()) { + scaling_mode = ScalingModeFromQuantizer(quantizer); + quantizer_dtype = quantizer.attr("dtype").cast(); + } + } + auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); + + // Rowwise data + if (!tensor.attr("data").is_none()) { + const auto &data = tensor.attr("data").cast(); + DType data_dtype = + quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_rowwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); + } + + // Columnwise data + if (!tensor.attr("columnwise_data").is_none()) { + const auto &data = tensor.attr("columnwise_data").cast(); + DType data_dtype = + quantizer.is_none() ? 
GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_columnwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); + } + + // Scale + if (!tensor.attr("scale").is_none()) { + const auto &scale = tensor.attr("scale").cast(); + ret.set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + } + + // Amax + if (!tensor.attr("amax").is_none()) { + const auto &amax = tensor.attr("amax").cast(); + ret.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + if (!tensor.attr("columnwise_amax").is_none()) { + const auto &amax = tensor.attr("columnwise_amax").cast(); + ret.set_columnwise_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + + // Scale inverse + if (!tensor.attr("scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("scale_inv").cast(); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + if (!tensor.attr("columnwise_scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("columnwise_scale_inv").cast(); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + + // Shape metadata + if (!tensor.attr("first_dims").is_none()) { + const auto &first_dims = tensor.attr("first_dims").cast(); + ret.set_first_dims(first_dims.data_ptr(), GetTransformerEngineDType(first_dims.scalar_type()), + getTensorShape(first_dims)); + } + if (!tensor.attr("last_dims").is_none()) { + const auto &last_dims = tensor.attr("last_dims").cast(); + ret.set_last_dims(last_dims.data_ptr(), GetTransformerEngineDType(last_dims.scalar_type()), + getTensorShape(last_dims)); + } + if (!tensor.attr("tensor_offsets").is_none()) { + const auto &tensor_offsets = tensor.attr("tensor_offsets").cast(); + 
ret.set_tensor_offsets(tensor_offsets.data_ptr(), + GetTransformerEngineDType(tensor_offsets.scalar_type()), + getTensorShape(tensor_offsets)); + } + + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index dad4d1d0ea..bf5792ffc9 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -52,7 +52,7 @@ class GroupedTensor: def __init__( self, num_tensors: int, - shape: List[Tuple[int, int]], + shape: Optional[List[Tuple[int, int]]] = None, quantizer: Optional[Quantizer] = None, dtype: Optional[torch.dtype] = None, data: Optional[torch.Tensor] = None, @@ -245,6 +245,7 @@ def clear(self) -> None: """ Reset tensor data and clear all buffers. """ + self.shape = None self.data = None self.columnwise_data = None self.scale_inv = None @@ -452,8 +453,7 @@ def make_grouped_tensor( scale_inv_shape = quantizer.get_scale_shape(s, False) scale_elements = math.prod(scale_inv_shape) total_scale_elements += scale_elements - if i < num_tensors - 1: - scale_inv_offsets.append(total_scale_elements) + scale_inv_offsets.append(total_scale_elements) scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) if columnwise_usage: @@ -466,8 +466,7 @@ def make_grouped_tensor( scale_inv_shape = quantizer.get_scale_shape(s, False) columnwise_scale_elements = math.prod(scale_inv_shape) total_columnwise_scale_elements += columnwise_scale_elements - if i < num_tensors - 1: - columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) columnwise_scale_inv = torch.empty( total_columnwise_scale_elements, dtype=torch.uint8, device=device ) @@ -477,16 +476,16 @@ def make_grouped_tensor( data = torch.empty(total_elements, dtype=torch.uint8, device=device) 
# Scale inverse - one per tensor scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) - # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 - scale_inv_offsets = list(range(num_tensors)) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors + scale_inv_offsets = list(range(num_tensors + 1)) if columnwise_usage: # Allocate columnwise data buffer (1D flattened, uint8) columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) # Columnwise scale inverse - one per tensor columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) - # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 - columnwise_scale_inv_offsets = list(range(num_tensors)) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors + columnwise_scale_inv_offsets = list(range(num_tensors + 1)) # Amax buffer for delayed scaling - one per tensor amax = torch.empty(num_tensors, dtype=torch.float32, device=device) @@ -502,8 +501,7 @@ def make_grouped_tensor( for i, s in enumerate(shape): scale_inv_shape = quantizer.get_scale_shape(s, False) total_scale_elements += math.prod(scale_inv_shape) - if i < num_tensors - 1: - scale_inv_offsets.append(total_scale_elements) + scale_inv_offsets.append(total_scale_elements) scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) # Amax buffer - one per tensor amax = torch.empty(num_tensors, dtype=torch.float32, device=device) @@ -519,8 +517,7 @@ def make_grouped_tensor( for i, s in enumerate(shape): columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) - if i < num_tensors - 1: - columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) columnwise_scale_inv = torch.empty( total_columnwise_scale_elements, dtype=torch.uint8, device=device ) @@ 
-537,8 +534,7 @@ def make_grouped_tensor( for i, s in enumerate(shape): scale_inv_shape = quantizer.get_scale_shape(s, False) total_scale_elements += math.prod(scale_inv_shape) - if i < num_tensors - 1: - scale_inv_offsets.append(total_scale_elements) + scale_inv_offsets.append(total_scale_elements) scale_inv = torch.empty(total_scale_elements, dtype=torch.float32, device=device) if columnwise_usage: @@ -550,8 +546,7 @@ def make_grouped_tensor( for i, s in enumerate(shape): columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) - if i < num_tensors - 1: - columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) columnwise_scale_inv = torch.empty( total_columnwise_scale_elements, dtype=torch.float32, device=device ) @@ -562,16 +557,16 @@ def make_grouped_tensor( data = torch.empty(total_elements, dtype=torch.uint8, device=device) # Scale inverse - one per tensor scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) - # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 - scale_inv_offsets = list(range(num_tensors)) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors + scale_inv_offsets = list(range(num_tensors + 1)) if columnwise_usage: # Allocate columnwise data buffer (1D flattened, uint8) columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) # Columnwise scale inverse - one per tensor columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) - # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 - columnwise_scale_inv_offsets = list(range(num_tensors)) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors + columnwise_scale_inv_offsets = list(range(num_tensors + 1)) # Scale and amax buffers for current scaling - one per tensor scale = 
torch.empty(num_tensors, dtype=torch.float32, device=device) @@ -615,6 +610,8 @@ def split_into_quantized_tensors( If quantizer.internal is True, returns QuantizedTensorStorage. Otherwise, returns QuantizedTensor. + This API is NOT graph safe, but can be used for testing & debugging. + TODO(ksivaman): Block cases where any dims are varying. This is needed only to expose the weights as separate parameters. """ @@ -623,6 +620,27 @@ def split_into_quantized_tensors( no_quantization = self.quantizer is None + # if self.shape is None, then trigger D2H copy and get the shape (not graph safe) + if self.shape is None: + first_dims_list = ( + [self.logical_shape[0]] * self.num_tensors + if self.first_dims is None + else self.first_dims.tolist() + ) + last_dims_list = ( + [self.logical_shape[1]] * self.num_tensors + if self.last_dims is None + else self.last_dims.tolist() + ) + shape_list = [] + for i in range(self.num_tensors): + shape_list.append((first_dims_list[i], last_dims_list[i])) + self.shape = shape_list + + # edge case: handle the case where tensor_offsets is given but offsets is not set + if self.offsets is None and self.tensor_offsets is not None: + self.offsets = self.tensor_offsets.tolist() + # Case 1: No quantization - return regular torch tensors if no_quantization: for i in range(self.num_tensors): @@ -667,6 +685,18 @@ def split_into_quantized_tensors( # Case 2: Quantized tensors recipe = self.quantizer._get_compatible_recipe() + # populate scale_inv_offsets from the tensor offsets + if self.scale_inv is not None and self.scale_inv_offsets is None: + if recipe.nvfp4(): + self.scale_inv_offsets = self.tensor_offsets // 16 + if recipe.mxfp8(): + self.scale_inv_offsets = self.tensor_offsets // 32 + if self.columnwise_scale_inv is not None and self.columnwise_scale_inv_offsets is None: + if recipe.nvfp4(): + self.columnwise_scale_inv_offsets = self.tensor_offsets // 16 + if recipe.mxfp8(): + self.columnwise_scale_inv_offsets = self.tensor_offsets // 32 + for i 
in range(self.num_tensors): # Get tensor shape tensor_shape = self.shape[i] @@ -716,10 +746,8 @@ def split_into_quantized_tensors( if self.scale_inv is not None and self.scale_inv_offsets is not None: scale_start = self.scale_inv_offsets[i] - if i < self.num_tensors - 1: - scale_end = self.scale_inv_offsets[i + 1] - else: - scale_end = self.scale_inv.numel() + # for paged stashing, scale_inv should depend on the split offsets + scale_end = self.scale_inv_offsets[i + 1] # Calculate expected scale shape for MXFP8 scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) @@ -730,10 +758,8 @@ def split_into_quantized_tensors( and self.columnwise_scale_inv_offsets is not None ): cscale_start = self.columnwise_scale_inv_offsets[i] - if i < self.num_tensors - 1: - cscale_end = self.columnwise_scale_inv_offsets[i + 1] - else: - cscale_end = self.columnwise_scale_inv.numel() + # for paged stashing, columnwise_scale_inv should depend on the split offsets + cscale_end = self.columnwise_scale_inv_offsets[i + 1] cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( @@ -788,10 +814,8 @@ def split_into_quantized_tensors( if self.scale_inv is not None and self.scale_inv_offsets is not None: scale_start = self.scale_inv_offsets[i] - if i < self.num_tensors - 1: - scale_end = self.scale_inv_offsets[i + 1] - else: - scale_end = self.scale_inv.numel() + # for paged stashing, scale_inv should depend on the split offsets + scale_end = self.scale_inv_offsets[i + 1] # Get scale shape from quantizer scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) @@ -802,10 +826,8 @@ def split_into_quantized_tensors( and self.columnwise_scale_inv_offsets is not None ): cscale_start = self.columnwise_scale_inv_offsets[i] - if i < self.num_tensors - 1: - cscale_end = self.columnwise_scale_inv_offsets[i + 1] - else: - cscale_end = self.columnwise_scale_inv.numel() + # for paged stashing, 
columnwise_scale_inv should depend on the split offsets + cscale_end = self.columnwise_scale_inv_offsets[i + 1] # Get columnwise scale shape from quantizer cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) @@ -844,10 +866,8 @@ def split_into_quantized_tensors( if self.scale_inv is not None and self.scale_inv_offsets is not None: scale_start = self.scale_inv_offsets[i] - if i < self.num_tensors - 1: - scale_end = self.scale_inv_offsets[i + 1] - else: - scale_end = self.scale_inv.numel() + # for paged stashing, scale_inv should depend on the split offsets + scale_end = self.scale_inv_offsets[i + 1] # Get scale shape from quantizer scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) @@ -858,10 +878,8 @@ def split_into_quantized_tensors( and self.columnwise_scale_inv_offsets is not None ): cscale_start = self.columnwise_scale_inv_offsets[i] - if i < self.num_tensors - 1: - cscale_end = self.columnwise_scale_inv_offsets[i + 1] - else: - cscale_end = self.columnwise_scale_inv.numel() + # for paged stashing, columnwise_scale_inv should depend on the split offsets + cscale_end = self.columnwise_scale_inv_offsets[i + 1] # Get columnwise scale shape from quantizer cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True)