
[PyTorch] Python GroupedTensor #2654

Draft

ksivaman wants to merge 1 commit into NVIDIA:main from ksivaman:pytorch_python_grouped_tensor

Conversation

ksivaman (Member) commented Feb 6, 2026

Description

Extracts the Python pieces of the GroupedTensor infrastructure from #2600. Since this portion mainly covers creating the weights as a single GroupedTensor and exposing them to PyTorch as multiple QuantizedTensors, it does not need to be graph capturable.
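
Below is a minimal, hypothetical usage sketch of the behavior described above. The weight0/weight1/... attribute names and the params_dtype argument follow GroupedLinear's existing interface, while the back-to-back pointer check is an assumption based on the contiguous-storage goal of this PR:

```python
import torch
import transformer_engine.pytorch as te

# Hypothetical sizes; GroupedLinear holds one weight per GEMM.
num_gemms, in_features, out_features = 4, 128, 256
layer = te.GroupedLinear(num_gemms, in_features, out_features, params_dtype=torch.bfloat16)

# With grouped storage, consecutive per-GEMM weights are expected to sit
# back-to-back inside a single contiguous buffer.
weights = [getattr(layer, f"weight{i}") for i in range(num_gemms)]
for prev, nxt in zip(weights, weights[1:]):
    assert nxt.data_ptr() == prev.data_ptr() + prev.numel() * prev.element_size()
```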

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Expose a python GroupedTensor class.
  • Integrate GroupedTensor into GroupedLinear such that the parameters are contiguous.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

greptile-apps bot (Contributor) commented Feb 6, 2026

Greptile Overview

Greptile Summary

This PR extracts Python GroupedTensor infrastructure from #2600, enabling PyTorch to store multiple weight tensors with contiguous memory layout.

Key Changes

  • New GroupedTensor class (942 lines): Manages collections of tensors with different shapes but same dtype/scaling. Stores all data in contiguous 1D buffers with logical 2D shape representation. Supports all quantization recipes (FP8 delayed/current scaling, MXFP8, NVFP4, Float8 block scaling).

  • GroupedLinear integration: Added make_grouped_weights() method that converts individual weight parameters into a single GroupedTensor with contiguous storage, then re-registers them as parameters. This creates memory-efficient grouped storage while maintaining the same parameter interface.

  • Recipe API refactor: Changed Recipe class methods from instance methods (isinstance(self, Recipe)) to classmethods (issubclass(cls, Recipe)). Required because _get_compatible_recipe() returns recipe classes, not instances (see the sketch after this list).

  • Comprehensive tests: Added 441-line test suite covering construction, quantization, splitting, and varying tensor shapes. Enhanced sanity tests with contiguous memory verification.
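
A minimal sketch of the instance-method to classmethod change mentioned in the Recipe API bullet above; the classes here are simplified stand-ins, not the actual recipe hierarchy in transformer_engine/common/recipe/__init__.py:

```python
class Recipe:
    @classmethod
    def mxfp8(cls) -> bool:
        # Before the refactor this was an instance check (isinstance(self, MXFP8BlockScaling)),
        # which cannot be used when only a recipe *class* is available, e.g. the return
        # value of _get_compatible_recipe().
        return issubclass(cls, MXFP8BlockScaling)

    @classmethod
    def delayed(cls) -> bool:
        return issubclass(cls, DelayedScaling)


class MXFP8BlockScaling(Recipe): ...


class DelayedScaling(Recipe): ...


# Works on classes as well as instances:
assert MXFP8BlockScaling.mxfp8()
assert not DelayedScaling().mxfp8()
assert DelayedScaling.delayed()
```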

Architecture

The GroupedTensor acts as unified storage that exposes multiple QuantizedTensor views. Each view shares the underlying contiguous buffers but presents individual tensor semantics. This enables efficient grouped GEMM operations while maintaining PyTorch's parameter interface.
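
A conceptual sketch of that storage pattern, using plain torch tensors instead of the actual QuantizedTensor storage types; shapes, dtype, and the flat-buffer layout are illustrative only:

```python
import torch

# Three "weights" of different shapes but the same dtype, packed into one flat buffer.
shapes = [(256, 128), (256, 128), (512, 128)]
numels = [h * w for h, w in shapes]
flat = torch.empty(sum(numels), dtype=torch.bfloat16)

# Each logical tensor is a view into the shared storage, analogous to how GroupedTensor
# exposes per-weight QuantizedTensor views over its contiguous buffers.
views, offset = [], 0
for (h, w), n in zip(shapes, numels):
    views.append(flat[offset:offset + n].view(h, w))
    offset += n

# Views share storage with the flat buffer, so an in-place write is visible through it,
# and consecutive views are adjacent in memory.
views[0].fill_(1.0)
assert flat[:numels[0]].eq(1.0).all()
assert views[1].data_ptr() == views[0].data_ptr() + views[0].numel() * views[0].element_size()
```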

Minor Issue

Type hint inconsistency on line 344 of grouped_tensor.py: uses torch.tensor instead of torch.Tensor.

Confidence Score: 4.5/5

  • This PR is safe to merge with only a minor type hint fix needed
  • The implementation is well-architected with comprehensive test coverage. The GroupedTensor design properly handles memory management and supports all quantization formats. The GroupedLinear integration is clean and maintains backward compatibility. Only issue found is a minor type hint inconsistency (torch.tensor vs torch.Tensor) that doesn't affect runtime behavior.
  • transformer_engine/pytorch/tensor/storage/grouped_tensor.py needs type hint fix on line 344

Important Files Changed

  • transformer_engine/pytorch/tensor/storage/grouped_tensor.py: New 942-line GroupedTensor class for storing multiple tensors with shared contiguous storage. Supports all quantization recipes (FP8, MXFP8, NVFP4). Minor type hint issue on line 344.
  • transformer_engine/pytorch/module/grouped_linear.py: Integrated GroupedTensor into GroupedLinear by adding the make_grouped_weights() method. Creates contiguous parameter storage during reset_parameters(). Clean implementation with proper parameter re-registration.
  • tests/pytorch/test_grouped_tensor.py: New comprehensive 441-line test file covering GroupedTensor construction, quantization, splitting, and varying shapes across all recipes. Excellent coverage.
  • tests/pytorch/test_sanity.py: Added a check_grouped_tensor_pointers() helper to verify contiguous memory layout in GroupedLinear tests. Validates that the GroupedTensor integration works correctly.
  • transformer_engine/common/recipe/__init__.py: Changed Recipe methods from instance methods to classmethods (isinstance to issubclass). Necessary for GroupedTensor to check recipe types via _get_compatible_recipe(), which returns classes.

Sequence Diagram

sequenceDiagram
    participant User
    participant GroupedLinear
    participant GroupedTensor
    participant Quantizer
    participant QuantizedTensor
    
    Note over User,QuantizedTensor: Initialization Phase
    User->>GroupedLinear: __init__(num_gemms, in_features, out_features)
    GroupedLinear->>GroupedLinear: Create individual weight parameters
    GroupedLinear->>GroupedLinear: reset_parameters()
    GroupedLinear->>GroupedLinear: make_grouped_weights()
    
    Note over GroupedLinear,GroupedTensor: GroupedTensor Creation
    GroupedLinear->>GroupedTensor: make_grouped_tensor_with_shapes(num_tensors, shape, quantizer)
    GroupedTensor->>GroupedTensor: Calculate logical_shape and offsets
    GroupedTensor->>GroupedTensor: Allocate contiguous buffers (data, scale_inv, etc.)
    GroupedTensor->>GroupedTensor: split_into_quantized_tensors()
    
    loop For each tensor
        GroupedTensor->>QuantizedTensor: Create view of contiguous storage
        QuantizedTensor-->>GroupedTensor: Return quantized tensor view
    end
    
    GroupedTensor-->>GroupedLinear: Return GroupedTensor with quantized_tensors
    
    loop For each weight
        GroupedLinear->>QuantizedTensor: copy_(original_weight)
        GroupedLinear->>GroupedLinear: register_parameter(weight_i, quantized_tensor)
    end
    
    Note over User,QuantizedTensor: Forward Pass
    User->>GroupedLinear: forward(inp, m_splits)
    GroupedLinear->>GroupedLinear: _get_weight_tensors()
    GroupedLinear->>GroupedLinear: Prepare input quantization
    GroupedLinear->>GroupedLinear: general_grouped_gemm(weights, inputs, outputs)
    Note over GroupedLinear: All weights share contiguous storage
    GroupedLinear-->>User: Return output

greptile-apps bot (Contributor) left a comment

5 files reviewed, 1 comment

def make_grouped_tensor(
    num_tensors: int,
    first_dims: Optional[torch.Tensor],
    last_dims: Optional[torch.tensor],

greptile-apps bot (Contributor):

Inconsistent type hint: torch.tensor should be torch.Tensor (uppercase T)

Suggested change:
-    last_dims: Optional[torch.tensor],
+    last_dims: Optional[torch.Tensor],

ksivaman added the MoE label Feb 6, 2026
from .nvfp4_tensor_storage import NVFP4TensorStorage


class GroupedTensor:

Collaborator:

I don't think it's a good idea to put everything within a single class. We should have an abstract base class (GroupedTensorBase) and concrete classes like GroupedTensor (or UnquantizedGroupedTensor?), MXFP8GroupedTensor, NVFP4GroupedTensor. The giant-pile-of-attrs design results in ugly implementations (like the if-else blocks in make_grouped_tensor) and it generalizes poorly (columnwise_data is treated very differently between FP8 and MXFP8, enough that giving them the same name is questionable). We do use this design in the C++ grouped tensor class, but that should be viewed as a short-term expedient and not a long-term design (#2388 (comment)).
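
A rough sketch of the split being proposed here, using hypothetical class names; it only illustrates the shape of the hierarchy, not actual Transformer Engine internals:

```python
from abc import ABC, abstractmethod
from typing import List

import torch


class GroupedTensorBase(ABC):
    """Shared bookkeeping: shapes, offsets, splitting into per-tensor views."""

    def __init__(self, shapes: List[torch.Size]):
        self.shapes = shapes

    @abstractmethod
    def split_into_quantized_tensors(self) -> List[torch.Tensor]:
        ...


class UnquantizedGroupedTensor(GroupedTensorBase):
    """Plain high-precision storage: a single flat data buffer, no scales."""
    # Concrete buffers and split logic elided in this sketch.


class MXFP8GroupedTensor(GroupedTensorBase):
    """MXFP8-specific storage: rowwise/columnwise data plus block scale-inv buffers,
    which differ enough from FP8 that a shared attribute layout gets awkward."""
    # Concrete buffers and split logic elided in this sketch.
```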

Member:

This ultimately depends on what we want to optimize for. If we believe that the majority of what we write here will be "grouped" functionality that does not really care about the underlying format (or functionality where we could delegate that decision to C++, which has full knowledge of the quantizer type and could implement things without huge if/else blocks), then it makes sense to have a single class here. If we believe that most of the functionality will depend on the quantization format, then I agree that we should split this into multiple classes.
@ksivaman Can you comment on that?

Collaborator:

I think GroupedTensor in Python should be a faithful copy of the C++ grouped tensor, so I do think it's okay to have a single class.

ksivaman marked this pull request as draft February 6, 2026 21:28
    columnwise_scale_inv = torch.empty(
        total_columnwise_scale_elements, dtype=torch.uint8, device=device
    )
elif quantizer._get_compatible_recipe().delayed():

Collaborator:

Since we have gone for a single quantizer, we should remove the delayed scaling recipe and per-tensor current scaling for now, since their quantizers are not stateless.

    columnwise_scale_inv = torch.empty(
        total_columnwise_scale_elements, dtype=torch.float32, device=device
    )
elif quantizer._get_compatible_recipe().float8_current_scaling():

Collaborator:

float8_current_scaling can work with GroupedTensor once we refactor its implementation to move the amax tensor out of its quantizer. Then it will be safe to put a single quantizer into the grouped tensor.

    result.append(tensor)

# Delayed scaling or current scaling (both use Float8TensorStorage)
elif recipe.delayed() or recipe.float8_current_scaling():

Collaborator:

Let's raise an error for this case?
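
A minimal sketch of such a guard, using a hypothetical helper name and message; only the fail-loudly idea comes from the comment above:

```python
def _check_recipe_supported(recipe) -> None:
    # Hypothetical guard: fail loudly for recipes whose quantizers are stateful
    # instead of falling through to the Float8TensorStorage branch.
    if recipe.delayed() or recipe.float8_current_scaling():
        raise NotImplementedError(
            "GroupedTensor does not yet support delayed scaling or per-tensor current scaling"
        )
```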

    dtype: Optional[torch.dtype] = None,
) -> GroupedTensor:
    """
    Create a GroupedTensor for storing multiple weight tensors of the same shape.

Collaborator:

Minor comment: the intent of this API is to create a grouped tensor with variable first_dims/last_dims, so we can say that in the docstring, since this is not going to be used to create weights.

Also, could the API be named make_grouped_tensor_graph_safe? That way people know it is safe to use within the forward/backward of a module that needs to be CUDA-graph capturable.

        torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype),
        torch.cumsum(first_dims * logical_last_dim, dim=0),
    ]
)

Collaborator:

I see the comment above about using a single kernel and am not sure what your plan is to implement that.
But with torch ops you can avoid one memory op by using

tensor_offsets = torch.empty(num_tensors + 1, device=first_dims.device, dtype=first_dims.dtype)
tensor_offsets[0] = 0  # torch.empty leaves the leading offset uninitialized
torch.cumsum(first_dims * logical_last_dim, dim=0, out=tensor_offsets[1:])
