
Conversation

@zzc0430 (Contributor) commented on Jan 20, 2026

PR Type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR Information

Problem

When running RLHF (e.g., KTO, DPO) with Megatron-LM + LoRA + TransformerEngine, a RuntimeError occurs during the reference model's forward pass when calculating KL divergence.

Root Cause:
The null_ref_context reuses the training model instance (or its unwrapped version) as the reference model without switching it to eval() mode. This leads to two critical issues:

  1. Crash with TransformerEngine: The model remains in train() mode, causing Gradient Checkpointing (Recompute) to stay enabled. The combination of no_grad, Gradient Checkpointing, LoRA, and TransformerEngine triggers an internal state error in TE (Input x is not allocated).
  2. Incorrect KL Calculation: Since the model is in training mode, Dropout remains active. This introduces randomness into the reference log-probabilities, resulting in unstable and mathematically incorrect KL divergence (illustrated by the sketch after this list).
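
As an illustration of point 2, here is a minimal, self-contained PyTorch sketch (a toy module, not the trainer code): a module left in train() mode gives different outputs on repeated calls because Dropout stays active even under torch.no_grad(), whereas eval() makes the reference forward deterministic.

import torch
import torch.nn as nn

# Toy stand-in for the reference model; any module containing Dropout behaves this way.
model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5), nn.Linear(16, 4))
x = torch.randn(2, 16)

with torch.no_grad():
    model.train()                          # state inherited from the training model
    out_a, out_b = model(x), model(x)      # different dropout masks on each call
    print(torch.allclose(out_a, out_b))    # almost certainly False

    model.eval()                           # what the fix enforces inside null_ref_context
    out_a, out_b = model(x), model(x)
    print(torch.allclose(out_a, out_b))    # True: deterministic reference outputs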

Solution

Forcibly switch the reference models to eval() mode within the null_ref_context manager and restore their original training state upon exit.

  • eval() disables Gradient Checkpointing, resolving the TE crash (see the sketch after this list).
  • eval() disables Dropout, ensuring deterministic reference outputs.
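
The first bullet reflects how the recompute path is selected: the traceback enters _checkpointed_forward only because the module is still in training mode. Below is a minimal sketch (generic PyTorch, not Megatron's TransformerBlock) of a block that checkpoints only while self.training is True, so switching the reference model to eval() routes its forward around gradient checkpointing entirely.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Toy block mirroring the 'recompute only while training' gating."""

    def __init__(self):
        super().__init__()
        self.ffn = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))

    def forward(self, x):
        if self.training:
            # Recompute path: activations are recomputed during backward.
            return checkpoint(self.ffn, x, use_reentrant=False)
        # Plain path taken once the reference model is switched to eval().
        return self.ffn(x)

block = Block()
x = torch.randn(2, 16)
with torch.no_grad():
    block.train()   # reused training state: the checkpointed branch is taken
    _ = block(x)
    block.eval()    # what the fix enforces: ordinary forward, no recompute
    _ = block(x)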

Error Log

[rank10]: Traceback (most recent call last):
[rank10]:   File "/workspace/ms-swift/swift/cli/_megatron/rlhf.py", line 7, in <module>
[rank10]:     megatron_rlhf_main()
...
[rank10]:   File "/workspace/ms-swift/swift/megatron/trainers/base.py", line 565, in train_step
[rank10]:     return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
...
[rank10]:   File "/workspace/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 402, in forward_step
[rank10]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank10]:   # Error happens during KL calculation entry point
[rank10]:   File "/workspace/ms-swift/swift/megatron/trainers/kto_trainer.py", line 136, in forward_step
[rank10]:     ref_KL_output_tensor = ref_model(**kl_data)
...
[rank10]:   File "/workspace/Megatron-LM/megatron/core/transformer/transformer_block.py", line 669, in forward
[rank10]:     hidden_states = self._checkpointed_forward(  # <--- Checkpointing is wrongly enabled
...
[rank10]:   File "/workspace/Megatron-LM/megatron/core/extensions/transformer_engine.py", line 615, in forward
[rank10]:     out = super().forward(x, is_first_microbatch=_is_first_microbatch)
...
[rank10]:   File "/root/anaconda3/envs/swift/lib/python3.10/site-packages/transformer_engine/pytorch/module/_common.py", line 51, in apply_normalization
[rank10]:     return normalization_func(
[rank10]: RuntimeError: /TransformerEngine/transformer_engine/common/transformer_engine.cpp:164 in function CheckInputTensor: Assertion failed: t.has_data() || t.has_columnwise_data(). Input x is not allocated!

@gemini-code-assist (bot) commented
Summary of Changes

Hello @zzc0430, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request fixes a critical bug affecting RLHF training, particularly when using Megatron-LM, LoRA, and TransformerEngine. The issue stemmed from reference models incorrectly remaining in training mode during KL divergence calculation, leading to a RuntimeError caused by gradient checkpointing conflicts and to inaccurate KL values caused by active dropout. The solution ensures these reference models are temporarily switched to evaluation mode, resolving both the stability and the correctness issues.

Highlights

  • Resolved RuntimeError in RLHF with TransformerEngine: The null_ref_context now explicitly switches reference models to eval() mode, which disables gradient checkpointing. This prevents a RuntimeError that occurred when using Megatron-LM, LoRA, and TransformerEngine for RLHF tasks, specifically during KL divergence calculation.
  • Ensured accurate KL divergence calculation: By forcing reference models into eval() mode, dropout is disabled, guaranteeing deterministic log-probabilities. This corrects the mathematical instability and inaccuracy previously observed in KL divergence calculations due to active dropout in training mode.


@gemini-code-assist bot left a comment
Code Review

This pull request effectively resolves a critical bug that caused runtime errors and incorrect KL divergence calculations during RLHF training with Megatron-LM, LoRA, and TransformerEngine. The issue stemmed from the reference model not being switched to evaluation mode. The proposed solution, which involves using a try...finally block to set the model to eval() mode and reliably restore its state, is well-implemented and correct. My review includes a minor suggestion to further improve the robustness of this state management.

Comment on lines +55 to +66
training_states = [m.training for m in ref_models]
for m in ref_models:
    m.eval()
try:
    yield ref_models
finally:
    for m, was_training in zip(ref_models, training_states):
        if was_training:
            m.train()
    if has_ref_adapter:
        for m in self.peft_models:
            m.set_adapter('default')
Severity: medium

The use of try...finally is a good practice for ensuring resource cleanup. To enhance robustness, I recommend moving the state modification logic (setting models to eval() mode) inside the try block. This change ensures that if an exception occurs during the m.eval() call on one of the models, the finally block will still be executed, guaranteeing that the training states of all models are correctly restored. This prevents potential state corruption that could affect subsequent training steps.

Suggested change (original lines first, proposed replacement second):

Original:

training_states = [m.training for m in ref_models]
for m in ref_models:
    m.eval()
try:
    yield ref_models
finally:
    for m, was_training in zip(ref_models, training_states):
        if was_training:
            m.train()
    if has_ref_adapter:
        for m in self.peft_models:
            m.set_adapter('default')

Suggested:

training_states = [m.training for m in ref_models]
try:
    for m in ref_models:
        m.eval()
    yield ref_models
finally:
    for m, was_training in zip(ref_models, training_states):
        if was_training:
            m.train()
    if has_ref_adapter:
        for m in self.peft_models:
            m.set_adapter('default')

@Jintao-Huang (Collaborator) commented
hello! Could you provide a training script that can reproduce your issue?

@zzc0430 (Contributor, Author) commented on Jan 22, 2026

> hello! Could you provide a training script that can reproduce your issue?

Hardware: 2 nodes × 8 A800 GPUs.

PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NNODES=$WORLD_SIZE \
NODE_RANK=$RANK \
megatron rlhf \
    --rlhf_type kto \
    --model Qwen/Qwen3-235B-A22B-Instruct-2507/ \
    --adapters ${ADAPTER} \
    --ref_adapters ${ADAPTER} \
    --dataset ${DATASET_PATHS} \
    --train_type lora \
    --use_rslora true \
    --lora_rank 64 \
    --lora_alpha 6 \
    --load_safetensors true \
    --save_safetensors true \
    --merge_lora false \
    --split_dataset_ratio 0 \
    --pipeline_model_parallel_size 2 \
    --tensor_model_parallel_size 4 \
    --expert_model_parallel_size 8 \
    --moe_permute_fusion true \
    --moe_grouped_gemm true \
    --moe_shared_expert_overlap true \
    --moe_aux_loss_coeff 1e-4 \
    --micro_batch_size 1 \
    --global_batch_size 16 \
    --packing false \
    --recompute_granularity full \
    --recompute_method uniform \
    --recompute_num_layers 1 \
    --max_epochs 3 \
    --finetune true \
    --cross_entropy_loss_fusion true \
    --lr 1e-4 \
    --lr_warmup_fraction 0.05 \
    --save_strategy epoch \
    --loss_scale last_round \
    --num_workers 8 \
    --dataset_num_proc 8 \
    --no_save_optim true \
    --no_save_rng true \
    --sequence_parallel true \
    --attention_backend flash \
    --log_interval 1 

@Jintao-Huang (Collaborator) commented on Jan 23, 2026

The error message above is a bug that will be fixed in this PR.

#7882
