Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Overview

Greptile Summary

This PR extracts Python GroupedTensor infrastructure from #2600, enabling PyTorch to store multiple weight tensors with a contiguous memory layout.

Key Changes

Architecture

The GroupedTensor acts as unified storage that exposes multiple QuantizedTensor views of one contiguous buffer.

Minor Issue

Type hint inconsistency on line 344: torch.tensor should be torch.Tensor.

Confidence Score: 4.5/5

Important Files Changed
Sequence Diagram

sequenceDiagram
participant User
participant GroupedLinear
participant GroupedTensor
participant Quantizer
participant QuantizedTensor
Note over User,QuantizedTensor: Initialization Phase
User->>GroupedLinear: __init__(num_gemms, in_features, out_features)
GroupedLinear->>GroupedLinear: Create individual weight parameters
GroupedLinear->>GroupedLinear: reset_parameters()
GroupedLinear->>GroupedLinear: make_grouped_weights()
Note over GroupedLinear,GroupedTensor: GroupedTensor Creation
GroupedLinear->>GroupedTensor: make_grouped_tensor_with_shapes(num_tensors, shape, quantizer)
GroupedTensor->>GroupedTensor: Calculate logical_shape and offsets
GroupedTensor->>GroupedTensor: Allocate contiguous buffers (data, scale_inv, etc.)
GroupedTensor->>GroupedTensor: split_into_quantized_tensors()
loop For each tensor
GroupedTensor->>QuantizedTensor: Create view of contiguous storage
QuantizedTensor-->>GroupedTensor: Return quantized tensor view
end
GroupedTensor-->>GroupedLinear: Return GroupedTensor with quantized_tensors
loop For each weight
GroupedLinear->>QuantizedTensor: copy_(original_weight)
GroupedLinear->>GroupedLinear: register_parameter(weight_i, quantized_tensor)
end
Note over User,QuantizedTensor: Forward Pass
User->>GroupedLinear: forward(inp, m_splits)
GroupedLinear->>GroupedLinear: _get_weight_tensors()
GroupedLinear->>GroupedLinear: Prepare input quantization
GroupedLinear->>GroupedLinear: general_grouped_gemm(weights, inputs, outputs)
Note over GroupedLinear: All weights share contiguous storage
GroupedLinear-->>User: Return output
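The central point of the diagram is the note "All weights share contiguous storage": every per-GEMM weight is a view into one allocation. A minimal, plain-PyTorch sketch of that storage pattern (not the TE GroupedTensor API; the sizes and dtype are made up for illustration):

```python
import torch

# Hypothetical sizes for illustration only.
num_gemms, out_features, in_features = 4, 128, 64

# One contiguous buffer backs all "weights"; each weight is a view into it.
storage = torch.empty(num_gemms * out_features * in_features, dtype=torch.bfloat16)
weights = [
    storage[i * out_features * in_features : (i + 1) * out_features * in_features]
    .view(out_features, in_features)
    for i in range(num_gemms)
]

# The views alias the same memory: weight 1 starts exactly where weight 0 ends.
elem = storage.element_size()
assert weights[1].data_ptr() == storage.data_ptr() + out_features * in_features * elem
```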
def make_grouped_tensor(
    num_tensors: int,
    first_dims: Optional[torch.Tensor],
    last_dims: Optional[torch.tensor],
Inconsistent type hint: torch.tensor should be torch.Tensor (uppercase T)
Suggested change:
-    last_dims: Optional[torch.tensor],
+    last_dims: Optional[torch.Tensor],
from .nvfp4_tensor_storage import NVFP4TensorStorage

class GroupedTensor:
I don't think it's a good idea to put everything within a single class. We should have an abstract base class (GroupedTensorBase) and concrete classes like GroupedTensor (or UnquantizedGroupTensor?), MXFP8GroupedTensor, NVFP4GroupedTensor. The giant-pile-of-attrs design results in ugly implementations (like the if-else blocks in make_grouped_tensor) and it generalizes poorly (columnwise_data is treated very differently between FP8 and MXFP8, enough that giving them the same name is questionable). We do use this design in the C++ grouped tensor class, but that should be viewed as a short-term expedient and not a long-term design (#2388 (comment)).
This ultimately depends on what we want to optimize for. If we believe that the majority of what we are going to write here is "grouped" functionality that does not really care about the underlying format (or functionality where we could delegate that decision to C++, which has full knowledge of the quantizer type and could implement things without huge if/else blocks), then it makes sense to have a single class here. If we believe that the majority of the functionality will depend on the quantization format, then I agree that we should split this into multiple classes.
@ksivaman Can you comment on that?
I think GroupedTensor in Python should be a faithful copy of the C++ grouped tensor, so I do think it's okay to have a single class.
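For reference, a rough sketch of the split proposed in this thread; every name and method body here is hypothetical and not part of this PR:

```python
from abc import ABC, abstractmethod
from typing import List
import torch

class GroupedTensorBase(ABC):
    """Format-agnostic 'grouped' bookkeeping: tensor count and storage offsets."""

    def __init__(self, num_tensors: int, offsets: torch.Tensor):
        self.num_tensors = num_tensors
        self.offsets = offsets  # prefix sums into the contiguous buffers

    @abstractmethod
    def split_into_quantized_tensors(self) -> List[torch.Tensor]:
        """Return per-tensor views of the contiguous storage."""

class MXFP8GroupedTensor(GroupedTensorBase):
    """Owns only the buffers that make sense for MXFP8 (block scales, etc.)."""

    def __init__(self, num_tensors, offsets, data, scale_inv,
                 columnwise_data, columnwise_scale_inv):
        super().__init__(num_tensors, offsets)
        self.data = data
        self.scale_inv = scale_inv
        self.columnwise_data = columnwise_data
        self.columnwise_scale_inv = columnwise_scale_inv

    def split_into_quantized_tensors(self) -> List[torch.Tensor]:
        # Would build MXFP8 views using the MXFP8 scale layout.
        raise NotImplementedError
```

An unquantized and an NVFP4 subclass would follow the same pattern, which is what removes the per-recipe if/else blocks from make_grouped_tensor.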
    columnwise_scale_inv = torch.empty(
        total_columnwise_scale_elements, dtype=torch.uint8, device=device
    )
elif quantizer._get_compatible_recipe().delayed():
Since we have gone with a single quantizer, we should remove the delayed scaling recipe and per-tensor current scaling for now, since their quantizers are not stateless.
    columnwise_scale_inv = torch.empty(
        total_columnwise_scale_elements, dtype=torch.float32, device=device
    )
elif quantizer._get_compatible_recipe().float8_current_scaling():
float8_current_scaling can work with GroupedTensor once we refactor its implementation to move the amax tensor out of its quantizer. Then it will be safe to put a single quantizer into the grouped tensor.
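A toy illustration of why a stateful quantizer cannot simply be shared across the group (this is not the TE quantizer API, just a sketch of the problem):

```python
import torch

class ToyCurrentScalingQuantizer:
    """Keeps per-tensor state (amax) on the quantizer itself."""

    def __init__(self):
        self.amax = torch.zeros(1)  # state that a single shared instance would clobber

    def quantize(self, x: torch.Tensor) -> torch.Tensor:
        self.amax.copy_(x.abs().max())            # state update per call
        scale = 448.0 / self.amax.clamp(min=1e-12)
        return (x * scale).clamp(-448.0, 448.0)   # stand-in for an FP8 cast

shared = ToyCurrentScalingQuantizer()
a = shared.quantize(torch.randn(4))
b = shared.quantize(torch.randn(4) * 100)
# shared.amax now reflects only the second tensor, so any later use of the
# quantizer state for the first tensor would be wrong.
```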
        result.append(tensor)

# Delayed scaling or current scaling (both use Float8TensorStorage)
elif recipe.delayed() or recipe.float8_current_scaling():
let's assert an error for this case?
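One possible form of the error suggested here (a sketch only; the recipe-checking helpers are taken from the surrounding diff, and the message is made up — an assert would work equally well):

```python
recipe = quantizer._get_compatible_recipe()
if recipe.delayed() or recipe.float8_current_scaling():
    # These recipes keep state (amax, per-tensor scales) inside the quantizer,
    # so a single shared quantizer cannot be used for the whole group yet.
    raise NotImplementedError(
        "GroupedTensor does not support delayed scaling or per-tensor "
        "current scaling until their quantizers are made stateless."
    )
```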
    dtype: Optional[torch.dtype] = None,
) -> GroupedTensor:
    """
    Create a GroupedTensor for storing multiple weight tensors of the same shape.
Minor comment: the intent of this API is to create a grouped tensor with variable first_dims/last_dims, so we can state that in the docstring, since this is not going to be used to create weights.
Also, the API could be named make_grouped_tensor_graph_safe, so people know it is safe to use within the forward/backward of a module, which we need to be CUDA-graphable.
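A hypothetical call with variable per-tensor shapes, to make the intent concrete (argument names follow the diff above; the quantizer argument and the device/dtype choices are assumptions):

```python
import torch

# Three tensors with different leading dims but a shared trailing dim.
first_dims = torch.tensor([128, 256, 64], device="cuda")
last_dims = torch.tensor([512, 512, 512], device="cuda")

grouped = make_grouped_tensor(
    num_tensors=3,
    first_dims=first_dims,
    last_dims=last_dims,
    quantizer=None,        # unquantized storage in this sketch
    dtype=torch.bfloat16,
)
```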
        torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype),
        torch.cumsum(first_dims * logical_last_dim, dim=0),
    ]
)
I see the above comment about doing this with a single kernel and am not sure what your plan is to implement that. But with torch ops you can avoid one memory op by using:

    tensor_offsets = torch.empty(num_tensors + 1, device=first_dims.device, dtype=first_dims.dtype)
    torch.cumsum(first_dims * logical_last_dim, dim=0, out=tensor_offsets[1:])
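For concreteness, the two variants side by side (plain PyTorch; logical_last_dim is a scalar here, and the out= variant still needs its leading element zeroed, which the snippet above omits):

```python
import torch

first_dims = torch.tensor([2, 3, 5])
logical_last_dim = 4
num_tensors = first_dims.numel()

# Variant in the diff: concatenate a zero with the cumulative sums.
offsets_cat = torch.cat(
    [
        torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype),
        torch.cumsum(first_dims * logical_last_dim, dim=0),
    ]
)

# Suggested variant: write the cumsum into a preallocated buffer, skipping
# the extra concatenation; the first offset must still be set to zero.
offsets_out = torch.empty(num_tensors + 1, device=first_dims.device, dtype=first_dims.dtype)
offsets_out[0] = 0
torch.cumsum(first_dims * logical_last_dim, dim=0, out=offsets_out[1:])

assert torch.equal(offsets_cat, offsets_out)  # both give [0, 8, 20, 40]
```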
Description
Extracts the Python pieces of GroupedTensor infrastructure from #2600. Since this is mainly focused on the creation of weights as a single GroupedTensor and exposing them as multiple QuantizedTensors for PyTorch, this portion does not need to be graph capturable.

Type of change

Changes

- Added the GroupedTensor class.
- Integrated GroupedTensor into GroupedLinear such that the parameters are contiguous.

Checklist: