Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
5e39835
Python GroupedTensor and contiguous weights for GroupedLinear
ksivaman Jan 15, 2026
66e7d7f
Merge branch 'main' into grouped_tensor_python
ksivaman Jan 15, 2026
40c619e
Graph safe C API for grouped RHT, needs testing
ksivaman Jan 16, 2026
cf61339
Merge branch 'main' into grouped_tensor_python
ksivaman Jan 16, 2026
759e7bb
C++ utils, untested
ksivaman Jan 16, 2026
1d09c2a
Merge branch 'main' into grouped_tensor_python
vthumbe1503 Jan 23, 2026
e1b65ac
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 2, 2026
3ba639e
Pytorch Binding for GroupedTensor APIs (#13)
vthumbe1503 Feb 4, 2026
ebf2194
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 4, 2026
4337520
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
5ab30f5
Fix make grouped tensor api
ksivaman Feb 5, 2026
05dab12
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 6, 2026
68ce836
Fixes to tests
ksivaman Feb 6, 2026
3e7859c
PyTorch-Python GroupedTensor
ksivaman Feb 6, 2026
53c38ec
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 8, 2026
d57651d
Fix test
ksivaman Feb 9, 2026
bd41fd0
All tests pass
ksivaman Feb 9, 2026
351b74d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2026
fd8ce0f
Update transformer_engine/pytorch/tensor/storage/grouped_tensor.py
ksivaman Feb 9, 2026
24cfd8c
Remove mxfp8 gq test
ksivaman Feb 9, 2026
97a1f33
C++ PyTorch GroupedTensor changes WIP
ksivaman Feb 9, 2026
82f7ebe
Merge branch 'main' into pytorch_python_grouped_tensor
ksivaman Feb 10, 2026
e1788b3
Compiles
ksivaman Feb 10, 2026
9022383
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2026
11095a9
Fix runtime failure for test
ksivaman Feb 10, 2026
373d9e3
Fix IMA in mxfp8 GQ
ksivaman Feb 10, 2026
1601960
Add CG test for grouped_quantize
ksivaman Feb 10, 2026
bd57000
Fix recipe tests and FP8 weights
ksivaman Feb 10, 2026
91ab416
Fix recipe tests and FP8 weights
ksivaman Feb 10, 2026
52ab0ed
Merge branch 'main' into pytorch_python_grouped_tensor
ksivaman Feb 10, 2026
a5de7a5
Fix device test
ksivaman Feb 11, 2026
77fa728
Disable grouped weights for unsupported recipes
ksivaman Feb 11, 2026
9009f75
Merge branch 'pytorch_python_grouped_tensor' into grouped_tensor_python
ksivaman Feb 11, 2026
bea794f
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 11, 2026
6b0c420
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 11, 2026
e3278dd
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 12, 2026
864c484
Integrate NVFP4 Graph Safe Group Quantize (#14)
zhongbozhu Feb 14, 2026
4ee0339
improve mxfp8 unit test
zhongbozhu Feb 17, 2026
9f5f24c
pre-swizzle nvfp4 mxfp8 for MoE
zhongbozhu Feb 18, 2026
22f8a5b
avoid having nvte_get_grouped_tensor_param_v2
zhongbozhu Feb 18, 2026
63e1563
more tests
zhongbozhu Feb 18, 2026
4d66324
fix group quantize mxfp8 kernel
zhongbozhu Feb 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_mxfp8.xml $TE_PATH/tests/pytorch/mxfp8 || test_fail "test_mxfp8"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
Expand Down
50 changes: 35 additions & 15 deletions tests/cpp/operator/test_cast_mxfp8_grouped.cu
Original file line number Diff line number Diff line change
Expand Up @@ -385,28 +385,41 @@ void performTest(const ProcessingMethod processing_method,

NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast<NVTEDType>(itype), logical_shape_};
NVTEBasicTensor in_data_tensor = {in_data_d, static_cast<NVTEDType>(itype), logical_shape_};
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor);
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &grad_data_tensor);
nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData,
&in_data_tensor, sizeof(in_data_tensor));
nvte_set_grouped_tensor_param(grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData,
&grad_data_tensor, sizeof(grad_data_tensor));

if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) {
NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_};
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims, &first_dims_tensor);
nvte_set_grouped_tensor_param(grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims,
&first_dims_tensor, sizeof(first_dims_tensor));
nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims,
&first_dims_tensor, sizeof(first_dims_tensor));
nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedFirstDims,
&first_dims_tensor, sizeof(first_dims_tensor));
}

if ((shape_rep == VARYING_LAST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) {
NVTEBasicTensor last_dims_tensor = {last_dims_d, kNVTEInt64, last_dims_shape_};
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims, &last_dims_tensor);
nvte_set_grouped_tensor_param(grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims,
&last_dims_tensor, sizeof(last_dims_tensor));
nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims,
&last_dims_tensor, sizeof(last_dims_tensor));
nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedLastDims,
&last_dims_tensor, sizeof(last_dims_tensor));
}

if (shape_rep != SAME_BOTH_DIMS) {
NVTEBasicTensor offsets_tensor = {offsets_d, kNVTEInt64, offsets_shape_};
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets, &offsets_tensor);
nvte_set_grouped_tensor_param(grad_group_tensor,
NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets,
&offsets_tensor, sizeof(offsets_tensor));
nvte_set_grouped_tensor_param(in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets,
&offsets_tensor, sizeof(offsets_tensor));
nvte_set_grouped_tensor_param(out_group_tensor,
NVTEGroupedTensorParam::kNVTEGroupedTensorOffsets,
&offsets_tensor, sizeof(offsets_tensor));
}

if (rowwise) {
Expand All @@ -417,8 +430,11 @@ void performTest(const ProcessingMethod processing_method,
NVTEBasicTensor out_data_rowwise_tensor = {out_data_rowwise_d, static_cast<NVTEDType>(otype), logical_shape_};
NVTEShape scales_rowwise_shape_ = nvte_make_shape(scales_rowwise_shape.data(), scales_rowwise_shape.size());
NVTEBasicTensor out_scales_rowwise_tensor = {out_scales_rowwise_d, NVTEDType::kNVTEFloat8E8M0, scales_rowwise_shape_};
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &out_data_rowwise_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv, &out_scales_rowwise_tensor);
nvte_set_grouped_tensor_param(out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData,
&out_data_rowwise_tensor, sizeof(out_data_rowwise_tensor));
nvte_set_grouped_tensor_param(out_group_tensor,
NVTEGroupedTensorParam::kNVTEGroupedRowwiseScaleInv,
&out_scales_rowwise_tensor, sizeof(out_scales_rowwise_tensor));
}

if (colwise) {
Expand All @@ -429,8 +445,12 @@ void performTest(const ProcessingMethod processing_method,
NVTEBasicTensor out_data_colwise_tensor = {out_data_colwise_d, static_cast<NVTEDType>(otype), logical_shape_};
NVTEShape scales_colwise_shape_ = nvte_make_shape(scales_colwise_shape.data(), scales_colwise_shape.size());
NVTEBasicTensor out_scales_colwise_tensor = {out_scales_colwise_d, NVTEDType::kNVTEFloat8E8M0, scales_colwise_shape_};
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData, &out_data_colwise_tensor);
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor);
nvte_set_grouped_tensor_param(out_group_tensor,
NVTEGroupedTensorParam::kNVTEGroupedColumnwiseData,
&out_data_colwise_tensor, sizeof(out_data_colwise_tensor));
nvte_set_grouped_tensor_param(out_group_tensor,
NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv,
&out_scales_colwise_tensor, sizeof(out_scales_colwise_tensor));
}

Tensor output_dbias("output_dbias", std::vector<size_t>{ cols }, itype);
Expand Down
16 changes: 9 additions & 7 deletions tests/cpp/test_common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1157,7 +1157,7 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,

NVTEBasicTensor data_tensor{grouped.data.get(), static_cast<NVTEDType>(dtype), grouped.logical_shape};
NVTEGroupedTensor h = grouped.handle.get();
nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseData, &data_tensor);
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseData, &data_tensor, sizeof(data_tensor));

const bool include_columnwise = isFp8Type(dtype) || isFp4Type(dtype);
if (include_columnwise) {
Expand All @@ -1172,7 +1172,7 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
NVTEBasicTensor col_tensor{grouped.columnwise_data.get(),
static_cast<NVTEDType>(dtype),
grouped.logical_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseData, &col_tensor);
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseData, &col_tensor, sizeof(col_tensor));
}

if (!same_first) {
Expand All @@ -1181,7 +1181,7 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape fd_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor fd_tensor{grouped.first_dims_dev.get(), kNVTEInt64, fd_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedFirstDims, &fd_tensor);
nvte_set_grouped_tensor_param(h, kNVTEGroupedFirstDims, &fd_tensor, sizeof(fd_tensor));
}

if (!same_last) {
Expand All @@ -1190,7 +1190,7 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape ld_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor ld_tensor{grouped.last_dims_dev.get(), kNVTEInt64, ld_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedLastDims, &ld_tensor);
nvte_set_grouped_tensor_param(h, kNVTEGroupedLastDims, &ld_tensor, sizeof(ld_tensor));
}

if (!same_first || !same_last) {
Expand All @@ -1199,7 +1199,7 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
num_tensors * sizeof(int64_t), cudaMemcpyHostToDevice));
NVTEShape off_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor off_tensor{grouped.offsets_dev.get(), kNVTEInt64, off_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedTensorOffsets, &off_tensor);
nvte_set_grouped_tensor_param(h, kNVTEGroupedTensorOffsets, &off_tensor, sizeof(off_tensor));
}

if (isFp8Type(dtype)) {
Expand All @@ -1213,8 +1213,10 @@ GroupedBuffers build_grouped_tensor(const std::vector<Tensor*>& tensors,
sizeof(float) * num_tensors, cudaMemcpyHostToDevice));
NVTEShape scale_shape = nvte_make_shape(&num_tensors, 1);
NVTEBasicTensor scale_tensor{grouped.scale_inv.get(), kNVTEFloat32, scale_shape};
nvte_set_grouped_tensor_param(&h, kNVTEGroupedRowwiseScaleInv, &scale_tensor);
nvte_set_grouped_tensor_param(&h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor);
nvte_set_grouped_tensor_param(h, kNVTEGroupedRowwiseScaleInv, &scale_tensor,
sizeof(scale_tensor));
nvte_set_grouped_tensor_param(h, kNVTEGroupedColumnwiseScaleInv, &scale_tensor,
sizeof(scale_tensor));
}

return grouped;
Expand Down
62 changes: 62 additions & 0 deletions tests/pytorch/mxfp8/mxfp8_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch
import math


# Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization without padding
def get_mxfp8_scale_shape_no_padding(shape, columnwise):
    """Return the (outer, inner) shape of the MXFP8 1D blockwise scale tensor, no padding.

    The logical input ``shape`` is flattened to 2D as (M, K) with
    M = prod(shape[:-1]) and K = shape[-1].  One scale covers a block of 32
    elements along the quantization axis:

      * rowwise  (``columnwise=False``): blocks run along K -> (M, K // 32)
      * columnwise (``columnwise=True``): blocks run along M -> (M // 32, K)

    Args:
        shape: logical tensor shape (any rank >= 1).
        columnwise: whether the scales are for columnwise quantization.

    Returns:
        Tuple (outer, inner) giving the unpadded scale tensor shape.
    """
    # math.prod(()) == 1, so a 1-D input shape flattens to M == 1.
    M = math.prod(shape[:-1])
    K = shape[-1]

    if columnwise:
        return (M // 32, K)
    # rowwise
    return (M, K // 32)


def _rowwise_swizzle_mxfp8_scale(input_M, input_N, scale: torch.Tensor) -> torch.Tensor:
assert scale.dim() == 2
assert input_M == scale.shape[0]
assert input_N // 32 == scale.shape[1]

x = scale.view(input_M // 128, 4, 32, input_N // 128, 4)
x = x.permute(0, 3, 2, 1, 4)
x = x.contiguous()
# View back as original 2D shape
x = x.view(input_M, input_N // 32)
return x


def _columnwise_swizzle_mxfp8_scale(input_M, input_N, scale: torch.Tensor) -> torch.Tensor:
    """Reorder a columnwise MXFP8 scale tensor of shape (M // 32, N) into the
    swizzled layout, returning a tensor of the same 2D shape.

    Also cross-checks the result against an equivalent formulation (transpose
    the scales and apply the rowwise swizzle with M and N swapped) and raises
    if the two disagree.  Requires ``input_M`` and ``input_N`` to be
    multiples of 128.
    """
    assert scale.dim() == 2
    assert scale.shape[0] == input_M // 32
    assert scale.shape[1] == input_N

    swizzled = (
        scale.view(input_M // 128, 4, input_N // 128, 4, 32)
        .permute(2, 0, 4, 3, 1)
        .contiguous()
    )

    # Sanity check: transposing the scales and doing a rowwise swizzle with
    # M and N swapped must produce the same element ordering.
    reference = _rowwise_swizzle_mxfp8_scale(input_N, input_M, scale.transpose(0, 1).contiguous())
    torch.testing.assert_close(
        swizzled.view(-1),
        reference.view(-1),
        atol=0.0,
        rtol=0.0,
        msg="columnwise swizzle sanity check failed",
    )

    # Flatten back to the original 2D shape.
    return swizzled.view(input_M // 32, input_N)


def swizzle_mxfp8_scale(input_M, input_N, scale: torch.Tensor, columnwise: bool) -> torch.Tensor:
    """Swizzle an MXFP8 scale tensor, dispatching on quantization direction.

    Delegates to the columnwise variant when ``columnwise`` is true,
    otherwise to the rowwise variant.
    """
    swizzler = _columnwise_swizzle_mxfp8_scale if columnwise else _rowwise_swizzle_mxfp8_scale
    return swizzler(input_M, input_N, scale)
Loading
Loading