[Feature] Implement Single-Controller XCCL Weight Update #754
Conversation
Summary of Changes (Gemini Code Assist)

This pull request introduces a robust solution for performing distributed weight updates in a single-controller training architecture, specifically for XCCL. It resolves the architectural shift where training workers no longer manage rollout inference engines by implementing a mock inference engine pattern. This design ensures that the core FSDP weight update logic remains untouched while enabling proper orchestration of weight updates by the controller.
Code Review
This pull request introduces a significant feature to support distributed weight updates in a single-controller architecture, where training and rollout workers are decoupled. The core solution, which involves connecting training workers to a mock inference engine via a string import path, is a clever way to maintain backward compatibility with the existing FSDP logic. The changes are comprehensive, affecting APIs, controllers, engines, and the RPC server, and include a new example for testing. The code is generally well-structured and documented. My review focuses on improving maintainability by removing a hardcoded string and redundant code, fixing a minor typo, and highlighting a fragile implementation detail and a design assumption for future consideration.
```python
ParamSpec(
    name=n,
    shape=tuple(t.shape),
    dtype=str(t.dtype).split("torch.")[1],
)
```
Parsing the dtype string with str(t.dtype).split('torch.')[1] is a bit fragile. If the string representation of torch dtypes changes in a future version (e.g., torch.bfloat16 becomes something else), this code will break. A more robust approach would be to use a reverse mapping from torch.dtype objects to their string names, or use a utility function that handles this conversion safely.
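A hedged sketch of such a reverse mapping (the helper name is hypothetical, not part of this PR):

```python
import torch

# Reverse map from torch.dtype objects to their canonical string names,
# built once from torch's namespace; the repr check skips aliases such as
# torch.half, whose canonical name is "float16".
_DTYPE_NAMES: dict = {}
for _name in dir(torch):
    _attr = getattr(torch, _name)
    if isinstance(_attr, torch.dtype) and str(_attr) == f"torch.{_name}":
        _DTYPE_NAMES[_attr] = _name

def dtype_name(dtype: torch.dtype) -> str:
    # Fails loudly (KeyError) on an unknown dtype instead of silently
    # mis-parsing a changed repr.
    return _DTYPE_NAMES[dtype]

# e.g. dtype=dtype_name(t.dtype) in place of str(t.dtype).split("torch.")[1]
```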
For creating the `ParamSpec` we followed the same approach as `_update_bucket_weights_from_distributed` in the `fsdp_engine`.
Description
This PR implements `update_weights_from_distributed` in single-controller mode while preserving backward compatibility with the existing SPMD training architecture.

In the single-controller design, rollout and training responsibilities are split across controllers and workers. As a result, training workers no longer own rollout inference engines, which breaks an implicit assumption in the existing FSDP-based implementation: that every training engine has a locally constructed `rollout_engine`.

The goal of this PR is to enable XCCL-based distributed weight updates in single-controller mode.
Related Issue
Fixes #718
Type of Change
Motivation
In the original SPMD architecture:
- `connect_engine(rollout_engine)` is called locally
- `update_weights_from_distributed` directly invokes rollout update APIs from the training engine

In single-controller mode (the two flows are contrasted in the sketch below):
- training workers no longer own rollout inference engines
- weight updates are orchestrated by the controller
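A simplified contrast of the two flows (engine and method names are taken from this PR's description; the constructor and handles are assumptions):

```python
# SPMD: the training worker owns the rollout engine, so both calls happen
# in-process on a live object reference.
rollout_engine = build_rollout_engine(config)      # hypothetical constructor
train_engine.connect_engine(rollout_engine)        # local, in-process
train_engine.update_weights_from_distributed()     # drives rollout update APIs directly

# Single-controller: the training worker has no rollout engine to pass to
# connect_engine, so the in-process assumption above no longer holds.
```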
At the same time, the existing FSDP implementation assumes that `connect_engine()` has been called and that `self.rollout_engine` exists.

Why we cannot pass `MockInferenceEngine` via RPC

During implementation, we initially attempted to pass a `MockInferenceEngine` instance directly through the RPC call to `connect_engine`. This fails because:
- `MockInferenceEngine` is a runtime Python object with internal state

Attempting this results in runtime serialization errors (e.g. "object is not JSON serializable") and prevents the RPC call from being dispatched.
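A minimal reproduction of that failure mode, assuming a JSON-encoded RPC payload (the class body here is illustrative):

```python
import json

class MockInferenceEngine:
    def __init__(self):
        # runtime state makes the instance non-serializable as a JSON value
        self.group_initialized = False

payload = {"method": "connect_engine", "args": [MockInferenceEngine()]}
json.dumps(payload)
# TypeError: Object of type MockInferenceEngine is not JSON serializable
```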
High-Level Solution

To enable `update_weights_from_distributed` in single-controller mode without modifying existing FSDP logic, this PR introduces a mock inference engine connection pattern (a minimal sketch follows the lists below):
- a `MockInferenceEngine` that implements the `InferenceEngine` interface but performs no operations
- training workers connect to it via a string import path (e.g. `"module.path.MockInferenceEngine"`) rather than a live object

This allows:
- the existing `update_weights_from_distributed` logic to remain unchanged
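A hedged sketch of the pattern (the import path, class body, and loader here are illustrative, not this PR's exact code):

```python
import importlib

class MockInferenceEngine:
    """No-op stand-in satisfying the InferenceEngine interface on training workers."""

    def init_weights_update_group(self, *args, **kwargs):
        pass  # the controller drives the real rollout engines

    def update_weights_from_distributed(self, *args, **kwargs):
        pass

def load_engine_from_path(path: str):
    # "pkg.module.MockInferenceEngine" -> instance; only the string crosses
    # the RPC boundary, so no live object needs to be serialized.
    module_path, _, cls_name = path.rpartition(".")
    return getattr(importlib.import_module(module_path), cls_name)()

# On the training worker, after receiving the import path via RPC:
# train_engine.connect_engine(load_engine_from_path("pkg.module.MockInferenceEngine"))
```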
This PR also introduces `_get_bucket_param_specs` on `FSDPEngine`, which returns bucket-level `ParamSpec` metadata to the controller, as discussed in Issue #718.
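Roughly, extrapolating from the `ParamSpec` snippet in the review above (the bucket attribute and iteration helper are assumptions):

```python
def _get_bucket_param_specs(self):
    # One list of ParamSpec entries per weight-update bucket, so the
    # controller can orchestrate bucketed XCCL broadcasts without ever
    # holding the tensors themselves.
    return [
        [
            ParamSpec(
                name=n,
                shape=tuple(t.shape),
                dtype=str(t.dtype).split("torch.")[1],
            )
            for n, t in bucket.items()             # assumed name -> tensor map
        ]
        for bucket in self._weight_update_buckets  # assumed attribute
    ]
```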
In addition, this PR adds an optional `rank_ids` argument to `InferenceEngine.init_weights_update_group`, enabling correct per-worker rank assignment. Without it, all workers would incorrectly initialize with the same rank in the communication group.
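For example, on the controller side (the rank layout and the commented call are assumptions, not the repo's exact API):

```python
# Trainer ranks occupy [0, n_train); each rollout worker must join the
# XCCL weight-update group at a distinct rank above that range.
n_train = 8
n_rollout = 4
rank_ids = list(range(n_train, n_train + n_rollout))  # [8, 9, 10, 11]

# Per-worker dispatch keeps every worker at a unique rank instead of all
# of them defaulting to the same rank and deadlocking group initialization:
# for i, worker in enumerate(rollout_workers):
#     worker.init_weights_update_group(..., rank_ids=[rank_ids[i]])
```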
Compatibility
Checklist
Additional Context
This implementation follows the design direction discussed in #718 and subsequent comments.
Instead of restructuring FSDP internals, training workers are connected to a mock inference
engine so that inference-related assumptions remain satisfied while rollout orchestration
is handled entirely at the controller layer.
To test: