From 3afce1f133112d162cf66f680b83a7cd8d360ab0 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 2 Feb 2026 16:45:50 -0800 Subject: [PATCH 01/43] Add NVTE_KEEP_BACKWARD_UNQUANTIZED Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/base.py | 4 +- .../pytorch/module/grouped_linear.py | 36 +++-- .../pytorch/module/layernorm_linear.py | 80 +++++++--- .../pytorch/module/layernorm_mlp.py | 147 +++++++++++------- transformer_engine/pytorch/module/linear.py | 65 +++++--- .../pytorch/ops/basic/basic_linear.py | 48 ++++-- .../pytorch/ops/basic/quantize.py | 6 +- .../ops/fused/backward_activation_bias.py | 7 +- .../fused/forward_linear_bias_activation.py | 18 ++- .../ops/fused/forward_linear_bias_add.py | 18 ++- .../ops/fused/forward_linear_scale_add.py | 18 ++- .../ops/fused/userbuffers_forward_linear.py | 49 +++++- transformer_engine/pytorch/ops/fuser.py | 16 +- transformer_engine/pytorch/quantization.py | 5 + 14 files changed, 375 insertions(+), 142 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 841cdf04ca..4a2140718d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1135,9 +1135,11 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized # Non-FP8 case: bgrad is fused with wgrad for this case. - if not ctx.fp8 and not ctx.debug: + if not use_fp8_bwd and not ctx.debug: if gather_grad_output: if not ctx.ub_overlap_ag: # Perform NCCL all-gather grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c9ceb714e3..874eadeb36 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -96,6 +96,9 @@ def forward( save_original_input, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + if keep_backward_unquantized: + save_original_input = True num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] @@ -286,6 +289,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -294,7 +298,11 @@ def forward( ctx.inp_shape = inp.shape ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): + if ( + ctx.fp8 + and not ctx.keep_backward_unquantized + and requires_grad(inp, weights[0], biases[0]) + ): ctx.reduce_and_update_bwd_fp8_tensors = ( ctx.reduce_and_update_bwd_fp8_tensors or FP8GlobalStateManager.is_first_fp8_module() @@ -318,6 +326,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], origin_weights = saved_tensors[2 * N : 3 * N] biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + 
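        # Summary of the gating used throughout this patch (assumed semantics, inferred
        # from the surrounding changes, not text from the original commit): with
        # NVTE_KEEP_BACKWARD_UNQUANTIZED=1, `keep_backward_unquantized` is True and
        # `use_fp8_bwd` below evaluates to False, so the dgrad/wgrad GEMMs in this
        # backward pass run on the saved high-precision tensors instead of FP8 data,
        # even though the forward pass executed in FP8.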
use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -333,7 +343,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms - if ctx.fp8 and not ctx.debug: + if use_fp8_bwd and not ctx.debug: if ctx.use_bias: grad_output_mats = torch.split(grad_output_view, ctx.m_splits) recipe = ctx.fp8_recipe @@ -384,7 +394,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = ( @@ -395,13 +405,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - # Make sure weights are available in column-wise format - # for dgrad computation. - for weight in weights: - if isinstance(weight, QuantizedTensorStorage): - weight.update_usage(columnwise_usage=True) + weights_for_dgrad = weights if use_fp8_bwd else origin_weights + if use_fp8_bwd: + # Make sure weights are available in column-wise format + # for dgrad computation. + for weight in weights_for_dgrad: + if isinstance(weight, QuantizedTensorStorage): + weight.update_usage(columnwise_usage=True) general_grouped_gemm( - weights, + weights_for_dgrad, grad_output, [dgrad], ctx.grad_input_quantizers, @@ -415,7 +427,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): wgrad_gemm_use_split_accumulator = ( @@ -442,7 +454,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list - if ctx.fp8 and not ctx.debug: + if use_fp8_bwd and not ctx.debug: inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -516,7 +528,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() - and not ctx.fp8 + and not use_fp8_bwd ): grad_biases = [None] * ctx.num_gemms diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 702916696b..28842fc315 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -141,6 +141,7 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" @@ -200,7 +201,10 @@ def forward( if fp8: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input and not keep_backward_unquantized, + ) if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data 
input_quantizer.set_usage(columnwise=False) @@ -213,6 +217,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not keep_backward_unquantized and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) @@ -236,6 +241,7 @@ def forward( ln_out_return = None if return_layernorm_output or return_layernorm_output_gathered: ln_out_return = ln_out + ln_out_hp = ln_out if keep_backward_unquantized else None # ------------------------------------------------------ # Prepare GEMM input tensor @@ -409,13 +415,14 @@ def forward( # ------------------------------------------------------ if is_grad_enabled: + ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) # Input with column-wise usage is needed for wgrad GEMM. - if backward_needs_input: + if backward_needs_input and not keep_backward_unquantized: if isinstance(ln_out, QuantizedTensorStorage): # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. For MXFP8, columnwise only data @@ -427,7 +434,7 @@ def forward( ln_out.update_usage(rowwise_usage=False) if cpu_offloading: - mark_activation_offload(inputmat, mu, rsigma, ln_out) + mark_activation_offload(inputmat, mu, rsigma, ln_out_to_save) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -439,7 +446,7 @@ def forward( mu, rsigma, weightmat if fp8 and not is_weight_param_quantized else None, - ln_out if weight.requires_grad else None, + ln_out_to_save if weight.requires_grad else None, ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") @@ -466,7 +473,7 @@ def forward( weight, bias, ln_weight, - ln_out, + ln_out_to_save, mu, rsigma, ) @@ -493,6 +500,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -515,7 +523,11 @@ def forward( ctx.requires_dgrad = inp_requires_grad ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): + if ( + ctx.fp8 + and not ctx.keep_backward_unquantized + and requires_grad(inp, ln_weight, ln_bias, weight, bias) + ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): @@ -592,6 +604,15 @@ def backward( if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: origin_weight.main_grad = main_grad + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_quantized_bwd = use_fp8_bwd or ctx.debug + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -601,23 +622,23 @@ def backward( dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad 
compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -628,7 +649,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None: + if ctx.grad_output_quantizer is not None and use_quantized_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -665,7 +686,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None: + if ctx.input_quantizer is not None and use_quantized_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -703,18 +724,22 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage): + if ( + use_quantized_bwd + and ctx.weight_quantizer is not None + and isinstance(weight, QuantizedTensorStorage) + ): weight.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None: + if ctx.grad_input_quantizer is not None and use_quantized_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -730,12 +755,13 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight if use_quantized_bwd else origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( - weight, + weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer, + quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -782,7 +808,11 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): + if ( + use_fp8_bwd + and ctx.ub_overlap_ag + and isinstance(ctx.grad_output_quantizer, 
MXFP8Quantizer) + ): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -794,7 +824,7 @@ def backward( dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -820,14 +850,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -836,7 +866,7 @@ def backward( # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -862,7 +892,9 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.grad_weight_quantizer, + "quantization_params": ( + ctx.grad_weight_quantizer if use_quantized_bwd else None + ), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -870,7 +902,7 @@ def backward( ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bec6744518..2b3a72b803 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -232,6 +232,7 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: @@ -350,8 +351,10 @@ def _forward( # bwd needs fc1 input when grad is enabled, fc1 needs grad, and either # 1) no checkpointing # or 2) doing the recomputation with checkpointing - backwards_needs_fc1_input = fc1_weight.requires_grad and ( - (is_grad_enabled and not checkpoint) or is_recomputation + backwards_needs_fc1_input = ( + fc1_weight.requires_grad + and ((is_grad_enabled and not checkpoint) or is_recomputation) + and not keep_backward_unquantized ) device = inp.device @@ -394,6 +397,7 @@ def _forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not keep_backward_unquantized and not custom ) @@ -415,6 +419,7 @@ def _forward( # do not return layernorm output unless 1) no checkpointing or 2) checkpointing but not recomputing if (return_layernorm_output or return_layernorm_output_gathered) and not 
is_recomputation: ln_out_return = ln_out + ln_out_hp = ln_out if keep_backward_unquantized else None # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication @@ -611,6 +616,10 @@ def _forward( if fc2_input_quantizer is not None: fc2_input_quantizer.calibrate(act_out) + act_out_hp = act_out + if keep_backward_unquantized and is_grad_enabled and fc1_out is not None: + act_out_hp = activation_func(fc1_out, None, **act_params) + # we want to skip fc2 computation if we are checkpointing and recomputing, # otherwise we compute fc2 if not (is_recomputation and checkpoint): @@ -686,22 +695,30 @@ def _forward( # if we are not checkpointing, then we must save this if grad is enabled if is_grad_enabled and not save_for_checkpoint: + ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out + act_out_to_save = act_out_hp if keep_backward_unquantized else act_out ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer if not fc1_weight.requires_grad: if not return_layernorm_output: - clear_tensor_data(ln_out) - ln_out = None + clear_tensor_data(ln_out_to_save) + ln_out_to_save = None if not fc2_weight.requires_grad: - clear_tensor_data(act_out) - act_out = None + clear_tensor_data(act_out_to_save) + act_out_to_save = None if not checkpoint: # regular path, no selective activation checkpointing if cpu_offloading: mark_activation_offload( - inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out + inputmat, + mu, + rsigma, + ln_out_to_save, + fc1_out, + fc1_out_without_bias, + act_out_to_save, ) # Scatter intermediate/activation tensors saved for the backward pass @@ -714,9 +731,9 @@ def _forward( fsdp_group, mu, rsigma, - ln_out, + ln_out_to_save, fc1_out_without_bias if bias_gelu_fusion else fc1_out, - act_out, + act_out_to_save, ( fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) @@ -744,13 +761,13 @@ def _forward( tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, - ln_out, + ln_out_to_save, fc1_weight_final, fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, - act_out, + act_out_to_save, fc2_weight_final, fc2_weight, fc2_bias, @@ -798,6 +815,7 @@ def _forward( ctx.activation_params = activation_params ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -826,8 +844,12 @@ def _forward( ) ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad( - inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias + if ( + ctx.fp8 + and not ctx.keep_backward_unquantized + and requires_grad( + inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias + ) ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() @@ -996,6 +1018,16 @@ def backward( origin_fc1_weight.main_grad = fc1_weight_main_grad origin_fc2_weight.main_grad = fc2_weight_main_grad + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_quantized_bwd = use_fp8_bwd or ctx.debug + fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + 
ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # TODO: Fix this # pylint: disable=fixme # Gather saved autograd context tensors when running with FSDP # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -1015,7 +1047,7 @@ def backward( # Choose whether to use GEMM kernel with split accumulator dgrad_use_split_accumulator = _2X_ACC_DGRAD wgrad_use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator @@ -1029,7 +1061,7 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.fc2_grad_output_quantizer is not None: + if ctx.fc2_grad_output_quantizer is not None and use_quantized_bwd: quantizer = ctx.fc2_grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -1042,7 +1074,7 @@ def backward( # Note: Cast to expected dtype and perform tensor-parallel communication ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: - ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8) + ub_obj_fc2_dgrad = get_ub("fc2_dgrad", use_fp8_bwd) ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, @@ -1057,7 +1089,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1066,7 +1098,7 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) if ctx.ub_bulk_dgrad: - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_fc1_dgrad, ln_out, @@ -1103,7 +1135,7 @@ def backward( # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not ctx.fp8 + not use_fp8_bwd and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) and (not ctx.debug) @@ -1112,20 +1144,23 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.fc2_weight_quantizer is not None and isinstance( - ctx.fc2_weight, QuantizedTensorStorage + if ( + use_quantized_bwd + and ctx.fc2_weight_quantizer is not None + and isinstance(ctx.fc2_weight, QuantizedTensorStorage) ): ctx.fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM + fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight gemm_output, *_ = general_gemm( - fc2_weight, + fc2_weight_for_dgrad, grad_output, layout="NN", grad=True, quantization_params=( ctx.fc1_grad_input_quantizer - if fc2_dgrad_gemm_gelu_fusion or ctx.debug + if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_quantized_bwd else None ), # high precision to activation out_dtype=ctx.activation_dtype, @@ -1157,7 +1192,11 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): + if ( + use_fp8_bwd + and ctx.ub_overlap_ag + and 
isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer) + ): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -1170,7 +1209,7 @@ def backward( ub_obj_fc2_dgrad.get_communication_stream() ) - ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8) + ub_obj_fc2_wgrad = get_ub("fc2_wgrad", use_fp8_bwd) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1193,14 +1232,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1209,7 +1248,7 @@ def backward( # Whether to set grad arg in general_gemm grad_arg = True - if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling(): + if use_fp8_bwd and fp8_recipe_bwd.float8_block_scaling(): grad_arg = False # Arguments to include in wgrad GEMM closure @@ -1219,7 +1258,9 @@ def backward( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision + "quantization_params": ( + ctx.fc2_grad_weight_quantizer if use_quantized_bwd else None + ), # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc1_weight, "overwrite_main_grad", False) @@ -1256,8 +1297,8 @@ def fc2_wgrad_gemm( # Update grad bias if needed if fc2_bias_grad is None: if ( - ctx.fp8 - and ctx.fp8_recipe.float8_block_scaling() + use_fp8_bwd + and fp8_recipe_bwd.float8_block_scaling() and fc2_bias is not None ): # BGRAD not fused with GEMM for float8 blockwise gemm. 
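The `use_fp8_bwd` / `fp8_recipe_bwd` checks above all key off the new NVTE_KEEP_BACKWARD_UNQUANTIZED environment variable, which this patch reads via FP8GlobalStateManager.keep_backward_unquantized() (os.getenv with default "0"). A minimal usage sketch follows; it is illustrative only and not part of the patch, te.Linear / te.fp8_autocast are standard Transformer Engine APIs, and the shapes and training-loop details are assumptions.

    import os
    # Opt in before any TE module runs its backward pass; the helper added in this
    # patch re-reads the variable at call time, defaulting to "0" (disabled).
    os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"

    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(1024, 1024).cuda()
    x = torch.randn(32, 1024, device="cuda", requires_grad=True)
    with te.fp8_autocast(enabled=True):
        y = layer(x)      # forward GEMM still runs quantized (FP8)
    y.sum().backward()    # dgrad/wgrad GEMMs fall back to high precision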
@@ -1277,12 +1318,12 @@ def fc2_wgrad_gemm( act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.fc1_grad_output_quantizer is not None: + if ctx.fc1_grad_output_quantizer is not None and use_quantized_bwd: ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu assert ctx.activation == "gelu" - assert not ctx.fp8 + assert not use_fp8_bwd fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) if ctx.fc1_grad_output_quantizer is not None: dact = ctx.fc1_grad_output_quantizer(dact) @@ -1292,13 +1333,10 @@ def fc2_wgrad_gemm( fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) elif ( - _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None - and ctx.fp8 + _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd ): # Fusion: gemm, bias + gelu + quantize - dbias_dact_quantize_func = _act_func( - ctx.activation, ctx.fp8_recipe if ctx.fp8 else None - )[2] + dbias_dact_quantize_func = _act_func(ctx.activation, fp8_recipe_bwd)[2] fc1_bias_grad, dact = dbias_dact_quantize_func( fc2_dgrad, fc1_out.to(ctx.activation_dtype), @@ -1308,18 +1346,16 @@ def fc2_wgrad_gemm( else: # Fusion: gemm + gelu, if not fc2_dgrad_gemm_gelu_fusion: - activation_func_bwd = _act_func( - ctx.activation, ctx.fp8_recipe if ctx.fp8 else None - )[1] + activation_func_bwd = _act_func(ctx.activation, fp8_recipe_bwd)[1] dact = activation_func_bwd( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision - if ctx.fp8: + if use_fp8_bwd: # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now if ( isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) - or ctx.fp8_recipe.custom() + or fp8_recipe_bwd.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) @@ -1347,16 +1383,16 @@ def fc2_wgrad_gemm( fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] if ctx.ub_overlap_rs_dgrad: # Overlap DGRAD+RS - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) ub_type_fc1_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap ln_out all-gather with DGRAD compute - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) ub_type_fc1_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap FC1 DGRAD reduce-scatter with WGRAD compute - ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8) + ub_obj_fc1_wgrad = get_ub("fc1_wgrad", use_fp8_bwd) ub_type_fc1_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -1364,8 +1400,10 @@ def fc2_wgrad_gemm( # -------------------------------------------------- # Make sure required data is available - if ctx.fc1_weight_quantizer is not None and isinstance( - ctx.fc1_weight_quantizer, QuantizedTensorStorage + if ( + use_quantized_bwd + and ctx.fc1_weight_quantizer is not None + and isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) ): ctx.fc1_weight.update_usage(columnwise_usage=True) @@ -1380,12 +1418,13 @@ def fc2_wgrad_gemm( gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False) # dgrad GEMM + fc1_weight_for_dgrad = fc1_weight if use_fp8_bwd else origin_fc1_weight gemm_out, *_, reduce_scatter_out = general_gemm( - fc1_weight, + fc1_weight_for_dgrad, dact, out=gemm_out, 
out_dtype=ctx.activation_dtype, - quantization_params=ctx.fc1_grad_input_quantizer, + quantization_params=ctx.fc1_grad_input_quantizer if use_quantized_bwd else None, layout="NN", grad=True, use_split_accumulator=dgrad_use_split_accumulator, @@ -1434,7 +1473,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1444,7 +1483,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: @@ -1466,7 +1505,9 @@ def fc2_wgrad_gemm( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.fc1_grad_weight_quantizer, + "quantization_params": ( + ctx.fc1_grad_weight_quantizer if use_quantized_bwd else None + ), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc2_weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 23ad8cacb0..b4bad849c1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,6 +129,9 @@ def forward( save_original_input, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + if keep_backward_unquantized: + save_original_input = True # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" @@ -443,6 +446,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.input_quantizer = input_quantizer ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer @@ -479,7 +483,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = False ctx.owns_input = saved_inputmat is not inp - if ctx.fp8 and requires_grad(inp, weight, bias): + if ctx.fp8 and not ctx.keep_backward_unquantized and requires_grad(inp, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): @@ -536,6 +540,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_quantized_bwd = use_fp8_bwd or ctx.debug + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -545,23 +558,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout 
ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -575,7 +588,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None: + if ctx.grad_output_quantizer is not None and use_quantized_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -594,6 +607,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None + and use_quantized_bwd ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -623,7 +637,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -649,7 +663,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -690,20 +704,22 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance( - weight_fp8, QuantizedTensorStorage + if ( + use_quantized_bwd + and ctx.weight_quantizer is not None + and isinstance(weight_fp8, QuantizedTensorStorage) ): weight_fp8.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None: + if ctx.grad_input_quantizer is not None and use_quantized_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -720,12 +736,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight_fp8 if use_quantized_bwd else 
weight gemm_out, *_, reduce_scatter_out = general_gemm( - weight_fp8, + weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer, + quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -774,7 +791,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -784,7 +801,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): + if ( + use_fp8_bwd + and ctx.ub_overlap_ag + and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) + ): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -796,7 +817,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -816,7 +837,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -825,7 +846,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -851,7 +872,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.grad_weight_quantizer, + "quantization_params": ( + ctx.grad_weight_quantizer if use_quantized_bwd else None + ), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -859,7 +882,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index e640f3ffb1..a9a6895112 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,12 +332,14 @@ def 
pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. weight_requires_grad = requires_grad and self.weight.requires_grad + keep_backward_unquantized = FP8GlobalStateManager.keep_backward_unquantized() + columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) grad_output_quantizer = self.get_quantizer("backward", 0) - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) weight_quantizer.set_usage(rowwise=True, columnwise=False) - grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) @@ -420,6 +422,7 @@ def _functional_forward( tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, + keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -459,6 +462,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = `False` Whether to perform compute with quantized data. + keep_backward_unquantized: bool, default = `False` + Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. weight_quantizer: Quantizer, optional @@ -510,7 +515,10 @@ def _functional_forward( if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) if with_x_all_gather: input_quantizer.set_usage(columnwise=False) x, x_async = gather_along_first_dim( @@ -542,7 +550,10 @@ def _functional_forward( elif with_quantized_compute and not is_quantized_tensor(w): if weight_quantizer is None: raise ValueError("Missing quantizer for weight tensor") - weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + weight_quantizer.set_usage( + rowwise=True, + columnwise=input_requires_grad and not keep_backward_unquantized, + ) w = weight_quantizer(w) # Check output tensor @@ -611,14 +622,23 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if w is not weight and with_quantized_compute and is_quantized_tensor(w): + if ( + w is not weight + and with_quantized_compute + and is_quantized_tensor(w) + and not keep_backward_unquantized + ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if with_quantized_compute and is_quantized_tensor(x_local): + if ( + with_quantized_compute + and is_quantized_tensor(x_local) + and not keep_backward_unquantized + ): if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -968,6 +988,9 @@ def op_forward( grad_output_quantizer = 
self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -984,6 +1007,7 @@ def op_forward( tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -993,10 +1017,16 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = self.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - ctx.save_for_backward(x_local, w) - ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + ctx.save_for_backward(saved_input, saved_weight) + ctx.with_quantized_compute = with_quantized_compute and not keep_backward_unquantized ctx.input_quantizer = input_quantizer ctx.weight_quantizer = weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index d126b554b5..cc26022d0e 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -57,7 +57,11 @@ def op_forward( # Check if FP8 is enabled fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() quantize_forward = fp8_enabled and self._quantize_forward - quantize_backward = fp8_enabled and self._quantize_backward + quantize_backward = ( + fp8_enabled + and self._quantize_backward + and not FP8GlobalStateManager.keep_backward_unquantized() + ) # Quantize if needed out = input_ diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 4ab082d32b..59e9af14f4 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -10,7 +10,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.quantization import Recipe +from transformer_engine.pytorch.quantization import Recipe, FP8GlobalStateManager from transformer_engine.pytorch.ops.basic import Bias from transformer_engine.pytorch.ops.basic.activation import ( _ActivationOperation, @@ -105,7 +105,10 @@ def fuse_backward_ops( """ # Check if recipe supports bias activation fusion - if recipe is None: + if recipe is None or ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.keep_backward_unquantized() + ): return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index dfc11a19e7..0a28d00706 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -92,6 +92,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = 
prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -109,6 +112,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -118,10 +122,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 2dfc0566b7..41ae096e54 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -86,6 +86,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -106,6 +109,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -115,10 +119,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py 
b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index ae4bdd4b19..b06f5ad36a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -65,6 +65,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get extra input tensor for add operation extra_input = basic_op_extra_inputs[2][0] @@ -87,6 +90,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -96,10 +100,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 6ef9bf083b..8c04fca17c 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -94,6 +94,7 @@ def _functional_forward( tensor_parallel_size: Optional[int] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, + keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -126,6 +127,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = `False` Whether to perform compute with quantized data. + keep_backward_unquantized: bool, default = `False` + Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. 
weight_quantizer: Quantizer, optional @@ -200,7 +203,10 @@ def _functional_forward( if with_ub_all_gather: if input_quantizer is not None: if not is_quantized_tensor(x_local): - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -216,7 +222,10 @@ def _functional_forward( else: if with_quantized_compute: if not is_quantized_tensor(x_local): - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) x_local = input_quantizer(x_local) else: x_local = maybe_dequantize(x_local, dtype) @@ -227,7 +236,10 @@ def _functional_forward( if not with_quantized_compute: w = maybe_dequantize(w, dtype) elif with_quantized_compute and not is_quantized_tensor(w): - weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + weight_quantizer.set_usage( + rowwise=True, + columnwise=input_requires_grad and not keep_backward_unquantized, + ) w = weight_quantizer(w) # Construct output tensor if needed @@ -257,14 +269,23 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if w is not weight and with_quantized_compute and is_quantized_tensor(w): + if ( + w is not weight + and with_quantized_compute + and is_quantized_tensor(w) + and not keep_backward_unquantized + ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if with_quantized_compute and is_quantized_tensor(x_local): + if ( + with_quantized_compute + and is_quantized_tensor(x_local) + and not keep_backward_unquantized + ): if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -311,6 +332,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): @@ -340,6 +364,7 @@ def fuser_forward( tensor_parallel_size=self.tensor_parallel_size, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=None, # Not supported @@ -352,10 +377,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + 
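            # Assumed intent at this step (inferred from the parallel changes in
            # basic_linear.py and the other fused forward ops): when
            # keep_backward_unquantized is set, the unquantized input/weight were
            # saved above, and with_quantized_compute is recorded as False below so
            # the shared linear backward takes its high-precision path.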
linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 7fe6ea37ed..035233fb55 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -109,6 +109,10 @@ def forward( # Apply forward ops x = input_ extra_outputs = [None] * fuser._num_basic_ops + keep_backward_unquantized = ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.keep_backward_unquantized() + ) for op, basic_op_idxs in fuser._forward_ops: # Set if backward op is required @@ -120,7 +124,7 @@ def forward( prev_op_idx = basic_op_idxs[0] - 1 prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None prev_op_grad_output_quantizer = None - if prev_op is not None: + if prev_op is not None and not keep_backward_unquantized: prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer() next_op_idx = basic_op_idxs[-1] + 1 next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None @@ -286,7 +290,15 @@ def backward( grad_extra_inputs_flat.extend(dxs) # Update FP8 scaling factors - if func_ctx.is_first_module and not _is_graph_capturing(): + keep_backward_unquantized = ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.keep_backward_unquantized() + ) + if ( + func_ctx.is_first_module + and not keep_backward_unquantized + and not _is_graph_capturing() + ): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..9806871ef6 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -430,6 +430,11 @@ def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" return cls.HIGH_PRECISION_INIT_VAL + @classmethod + def keep_backward_unquantized(cls) -> bool: + """Should backward skip FP8 quantization and use high precision""" + return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) + @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" From 72149be265539dc732cf8656e4ed2d21ecde374c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 00:49:22 +0000 Subject: [PATCH 02/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +--- transformer_engine/pytorch/ops/fuser.py | 6 +----- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 2b3a72b803..8e8749b237 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1332,9 +1332,7 @@ def fc2_wgrad_gemm( dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params) fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) - elif ( - _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd - ): + elif _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and 
use_fp8_bwd: # Fusion: gemm, bias + gelu + quantize dbias_dact_quantize_func = _act_func(ctx.activation, fp8_recipe_bwd)[2] fc1_bias_grad, dact = dbias_dact_quantize_func( diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 035233fb55..a692bc9487 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -294,11 +294,7 @@ def backward( FP8GlobalStateManager.is_fp8_enabled() and FP8GlobalStateManager.keep_backward_unquantized() ) - if ( - func_ctx.is_first_module - and not keep_backward_unquantized - and not _is_graph_capturing() - ): + if func_ctx.is_first_module and not keep_backward_unquantized and not _is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( From 927d482136a3f297813f7bdb3b36d678e44faf6c Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 09:36:13 -0800 Subject: [PATCH 03/43] Disable ub and clean up Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_linear.py | 9 ++-- .../pytorch/module/layernorm_mlp.py | 13 ++--- transformer_engine/pytorch/module/linear.py | 17 +++---- .../ops/fused/userbuffers_forward_linear.py | 49 +++---------------- 4 files changed, 25 insertions(+), 63 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 28842fc315..66e67522f6 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -608,6 +608,7 @@ def backward( use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -622,23 +623,23 @@ def backward( dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 8e8749b237..5d72508d0d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1023,6 +1023,7 @@ def backward( use_quantized_bwd = use_fp8_bwd or ctx.debug fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None if 
keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -1074,7 +1075,7 @@ def backward( # Note: Cast to expected dtype and perform tensor-parallel communication ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: - ub_obj_fc2_dgrad = get_ub("fc2_dgrad", use_fp8_bwd) + ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8) ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, @@ -1098,7 +1099,7 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) if ctx.ub_bulk_dgrad: - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_fc1_dgrad, ln_out, @@ -1192,11 +1193,7 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ( - use_fp8_bwd - and ctx.ub_overlap_ag - and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer) - ): + if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -1209,7 +1206,7 @@ def backward( ub_obj_fc2_dgrad.get_communication_stream() ) - ub_obj_fc2_wgrad = get_ub("fc2_wgrad", use_fp8_bwd) + ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b4bad849c1..a03e9ac4d5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -544,6 +544,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -558,23 +559,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # 
-------------------------------------------------- @@ -801,11 +802,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ( - use_fp8_bwd - and ctx.ub_overlap_ag - and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) - ): + if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -817,7 +814,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 8c04fca17c..6ef9bf083b 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -94,7 +94,6 @@ def _functional_forward( tensor_parallel_size: Optional[int] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, - keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -127,8 +126,6 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = `False` Whether to perform compute with quantized data. - keep_backward_unquantized: bool, default = `False` - Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. 
weight_quantizer: Quantizer, optional @@ -203,10 +200,7 @@ def _functional_forward( if with_ub_all_gather: if input_quantizer is not None: if not is_quantized_tensor(x_local): - input_quantizer.set_usage( - rowwise=True, - columnwise=weight_requires_grad and not keep_backward_unquantized, - ) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -222,10 +216,7 @@ def _functional_forward( else: if with_quantized_compute: if not is_quantized_tensor(x_local): - input_quantizer.set_usage( - rowwise=True, - columnwise=weight_requires_grad and not keep_backward_unquantized, - ) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) x_local = input_quantizer(x_local) else: x_local = maybe_dequantize(x_local, dtype) @@ -236,10 +227,7 @@ def _functional_forward( if not with_quantized_compute: w = maybe_dequantize(w, dtype) elif with_quantized_compute and not is_quantized_tensor(w): - weight_quantizer.set_usage( - rowwise=True, - columnwise=input_requires_grad and not keep_backward_unquantized, - ) + weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) w = weight_quantizer(w) # Construct output tensor if needed @@ -269,23 +257,14 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if ( - w is not weight - and with_quantized_compute - and is_quantized_tensor(w) - and not keep_backward_unquantized - ): + if w is not weight and with_quantized_compute and is_quantized_tensor(w): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if ( - with_quantized_compute - and is_quantized_tensor(x_local) - and not keep_backward_unquantized - ): + if with_quantized_compute and is_quantized_tensor(x_local): if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -332,9 +311,6 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() - ) if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): @@ -364,7 +340,6 @@ def fuser_forward( tensor_parallel_size=self.tensor_parallel_size, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, - keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=None, # Not supported @@ -377,18 +352,10 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) - linear_op_ctx.save_for_backward(saved_input, saved_weight) - linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and not keep_backward_unquantized - ) + mark_activation_offload(x_local) + 
linear_op_ctx.save_for_backward(x_local, w) + linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer From cc85b606cf31717ccb7684b21125e858505413d0 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 09:37:57 -0800 Subject: [PATCH 04/43] Drop fuser changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/ops/fuser.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index a692bc9487..7fe6ea37ed 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -109,10 +109,6 @@ def forward( # Apply forward ops x = input_ extra_outputs = [None] * fuser._num_basic_ops - keep_backward_unquantized = ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.keep_backward_unquantized() - ) for op, basic_op_idxs in fuser._forward_ops: # Set if backward op is required @@ -124,7 +120,7 @@ def forward( prev_op_idx = basic_op_idxs[0] - 1 prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None prev_op_grad_output_quantizer = None - if prev_op is not None and not keep_backward_unquantized: + if prev_op is not None: prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer() next_op_idx = basic_op_idxs[-1] + 1 next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None @@ -290,11 +286,7 @@ def backward( grad_extra_inputs_flat.extend(dxs) # Update FP8 scaling factors - keep_backward_unquantized = ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.keep_backward_unquantized() - ) - if func_ctx.is_first_module and not keep_backward_unquantized and not _is_graph_capturing(): + if func_ctx.is_first_module and not _is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( From fe24f95c16d8c5a46b363f612afbcbc7fd676b6d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 09:56:43 -0800 Subject: [PATCH 05/43] Replace use_quantized_bwd with use_fp8_bwd Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_linear.py | 19 +++++++------ .../pytorch/module/layernorm_mlp.py | 27 +++++++++---------- transformer_engine/pytorch/module/linear.py | 23 ++++++++-------- 3 files changed, 33 insertions(+), 36 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 66e67522f6..b759c152ec 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -606,7 +606,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False @@ -650,7 +649,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_quantized_bwd: + if ctx.grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -687,7 +686,7 @@ def 
backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None and use_quantized_bwd: + if ctx.input_quantizer is not None and use_fp8_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -726,7 +725,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_quantized_bwd + use_fp8_bwd and ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage) ): @@ -740,7 +739,7 @@ def backward( use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_quantized_bwd: + if ctx.grad_input_quantizer is not None and use_fp8_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -756,13 +755,13 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight if use_quantized_bwd else origin_weight + weight_for_dgrad = weight if use_fp8_bwd else origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, + quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -851,14 +850,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -894,7 +893,7 @@ def backward( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": ( - ctx.grad_weight_quantizer if use_quantized_bwd else None + ctx.grad_weight_quantizer if use_fp8_bwd else None ), "accumulate": ( accumulate_wgrad_into_param_main_grad diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 5d72508d0d..1414bb4afa 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1020,7 +1020,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - use_quantized_bwd = use_fp8_bwd or ctx.debug fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True @@ -1062,7 +1061,7 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.fc2_grad_output_quantizer is not None and use_quantized_bwd: + if ctx.fc2_grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.fc2_grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -1090,7 +1089,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad 
and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if use_quantized_bwd: + if use_fp8_bwd: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1146,7 +1145,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_quantized_bwd + use_fp8_bwd and ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensorStorage) ): @@ -1161,7 +1160,7 @@ def backward( grad=True, quantization_params=( ctx.fc1_grad_input_quantizer - if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_quantized_bwd + if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_fp8_bwd else None ), # high precision to activation out_dtype=ctx.activation_dtype, @@ -1229,14 +1228,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1256,7 +1255,7 @@ def backward( else ctx.activation_dtype ), "quantization_params": ( - ctx.fc2_grad_weight_quantizer if use_quantized_bwd else None + ctx.fc2_grad_weight_quantizer if use_fp8_bwd else None ), # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad @@ -1315,7 +1314,7 @@ def fc2_wgrad_gemm( act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.fc1_grad_output_quantizer is not None and use_quantized_bwd: + if ctx.fc1_grad_output_quantizer is not None and use_fp8_bwd: ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu @@ -1396,7 +1395,7 @@ def fc2_wgrad_gemm( # Make sure required data is available if ( - use_quantized_bwd + use_fp8_bwd and ctx.fc1_weight_quantizer is not None and isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) ): @@ -1419,7 +1418,7 @@ def fc2_wgrad_gemm( dact, out=gemm_out, out_dtype=ctx.activation_dtype, - quantization_params=ctx.fc1_grad_input_quantizer if use_quantized_bwd else None, + quantization_params=ctx.fc1_grad_input_quantizer if use_fp8_bwd else None, layout="NN", grad=True, use_split_accumulator=dgrad_use_split_accumulator, @@ -1468,7 +1467,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1478,7 +1477,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: @@ -1501,7 +1500,7 @@ def fc2_wgrad_gemm( else ctx.activation_dtype ), "quantization_params": ( - ctx.fc1_grad_weight_quantizer if use_quantized_bwd else None + ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None ), "accumulate": ( accumulate_wgrad_into_param_main_grad diff --git 
a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a03e9ac4d5..6ecc647626 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -542,7 +542,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False @@ -589,7 +588,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_quantized_bwd: + if ctx.grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -608,7 +607,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None - and use_quantized_bwd + and use_fp8_bwd ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -638,7 +637,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -664,7 +663,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if use_quantized_bwd: + if use_fp8_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -706,7 +705,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_quantized_bwd + use_fp8_bwd and ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorStorage) ): @@ -720,7 +719,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_quantized_bwd: + if ctx.grad_input_quantizer is not None and use_fp8_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -737,13 +736,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight_fp8 if use_quantized_bwd else weight + weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, + quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -792,7 +791,7 @@ def backward(ctx, 
grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -834,7 +833,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -870,7 +869,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": ( - ctx.grad_weight_quantizer if use_quantized_bwd else None + ctx.grad_weight_quantizer if use_fp8_bwd else None ), "accumulate": ( accumulate_wgrad_into_param_main_grad From 5ca361584796e6010768f8c91ee9b265a379f8bd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 17:57:32 +0000 Subject: [PATCH 06/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_linear.py | 4 +--- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +--- transformer_engine/pytorch/module/linear.py | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index b759c152ec..bdfeff056b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -892,9 +892,7 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.grad_weight_quantizer if use_fp8_bwd else None - ), + "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 1414bb4afa..c5f7051fa1 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1499,9 +1499,7 @@ def fc2_wgrad_gemm( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None - ), + "quantization_params": (ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc2_weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6ecc647626..1ce4fac445 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -868,9 +868,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.grad_weight_quantizer if use_fp8_bwd else None - ), + "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", 
False) From 5ba76747ab50fc5cd8cccd3e5bfa9fcf53fe58bb Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 10:30:04 -0800 Subject: [PATCH 07/43] Ignore keep_backward_unquantized if delayed scaling Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 1 + transformer_engine/pytorch/module/linear.py | 1 + transformer_engine/pytorch/quantization.py | 3 +++ 3 files changed, 5 insertions(+) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 874eadeb36..0ccacd9b17 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -98,6 +98,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: + # Note, keep_backward_unquantized is ignored when delayed scaling is used save_original_input = True num_gemms = len(m_splits) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1ce4fac445..49b78382d2 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -131,6 +131,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: + # Note, keep_backward_unquantized is ignored when delayed scaling is used save_original_input = True # NVTX label for profiling diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 9806871ef6..e8f6dafdb5 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -433,6 +433,9 @@ def with_high_precision_init_val(cls) -> bool: @classmethod def keep_backward_unquantized(cls) -> bool: """Should backward skip FP8 quantization and use high precision""" + recipe = cls.get_fp8_recipe() + if recipe.delayed(): + return False return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) @classmethod From 02b7b2ae23f01942968e59eda24a47d74ee832a3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 10:39:02 -0800 Subject: [PATCH 08/43] Refactor ignoring NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- transformer_engine/pytorch/quantization.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 0ccacd9b17..9e2eb60ea5 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -98,7 +98,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: - # Note, keep_backward_unquantized is ignored when delayed scaling is used + # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True num_gemms = len(m_splits) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 49b78382d2..0bf560c7b7 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -131,7 +131,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and 
FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: - # Note, keep_backward_unquantized is ignored when delayed scaling is used + # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True # NVTX label for profiling diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index e8f6dafdb5..aab7ed2d1c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -434,7 +434,8 @@ def with_high_precision_init_val(cls) -> bool: def keep_backward_unquantized(cls) -> bool: """Should backward skip FP8 quantization and use high precision""" recipe = cls.get_fp8_recipe() - if recipe.delayed(): + if recipe is not None and recipe.delayed(): + # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used return False return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) From 01a7de026f92e7bb9e8f1e8b8e6f51b7da1c668a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 11:13:57 -0800 Subject: [PATCH 09/43] Add back missing ctx.debug Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_linear.py | 4 ++-- transformer_engine/pytorch/module/layernorm_mlp.py | 10 +++++----- transformer_engine/pytorch/module/linear.py | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index bdfeff056b..fd458a34b4 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -850,14 +850,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index c5f7051fa1..a98ecfb903 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1089,7 +1089,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1228,14 +1228,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1467,7 +1467,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if 
use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1477,7 +1477,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0bf560c7b7..930fbe061d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -638,7 +638,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -664,7 +664,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -792,7 +792,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -834,7 +834,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: From bf904aab91dad9d2a515dc249400b9282e65ce09 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 11:43:45 -0800 Subject: [PATCH 10/43] Refactor changes under fused Signed-off-by: Ziang Li --- .../ops/fused/backward_activation_bias.py | 7 ++----- .../ops/fused/forward_linear_bias_activation.py | 17 +++++++++++------ .../ops/fused/forward_linear_bias_add.py | 17 +++++++++++------ .../ops/fused/forward_linear_scale_add.py | 17 +++++++++++------ 4 files changed, 35 insertions(+), 23 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 59e9af14f4..4ab082d32b 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -10,7 +10,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.quantization import Recipe, FP8GlobalStateManager +from transformer_engine.pytorch.quantization import Recipe from transformer_engine.pytorch.ops.basic import Bias from transformer_engine.pytorch.ops.basic.activation import ( _ActivationOperation, @@ -105,10 +105,7 @@ def fuse_backward_ops( """ # Check if recipe supports bias activation fusion - if recipe is None or ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.keep_backward_unquantized() - ): + if recipe is None: return ops # Scan through ops, fusing if possible 
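For reference, the gating that the preceding patches wire through the modules reduces to the following standalone sketch. It is illustrative only and not part of the patch: the helper mirrors the new FP8GlobalStateManager.keep_backward_unquantized() classmethod added in quantization.py, and any other names are placeholders.

    import os

    def keep_backward_unquantized(recipe) -> bool:
        """Read NVTE_KEEP_BACKWARD_UNQUANTIZED, ignoring it for delayed scaling."""
        if recipe is not None and recipe.delayed():
            # The flag is a no-op when a delayed-scaling recipe is active.
            return False
        return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0")))

    # The module backward passes then guard every quantized step with
    #     use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized
    # so dgrad/wgrad GEMMs and their quantizers fall back to the activation
    # dtype when the flag is set.

Net effect for users: with a current-scaling or MXFP8 recipe, setting NVTE_KEEP_BACKWARD_UNQUANTIZED=1 keeps the forward GEMMs quantized while the backward pass runs in high precision (and the Userbuffers backward overlaps are disabled); under delayed scaling the variable is ignored.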
diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 0a28d00706..6e7c85988f 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -122,12 +122,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None + # saved_input = input_ if keep_backward_unquantized else x_local + # if not weight_requires_grad: + # saved_input = None + # saved_weight = linear_op.weight if keep_backward_unquantized else w + # if not input_requires_grad: + # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 41ae096e54..f3b4533848 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -119,12 +119,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None + # saved_input = input_ if keep_backward_unquantized else x_local + # if not weight_requires_grad: + # saved_input = None + # saved_weight = linear_op.weight if keep_backward_unquantized else w + # if not input_requires_grad: + # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index b06f5ad36a..53e7327873 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -100,12 +100,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None + # saved_input = input_ if keep_backward_unquantized else x_local + # if not weight_requires_grad: + # saved_input = None + # saved_weight = linear_op.weight if keep_backward_unquantized 
else w + # if not input_requires_grad: + # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From b449fc4516f5e3146d13f99d2377158788de385c Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 11:44:30 -0800 Subject: [PATCH 11/43] Clean up Signed-off-by: Ziang Li --- .../pytorch/ops/fused/forward_linear_bias_activation.py | 6 ------ .../pytorch/ops/fused/forward_linear_bias_add.py | 6 ------ .../pytorch/ops/fused/forward_linear_scale_add.py | 6 ------ 3 files changed, 18 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 6e7c85988f..2458d4d072 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -127,12 +127,6 @@ def fuser_forward( if keep_backward_unquantized: saved_input = input_ if input_requires_grad else None saved_weight = linear_op.weight if weight_requires_grad else None - # saved_input = input_ if keep_backward_unquantized else x_local - # if not weight_requires_grad: - # saved_input = None - # saved_weight = linear_op.weight if keep_backward_unquantized else w - # if not input_requires_grad: - # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index f3b4533848..efa543e555 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -124,12 +124,6 @@ def fuser_forward( if keep_backward_unquantized: saved_input = input_ if input_requires_grad else None saved_weight = linear_op.weight if weight_requires_grad else None - # saved_input = input_ if keep_backward_unquantized else x_local - # if not weight_requires_grad: - # saved_input = None - # saved_weight = linear_op.weight if keep_backward_unquantized else w - # if not input_requires_grad: - # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 53e7327873..2804534968 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -105,12 +105,6 @@ def fuser_forward( if keep_backward_unquantized: saved_input = input_ if input_requires_grad else None saved_weight = linear_op.weight if weight_requires_grad else None - # saved_input = input_ if keep_backward_unquantized else x_local - # if not weight_requires_grad: - # saved_input = None - # saved_weight = linear_op.weight if keep_backward_unquantized else w - # if not input_requires_grad: - # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From de3acaf7e11c79cc072face5d3fc8431be84fec6 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 14:11:07 -0800 Subject: [PATCH 12/43] Refactor high-precision overwrite if keep_backward_unquantized Signed-off-by: Ziang Li --- .../pytorch/module/grouped_linear.py | 
17 ++++++++++------- .../pytorch/module/layernorm_linear.py | 10 ++++++++-- .../pytorch/module/layernorm_mlp.py | 14 +++++++++++--- transformer_engine/pytorch/module/linear.py | 5 ++++- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 9e2eb60ea5..859e648579 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -406,13 +406,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - weights_for_dgrad = weights if use_fp8_bwd else origin_weights - if use_fp8_bwd: - # Make sure weights are available in column-wise format - # for dgrad computation. - for weight in weights_for_dgrad: - if isinstance(weight, QuantizedTensorStorage): - weight.update_usage(columnwise_usage=True) + # weights_for_dgrad = weights if use_fp8_bwd else origin_weights + # if use_fp8_bwd: + weights_for_dgrad = weights + if keep_backward_unquantized: + weights_for_dgrad = origin_weights + # Make sure weights are available in column-wise format + # for dgrad computation. + for weight in weights_for_dgrad: + if isinstance(weight, QuantizedTensorStorage): + weight.update_usage(columnwise_usage=True) general_grouped_gemm( weights_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index fd458a34b4..70d8936ce3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -415,7 +415,10 @@ def forward( # ------------------------------------------------------ if is_grad_enabled: - ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out + ln_out_to_save = ln_out + if keep_backward_unquantized: + ln_out_to_save = ln_out_hp + # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel @@ -755,7 +758,10 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight if use_fp8_bwd else origin_weight + # weight_for_dgrad = weight if use_fp8_bwd else origin_weight + weight_for_dgrad = weight + if keep_backward_unquantized: + weight_for_dgrad = origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a98ecfb903..a8e0bda73d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -695,8 +695,13 @@ def _forward( # if we are not checkpointing, then we must save this if grad is enabled if is_grad_enabled and not save_for_checkpoint: - ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out - act_out_to_save = act_out_hp if keep_backward_unquantized else act_out + ln_out_to_save = ln_out + act_out_to_save = act_out + if keep_backward_unquantized: + ln_out_to_save = ln_out_hp + act_out_to_save = act_out_hp + # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out + # act_out_to_save = act_out_hp if keep_backward_unquantized else act_out ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer @@ -1152,7 
+1157,10 @@ def backward( ctx.fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM - fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight + fc2_weight_for_dgrad = fc2_weight + if keep_backward_unquantized: + fc2_weight_for_dgrad = origin_fc2_weight + # fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight gemm_output, *_ = general_gemm( fc2_weight_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 930fbe061d..496bfd45b7 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -737,7 +737,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight + weight_for_dgrad = weight_fp8 + if keep_backward_unquantized: + weight_for_dgrad = weight + # weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, From fe65d34213cfa6061459e5a04ab2ce4610865535 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 14:14:22 -0800 Subject: [PATCH 13/43] Clean up Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 2 -- transformer_engine/pytorch/module/layernorm_linear.py | 2 -- transformer_engine/pytorch/module/layernorm_mlp.py | 3 --- transformer_engine/pytorch/module/linear.py | 1 - 4 files changed, 8 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 859e648579..e782f20cc6 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -406,8 +406,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - # weights_for_dgrad = weights if use_fp8_bwd else origin_weights - # if use_fp8_bwd: weights_for_dgrad = weights if keep_backward_unquantized: weights_for_dgrad = origin_weights diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 70d8936ce3..e3aab9b304 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -418,7 +418,6 @@ def forward( ln_out_to_save = ln_out if keep_backward_unquantized: ln_out_to_save = ln_out_hp - # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel @@ -758,7 +757,6 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - # weight_for_dgrad = weight if use_fp8_bwd else origin_weight weight_for_dgrad = weight if keep_backward_unquantized: weight_for_dgrad = origin_weight diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a8e0bda73d..6107c7d377 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -700,8 +700,6 @@ def _forward( if keep_backward_unquantized: ln_out_to_save = ln_out_hp act_out_to_save = act_out_hp - # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out - # act_out_to_save = act_out_hp if keep_backward_unquantized 
else act_out ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer @@ -1160,7 +1158,6 @@ def backward( fc2_weight_for_dgrad = fc2_weight if keep_backward_unquantized: fc2_weight_for_dgrad = origin_fc2_weight - # fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight gemm_output, *_ = general_gemm( fc2_weight_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 496bfd45b7..10ea095c16 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -740,7 +740,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weight_for_dgrad = weight_fp8 if keep_backward_unquantized: weight_for_dgrad = weight - # weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, From 59aaf6b7875202f19f4180e5057a07df418668cd Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 10:56:41 -0800 Subject: [PATCH 14/43] Drop redundant fp8_recipe_bwd Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_mlp.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6107c7d377..9406c0c7ef 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1023,7 +1023,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False @@ -1249,7 +1248,7 @@ def backward( # Whether to set grad arg in general_gemm grad_arg = True - if use_fp8_bwd and fp8_recipe_bwd.float8_block_scaling(): + if use_fp8_bwd and ctx.fp8_recipe.float8_block_scaling(): grad_arg = False # Arguments to include in wgrad GEMM closure @@ -1299,7 +1298,7 @@ def fc2_wgrad_gemm( if fc2_bias_grad is None: if ( use_fp8_bwd - and fp8_recipe_bwd.float8_block_scaling() + and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None ): # BGRAD not fused with GEMM for float8 blockwise gemm. 
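A condensed sketch of the backward-pass pattern the hunks above converge on (illustrative only, reusing names from the surrounding module code rather than quoting it verbatim):

# Sketch: choosing dgrad operands when the backward pass stays unquantized.
keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False)
use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized

# FP8 backward uses the quantized weight; the unquantized path falls back to
# the original high-precision weight saved in forward.
weight_for_dgrad = weight if use_fp8_bwd else origin_weight
if use_fp8_bwd and isinstance(weight_for_dgrad, QuantizedTensorStorage):
    # dgrad (dx = dy * w) needs the column-wise representation of the weight.
    weight_for_dgrad.update_usage(columnwise_usage=True)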
@@ -1333,9 +1332,14 @@ def fc2_wgrad_gemm( dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params) fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) - elif _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd: + elif ( + _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None + and use_fp8_bwd + ): # Fusion: gemm, bias + gelu + quantize - dbias_dact_quantize_func = _act_func(ctx.activation, fp8_recipe_bwd)[2] + dbias_dact_quantize_func = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[2] fc1_bias_grad, dact = dbias_dact_quantize_func( fc2_dgrad, fc1_out.to(ctx.activation_dtype), @@ -1345,7 +1349,9 @@ def fc2_wgrad_gemm( else: # Fusion: gemm + gelu, if not fc2_dgrad_gemm_gelu_fusion: - activation_func_bwd = _act_func(ctx.activation, fp8_recipe_bwd)[1] + activation_func_bwd = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[1] dact = activation_func_bwd( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision @@ -1354,7 +1360,7 @@ def fc2_wgrad_gemm( # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now if ( isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) - or fp8_recipe_bwd.custom() + or ctx.fp8_recipe.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) From 44da62593ef2476d80691f79f652ec907333870f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:57:29 +0000 Subject: [PATCH 15/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9406c0c7ef..863a70e5e8 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1335,7 +1335,7 @@ def fc2_wgrad_gemm( elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None and use_fp8_bwd - ): + ): # Fusion: gemm, bias + gelu + quantize dbias_dact_quantize_func = _act_func( ctx.activation, ctx.fp8_recipe if ctx.fp8 else None From 0f5879380fcdb9a9c90d0fa73d6de3edfb646df0 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 11:02:24 -0800 Subject: [PATCH 16/43] Drop redundant ub changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_mlp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 863a70e5e8..add32c0ba9 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1388,16 +1388,16 @@ def fc2_wgrad_gemm( fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] if ctx.ub_overlap_rs_dgrad: # Overlap DGRAD+RS - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap ln_out all-gather with DGRAD compute - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = 
tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap FC1 DGRAD reduce-scatter with WGRAD compute - ub_obj_fc1_wgrad = get_ub("fc1_wgrad", use_fp8_bwd) + ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8) ub_type_fc1_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- From 192fbad0501fb967bb02c5e545343726a2dbaff1 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 11:07:16 -0800 Subject: [PATCH 17/43] Drop more redundant ub changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_linear.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index e3aab9b304..60c4e1d8b2 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -812,11 +812,7 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ( - use_fp8_bwd - and ctx.ub_overlap_ag - and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) - ): + if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -828,7 +824,7 @@ def backward( dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) From 0dd12689957868370d0f17890cbb743361bf134a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 11:25:01 -0800 Subject: [PATCH 18/43] Drop redundant delayed scaling changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 6 +----- transformer_engine/pytorch/module/layernorm_mlp.py | 6 +----- transformer_engine/pytorch/module/linear.py | 2 +- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e782f20cc6..7e6773043d 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -299,11 +299,7 @@ def forward( ctx.inp_shape = inp.shape ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False - if ( - ctx.fp8 - and not ctx.keep_backward_unquantized - and requires_grad(inp, weights[0], biases[0]) - ): + if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): ctx.reduce_and_update_bwd_fp8_tensors = ( ctx.reduce_and_update_bwd_fp8_tensors or FP8GlobalStateManager.is_first_fp8_module() diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index add32c0ba9..5f8de6159e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -847,12 +847,8 @@ def _forward( ) ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ( - ctx.fp8 - and not ctx.keep_backward_unquantized - and requires_grad( + if ctx.fp8 and requires_grad( inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias - ) ): _first_fp8_module = 
FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 10ea095c16..535d2e75e5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -484,7 +484,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = False ctx.owns_input = saved_inputmat is not inp - if ctx.fp8 and not ctx.keep_backward_unquantized and requires_grad(inp, weight, bias): + if ctx.fp8 and requires_grad(inp, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): From 216621d01a3021a63e1c6f102817113ec46edd0a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 19:25:49 +0000 Subject: [PATCH 19/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 5f8de6159e..6a88848236 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -848,7 +848,7 @@ def _forward( ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad( - inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias + inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() From ab8749bb120ce73f6009d285c2c2c84c7890590b Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 12:01:36 -0800 Subject: [PATCH 20/43] Drop unneeded backwards_needs_fc1_input Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_mlp.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6a88848236..44028aebcc 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -351,10 +351,8 @@ def _forward( # bwd needs fc1 input when grad is enabled, fc1 needs grad, and either # 1) no checkpointing # or 2) doing the recomputation with checkpointing - backwards_needs_fc1_input = ( - fc1_weight.requires_grad - and ((is_grad_enabled and not checkpoint) or is_recomputation) - and not keep_backward_unquantized + backwards_needs_fc1_input = fc1_weight.requires_grad and ( + (is_grad_enabled and not checkpoint) or is_recomputation ) device = inp.device From 58810837b3d3794c4c66b2994c767418ec2b9e8d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 14:01:43 -0800 Subject: [PATCH 21/43] Drop and disallow LayerNormMLP implementation Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_mlp.py | 104 ++++++------------ 1 file changed, 34 insertions(+), 70 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 44028aebcc..8d78ceab86 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ 
b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -233,6 +233,7 @@ def _forward( recompute_for_bwd, ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + assert not keep_backward_unquantized, "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: @@ -395,7 +396,6 @@ def _forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered - and not keep_backward_unquantized and not custom ) @@ -417,7 +417,6 @@ def _forward( # do not return layernorm output unless 1) no checkpointing or 2) checkpointing but not recomputing if (return_layernorm_output or return_layernorm_output_gathered) and not is_recomputation: ln_out_return = ln_out - ln_out_hp = ln_out if keep_backward_unquantized else None # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication @@ -614,10 +613,6 @@ def _forward( if fc2_input_quantizer is not None: fc2_input_quantizer.calibrate(act_out) - act_out_hp = act_out - if keep_backward_unquantized and is_grad_enabled and fc1_out is not None: - act_out_hp = activation_func(fc1_out, None, **act_params) - # we want to skip fc2 computation if we are checkpointing and recomputing, # otherwise we compute fc2 if not (is_recomputation and checkpoint): @@ -693,33 +688,22 @@ def _forward( # if we are not checkpointing, then we must save this if grad is enabled if is_grad_enabled and not save_for_checkpoint: - ln_out_to_save = ln_out - act_out_to_save = act_out - if keep_backward_unquantized: - ln_out_to_save = ln_out_hp - act_out_to_save = act_out_hp ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer if not fc1_weight.requires_grad: if not return_layernorm_output: - clear_tensor_data(ln_out_to_save) - ln_out_to_save = None + clear_tensor_data(ln_out) + ln_out = None if not fc2_weight.requires_grad: - clear_tensor_data(act_out_to_save) - act_out_to_save = None + clear_tensor_data(act_out) + act_out = None if not checkpoint: # regular path, no selective activation checkpointing if cpu_offloading: mark_activation_offload( - inputmat, - mu, - rsigma, - ln_out_to_save, - fc1_out, - fc1_out_without_bias, - act_out_to_save, + inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out ) # Scatter intermediate/activation tensors saved for the backward pass @@ -732,9 +716,9 @@ def _forward( fsdp_group, mu, rsigma, - ln_out_to_save, + ln_out, fc1_out_without_bias if bias_gelu_fusion else fc1_out, - act_out_to_save, + act_out, ( fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) @@ -762,13 +746,13 @@ def _forward( tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, - ln_out_to_save, + ln_out, fc1_weight_final, fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, - act_out_to_save, + act_out, fc2_weight_final, fc2_weight, fc2_bias, @@ -816,7 +800,6 @@ def _forward( ctx.activation_params = activation_params ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -1015,15 +998,6 @@ def backward( origin_fc1_weight.main_grad = fc1_weight_main_grad origin_fc2_weight.main_grad = fc2_weight_main_grad - 
keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True - ctx.ub_overlap_ag = False - ctx.ub_overlap_rs_dgrad = False - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False - # TODO: Fix this # pylint: disable=fixme # Gather saved autograd context tensors when running with FSDP # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -1043,7 +1017,7 @@ def backward( # Choose whether to use GEMM kernel with split accumulator dgrad_use_split_accumulator = _2X_ACC_DGRAD wgrad_use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator @@ -1057,7 +1031,7 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.fc2_grad_output_quantizer is not None and use_fp8_bwd: + if ctx.fc2_grad_output_quantizer is not None: quantizer = ctx.fc2_grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -1085,7 +1059,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1131,7 +1105,7 @@ def backward( # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not use_fp8_bwd + not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) and (not ctx.debug) @@ -1140,25 +1114,20 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ( - use_fp8_bwd - and ctx.fc2_weight_quantizer is not None - and isinstance(ctx.fc2_weight, QuantizedTensorStorage) + if ctx.fc2_weight_quantizer is not None and isinstance( + ctx.fc2_weight, QuantizedTensorStorage ): ctx.fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM - fc2_weight_for_dgrad = fc2_weight - if keep_backward_unquantized: - fc2_weight_for_dgrad = origin_fc2_weight gemm_output, *_ = general_gemm( - fc2_weight_for_dgrad, + fc2_weight, grad_output, layout="NN", grad=True, quantization_params=( ctx.fc1_grad_input_quantizer - if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_fp8_bwd + if fc2_dgrad_gemm_gelu_fusion or ctx.debug else None ), # high precision to activation out_dtype=ctx.activation_dtype, @@ -1226,14 +1195,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1242,7 +1211,7 @@ def backward( # Whether to 
set grad arg in general_gemm grad_arg = True - if use_fp8_bwd and ctx.fp8_recipe.float8_block_scaling(): + if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling(): grad_arg = False # Arguments to include in wgrad GEMM closure @@ -1252,9 +1221,7 @@ def backward( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.fc2_grad_weight_quantizer if use_fp8_bwd else None - ), # wgrad in high precision + "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc1_weight, "overwrite_main_grad", False) @@ -1291,7 +1258,7 @@ def fc2_wgrad_gemm( # Update grad bias if needed if fc2_bias_grad is None: if ( - use_fp8_bwd + ctx.fp8 and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None ): @@ -1312,12 +1279,12 @@ def fc2_wgrad_gemm( act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.fc1_grad_output_quantizer is not None and use_fp8_bwd: + if ctx.fc1_grad_output_quantizer is not None: ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu assert ctx.activation == "gelu" - assert not use_fp8_bwd + assert not ctx.fp8 fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) if ctx.fc1_grad_output_quantizer is not None: dact = ctx.fc1_grad_output_quantizer(dact) @@ -1328,7 +1295,7 @@ def fc2_wgrad_gemm( dact = ctx.fc1_grad_output_quantizer(dact) elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None - and use_fp8_bwd + and ctx.fp8 ): # Fusion: gemm, bias + gelu + quantize dbias_dact_quantize_func = _act_func( @@ -1350,7 +1317,7 @@ def fc2_wgrad_gemm( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision - if use_fp8_bwd: + if ctx.fp8: # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now if ( isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) @@ -1399,10 +1366,8 @@ def fc2_wgrad_gemm( # -------------------------------------------------- # Make sure required data is available - if ( - use_fp8_bwd - and ctx.fc1_weight_quantizer is not None - and isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) + if ctx.fc1_weight_quantizer is not None and isinstance( + ctx.fc1_weight_quantizer, QuantizedTensorStorage ): ctx.fc1_weight.update_usage(columnwise_usage=True) @@ -1417,13 +1382,12 @@ def fc2_wgrad_gemm( gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False) # dgrad GEMM - fc1_weight_for_dgrad = fc1_weight if use_fp8_bwd else origin_fc1_weight gemm_out, *_, reduce_scatter_out = general_gemm( - fc1_weight_for_dgrad, + fc1_weight, dact, out=gemm_out, out_dtype=ctx.activation_dtype, - quantization_params=ctx.fc1_grad_input_quantizer if use_fp8_bwd else None, + quantization_params=ctx.fc1_grad_input_quantizer, layout="NN", grad=True, use_split_accumulator=dgrad_use_split_accumulator, @@ -1472,7 +1436,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1482,7 +1446,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if 
isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: @@ -1504,7 +1468,7 @@ def fc2_wgrad_gemm( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None), + "quantization_params": ctx.fc1_grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc2_weight, "overwrite_main_grad", False) From 431f0c8fd3c643380fefa3f4ca923d59ada5bcea Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 22:02:31 +0000 Subject: [PATCH 22/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 8d78ceab86..da236e7be0 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -233,7 +233,9 @@ def _forward( recompute_for_bwd, ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() - assert not keep_backward_unquantized, "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" + assert ( + not keep_backward_unquantized + ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: From 937e34b10585058383293c649cf2a5841813e7a9 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:10:10 -0800 Subject: [PATCH 23/43] Move interface changes to recipe Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 67 +++++++++++++++++-- .../pytorch/module/grouped_linear.py | 2 +- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- .../pytorch/ops/basic/basic_linear.py | 6 +- .../pytorch/ops/basic/quantize.py | 2 +- .../fused/forward_linear_bias_activation.py | 2 +- .../ops/fused/forward_linear_bias_add.py | 2 +- .../ops/fused/forward_linear_scale_add.py | 2 +- transformer_engine/pytorch/quantization.py | 39 +++++++---- 11 files changed, 99 insertions(+), 29 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 64ee2a5a16..a36b743f3b 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -11,6 +11,11 @@ from pydantic.dataclasses import dataclass +def _default_quantize_backward() -> bool: + """Default backward quantization setting.""" + return not bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) + + class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. @@ -181,6 +186,11 @@ def scaling_factor_compute(amax: Tensor, `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. When `fp8_mha = True, fp8_dpa = True`, it becomes `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. 
Delayed scaling + always quantizes backward; setting this to False is not supported. Notes ----- @@ -204,6 +214,8 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -216,7 +228,9 @@ def __repr__(self) -> str: f"amax_history_len={self.amax_history_len}, " f"reduce_amax={self.reduce_amax}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -230,6 +244,10 @@ class Float8CurrentScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" @@ -242,6 +260,10 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -257,7 +279,9 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -284,12 +308,18 @@ class MXFP8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ margin: int = 0 fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -298,7 +328,9 @@ def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " - f"format={str(self.fp8_format).split('.')[1]}" + f"format={str(self.fp8_format).split('.')[1]}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -327,6 +359,10 @@ class Float8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. 
""" use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" @@ -379,7 +415,9 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -428,6 +466,10 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ # Configuration envvars @@ -443,6 +485,8 @@ class NVFP4BlockScaling(Recipe): # Not applying quantization to attention for now fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" @@ -474,6 +518,8 @@ def __repr__(self) -> str: f"fp8_format={str(self.fp8_format).split('.')[1]}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " @@ -505,12 +551,23 @@ class CustomRecipe(Recipe): - forward: "linear_input", "linear_weight", "linear_output" - backward: "linear_grad_output", "linear_grad_input" + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. 
""" qfactory: Callable[..., Any] fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __repr__(self) -> str: - return f"recipe_type={self.__class__.__name__}, qfactory={self.qfactory}" + return ( + f"recipe_type={self.__class__.__name__}, " + f"qfactory={self.qfactory}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" + ) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 7e6773043d..a7d7bc8948 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -96,7 +96,7 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) if keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 60c4e1d8b2..4173c76216 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -141,7 +141,7 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index da236e7be0..82e7d868b4 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -232,7 +232,7 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) assert ( not keep_backward_unquantized ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 535d2e75e5..76ff5dd1d4 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,7 +129,7 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) if keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index a9a6895112..a362485a7e 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,7 +332,9 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
weight_requires_grad = requires_grad and self.weight.requires_grad - keep_backward_unquantized = FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = ( + FP8GlobalStateManager.is_fp8_enabled() and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + ) columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) @@ -989,7 +991,7 @@ def op_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index cc26022d0e..e6c28b9fdc 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -60,7 +60,7 @@ def op_forward( quantize_backward = ( fp8_enabled and self._quantize_backward - and not FP8GlobalStateManager.keep_backward_unquantized() + and FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Quantize if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 2458d4d072..80cb5647d7 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -93,7 +93,7 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index efa543e555..cf29140a20 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -87,7 +87,7 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 2804534968..0caae13af9 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -66,7 +66,7 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get extra input tensor for add operation diff --git 
a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index aab7ed2d1c..fb0553056a 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -87,6 +87,21 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: ) +def _validate_recipe_quantization_flags(recipe: Recipe) -> None: + """Validate forward/backward quantization flags on a recipe.""" + quantize_forward = getattr(recipe, "quantize_forward", True) + quantize_backward = getattr(recipe, "quantize_backward", True) + if not quantize_forward and quantize_backward: + raise ValueError( + "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." + ) + if recipe.delayed() and not quantize_backward: + raise ValueError( + "Invalid recipe configuration: delayed scaling does not support " + "quantize_backward=False." + ) + + def check_recipe_support(recipe: Recipe) -> None: """Check if the given recipe is supported.""" recipe_supported = True @@ -430,15 +445,6 @@ def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" return cls.HIGH_PRECISION_INIT_VAL - @classmethod - def keep_backward_unquantized(cls) -> bool: - """Should backward skip FP8 quantization and use high precision""" - recipe = cls.get_fp8_recipe() - if recipe is not None and recipe.delayed(): - # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used - return False - return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) - @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" @@ -851,16 +857,21 @@ def autocast( are reduced at the end of each training step. """ - if enabled: - check_recipe_support(recipe) + fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe + if enabled or calibrating: + _validate_recipe_quantization_flags(fp8_recipe) + quantize_forward = getattr(fp8_recipe, "quantize_forward", True) + effective_enabled = enabled and quantize_forward + if effective_enabled: + check_recipe_support(fp8_recipe) # Save current state so we always restore it on exit. 
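The flag combinations accepted by _validate_recipe_quantization_flags, illustrated with hypothetical recipe constructions (the check itself runs when entering autocast, not at construction time):

from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling

MXFP8BlockScaling()                                  # OK: quantize forward and backward
MXFP8BlockScaling(quantize_backward=False)           # OK: high-precision backward
MXFP8BlockScaling(quantize_forward=False,
                  quantize_backward=False)           # OK: quantization disabled for the region
# Rejected with ValueError once autocast() validates the recipe:
#   MXFP8BlockScaling(quantize_forward=False, quantize_backward=True)
#   DelayedScaling(quantize_backward=False)          # delayed scaling requires a quantized backward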
fp8_state = FP8GlobalStateManager.get_autocast_state() FP8GlobalStateManager.autocast_enter( - enabled=enabled, + enabled=effective_enabled, calibrating=calibrating, - fp8_recipe=recipe, + fp8_recipe=fp8_recipe, fp8_group=amax_reduction_group, _graph=_graph, ) @@ -868,7 +879,7 @@ def autocast( yield finally: FP8GlobalStateManager.set_autocast_state(fp8_state) - FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph) + FP8GlobalStateManager.autocast_exit(effective_enabled, _graph=_graph) def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: From 0d26127d2d90370bfedfe834fb3d7e10ac4e07ba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 21:11:01 +0000 Subject: [PATCH 24/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 4 +++- transformer_engine/pytorch/module/layernorm_linear.py | 4 +++- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +++- transformer_engine/pytorch/module/linear.py | 4 +++- transformer_engine/pytorch/ops/basic/basic_linear.py | 8 ++++---- .../pytorch/ops/fused/forward_linear_bias_activation.py | 4 ++-- .../pytorch/ops/fused/forward_linear_bias_add.py | 4 ++-- .../pytorch/ops/fused/forward_linear_scale_add.py | 4 ++-- 8 files changed, 22 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index a7d7bc8948..9aad36a868 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -96,7 +96,9 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) if keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4173c76216..3016d41c5f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -141,7 +141,9 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 82e7d868b4..8e6a189843 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -232,7 +232,9 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) assert ( not keep_backward_unquantized ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 
76ff5dd1d4..c8feddf5af 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,7 +129,9 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) if keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index a362485a7e..ba7de55f69 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,8 +332,8 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. weight_requires_grad = requires_grad and self.weight.requires_grad - keep_backward_unquantized = ( - FP8GlobalStateManager.is_fp8_enabled() and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = FP8GlobalStateManager.is_fp8_enabled() and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) @@ -990,8 +990,8 @@ def op_forward( grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 80cb5647d7..2bccabb306 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -92,8 +92,8 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index cf29140a20..03e3bff6f3 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -86,8 +86,8 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not 
FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 0caae13af9..8cebcec53a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -65,8 +65,8 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get extra input tensor for add operation From 0135366a68fc8add8988c540b60ec96cbf25b723 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:43:08 -0800 Subject: [PATCH 25/43] Move ub overrides to fwd Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_linear.py | 14 ++++++++------ transformer_engine/pytorch/module/linear.py | 13 +++++++------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 3016d41c5f..f39fb45608 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -539,6 +539,14 @@ def forward( ctx.wgrad_store = wgrad_store ctx.debug = debug + # keep_backward_unquantized overrides + if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # ------------------------------------------------------ # Cached state for backward pass is ready... # ------------------------------------------------------ @@ -610,12 +618,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True - ctx.ub_overlap_ag = False - ctx.ub_overlap_rs_dgrad = False - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index c8feddf5af..3ed78e85da 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -493,6 +493,13 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module ctx.wgrad_store = wgrad_store + # keep_backward_unquantized overrides + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # ------------------------------------------------------ # Cached state for backward pass is ready... 
# ------------------------------------------------------ @@ -545,12 +552,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True - ctx.ub_overlap_ag = False - ctx.ub_overlap_rs_dgrad = False - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None From 1de3c64a524e1d5127ab94e0eaf54037461cc7bd Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:44:22 -0800 Subject: [PATCH 26/43] Remove duplication Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index a36b743f3b..85b232c26b 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -262,8 +262,6 @@ class Float8CurrentScaling(Recipe): fp8_mha: bool = False quantize_forward: bool = True quantize_backward: bool = field(default_factory=_default_quantize_backward) - quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." From 04d35430cdd4537056f2dd18d4e62275a133e245 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:59:39 -0800 Subject: [PATCH 27/43] Simplify use_fp8_bwd logic in bwd Signed-off-by: Ziang Li --- .../pytorch/module/grouped_linear.py | 19 +++++++++---- .../pytorch/module/layernorm_linear.py | 25 ++++++++--------- transformer_engine/pytorch/module/linear.py | 28 +++++++++---------- 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 9aad36a868..38e3ceef9a 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -310,6 +310,14 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + + # keep_backward_unquantized overrides + if keep_backward_unquantized: + ctx.fp8 = ctx.fp8 and not keep_backward_unquantized + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -326,7 +334,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -342,7 +349,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms - if use_fp8_bwd and not ctx.debug: + if ctx.fp8 and not ctx.debug: if ctx.use_bias: grad_output_mats = 
torch.split(grad_output_view, ctx.m_splits) recipe = ctx.fp8_recipe @@ -393,7 +400,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = ( @@ -427,7 +434,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): wgrad_gemm_use_split_accumulator = ( @@ -454,7 +461,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list - if use_fp8_bwd and not ctx.debug: + if ctx.fp8 and not ctx.debug: inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -528,7 +535,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() - and not use_fp8_bwd + and not ctx.fp8 ): grad_biases = [None] * ctx.num_gemms diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index f39fb45608..1ef8536e4f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -541,7 +541,7 @@ def forward( # keep_backward_unquantized overrides if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True + ctx.fp8 = ctx.fp8 and not keep_backward_unquantized ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -617,7 +617,6 @@ def backward( origin_weight.main_grad = main_grad keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None @@ -655,7 +654,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_fp8_bwd: + if ctx.grad_output_quantizer is not None and ctx.fp8: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -692,7 +691,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None and use_fp8_bwd: + if ctx.input_quantizer is not None and ctx.fp8: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -731,7 +730,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_fp8_bwd + ctx.fp8 and ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage) ): @@ -739,13 +738,13 @@ def backward( # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer 
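With the recipe flag folded into the cached autograd state during forward, the backward passes branch on ctx.fp8 alone; a simplified sketch of the resulting flow (not verbatim module code):

# Forward: decide once, based on the active recipe, and cache the result.
keep_backward_unquantized = fp8 and not FP8GlobalStateManager.get_fp8_recipe().quantize_backward
if keep_backward_unquantized:
    ctx.fp8 = False                  # backward sees only high-precision tensors
    ctx.ub_overlap_ag = False        # Userbuffers overlaps assume quantized buffers,
    ctx.ub_overlap_rs_dgrad = False  # so they are disabled for this backward pass
    ctx.ub_bulk_dgrad = False
    ctx.ub_bulk_wgrad = False

# Backward: no separate use_fp8_bwd flag is needed any more.
quantization_params = ctx.grad_weight_quantizer if ctx.fp8 else None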
- if ctx.grad_input_quantizer is not None and use_fp8_bwd: + if ctx.grad_input_quantizer is not None and ctx.fp8: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -769,7 +768,7 @@ def backward( grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, + quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -854,14 +853,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -870,7 +869,7 @@ def backward( # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -896,7 +895,7 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), + "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -904,7 +903,7 @@ def backward( ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), + "bias": (bias if (grad_bias is None and not ctx.fp8) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3ed78e85da..a97ba398e0 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -495,6 +495,7 @@ def forward( # keep_backward_unquantized overrides if keep_backward_unquantized: + ctx.fp8 = ctx.fp8 and not keep_backward_unquantized ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -551,7 +552,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], nvtx_range_pop(f"{nvtx_label}.fsdp_gather") keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None @@ -592,7 +592,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_fp8_bwd: + if ctx.grad_output_quantizer is not None and ctx.fp8: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -611,7 +611,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and 
ctx.grad_output_quantizer is not None - and use_fp8_bwd + and ctx.fp8 ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -641,7 +641,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -667,7 +667,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -709,7 +709,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_fp8_bwd + ctx.fp8 and ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorStorage) ): @@ -717,13 +717,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_fp8_bwd: + if ctx.grad_input_quantizer is not None and ctx.fp8: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -748,7 +748,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, + quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -797,7 +797,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -839,7 +839,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -848,7 +848,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -874,7 +874,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), + "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), 
"accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -882,7 +882,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), + "bias": (bias if (grad_bias is None and not ctx.fp8) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, From 454976eaeb1520ad075d0f3dbc2de736108ea0cd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 22:00:24 +0000 Subject: [PATCH 28/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 38e3ceef9a..54caabdb7e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -310,7 +310,7 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers - + # keep_backward_unquantized overrides if keep_backward_unquantized: ctx.fp8 = ctx.fp8 and not keep_backward_unquantized From f7794c94eb301e466db5c1c0b311bf054977caa6 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 14:28:06 -0800 Subject: [PATCH 29/43] Set grad quantizers to none if keep bwd unquantized Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 3 +++ .../pytorch/module/layernorm_linear.py | 11 +++++++---- transformer_engine/pytorch/module/linear.py | 13 ++++++++----- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 54caabdb7e..73dc81ad41 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -318,6 +318,9 @@ def forward( ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 1ef8536e4f..4de6afa38b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -546,6 +546,9 @@ def forward( ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None # ------------------------------------------------------ # Cached state for backward pass is ready... 
@@ -654,7 +657,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and ctx.fp8: + if ctx.grad_output_quantizer is not None: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -744,7 +747,7 @@ def backward( use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and ctx.fp8: + if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -768,7 +771,7 @@ def backward( grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, + quantization_params=ctx.grad_input_quantizer, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -895,7 +898,7 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), + "quantization_params": ctx.grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a97ba398e0..1fd2fcba8d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -500,6 +500,10 @@ def forward( ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None + # ------------------------------------------------------ # Cached state for backward pass is ready... 
@@ -592,7 +596,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and ctx.fp8: + if ctx.grad_output_quantizer is not None: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -611,7 +615,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None - and ctx.fp8 ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -723,7 +726,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and ctx.fp8: + if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -748,7 +751,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, + quantization_params=ctx.grad_input_quantizer, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -874,7 +877,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), + "quantization_params": ctx.grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) From 58db8ea72fd2a52a8c4fabe324c487352919fa35 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 22:28:55 +0000 Subject: [PATCH 30/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/linear.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1fd2fcba8d..3e8c4c146f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -503,7 +503,6 @@ def forward( ctx.grad_input_quantizer = None ctx.grad_weight_quantizer = None ctx.grad_output_quantizer = None - # ------------------------------------------------------ # Cached state for backward pass is ready... 
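A usage sketch of the overall feature, following the unit tests added later in this series (assumes a CUDA device with FP8 support). Setting NVTE_KEEP_BACKWARD_UNQUANTIZED=1 before import flips the recipes' quantize_backward default to False; passing quantize_backward=False explicitly, as below, has the same effect:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Keep the forward GEMM quantized but run dgrad/wgrad in high precision.
fp8_recipe = recipe.Float8CurrentScaling(
    fp8_format=recipe.Format.E4M3,
    quantize_forward=True,
    quantize_backward=False,
)

layer = te.Linear(64, 64, params_dtype=torch.bfloat16, device="cuda")
x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)

with te.autocast(enabled=True, recipe=fp8_recipe):
    y = layer(x)      # forward pass runs under the FP8 recipe

y.sum().backward()    # dgrad/wgrad GEMMs stay in bf16
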
From 9d0b6547427e4e6f2c969d84431e665c762577b1 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 17:28:04 -0800 Subject: [PATCH 31/43] Drop delayed scaling change Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_linear.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4de6afa38b..26b14c2d8a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -527,11 +527,7 @@ def forward( ctx.requires_dgrad = inp_requires_grad ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ( - ctx.fp8 - and not ctx.keep_backward_unquantized - and requires_grad(inp, ln_weight, ln_bias, weight, bias) - ): + if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): From 004cb455f8a39abe5c75f13c8b3a0fb5b179d664 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:29:24 -0800 Subject: [PATCH 32/43] Simplify env var logic Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 85b232c26b..55010499ec 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -11,11 +11,6 @@ from pydantic.dataclasses import dataclass -def _default_quantize_backward() -> bool: - """Default backward quantization setting.""" - return not bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) - - class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. @@ -215,7 +210,7 @@ def scaling_factor_compute(amax: Tensor, fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -261,7 +256,7 @@ class Float8CurrentScaling(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -317,7 +312,7 @@ class MXFP8BlockScaling(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
@@ -484,7 +479,7 @@ class NVFP4BlockScaling(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" @@ -560,7 +555,7 @@ class CustomRecipe(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __repr__(self) -> str: return ( From 9baccfd65dde556099fe8ded69160de09342c3a4 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:41:01 -0800 Subject: [PATCH 33/43] Move validation check to recipe Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 18 ++++++++++++++++++ transformer_engine/pytorch/quantization.py | 17 ----------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 55010499ec..673df45f4c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -214,6 +214,12 @@ def scaling_factor_compute(amax: Tensor, def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." + assert ( + not self.quantize_backward + ), "Delayed scaling does not support quantize_backward=False." def __repr__(self) -> str: return ( @@ -260,6 +266,9 @@ class Float8CurrentScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -316,6 +325,9 @@ class MXFP8BlockScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -393,6 +405,9 @@ def __post_init__(self) -> None: not self.fp8_dpa and not self.fp8_mha ), "FP8 attention is not supported for Float8BlockScaling." assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -484,6 +499,9 @@ class NVFP4BlockScaling(Recipe): def __post_init__(self) -> None: assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." 
# Quantization params # Note: RHT is currently only applied to column-wise usage so that diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index fb0553056a..bbffe51eec 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -87,21 +87,6 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: ) -def _validate_recipe_quantization_flags(recipe: Recipe) -> None: - """Validate forward/backward quantization flags on a recipe.""" - quantize_forward = getattr(recipe, "quantize_forward", True) - quantize_backward = getattr(recipe, "quantize_backward", True) - if not quantize_forward and quantize_backward: - raise ValueError( - "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." - ) - if recipe.delayed() and not quantize_backward: - raise ValueError( - "Invalid recipe configuration: delayed scaling does not support " - "quantize_backward=False." - ) - - def check_recipe_support(recipe: Recipe) -> None: """Check if the given recipe is supported.""" recipe_supported = True @@ -858,8 +843,6 @@ def autocast( """ fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe - if enabled or calibrating: - _validate_recipe_quantization_flags(fp8_recipe) quantize_forward = getattr(fp8_recipe, "quantize_forward", True) effective_enabled = enabled and quantize_forward if effective_enabled: From 207eb5a7d2319d4e12a016faa37a0438c4c8ce27 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:55:28 -0800 Subject: [PATCH 34/43] Simplify effective_enabled Signed-off-by: Ziang Li --- transformer_engine/pytorch/quantization.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index bbffe51eec..00196c584f 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -842,11 +842,9 @@ def autocast( are reduced at the end of each training step. """ - fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe - quantize_forward = getattr(fp8_recipe, "quantize_forward", True) - effective_enabled = enabled and quantize_forward + effective_enabled = enabled and getattr(recipe, "quantize_forward", True) if effective_enabled: - check_recipe_support(fp8_recipe) + check_recipe_support(recipe) # Save current state so we always restore it on exit. fp8_state = FP8GlobalStateManager.get_autocast_state() @@ -854,7 +852,7 @@ def autocast( FP8GlobalStateManager.autocast_enter( enabled=effective_enabled, calibrating=calibrating, - fp8_recipe=fp8_recipe, + fp8_recipe=recipe, fp8_group=amax_reduction_group, _graph=_graph, ) From 15117b1d545660aa3a9ceae82fb0bd4c4191ed44 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:56:33 -0800 Subject: [PATCH 35/43] Fix inverted assertion logic Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 673df45f4c..f03e9b24d6 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -217,9 +217,7 @@ def __post_init__(self) -> None: assert not ( not self.quantize_forward and self.quantize_backward ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." 
- assert ( - not self.quantize_backward - ), "Delayed scaling does not support quantize_backward=False." + assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False." def __repr__(self) -> str: return ( From 3fc5270e82689136aa58d12deb98b1012bdc11a0 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 12:33:38 -0800 Subject: [PATCH 36/43] Simplify changes under ops Signed-off-by: Ziang Li --- transformer_engine/pytorch/ops/basic/basic_linear.py | 4 ---- transformer_engine/pytorch/ops/basic/quantize.py | 11 ++++++----- .../ops/fused/forward_linear_bias_activation.py | 7 ++----- .../pytorch/ops/fused/forward_linear_bias_add.py | 7 ++----- .../pytorch/ops/fused/forward_linear_scale_add.py | 7 ++----- 5 files changed, 12 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index ba7de55f69..16b7bcb7c5 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -1020,11 +1020,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None saved_weight = self.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index e6c28b9fdc..6e90e33846 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -57,11 +57,12 @@ def op_forward( # Check if FP8 is enabled fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() quantize_forward = fp8_enabled and self._quantize_forward - quantize_backward = ( - fp8_enabled - and self._quantize_backward - and FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + quantize_backward = fp8_enabled and self._quantize_backward + + # Recipe quantize overrides + if FP8GlobalStateManager.get_fp8_recipe() is not None: + quantize_forward = quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward + quantize_backward = quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward # Quantize if needed out = input_ diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 2bccabb306..860407904c 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -122,11 +122,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = x_local - saved_weight = w - if keep_backward_unquantized: - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 03e3bff6f3..0729291d55 100644 --- 
a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -119,11 +119,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = x_local - saved_weight = w - if keep_backward_unquantized: - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 8cebcec53a..dfdd11a231 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -100,11 +100,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = x_local - saved_weight = w - if keep_backward_unquantized: - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From 9201d1926d44099f556745613b555869b70b08a3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Feb 2026 20:34:39 +0000 Subject: [PATCH 37/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/basic/quantize.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 6e90e33846..33062d5b88 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -61,8 +61,12 @@ def op_forward( # Recipe quantize overrides if FP8GlobalStateManager.get_fp8_recipe() is not None: - quantize_forward = quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward - quantize_backward = quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward + quantize_forward = ( + quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward + ) + quantize_backward = ( + quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) # Quantize if needed out = input_ From 1e0f1d2deb435facb7c28bbc4374036db930c91b Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 12:52:01 -0800 Subject: [PATCH 38/43] Simplify ctx.keep_backward_unquantized Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/base.py | 3 +-- transformer_engine/pytorch/module/grouped_linear.py | 3 +-- transformer_engine/pytorch/module/layernorm_linear.py | 4 +--- transformer_engine/pytorch/module/linear.py | 4 +--- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 4a2140718d..a878f2ace2 100644 --- a/transformer_engine/pytorch/module/base.py +++ 
b/transformer_engine/pytorch/module/base.py @@ -1135,8 +1135,7 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_fp8_bwd = ctx.fp8 and not ctx.keep_backward_unquantized # Non-FP8 case: bgrad is fused with wgrad for this case. if not use_fp8_bwd and not ctx.debug: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 73dc81ad41..abe6df6875 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -336,7 +336,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], origin_weights = saved_tensors[2 * N : 3 * N] biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -415,7 +414,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], device=ctx.device, ) weights_for_dgrad = weights - if keep_backward_unquantized: + if ctx.keep_backward_unquantized: weights_for_dgrad = origin_weights # Make sure weights are available in column-wise format # for dgrad computation. diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 26b14c2d8a..187fd70f92 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -615,8 +615,6 @@ def backward( if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: origin_weight.main_grad = main_grad - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -760,7 +758,7 @@ def backward( # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight - if keep_backward_unquantized: + if ctx.keep_backward_unquantized: weight_for_dgrad = origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3e8c4c146f..7d960102ec 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -554,8 +554,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -743,7 +741,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight_fp8 - if keep_backward_unquantized: + if ctx.keep_backward_unquantized: weight_for_dgrad = weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, From 253873a4560b2c2a2c909918cc3ee26500e5b43d Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 15:07:48 -0800 Subject: [PATCH 39/43] Fix missing attribute Signed-off-by: Ziang Li --- 
transformer_engine/common/recipe/__init__.py | 2 ++ transformer_engine/pytorch/module/layernorm_mlp.py | 1 + 2 files changed, 3 insertions(+) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index f03e9b24d6..d534ad883b 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -382,6 +382,8 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 8e6a189843..ac10534012 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -784,6 +784,7 @@ def _forward( ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer From fd947612cf65d9b15f0e55235f42dd27913b3f63 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 10 Feb 2026 14:02:10 -0800 Subject: [PATCH 40/43] Add unit tests Signed-off-by: Ziang Li --- qa/L0_pytorch_unittest/test.sh | 1 + .../pytorch/test_keep_backward_unquantized.py | 701 ++++++++++++++++++ 2 files changed, 702 insertions(+) create mode 100644 tests/pytorch/test_keep_backward_unquantized.py diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index a13dfada79..5ee843987c 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -40,6 +40,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +NVTE_KEEP_BACKWARD_UNQUANTIZED=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_keep_backward_unquantized.xml $TE_PATH/tests/pytorch/test_keep_backward_unquantized.py || test_fail "test_keep_backward_unquantized.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" diff --git a/tests/pytorch/test_keep_backward_unquantized.py b/tests/pytorch/test_keep_backward_unquantized.py new file mode 100644 index 0000000000..a5ef00e34c --- /dev/null +++ 
b/tests/pytorch/test_keep_backward_unquantized.py @@ -0,0 +1,701 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +from contextlib import nullcontext +import os +from typing import Optional + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.common import recipe +from transformer_engine.pytorch.ops.fused import ( + BackwardActivationBias, + ForwardLinearBiasActivation, + ForwardLinearBiasAdd, + ForwardLinearScaleAdd, +) + +from utils import quantization_tols, reset_rng_states + + +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) + +# This file is intended to run in dedicated keep-backward-unquantized mode. +pytestmark = pytest.mark.skipif( + os.environ.get("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") != "1", + reason="Requires NVTE_KEEP_BACKWARD_UNQUANTIZED=1", +) + + +_quantized_numerics_recipe_list = [ + pytest.param( + "fp8_current_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + id="Float8CurrentScaling", + ), + pytest.param( + "mxfp8", + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + id="MXFP8BlockScaling", + ), + pytest.param( + "fp8_block_scaling", + marks=pytest.mark.skipif( + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + id="Float8BlockScaling", + ), + pytest.param( + "nvfp4", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4BlockScaling", + ), +] + +_shape_test_cases = [ + pytest.param((32, 64), 64, id="2d_m32_k64_n64"), + pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), + pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), +] + +_bias_activation_shape_cases = [ + pytest.param((32, 64), id="2d_m32_k64"), + pytest.param((8, 4, 64), id="3d_m32_k64"), +] + + +def _make_recipe(recipe_name: str, quantize_backward: Optional[bool]) -> recipe.Recipe: + kwargs = {} + if quantize_backward is not None: + kwargs = {"quantize_forward": True, "quantize_backward": quantize_backward} + + if recipe_name == "fp8_current_scaling": + return recipe.Float8CurrentScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "mxfp8": + return recipe.MXFP8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "fp8_block_scaling": + return recipe.Float8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "nvfp4": + return recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + **kwargs, + ) + + raise ValueError(f"Unsupported recipe for keep-backward-unquantized test: {recipe_name}") + + +def _build_keep_backward_unquantized_recipe(recipe_name: str) -> recipe.Recipe: + fp8_recipe = _make_recipe(recipe_name, quantize_backward=None) + assert fp8_recipe.quantize_forward + assert not fp8_recipe.quantize_backward + return fp8_recipe + + +def _build_quantized_reference_recipe(recipe_name: str) -> recipe.Recipe: + return _make_recipe(recipe_name, quantize_backward=True) + + +def _copy_named_parameters(src_module: torch.nn.Module, dst_module: 
torch.nn.Module) -> None: + src_params = dict(src_module.named_parameters()) + with torch.no_grad(): + for name, dst_param in dst_module.named_parameters(): + if name not in src_params: + raise RuntimeError(f"Parameter {name} missing in source module") + dst_param.copy_(src_params[name]) + + +def _fprop_tolerances(recipe_name: str) -> dict[str, float]: + if recipe_name == "mxfp8": + return quantization_tols("mxfp8") + if recipe_name in ("fp8_current_scaling", "fp8_block_scaling"): + return quantization_tols("fp8_current_scaling") + if recipe_name == "nvfp4": + return quantization_tols("nvfp4") + raise ValueError(f"Unsupported recipe for keep-backward-unquantized test: {recipe_name}") + + +def _make_linear_like_module( + module_type: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + bias: bool = False, +) -> torch.nn.Module: + if module_type == "linear": + return te.Linear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "layernorm_linear": + return te.LayerNormLinear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "ops_linear": + return te_ops.Linear( + in_features, + out_features, + bias=bias, + dtype=dtype, + device="cuda", + ) + raise ValueError(f"Unsupported module type: {module_type}") + + +def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: + if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": + pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") + + +def _run_single_step( + module: torch.nn.Module, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + y.backward(dy) + assert x_run.grad is not None + assert module.weight.grad is not None + return ( + y.detach().clone(), + x_run.grad.detach().clone(), + module.weight.grad.detach().clone(), + ) + + +def _extract_bias_grad(module: torch.nn.Module) -> Optional[torch.Tensor]: + bias = getattr(module, "bias", None) + if bias is None or bias.grad is None: + return None + return bias.grad.detach().clone() + + +def _run_grouped_linear_single_step( + module: te.GroupedLinear, + x: torch.Tensor, + m_splits: list[int], + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[Optional[torch.Tensor]]]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run, m_splits) + y.backward(dy) + assert x_run.grad is not None + weight_grads = [getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms)] + bias_grads: list[Optional[torch.Tensor]] = [] + for i in range(module.num_gemms): + if module.use_bias: + bias_grads.append(getattr(module, f"bias{i}").grad.detach().clone()) + else: + bias_grads.append(None) + return y.detach().clone(), x_run.grad.detach().clone(), weight_grads, bias_grads + + +def _make_fused_model( + pattern: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + 
scale: float = 0.5, +) -> te_ops.Sequential: + if pattern == "bias_activation": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.ReLU(), + ) + if pattern == "bias_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.AddExtraInput(in_place=True), + ) + if pattern == "scale_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=False, device="cuda", dtype=dtype), + te_ops.ConstantScale(scale), + te_ops.AddExtraInput(in_place=True), + ) + raise ValueError(f"Unsupported fused test pattern: {pattern}") + + +def _run_fused_single_step( + pattern: str, + model: te_ops.Sequential, + x1: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], + x2: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: + model.zero_grad(set_to_none=True) + x1_run = x1.detach().clone().requires_grad_(True) + x2_run = x2.detach().clone().requires_grad_(True) if x2 is not None else None + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + if pattern in ("bias_add", "scale_add"): + assert x2_run is not None + y = model(x1_run, x2_run) + else: + y = model(x1_run) + y.backward(dy) + assert x1_run.grad is not None + weight_grad = model[0].weight.grad.detach().clone() + bias_grad = None + if getattr(model[0], "bias", None) is not None and model[0].bias.grad is not None: + bias_grad = model[0].bias.grad.detach().clone() + x2_grad = x2_run.grad.detach().clone() if x2_run is not None and x2_run.grad is not None else None + return y.detach().clone(), x1_run.grad.detach().clone(), x2_grad, weight_grad, bias_grad + + +def _run_quantize_op_single_step( + model: te_ops.Sequential, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor]: + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = model(x_run) + y.backward(dy) + assert x_run.grad is not None + return y.detach().clone(), x_run.grad.detach().clone() + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +def test_keep_backward_unquantized_recipe_defaults(recipe_name: str): + _ = _build_keep_backward_unquantized_recipe(recipe_name) + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize( + "module_type", + ("linear", "layernorm_linear", "ops_linear"), +) +@pytest.mark.parametrize( + "input_shape,out_features", + _shape_test_cases, +) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads( + recipe_name: str, + module_type: str, + input_shape: tuple[int, ...], + out_features: int, + use_bias: bool, +): + reset_rng_states() + _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) + dtype = torch.bfloat16 + in_features = input_shape[-1] + + module_quantized_ref = _make_linear_like_module( + module_type, in_features, out_features, dtype, bias=use_bias + ) + module_keep_bwd_hp = _make_linear_like_module( + module_type, in_features, out_features, dtype, bias=use_bias + ) + module_unquantized_ref = _make_linear_like_module( + 
module_type, in_features, out_features, dtype, bias=use_bias + ) + + # Start all runs from identical parameters. + _copy_named_parameters(module_quantized_ref, module_keep_bwd_hp) + _copy_named_parameters(module_quantized_ref, module_unquantized_ref) + + output_shape = input_shape[:-1] + (out_features,) + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(*output_shape, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _ = _run_single_step(module_quantized_ref, x, dy, quantized_ref_recipe) + y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp = _run_single_step( + module_keep_bwd_hp, x, dy, keep_bwd_hp_recipe + ) + _, dx_unquantized_ref, dw_unquantized_ref = _run_single_step(module_unquantized_ref, x, dy, None) + + # Forward pass should still match quantized reference when only backward is unquantized. + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + + # Backward pass should match unquantized reference for dgrad and wgrad. + torch.testing.assert_close(dx_keep_bwd_hp, dx_unquantized_ref, rtol=0, atol=0) + torch.testing.assert_close(dw_keep_bwd_hp, dw_unquantized_ref, rtol=0, atol=0) + if use_bias: + bgrad_keep = _extract_bias_grad(module_keep_bwd_hp) + bgrad_unquantized = _extract_bias_grad(module_unquantized_ref) + assert bgrad_keep is not None + assert bgrad_unquantized is not None + torch.testing.assert_close(bgrad_keep, bgrad_unquantized, rtol=0, atol=0) + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +@pytest.mark.parametrize( + "m_splits", + ([32, 32, 32, 32], [64, 0, 32, 32]), + ids=("uniform_splits", "with_empty_split"), +) +def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_unquantized_grads( + recipe_name: str, + use_bias: bool, + m_splits: list[int], +): + if recipe_name == "nvfp4": + pytest.skip("NVFP4 not supported for grouped linear") + + reset_rng_states() + dtype = torch.bfloat16 + in_features = 64 + out_features = 64 + num_gemms = len(m_splits) + num_tokens = sum(m_splits) + + module_quantized_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + module_keep_bwd_hp = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + module_unquantized_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + + _copy_named_parameters(module_quantized_ref, module_keep_bwd_hp) + _copy_named_parameters(module_quantized_ref, module_unquantized_ref) + + x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _, _ = _run_grouped_linear_single_step( + module_quantized_ref, x, m_splits, dy, quantized_ref_recipe + ) + y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp, db_keep_bwd_hp = _run_grouped_linear_single_step( + module_keep_bwd_hp, x, m_splits, dy, keep_bwd_hp_recipe + ) + _, dx_unquantized_ref, dw_unquantized_ref, db_unquantized_ref = 
_run_grouped_linear_single_step( + module_unquantized_ref, x, m_splits, dy, None + ) + + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + torch.testing.assert_close(dx_keep_bwd_hp, dx_unquantized_ref, rtol=0, atol=0) + for test_dw, ref_dw in zip(dw_keep_bwd_hp, dw_unquantized_ref): + torch.testing.assert_close(test_dw, ref_dw, rtol=0, atol=0) + if use_bias: + for test_db, ref_db in zip(db_keep_bwd_hp, db_unquantized_ref): + assert test_db is not None + assert ref_db is not None + torch.testing.assert_close(test_db, ref_db, rtol=0, atol=0) + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize( + "fused_pattern,expected_fused_op", + ( + ("bias_add", ForwardLinearBiasAdd), + ("scale_add", ForwardLinearScaleAdd), + ), +) +def test_keep_backward_unquantized_fused_linear_paths( + recipe_name: str, + fused_pattern: str, + expected_fused_op: type, +): + # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + + reset_rng_states() + dtype = torch.bfloat16 + in_features = 64 + out_features = 64 + m = 32 + + model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) + model_keep_bwd_hp = _make_fused_model(fused_pattern, in_features, out_features, dtype) + model_unquantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) + + _copy_named_parameters(model_quantized_ref, model_keep_bwd_hp) + _copy_named_parameters(model_quantized_ref, model_unquantized_ref) + + x1 = torch.randn(m, in_features, dtype=dtype, device="cuda") + x2 = None + if fused_pattern in ("bias_add", "scale_add"): + x2 = torch.randn(m, out_features, dtype=dtype, device="cuda") + dy = torch.randn(m, out_features, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _, _, _ = _run_fused_single_step( + fused_pattern, model_quantized_ref, x1, dy, quantized_ref_recipe, x2=x2 + ) + y_keep_bwd_hp, dx1_keep_bwd_hp, dx2_keep_bwd_hp, dw_keep_bwd_hp, db_keep_bwd_hp = ( + _run_fused_single_step( + fused_pattern, + model_keep_bwd_hp, + x1, + dy, + keep_bwd_hp_recipe, + x2=x2, + ) + ) + _, dx1_unquantized_ref, dx2_unquantized_ref, dw_unquantized_ref, db_unquantized_ref = ( + _run_fused_single_step( + fused_pattern, + model_unquantized_ref, + x1, + dy, + None, + x2=x2, + ) + ) + + # Ensure this test executes the fused path changed by the keep-bwd feature. 
+ fused_ops = model_keep_bwd_hp._module_groups[0]._forward_ops + assert len(fused_ops) >= 1 + assert isinstance(fused_ops[0][0], expected_fused_op) + + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + torch.testing.assert_close(dx1_keep_bwd_hp, dx1_unquantized_ref, rtol=0, atol=0) + torch.testing.assert_close(dw_keep_bwd_hp, dw_unquantized_ref, rtol=0, atol=0) + if dx2_keep_bwd_hp is not None and dx2_unquantized_ref is not None: + torch.testing.assert_close(dx2_keep_bwd_hp, dx2_unquantized_ref, rtol=0, atol=0) + if db_keep_bwd_hp is not None and db_unquantized_ref is not None: + torch.testing.assert_close(db_keep_bwd_hp, db_unquantized_ref, rtol=0, atol=0) + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize("input_shape", _bias_activation_shape_cases) +def test_keep_backward_unquantized_fused_bias_activation_matches_masked_linear_backward( + recipe_name: str, + input_shape: tuple[int, ...], +): + # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + + reset_rng_states() + dtype = torch.bfloat16 + in_features = input_shape[-1] + out_features = 64 + + model_quantized_ref = _make_fused_model("bias_activation", in_features, out_features, dtype) + model_keep_bwd_hp = _make_fused_model("bias_activation", in_features, out_features, dtype) + linear_unquantized_ref = _make_linear_like_module( + "ops_linear", in_features, out_features, dtype, bias=True + ) + + _copy_named_parameters(model_quantized_ref, model_keep_bwd_hp) + _copy_named_parameters(model_keep_bwd_hp[0], linear_unquantized_ref) + + x1 = torch.randn(*input_shape, dtype=dtype, device="cuda") + out_shape = x1.shape[:-1] + (out_features,) + dy = torch.randn(*out_shape, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _, _, _ = _run_fused_single_step( + "bias_activation", model_quantized_ref, x1, dy, quantized_ref_recipe + ) + y_keep_bwd_hp, dx1_keep_bwd_hp, _, dw_keep_bwd_hp, db_keep_bwd_hp = _run_fused_single_step( + "bias_activation", model_keep_bwd_hp, x1, dy, keep_bwd_hp_recipe + ) + + # Ensure this test executes the fused path changed by the keep-bwd feature. + fused_ops = model_keep_bwd_hp._module_groups[0]._forward_ops + assert len(fused_ops) >= 1 + assert isinstance(fused_ops[0][0], ForwardLinearBiasActivation) + + # keep-bwd mode should disable backward-activation+bias fusion, while quantized + # reference should still use it. + keep_bwd_backward_ops = model_keep_bwd_hp._module_groups[0]._backward_ops + assert not any( + isinstance(op, BackwardActivationBias) for op, _ in keep_bwd_backward_ops + ) + quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops + assert any( + isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops + ) + + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + + # In keep-backward-unquantized mode, backward should behave as high-precision linear backward + # given the ReLU mask induced by quantized forward activations. 
+ dy_after_activation = dy * (y_keep_bwd_hp > 0).to(dy.dtype) + _, dx1_expected, dw_expected = _run_single_step(linear_unquantized_ref, x1, dy_after_activation, None) + db_expected = _extract_bias_grad(linear_unquantized_ref) + assert db_keep_bwd_hp is not None + assert db_expected is not None + + torch.testing.assert_close(dx1_keep_bwd_hp, dx1_expected, rtol=0, atol=0) + torch.testing.assert_close(dw_keep_bwd_hp, dw_expected, rtol=0, atol=0) + torch.testing.assert_close(db_keep_bwd_hp, db_expected, rtol=0, atol=0) + + +def test_keep_backward_unquantized_autocast_respects_quantize_forward_flag(): + reset_rng_states() + dtype = torch.bfloat16 + in_features = 64 + out_features = 64 + + module_quantization_disabled = _make_linear_like_module( + "linear", in_features, out_features, dtype, bias=True + ) + module_unquantized_ref = _make_linear_like_module("linear", in_features, out_features, dtype, bias=True) + _copy_named_parameters(module_quantization_disabled, module_unquantized_ref) + + x = torch.randn(32, in_features, dtype=dtype, device="cuda") + dy = torch.randn(32, out_features, dtype=dtype, device="cuda") + + recipe_no_fwd_quant = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.E4M3, + quantize_forward=False, + quantize_backward=False, + ) + + y_test, dx_test, dw_test = _run_single_step( + module_quantization_disabled, x, dy, recipe_no_fwd_quant + ) + y_ref, dx_ref, dw_ref = _run_single_step(module_unquantized_ref, x, dy, None) + + torch.testing.assert_close(y_test, y_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0) + torch.testing.assert_close(dw_test, dw_ref, rtol=0, atol=0) + bgrad_test = _extract_bias_grad(module_quantization_disabled) + bgrad_ref = _extract_bias_grad(module_unquantized_ref) + assert bgrad_test is not None + assert bgrad_ref is not None + torch.testing.assert_close(bgrad_test, bgrad_ref, rtol=0, atol=0) + + +def test_keep_backward_unquantized_quantize_op_respects_recipe_overrides(): + reset_rng_states() + dtype = torch.bfloat16 + x = torch.randn(32, 64, dtype=dtype, device="cuda") + dy = torch.randn(32, 64, dtype=dtype, device="cuda") + + model_override = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) + model_ref = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) + + recipe_no_quant = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.E4M3, + quantize_forward=False, + quantize_backward=False, + ) + y_override, dx_override = _run_quantize_op_single_step(model_override, x, dy, recipe_no_quant) + y_ref, dx_ref = _run_quantize_op_single_step(model_ref, x, dy, None) + + torch.testing.assert_close(y_override, y_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_override, dx_ref, rtol=0, atol=0) + + +def test_keep_backward_unquantized_is_invalid_for_delayed_scaling(): + with pytest.raises( + (AssertionError, ValueError), + match="Delayed scaling does not support quantize_backward=False", + ): + _ = recipe.DelayedScaling() + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_keep_backward_unquantized_not_implemented_for_layernorm_mlp(): + reset_rng_states() + layer = te.LayerNormMLP( + hidden_size=64, + ffn_hidden_size=64, + params_dtype=torch.bfloat16, + bias=False, + device="cuda", + ) + x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe("fp8_current_scaling") + + with pytest.raises( + AssertionError, match="NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" 
+ ): + with te.autocast(enabled=True, recipe=keep_bwd_hp_recipe): + _ = layer(x) From 0b2dbf962f02aefc4b2c306a226dccc82bb00d82 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 10 Feb 2026 14:03:02 -0800 Subject: [PATCH 41/43] Fix bias errors in unit test Signed-off-by: Ziang Li --- transformer_engine/pytorch/ops/basic/bias.py | 8 +++++++- .../pytorch/ops/fused/backward_activation_bias.py | 5 +++-- .../pytorch/ops/fused/forward_linear_bias_activation.py | 4 +++- .../pytorch/ops/fused/forward_linear_bias_add.py | 4 +++- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index 8b60251088..d0ff6d5e15 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -10,6 +10,7 @@ import torch import transformer_engine_torch as tex +from ...quantization import FP8GlobalStateManager from ..op import BasicOperation, OperationContext from ...utils import canonicalize_device, canonicalize_dtype from ...tensor import Quantizer @@ -123,7 +124,12 @@ def op_forward( b = self.bias.view([1] * (x.dim() - 1) + [self.local_size]) if ctx.requires_grad: - ctx.grad_input_quantizer = prev_op_grad_output_quantizer + keep_backward_unquantized = FP8GlobalStateManager.is_fp8_enabled() and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) + ctx.grad_input_quantizer = ( + None if keep_backward_unquantized else prev_op_grad_output_quantizer + ) return x + b diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 4ab082d32b..395a9dbd67 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -104,8 +104,9 @@ def fuse_backward_ops( """ - # Check if recipe supports bias activation fusion - if recipe is None: + # Check if recipe supports bias activation fusion. + # keep-backward-unquantized mode should use unfused backward ops. 
+ if recipe is None or not recipe.quantize_backward: return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 860407904c..42f459a41e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -138,7 +138,9 @@ def fuser_forward( linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() + bias_op_ctx.grad_input_quantizer = ( + None if keep_backward_unquantized else linear_op.get_grad_output_quantizer() + ) return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 0729291d55..75d58fd5cc 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -135,7 +135,9 @@ def fuser_forward( linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() + bias_op_ctx.grad_input_quantizer = ( + None if keep_backward_unquantized else linear_op.get_grad_output_quantizer() + ) return output, [() for _ in range(len(self.basic_ops))] From 364332022caadee3bc10796b57414b18673b4130 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 22:03:50 +0000 Subject: [PATCH 42/43] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/test_keep_backward_unquantized.py | 33 ++++++++++++------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/test_keep_backward_unquantized.py b/tests/pytorch/test_keep_backward_unquantized.py index a5ef00e34c..fe11bfcd3a 100644 --- a/tests/pytorch/test_keep_backward_unquantized.py +++ b/tests/pytorch/test_keep_backward_unquantized.py @@ -214,7 +214,9 @@ def _run_grouped_linear_single_step( y = module(x_run, m_splits) y.backward(dy) assert x_run.grad is not None - weight_grads = [getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms)] + weight_grads = [ + getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms) + ] bias_grads: list[Optional[torch.Tensor]] = [] for i in range(module.num_gemms): if module.use_bias: @@ -257,7 +259,9 @@ def _run_fused_single_step( dy: torch.Tensor, fp8_recipe: Optional[recipe.Recipe], x2: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[ + torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor] +]: model.zero_grad(set_to_none=True) x1_run = x1.detach().clone().requires_grad_(True) x2_run = x2.detach().clone().requires_grad_(True) if x2 is not None else None @@ -276,7 +280,9 @@ def _run_fused_single_step( bias_grad = None if getattr(model[0], "bias", None) is not None and model[0].bias.grad is not None: bias_grad = model[0].bias.grad.detach().clone() - x2_grad = 
x2_run.grad.detach().clone() if x2_run is not None and x2_run.grad is not None else None + x2_grad = ( + x2_run.grad.detach().clone() if x2_run is not None and x2_run.grad is not None else None + ) return y.detach().clone(), x1_run.grad.detach().clone(), x2_grad, weight_grad, bias_grad @@ -355,7 +361,9 @@ def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp = _run_single_step( module_keep_bwd_hp, x, dy, keep_bwd_hp_recipe ) - _, dx_unquantized_ref, dw_unquantized_ref = _run_single_step(module_unquantized_ref, x, dy, None) + _, dx_unquantized_ref, dw_unquantized_ref = _run_single_step( + module_unquantized_ref, x, dy, None + ) # Forward pass should still match quantized reference when only backward is unquantized. torch.testing.assert_close( @@ -458,6 +466,7 @@ def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_un assert ref_db is not None torch.testing.assert_close(test_db, ref_db, rtol=0, atol=0) + @pytest.mark.parametrize( "recipe_name", _quantized_numerics_recipe_list, @@ -589,13 +598,9 @@ def test_keep_backward_unquantized_fused_bias_activation_matches_masked_linear_b # keep-bwd mode should disable backward-activation+bias fusion, while quantized # reference should still use it. keep_bwd_backward_ops = model_keep_bwd_hp._module_groups[0]._backward_ops - assert not any( - isinstance(op, BackwardActivationBias) for op, _ in keep_bwd_backward_ops - ) + assert not any(isinstance(op, BackwardActivationBias) for op, _ in keep_bwd_backward_ops) quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops - assert any( - isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops - ) + assert any(isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops) torch.testing.assert_close( y_keep_bwd_hp, @@ -606,7 +611,9 @@ def test_keep_backward_unquantized_fused_bias_activation_matches_masked_linear_b # In keep-backward-unquantized mode, backward should behave as high-precision linear backward # given the ReLU mask induced by quantized forward activations. 
dy_after_activation = dy * (y_keep_bwd_hp > 0).to(dy.dtype) - _, dx1_expected, dw_expected = _run_single_step(linear_unquantized_ref, x1, dy_after_activation, None) + _, dx1_expected, dw_expected = _run_single_step( + linear_unquantized_ref, x1, dy_after_activation, None + ) db_expected = _extract_bias_grad(linear_unquantized_ref) assert db_keep_bwd_hp is not None assert db_expected is not None @@ -625,7 +632,9 @@ def test_keep_backward_unquantized_autocast_respects_quantize_forward_flag(): module_quantization_disabled = _make_linear_like_module( "linear", in_features, out_features, dtype, bias=True ) - module_unquantized_ref = _make_linear_like_module("linear", in_features, out_features, dtype, bias=True) + module_unquantized_ref = _make_linear_like_module( + "linear", in_features, out_features, dtype, bias=True + ) _copy_named_parameters(module_quantization_disabled, module_unquantized_ref) x = torch.randn(32, in_features, dtype=dtype, device="cuda") From 74c787dafa267ad3bcfe9059504facbdc4f245d0 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 10 Feb 2026 15:22:40 -0800 Subject: [PATCH 43/43] Add more shapes to unit test Signed-off-by: Ziang Li --- .../pytorch/test_keep_backward_unquantized.py | 54 +++++++++++++++++-- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_keep_backward_unquantized.py b/tests/pytorch/test_keep_backward_unquantized.py index fe11bfcd3a..f5c3339a71 100644 --- a/tests/pytorch/test_keep_backward_unquantized.py +++ b/tests/pytorch/test_keep_backward_unquantized.py @@ -5,6 +5,7 @@ from __future__ import annotations from contextlib import nullcontext +import math import os from typing import Optional @@ -64,7 +65,9 @@ ] _shape_test_cases = [ + pytest.param((1, 64), 64, id="2d_m1_k64_n64"), pytest.param((32, 64), 64, id="2d_m32_k64_n64"), + pytest.param((32, 1, 64), 64, id="3d_m32_s1_k64_n64"), pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), ] @@ -166,6 +169,46 @@ def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: s pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") +def _maybe_skip_unsupported_recipe_shape( + recipe_name: str, + input_shape: tuple[int, ...], + module_type: str, +) -> None: + flat_first_dim = math.prod(input_shape[:-1]) + last_dim = input_shape[-1] + + # TE Linear / LayerNormLinear FP8 kernels require FP8-GEMM-compatible dimensions. + if module_type in ("linear", "layernorm_linear"): + if flat_first_dim % 8 != 0 or last_dim % 16 != 0: + pytest.skip( + "Linear/LayerNormLinear FP8 execution requires prod(shape[:-1]) divisible by 8 " + "and shape[-1] divisible by 16." + ) + return + + # te_ops.Linear (fusible ops) has stricter constraints for some block-scaled recipes. + if module_type == "ops_linear": + if recipe_name == "mxfp8" and (flat_first_dim % 32 != 0 or last_dim % 32 != 0): + pytest.skip( + "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." + ) + if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + pytest.skip( + "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." + ) + + +def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]) -> None: + # Grouped GEMM paths enforce additional split-alignment constraints for block-scaled recipes. 
+ non_empty_splits = [m for m in m_splits if m > 0] + if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): + pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") + if recipe_name == "fp8_block_scaling" and any(m % 4 != 0 for m in non_empty_splits): + pytest.skip( + "GroupedLinear + Float8BlockScaling requires each non-empty m_split divisible by 4." + ) + + def _run_single_step( module: torch.nn.Module, x: torch.Tensor, @@ -333,6 +376,7 @@ def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads ): reset_rng_states() _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) dtype = torch.bfloat16 in_features = input_shape[-1] @@ -390,8 +434,8 @@ def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads @pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) @pytest.mark.parametrize( "m_splits", - ([32, 32, 32, 32], [64, 0, 32, 32]), - ids=("uniform_splits", "with_empty_split"), + ([32, 32, 32, 32], [64, 0, 32, 32], [1, 31, 0, 96]), + ids=("uniform_splits", "with_empty_split", "small_and_empty_splits"), ) def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_unquantized_grads( recipe_name: str, @@ -400,6 +444,7 @@ def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_un ): if recipe_name == "nvfp4": pytest.skip("NVFP4 not supported for grouped linear") + _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) reset_rng_states() dtype = torch.bfloat16 @@ -478,10 +523,12 @@ def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_un ("scale_add", ForwardLinearScaleAdd), ), ) +@pytest.mark.parametrize("m", (1, 32), ids=("m1", "m32")) def test_keep_backward_unquantized_fused_linear_paths( recipe_name: str, fused_pattern: str, expected_fused_op: type, + m: int, ): # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") @@ -490,8 +537,7 @@ def test_keep_backward_unquantized_fused_linear_paths( dtype = torch.bfloat16 in_features = 64 out_features = 64 - m = 32 - + _maybe_skip_unsupported_recipe_shape(recipe_name, (m, in_features), "ops_linear") model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) model_keep_bwd_hp = _make_fused_model(fused_pattern, in_features, out_features, dtype) model_unquantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype)
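
Note, not part of the patch series: the snippet below is a minimal usage sketch of the keep-backward-unquantized mode these tests exercise, assuming the quantize_forward/quantize_backward recipe flags behave as the tests above show (quantized forward GEMMs, high-precision dgrad/wgrad). The te.Linear layer, shapes, and dtype are illustrative choices, not taken from the patch, and the recipe construction is presumed to match what the tests' _build_keep_backward_unquantized_recipe("fp8_current_scaling") helper builds.

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # Quantize only the forward pass; keep the backward pass in high precision
    # (assumed behavior of the quantize_backward=False flag added in this series).
    keep_bwd_hp_recipe = recipe.Float8CurrentScaling(
        fp8_format=recipe.Format.E4M3,
        quantize_forward=True,
        quantize_backward=False,
    )

    layer = te.Linear(64, 64, params_dtype=torch.bfloat16, device="cuda")
    x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)

    with te.autocast(enabled=True, recipe=keep_bwd_hp_recipe):
        y = layer(x)                     # forward GEMM runs quantized
    y.backward(torch.randn_like(y))      # dgrad/wgrad stay unquantized

As exercised in the tests, the forward output is expected to match a fully quantized run within fprop tolerances, while the input, weight, and bias gradients match an unquantized bf16 reference exactly.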