Conversation

@HwVanICI
Collaborator

Description

This PR implements update_weights_from_distributed in single-controller mode while preserving backward compatibility with the existing SPMD training architecture.

In the single-controller design, rollout and training responsibilities are split across controllers and workers. As a result, training workers no longer own rollout inference engines, which breaks an implicit assumption in the existing FSDP-based implementation: that every training engine has a locally constructed rollout_engine.

The goal of this PR is to enable XCCL-based distributed weight updates in single-controller mode.


Related Issue

Fixes #718


Type of Change

  • New feature (non-breaking change that adds functionality)

Motivation

In the original SPMD architecture:

  • Each training process constructs its own rollout inference engine
  • connect_engine(rollout_engine) is called locally
  • update_weights_from_distributed directly invokes rollout update APIs from the training engine

In single-controller mode:

  • Rollout engines are owned and managed by the RolloutController
  • Training workers must not construct or own rollout inference engines
  • Training workers still execute FSDP collectives and broadcasts, but rollout coordination must occur at the controller layer

At the same time, the existing FSDP implementation assumes that connect_engine() has been called and that self.rollout_engine exists.
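
To make this assumption concrete, here is a minimal sketch of the SPMD-era coupling (method bodies and attribute names are illustrative, not the repo's actual code):

    class FSDPEngine:
        def connect_engine(self, rollout_engine):
            # SPMD: each training process wires up its locally built engine.
            self.rollout_engine = rollout_engine

        def update_weights_from_distributed(self):
            # Implicitly assumes connect_engine() ran earlier, so
            # self.rollout_engine exists. In single-controller mode no
            # rollout engine is ever constructed on the worker, and this
            # attribute access fails.
            self.rollout_engine.update_weights_from_distributed()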


Why we cannot pass MockInferenceEngine via RPC

During implementation, we initially attempted to pass a MockInferenceEngine instance directly through the RPC call to connect_engine.

This fails because:

  • RPC payloads are serialized to JSON
  • MockInferenceEngine is a runtime Python object with internal state
  • Such objects are not serializable and cannot cross process boundaries

Attempting this results in runtime serialization errors (e.g. “object is not JSON serializable”) and prevents the RPC call from being dispatched.
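
A minimal reproduction of the serialization failure (the payload shape is illustrative; the actual RPC framing is repo-specific):

    import json

    class MockInferenceEngine:
        def __init__(self):
            self.handle = object()  # runtime state with no JSON representation

    payload = {"method": "connect_engine", "args": [MockInferenceEngine()]}
    json.dumps(payload)
    # TypeError: Object of type MockInferenceEngine is not JSON serializable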


High-Level Solution

To enable update_weights_from_distributed in single-controller mode without modifying existing FSDP logic, this PR introduces a mock inference engine connection pattern:

  • Training workers are connected to a mock inference engine that implements the
    InferenceEngine interface but performs no real work
  • The mock engine is not passed as an object
  • Instead, the controller passes a string import path (e.g. "module.path.MockInferenceEngine")
  • Each worker:
    • Resolves the import path locally
    • Instantiates the mock engine in-process (see the sketch after the list below)

This allows:

  • Existing update_weights_from_distributed logic to remain unchanged
  • Inference-related operations to be safely bypassed on training workers
  • The TrainController to fully orchestrate rollout-side weight updates
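
A minimal sketch of the worker-side step, assuming a hypothetical helper resolve_import_path (the real wiring lives in the controller/worker RPC layer):

    import importlib

    def resolve_import_path(path: str):
        # Split "module.path.MockInferenceEngine" into module and class name.
        module_name, _, class_name = path.rpartition(".")
        module = importlib.import_module(module_name)
        return getattr(module, class_name)

    # The controller sends only the string; the worker instantiates locally,
    # after which the existing connect_engine() path runs unchanged.
    engine_cls = resolve_import_path("module.path.MockInferenceEngine")
    mock_engine = engine_cls()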

This PR also introduces _get_bucket_param_specs on FSDPEngine, which returns bucket-level
ParamSpec metadata to the controller, as discussed in Issue #718.
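
Roughly, the new method has the following shape (bucket iteration and attribute names are illustrative; the ParamSpec construction mirrors the snippet quoted in the review further below):

    def _get_bucket_param_specs(self):
        # For each weight bucket, report name/shape/dtype metadata so the
        # controller can coordinate broadcasts without owning the engine.
        bucket_specs = []
        for bucket in self._weight_buckets:  # illustrative attribute
            bucket_specs.append([
                ParamSpec(
                    name=n,
                    shape=tuple(t.shape),
                    dtype=str(t.dtype).split("torch.")[1],
                )
                for n, t in bucket.items()
            ])
        return bucket_specs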

In addition, this PR adds an optional rank_ids argument to
InferenceEngine.init_weights_update_group, enabling correct per-worker rank assignment.
Without this, all workers will incorrectly initialize with the same rank in the communication group.
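
A sketch of the extended signature (every parameter other than rank_ids is illustrative):

    def init_weights_update_group(self, master_address, master_port,
                                  group_name, world_size, rank_ids=None):
        # rank_ids: optional explicit ranks for this engine's workers.
        # When omitted, the previous behavior is preserved; when given,
        # each worker joins the communication group with its assigned
        # rank instead of all workers defaulting to the same one.
        ...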


Compatibility

  • Fully backward compatible
  • No breaking changes
  • Single-controller and SPMD modes both supported and tested

Checklist

  • I have read the Contributing Guide
  • I have run formatting tools
  • I have run relevant tests and they pass
  • My branch is up to date with main
  • I have added tests for new functionality (follow-up PR if needed)

Additional Context

This implementation follows the design direction discussed in #718 and subsequent comments.
Instead of restructuring FSDP internals, training workers are connected to a mock inference
engine so that inference-related assumptions remain satisfied while rollout orchestration
is handled entirely at the controller layer.

To test:

python examples/single-controller/gsm8k_grpo_xccl_weight_update.py \
  --config examples/single-controller/gsm8k_grpo.yaml

@gemini-code-assist
Contributor

Summary of Changes

Hello @HwVanICI, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a robust solution for performing distributed weight updates in a single-controller training architecture, specifically for XCCL. It accommodates the architectural shift in which training workers no longer manage rollout inference engines by implementing a mock inference engine pattern. This design ensures that the core FSDP weight update logic remains untouched while enabling proper orchestration of weight updates by the TrainController, thereby maintaining compatibility and enhancing the system's flexibility.

Highlights

  • Single-Controller XCCL Weight Update: This PR enables XCCL-based distributed weight updates in single-controller mode, addressing the challenge where training workers no longer own rollout inference engines, which was an implicit assumption in the existing FSDP-based implementation.
  • Mock Inference Engine Pattern: A mock inference engine connection pattern is introduced. Instead of serializing and passing a MockInferenceEngine object, a string import path is passed to workers. Each worker then resolves this path and instantiates the mock engine locally, allowing existing update_weights_from_distributed logic to remain unchanged and inference operations to be safely bypassed on training workers.
  • Parameter Bucket Metadata: The FSDPEngine now includes a _get_bucket_param_specs method, which returns bucket-level ParamSpec metadata to the controller. This is crucial for coordinating distributed weight updates.
  • Explicit Rank Assignment: The init_weights_update_group method in InferenceEngine and related remote engines now accepts an optional rank_ids argument, enabling correct per-worker rank assignment during the initialization of communication groups.
  • Backward Compatibility: The changes are fully backward compatible, introduce no breaking changes, and support both single-controller and SPMD modes.


@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a significant feature to support distributed weight updates in a single-controller architecture, where training and rollout workers are decoupled. The core solution, which involves connecting training workers to a mock inference engine via a string import path, is a clever way to maintain backward compatibility with the existing FSDP logic. The changes are comprehensive, affecting APIs, controllers, engines, and the RPC server, and include a new example for testing. The code is generally well-structured and documented. My review focuses on improving maintainability by removing a hardcoded string and redundant code, fixing a minor typo, and highlighting a fragile implementation detail and a design assumption for future consideration.

The comment below refers to this snippet from the diff:

    ParamSpec(
        name=n,
        shape=tuple(t.shape),
        dtype=str(t.dtype).split("torch.")[1],
Severity: medium

Parsing the dtype string with str(t.dtype).split('torch.')[1] is a bit fragile. If the string representation of torch dtypes changes in a future version (e.g., torch.bfloat16 becomes something else), this code will break. A more robust approach would be to use a reverse mapping from torch.dtype objects to their string names, or use a utility function that handles this conversion safely.
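
One way to implement the suggested hardening is a reverse mapping built once from the torch namespace instead of string parsing (a sketch, not the repo's code):

    import torch

    # torch.float32 -> "float32", torch.bfloat16 -> "bfloat16", etc.
    _DTYPE_NAMES = {
        getattr(torch, name): name
        for name in dir(torch)
        if isinstance(getattr(torch, name), torch.dtype)
    }

    def dtype_name(dtype: torch.dtype) -> str:
        return _DTYPE_NAMES[dtype]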

@HwVanICI (Collaborator, Author) replied:
For creating the ParamSpec, we followed the same approach as _update_bucket_weights_from_distributed in the fsdp_engine.

Merging this pull request closes: [Feature] Single-Controller update_weights_from_distributed: Design and Implementation (#718)