
Conversation

@fishcrap (Collaborator)

Description

This PR adds comprehensive FP8 (8-bit floating point) training support to AReaL, enabling memory-efficient training with low precision while maintaining training stability. The implementation includes:

  • FP8 quantization/dequantization utilities: New fp8_utils.py and fp8_kernels.py modules providing blockwise and per-tensor quantization support
  • CLI configuration: Extended TrainEngineConfig with FP8-related options (fp8 mode, recipe, parameter quantization, etc.)
  • Model loading/saving: Updated HuggingFace model loading and saving to handle FP8 weights with proper conversion between PyTorch FP8 and Transformer Engine FP8 formats
  • Megatron engine integration: Enhanced MegatronEngine to support FP8 training with proper configuration propagation
  • Comprehensive test suite: Added extensive tests for FP8 conversion, BF16 comparison, and gradient correctness

The implementation supports the blockwise quantization scheme and integrates with Transformer Engine's FP8 infrastructure for efficient GEMM operations.
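For illustration, here is a minimal pure-PyTorch sketch of the blockwise e4m3 quantize/dequantize round trip. This is a simplified stand-in, not the actual fp8_utils.py code; the function names and the 128x128 tile size are illustrative assumptions. The real fp8_kernels.py implements this with Triton kernels for performance.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def quantize_blockwise(x: torch.Tensor, block: int = 128):
    """Quantize a 2D tensor to FP8 e4m3 with one scale per (block x block) tile."""
    rows, cols = x.shape
    assert rows % block == 0 and cols % block == 0
    tiles = x.float().reshape(rows // block, block, cols // block, block)
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = FP8_MAX / amax                                   # multiply before casting down
    q = (tiles * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.reshape(rows, cols), scale.squeeze(1).squeeze(-1)

def dequantize_blockwise(q: torch.Tensor, scale: torch.Tensor, block: int = 128):
    """Invert quantize_blockwise: divide each tile by its scale and restore the shape."""
    rows, cols = q.shape
    tiles = q.float().reshape(rows // block, block, cols // block, block)
    return (tiles / scale.unsqueeze(1).unsqueeze(-1)).reshape(rows, cols)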

Related Issue

Fixes #(issue)

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update
  • Code refactoring (no functional changes)
  • Performance improvement
  • Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • I have run formatting tools (pre-commit or manual)
  • I have run relevant unit tests and they pass
  • I have added tests for new functionality
  • I have updated documentation if needed
  • My branch is up to date with main
  • This PR introduces breaking changes (if yes, fill out details below)
  • If this PR changes documentation, I have built and previewed it locally with jb build docs
  • No critical issues raised by AI reviewers (/gemini review)

Breaking Change Details (if applicable):

N/A - This is a new feature that adds optional FP8 support without breaking existing functionality.

Additional Context

TODO:

  • Memory profiling
  • Training time reduction
  • Fix MLA

@gemini-code-assist (Contributor)

Summary of Changes

Hello @fishcrap, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly upgrades AReaL by integrating comprehensive FP8 training support. The primary goal is to enable memory-efficient training with reduced precision without compromising model stability. This is achieved through the introduction of new FP8 quantization and dequantization utilities, extensive configuration options via the CLI, and updates to model loading and saving processes to handle FP8 weights. The core MegatronEngine has been adapted to leverage these FP8 capabilities, and new tests ensure the reliability of these low-precision operations.

Highlights

  • Comprehensive FP8 Training Support: This PR introduces full 8-bit floating point (FP8) training capabilities to AReaL, enabling more memory-efficient training while striving to maintain training stability.
  • FP8 Quantization Utilities: New modules fp8_utils.py and fp8_kernels.py have been added, providing blockwise and per-tensor quantization and dequantization functionalities, including Triton-based kernels for efficient operations.
  • Extended CLI Configuration: The TrainEngineConfig and MegatronEngineConfig have been significantly extended with numerous FP8-related options, allowing users to configure FP8 mode, scaling recipes, parameter quantization, and other precision-related settings via the command-line interface.
  • Enhanced Model Loading and Saving: HuggingFace model loading and saving mechanisms have been updated to correctly handle FP8 weights, including proper conversion between PyTorch FP8 and Transformer Engine FP8 formats, and dequantization when necessary.
  • MegatronEngine Integration: The MegatronEngine has been enhanced to seamlessly support FP8 training, ensuring that FP8 configurations are correctly propagated and applied throughout the training process.
  • New Test Suite: A comprehensive test suite (test_fp8_conversion.py) has been added to verify the correctness of FP8 conversion, compare results with BF16 baselines, and ensure gradient accuracy.
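For a concrete sense of what such a check looks like, here is a minimal round-trip test sketch. It is not taken from test_fp8_conversion.py; the shape and tolerances are assumptions sized to e4m3's roughly 6% relative precision.

import torch

def test_fp8_e4m3_roundtrip_close_to_bf16():
    torch.manual_seed(0)
    ref = torch.randn(256, 256, dtype=torch.bfloat16)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = fp8_max / ref.abs().amax().float()               # per-tensor scale for simplicity
    q = (ref.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    deq = (q.float() / scale).to(torch.bfloat16)
    # e4m3 has only 3 mantissa bits, so allow a coarse relative tolerance.
    torch.testing.assert_close(deq, ref, rtol=0.125, atol=1e-2)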

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces comprehensive FP8 training support, including new utilities for quantization/dequantization, CLI configurations, and updates to model loading/saving to handle FP8 weights. The changes are extensive and well-structured. I've identified a few areas with TODO or FIXME comments in the new code, particularly in tests and utility functions, that should be addressed to ensure correctness and clarity. The overall implementation seems robust, with good integration into the existing MegatronEngine and the addition of a comprehensive test suite.

@fishcrap changed the title from "Sxj/fp8 train" to "[Feat] Add FP8 training support" on Dec 24, 2025
@garrett4wade (Collaborator) left a comment


The critical issue is that we should enforce an FP8 HF base model when FP8 training is enabled.
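A sketch of what such a guard could look like: the quantization_config.quant_method == "fp8" convention below is the one used by FP8 HF checkpoints, but the helper name and call site are illustrative, not the PR's actual API.

import json
import os

def assert_hf_fp8_checkpoint(model_path: str) -> None:
    """Refuse to start FP8 training unless the HF base model already ships FP8 weights."""
    with open(os.path.join(model_path, "config.json")) as f:
        hf_config = json.load(f)
    quant_method = hf_config.get("quantization_config", {}).get("quant_method")
    if quant_method != "fp8":
        raise ValueError(
            f"FP8 training is enabled but {model_path} is not an FP8 HF checkpoint "
            f"(quant_method={quant_method!r}). Use an FP8 base model or disable fp8."
        )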

    bucket_size: int | None = None
    average_in_collective: bool = False
    fp8_param_gather: bool = False
    data_parallel_sharding_strategy: str = field(

Is this for FSDP or DDP? Does no_shard mean no sharding for optimizer states, or no sharding for parameters?


delete this field

    recompute_modules: list[str] | None = None

    # MoE
    moe_router_dtype: str | None = None

default to float32?


    def get_device_stats(self) -> DeviceRuntimeInfo:
        return DeviceRuntimeInfo.get_current()

    def _check_and_apply_fp8_config(self):

We should also check the transformer_engine installation here. If transformer_engine is not installed (e.g., in a plain uv pip install environment), a runtime error should be raised.
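A minimal sketch of this check, assuming it sits at the top of _check_and_apply_fp8_config (the surrounding attribute names are illustrative):

import importlib.util

def _check_and_apply_fp8_config(self):
    if self.config.fp8 is None:
        return
    # Transformer Engine provides the FP8 GEMM path, so fail fast when it is
    # missing, e.g. in an environment installed without it.
    if importlib.util.find_spec("transformer_engine") is None:
        raise RuntimeError(
            "FP8 training requires transformer_engine, but it is not installed. "
            "Install it or disable fp8."
        )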


should also revert the above change

Comment on lines +476 to +484
    # FP8 Training Configuration
    fp8: str | None = field(
        default=None,
        metadata={
            "help": "Enable FP8 precision training. Options: "
            "'e4m3' (uniform e4m3), "
            "'hybrid' (e4m3 for activations/weights, e5m2 for output activation gradients)."
        },
    )
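For context, this is roughly how the 'e4m3'/'hybrid' strings map onto Transformer Engine's recipe API. DelayedScaling is shown only as the simplest recipe; the PR's blockwise scheme would construct a different recipe object, so this is illustrative rather than the PR's actual mapping code.

from transformer_engine.common.recipe import DelayedScaling, Format

def fp8_recipe_from_config(fp8: str | None):
    # None disables FP8; otherwise pick the TE format matching the CLI string.
    if fp8 is None:
        return None
    fmt = {"e4m3": Format.E4M3, "hybrid": Format.HYBRID}[fp8]
    return DelayedScaling(fp8_format=fmt)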

Can we provide an example YAML config for FP8 Qwen3 training? We should also provide a learning curve obtained with that config (FP8 vs. BF16 training curves).
