From d8264ea597fbbe20c7501c0181aeae0d5f9fbe19 Mon Sep 17 00:00:00 2001
From: xiaoxi-wangfj <690912414@qq.com>
Date: Thu, 25 Dec 2025 10:09:18 +0000
Subject: [PATCH] [PyTorch]: add moe fp8 flow under blockwise recipe
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Add an FP8 rowwise scaling-aware transpose op for the wgrad columnwise
   input.
2. Support Float8BlockwiseQTensor input in grouped_linear.
3. Propagate _rowwise_scale_inv with a COMPACT layout along the
   `dispatch → permute → GroupedLinear` path.

Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
Co-authored-by: dantesuu@gmail.com
Co-authored-by: xzhu@zhejianglab.org
Co-authored-by: 123sssmmm@gmail.com
---
 .../pytorch/module/grouped_linear.py          |  39 ++-
 transformer_engine/pytorch/permutation.py     |  14 +-
 .../pytorch/tensor/float8_blockwise_tensor.py |  91 +++++++
 .../blockwise_scaling_aware_fp8_transpose.py  | 230 ++++++++++++++++++
 .../pytorch/triton/permutation.py             |   2 +-
 5 files changed, 360 insertions(+), 16 deletions(-)
 create mode 100644 transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py

diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index c4d35a9c2cd..6589a1cf8ec 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -44,6 +44,7 @@
 from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload
 from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
+from ..tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
 from ..quantized_tensor import (
     QuantizedTensorStorage,
     Quantizer,
@@ -143,7 +144,12 @@ def forward(
         inp_view = inp.reshape(-1, in_features)
         inputmats: list
         if fp8 and not debug:
-            inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
+            if isinstance(inp_view, Float8BlockwiseQTensor):
+                inputmats = inp_view.split_scaling_aware_fp8_transpose(
+                    m_splits, input_quantizers
+                )
+            else:
+                inputmats = tex.split_quantize(inp_view, m_splits, input_quantizers)
         elif debug:
             inputmats = DebugQuantizer.multi_tensor_quantize(
                 inp_view, input_quantizers, m_splits, activation_dtype
@@ -343,18 +349,28 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 # Unfused bias grad and multi-tensor quantize
                 for i in range(ctx.num_gemms):
                     grad_biases[i] = grad_output_mats[i].sum(dim=0)
+                if isinstance(grad_output_view, Float8BlockwiseQTensor):
+                    grad_output = grad_output_view.split_scaling_aware_fp8_transpose(
+                        ctx.m_splits, ctx.grad_output_quantizers
+                    )
+                else:
+                    grad_output = tex.split_quantize(
+                        grad_output_view,
+                        ctx.m_splits,
+                        ctx.grad_output_quantizers,
+                    )
+            else:
+                # Multi-tensor quantize
+                if isinstance(grad_output_view, Float8BlockwiseQTensor):
+                    grad_output = grad_output_view.split_scaling_aware_fp8_transpose(
+                        ctx.m_splits, ctx.grad_output_quantizers
+                    )
+                else:
                     grad_output = tex.split_quantize(
                         grad_output_view,
                         ctx.m_splits,
                         ctx.grad_output_quantizers,
                     )
-            else:
-                # Multi-tensor quantize
-                grad_output = tex.split_quantize(
-                    grad_output_view,
-                    ctx.m_splits,
-                    ctx.grad_output_quantizers,
-                )
         elif ctx.debug:
             grad_output_mats = torch.split(grad_output_view, ctx.m_splits)
             for i in range(ctx.num_gemms):
@@ -781,9 +797,10 @@ def forward(
         """
         debug = self.is_debug_iter()

-        assert not isinstance(
-            inp, QuantizedTensorStorage
-        ), "GroupedLinear doesn't support input tensor in FP8."
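+        # Float8BlockwiseQTensor is produced upstream by the blockwise-recipe
+        # dispatch/permute path; it is the one FP8 input type accepted here.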
+        if not isinstance(inp, Float8BlockwiseQTensor):
+            assert not isinstance(
+                inp, QuantizedTensorStorage
+            ), "GroupedLinear only supports FP8 input as Float8BlockwiseQTensor."
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."

         is_grad_enabled = torch.is_grad_enabled()
diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py
index d15814585ee..78766ad67dd 100644
--- a/transformer_engine/pytorch/permutation.py
+++ b/transformer_engine/pytorch/permutation.py
@@ -224,7 +224,10 @@ def forward(
         fake_dtype = inp.dtype
         # blockwise scaling
         if blockwise_recipe:
-            fp8_scale = inp._rowwise_scale_inv.T.contiguous()
+            if inp._rowwise_data.shape[0] == inp._rowwise_scale_inv.shape[0]:
+                fp8_scale = inp._rowwise_scale_inv
+            else:
+                fp8_scale = inp._rowwise_scale_inv.T.contiguous()
             scale_hidden_dim = fp8_scale.shape[1]
             assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
             inp = inp._rowwise_data
@@ -275,7 +278,7 @@ def forward(
             shape=output.shape,
             dtype=fake_dtype,
             rowwise_data=output,
-            rowwise_scale_inv=permuted_scale.T.contiguous(),
+            rowwise_scale_inv=permuted_scale,
             columnwise_data=None,
             columnwise_scale_inv=None,
             fp8_dtype=fp8_dtype,
@@ -423,7 +426,10 @@ def backward(ctx, unpermuted_act_grad):
             unpermuted_act_grad = unpermuted_act_grad._data
         # blockwise scaling
         elif blockwise_recipe:
-            fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
+            if unpermuted_act_grad._rowwise_data.shape[0] == unpermuted_act_grad._rowwise_scale_inv.shape[0]:
+                fp8_scale = unpermuted_act_grad._rowwise_scale_inv
+            else:
+                fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
             unpermuted_act_grad = unpermuted_act_grad._rowwise_data
             scale_hidden_dim = fp8_scale.shape[1]
             assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
@@ -485,7 +491,7 @@ def backward(ctx, unpermuted_act_grad):
             shape=act_grad.shape,
             dtype=fake_dtype,
             rowwise_data=act_grad,
-            rowwise_scale_inv=permuted_scale.T.contiguous(),
+            rowwise_scale_inv=permuted_scale,
             columnwise_data=None,
             columnwise_scale_inv=None,
             fp8_dtype=fp8_dtype,
diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
index 01e03e53551..6d3e4cdbd4e 100644
--- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -17,6 +17,9 @@
 from ..quantized_tensor import QuantizedTensor, Quantizer
 from ._quantization_helpers import _IdentityFunc
 from ..utils import devices_match, round_up_to_nearest_multiple
+from ..triton.blockwise_scaling_aware_fp8_transpose import (
+    blockwise_scaling_aware_fp8_transpose,
+)

 aten = torch.ops.aten

@@ -437,6 +440,94 @@ def untyped_storage(self) -> torch.UntypedStorage:
             return data.untyped_storage()
         return torch.UntypedStorage(0, device=self.device)

+    def split_scaling_aware_fp8_transpose(self, m_splits, quantizers):
+        assert (
+            self._rowwise_data is not None and self._rowwise_scale_inv is not None
+        ), "split_scaling_aware_fp8_transpose only supports rowwise_data inputs."
+
+        # Temporary restriction for the performant FP8 flow: scales must arrive
+        # in COMPACT layout, i.e. one scale row per data row.
+        assert (
+            self._rowwise_scale_inv.shape[0] == self._rowwise_data.shape[0]
+        ), "rowwise_data and rowwise_scale_inv must have same M (rows)."
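+        # A tensor tagged GEMM_READY whose scales already match the data's row
+        # count is effectively in COMPACT layout; relabel it so the format
+        # check below passes.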
+        if (
+            self._is_gemm_ready_format()
+            and self._rowwise_data.shape[0] == self._rowwise_scale_inv.shape[0]
+        ):
+            self._data_format = tex.Float8BlockScaleTensorFormat.COMPACT
+        assert (
+            not self._is_gemm_ready_format()
+        ), "Only COMPACT input format is supported."
+
+        rowwise_usage = quantizers[0].rowwise_usage
+        device = self._rowwise_data.device
+        kept = [i for i, m in enumerate(m_splits) if m > 0]
+        m_splits_kept = [m_splits[i] for i in kept]
+
+        if len(m_splits_kept) > 0:
+            (
+                rowwise_data_list,
+                rowwise_scale_inv_t_list,
+                columnwise_data_list,
+                columnwise_scale_inv_list,
+            ) = blockwise_scaling_aware_fp8_transpose(
+                self._rowwise_data, self._rowwise_scale_inv, m_splits_kept
+            )
+
+        if len(m_splits_kept) != len(m_splits):
+            K = self._rowwise_data.shape[1]
+            empty_rw_data = (
+                torch.empty((0, K), dtype=self._rowwise_data.dtype, device=device)
+                if rowwise_usage
+                else None
+            )
+            empty_rw_si_t = (
+                torch.empty(
+                    (self._rowwise_scale_inv.shape[1], 0),
+                    dtype=self._rowwise_scale_inv.dtype,
+                    device=device,
+                )
+                if rowwise_usage
+                else None
+            )
+            empty_cw_data = torch.empty(
+                (K, 0), dtype=self._rowwise_data.dtype, device=device
+            )
+            empty_cw_si = torch.empty(
+                (0, K), dtype=self._rowwise_scale_inv.dtype, device=device
+            )
+
+        results = []
+        kept_idx = 0
+        for i, m in enumerate(m_splits):
+            if m == 0:
+                rowwise_data = empty_rw_data
+                rowwise_scale_inv_t = empty_rw_si_t
+                columnwise_data = empty_cw_data
+                columnwise_scale_inv = empty_cw_si
+            else:
+                rowwise_data = rowwise_data_list[kept_idx] if rowwise_usage else None
+                rowwise_scale_inv_t = (
+                    rowwise_scale_inv_t_list[kept_idx] if rowwise_usage else None
+                )
+                columnwise_data = columnwise_data_list[kept_idx]
+                columnwise_scale_inv = columnwise_scale_inv_list[kept_idx]
+                kept_idx += 1
+
+            results.append(
+                Float8BlockwiseQTensorStorage(
+                    rowwise_data=rowwise_data,
+                    rowwise_scale_inv=rowwise_scale_inv_t,
+                    columnwise_data=columnwise_data,
+                    columnwise_scale_inv=columnwise_scale_inv,
+                    fp8_dtype=self._fp8_dtype,
+                    quantizer=quantizers[i],
+                    is_2D_scaled=self._is_2D_scaled,
+                    data_format=tex.Float8BlockScaleTensorFormat.GEMM_READY,
+                )
+            )
+
+        return results
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
diff --git a/transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py b/transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py
new file mode 100644
index 00000000000..52f2481138d
--- /dev/null
+++ b/transformer_engine/pytorch/triton/blockwise_scaling_aware_fp8_transpose.py
@@ -0,0 +1,230 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""PyTorch wrapper functions and Triton kernels for the scaling-aware FP8 transpose."""
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _scaling_aware_fp8_transpose_kernel(
+    # input pointers
+    rowwise_data_ptrs,
+    rowwise_scale_inv_ptrs,
+    columnwise_data_ptrs,
+    columnwise_scale_inv_ptrs,
+    rowwise_scale_inv_t_ptrs,
+    rows_ptr,
+    # sizes
+    cols,
+    rsi_cols,
+    # strides
+    stride_rowwise_data_r,
+    stride_rsi_r,
+    # metas
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid_group_index = tl.program_id(0)
+    pid_row = tl.program_id(1)
+    pid_col = tl.program_id(2)
+
+    rows = tl.load(rows_ptr + pid_group_index)
+    nbrows = (rows + BLOCK_SIZE - 1) // BLOCK_SIZE
+    if pid_row >= nbrows:
+        return
+
+    row_base = tl.load(rowwise_data_ptrs + pid_group_index).to(
+        tl.pointer_type(tl.uint8)
+    )
+    rsi_base = tl.load(rowwise_scale_inv_ptrs + pid_group_index).to(
+        tl.pointer_type(tl.float32)
+    )
+    col_base = tl.load(columnwise_data_ptrs + pid_group_index).to(
+        tl.pointer_type(tl.uint8)
+    )
+    csi_base = tl.load(columnwise_scale_inv_ptrs + pid_group_index).to(
+        tl.pointer_type(tl.float32)
+    )
+
+    r_start = pid_row * BLOCK_SIZE
+    c_start = pid_col * BLOCK_SIZE
+    r_offsets = r_start + tl.arange(0, BLOCK_SIZE)
+    c_offsets = c_start + tl.arange(0, BLOCK_SIZE)
+    valid_r = r_offsets < rows
+    valid_c = c_offsets < cols
+    data = tl.load(
+        row_base + (r_offsets[:, None] * stride_rowwise_data_r + c_offsets[None, :]),
+        mask=valid_r[:, None] & valid_c[None, :],
+        other=0,
+    )
+
+    rsi_c_offsets = pid_col + tl.arange(0, 1)
+    valid_rsi_c = rsi_c_offsets < rsi_cols
+    si = tl.load(
+        rsi_base + r_offsets[:, None] * stride_rsi_r + rsi_c_offsets[None, :],
+        mask=valid_r[:, None] & valid_rsi_c[None, :],
+        other=0.0,
+    )
+
+    # Write rowwise_scale_inv.T
+    rst_base = tl.load(rowwise_scale_inv_t_ptrs + pid_group_index).to(
+        tl.pointer_type(tl.float32)
+    )
+    tl.store(
+        rst_base + (rsi_c_offsets[:, None] * rows + r_offsets[None, :]),
+        si.T,
+        mask=valid_rsi_c[:, None] & valid_r[None, :],
+    )
+
+    # For the current block-row (BLOCK_SIZE rows), take the per-channel max of
+    # rowwise_scale_inv; this max becomes the columnwise scaling factor for
+    # this block.
+    target_si = tl.max(si, axis=0)
+    tl.store(csi_base + (pid_row * cols + c_offsets), target_si, mask=valid_c)
+
+    # FP8 decode/encode (E4M3 bit layout: 1 sign, 4 exponent, 3 mantissa bits).
+    sign = (data >> 7) & 1
+    exp = (data >> 3) & 0xF
+    mant = data & 0x7
+    # The exact rescale factor is k = round(log2(target_si / si)); reading the
+    # float32 exponent fields directly gives the same k without transcendentals
+    # (exact when the scales are powers of two, as in this recipe).
+    bits_target = tl.cast(target_si, tl.uint32, bitcast=True)
+    bits_si = tl.cast(si, tl.uint32, bitcast=True)
+    exp_t = ((bits_target & 0x7F800000) >> 23) - 127
+    exp_s = ((bits_si & 0x7F800000) >> 23) - 127
+    k_approx = exp_t[None, :] - exp_s
+    k = tl.cast(k_approx, tl.int32)
+    exp_new = exp - k
+    exp_new = tl.where(exp_new < 1, 0, exp_new)
+    new_data = (sign << 7) | (exp_new << 3) | mant
+    new_data = tl.where(exp == 0, 0, new_data)
+
+    # Write columnwise_data (uint8) in [K, M] layout, i.e. indexed (c, r).
+    tl.store(
+        col_base + (c_offsets[:, None] * rows + r_offsets[None, :]),
+        new_data.T,
+        mask=valid_c[:, None] & valid_r[None, :],
+    )
+
+
+def blockwise_scaling_aware_fp8_transpose(
+    rowwise_data: torch.Tensor,
+    rowwise_scale_inv: torch.Tensor,
+    m_splits: list[int],
+    block_size: int = 128,
+):
+    """
+    Scaling-aware FP8 transpose that converts row-wise quantized FP8 tensors to a
+    column-wise layout in the FP8 domain.
+
+    The input is split along the M dimension according to ``m_splits``.
+    For each split, the kernel transposes FP8 data from shape ``[m_i, cols]`` to
+    ``[cols, m_i]`` while producing column-wise scaling factors at block-row
+    granularity. The operation is performed without dequantizing to higher
+    precision types.
+
+    Parameters
+    ----------
+    rowwise_data : torch.Tensor
+        Row-wise FP8-encoded data stored as ``uint8`` with shape
+        ``[sum(m_splits), cols]``.
+    rowwise_scale_inv : torch.Tensor
+        Row-wise scaling factors associated with ``rowwise_data`` with shape
+        ``[sum(m_splits), rsi_cols]``.
+    m_splits : list[int]
+        Sizes of splits along the M dimension. Each entry ``m_i`` defines the
+        number of rows in one group.
+    block_size : int, optional
+        Tile size for the blockwise transpose and scaling-aware conversion.
+
+    Returns
+    -------
+    rowwise_data_list : list[torch.Tensor]
+        List of input views split by ``m_splits``, each with shape
+        ``[m_i, cols]`` and dtype matching ``rowwise_data``.
+    rowwise_scale_inv_t_list : list[torch.Tensor]
+        List of transposed row-wise inverse scaling tensors, each with shape
+        ``[nbcols, m_i]``, where ``nbcols = ceil(cols / block_size)`` and dtype
+        matching ``rowwise_scale_inv``.
+    columnwise_data_list : list[torch.Tensor]
+        List of column-wise FP8-encoded output tensors, each with shape
+        ``[cols, m_i]`` and dtype matching ``rowwise_data`` (raw FP8 bits in
+        ``uint8``).
+    columnwise_scale_inv_list : list[torch.Tensor]
+        List of column-wise inverse scaling tensors at block-row granularity,
+        each with shape ``[nbrows_i, cols]``, where
+        ``nbrows_i = ceil(m_i / block_size)`` and dtype matching
+        ``rowwise_scale_inv``.
+    """
+    assert len(m_splits) > 0, "m_splits must not be empty"
+    device = rowwise_data.device
+    data_dtype = rowwise_data.dtype
+    scale_dtype = rowwise_scale_inv.dtype
+
+    cols = rowwise_data.shape[1]
+    rsi_cols = rowwise_scale_inv.shape[1]
+    # Number of block-rows (along M) for each split; the splits differ in M,
+    # so the launch grid below is sized to the largest one.
+    nbrows_list = [(m + block_size - 1) // block_size for m in m_splits]
+    nbcols = (cols + block_size - 1) // block_size
+
+    rowwise_data_list = list(torch.split(rowwise_data, m_splits, dim=0))
+    rowwise_scale_inv_list = list(torch.split(rowwise_scale_inv, m_splits, dim=0))
+    rowwise_scale_inv_t_list = [
+        torch.empty((nbcols, m), dtype=scale_dtype, device=device) for m in m_splits
+    ]
+    columnwise_data_list = [
+        torch.empty((cols, m), dtype=data_dtype, device=device) for m in m_splits
+    ]
+    columnwise_scale_inv_list = [
+        torch.empty((nb, cols), dtype=scale_dtype, device=device) for nb in nbrows_list
+    ]
+
+    rowwise_data_ptrs = torch.as_tensor([t.data_ptr() for t in rowwise_data_list]).to(
+        device=device, non_blocking=True
+    )
+    rowwise_scale_inv_ptrs = torch.as_tensor(
+        [t.data_ptr() for t in rowwise_scale_inv_list]
+    ).to(device=device, non_blocking=True)
+    rowwise_scale_inv_t_ptrs = torch.as_tensor(
+        [t.data_ptr() for t in rowwise_scale_inv_t_list]
+    ).to(device=device, non_blocking=True)
+    columnwise_data_ptrs = torch.as_tensor(
+        [t.data_ptr() for t in columnwise_data_list]
+    ).to(device=device, non_blocking=True)
+    columnwise_scale_inv_ptrs = torch.as_tensor(
+        [t.data_ptr() for t in columnwise_scale_inv_list]
+    ).to(device=device, non_blocking=True)
+
+    rows_t = torch.as_tensor(m_splits, dtype=torch.int32).to(
+        device=device, non_blocking=True
+    )
+
+    grid = (len(m_splits), max(nbrows_list), nbcols)
+    _scaling_aware_fp8_transpose_kernel[grid](
+        rowwise_data_ptrs,
+        rowwise_scale_inv_ptrs,
+        columnwise_data_ptrs,
columnwise_scale_inv_ptrs, + rowwise_scale_inv_t_ptrs, + rows_t, + cols, + rsi_cols, + rowwise_data.stride(0), + rowwise_scale_inv.stride(0), + BLOCK_SIZE=block_size, + ) + + return ( + rowwise_data_list, + rowwise_scale_inv_t_list, + columnwise_data_list, + columnwise_scale_inv_list, + ) diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 27662e1b283..985d11c644b 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -165,7 +165,7 @@ def permute_with_mask_map( alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") if probs is not None else None ) permuted_scale = ( - torch.empty((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda") + alloc((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda") if scale is not None else None )
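
Usage sketch (reviewer note, not part of the patch): a minimal example of
driving the new Triton op directly. Tensor contents are random placeholders;
only the COMPACT-layout shape contract documented above is assumed, and a CUDA
device is required.

    import torch
    from transformer_engine.pytorch.triton.blockwise_scaling_aware_fp8_transpose import (
        blockwise_scaling_aware_fp8_transpose,
    )

    # Two expert groups of 256 and 128 tokens, hidden size 512, 1x128 blocks.
    m_splits = [256, 128]
    cols, block = 512, 128
    rows = sum(m_splits)
    # Raw FP8 bits stored as uint8, plus one float32 scale per
    # (row, 128-column block): the COMPACT layout this op expects.
    rowwise_data = torch.randint(0, 256, (rows, cols), dtype=torch.uint8, device="cuda")
    rowwise_scale_inv = torch.rand((rows, cols // block), device="cuda")

    rw, rw_si_t, cw, cw_si = blockwise_scaling_aware_fp8_transpose(
        rowwise_data, rowwise_scale_inv, m_splits, block_size=block
    )
    # Group 0: data [256, 512] -> columnwise [512, 256]; rowwise scales are
    # returned transposed as [4, 256]; new columnwise scales are [2, 512].
    assert cw[0].shape == (512, 256)
    assert rw_si_t[0].shape == (4, 256)
    assert cw_si[0].shape == (2, 512)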