
Conversation

@HwVanICI (Collaborator) commented Dec 17, 2025

Description

This PR provides a first implementation of a single-controller scheduler built on Ray (RayScheduler), enabling multi-node deployments as a distributed alternative to the pre-existing LocalScheduler.

This change introduces RayScheduler, implementing the Scheduler interface, and RayRPCServer, mimicking the RPCServer class.

RayScheduler
RayScheduler creates one Ray actor per training rank and one Ray actor per rollout instance; Ray handles the device assignment for each actor. The current design creates a single placement group for training and one placement group per rollout actor.
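
As a rough illustration of that layout, here is a minimal sketch (not the PR's code) of one placement group shared by all training ranks, with one actor pinned per bundle. The RayRPCServer stub and its constructor arguments are placeholders.

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

ray.init()

@ray.remote
class RayRPCServer:  # stand-in for the PR's actor class
    def ping(self):
        return "ok"

n_train_ranks = 4  # example value

# One CPU+GPU bundle per training rank, all inside a single placement group.
train_pg = placement_group([{"CPU": 1, "GPU": 1}] * n_train_ranks, strategy="PACK")
ray.get(train_pg.ready())

train_actors = [
    RayRPCServer.options(
        num_cpus=1,
        num_gpus=1,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=train_pg,
            placement_group_bundle_index=rank,  # pin each rank to its own bundle
        ),
    ).remote()
    for rank in range(n_train_ranks)
]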

RayRPCServer
RayRPCServer is the Ray actor itself and resembles RPCServer. Instead of communicating with the Scheduler over HTTP, all communication happens through Ray remote calls.
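
For readers new to Ray, a minimal sketch of the remote-call pattern (the real actor wraps a training/inference engine; the EchoServer class and its dispatch are illustrative only):

import ray

@ray.remote
class EchoServer:
    def call(self, method: str, *args, **kwargs):
        # In the PR, this would dispatch to the wrapped engine; here we just echo.
        return {"method": method, "args": args, "kwargs": kwargs}

server = EchoServer.remote()
# Equivalent of what an HTTP POST to the RPCServer would have carried.
result = ray.get(server.call.remote("train_step", step=1))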

Caveats
The current implementation is tested with vLLM using the "mp" distributed executor backend. A follow-up design supporting the "ray" backend would be ideal, but limitations in how the EngineCore hooks are run currently prevent such an implementation.

Related Issue

Fix #661

Type of Change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not
    work as expected)
  • Documentation update
  • Code refactoring (no functional changes)
  • Performance improvement
  • Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • I have run formatting tools (pre-commit or manual)
  • I have run relevant unit tests and they pass (test_rtensor.py)
  • I have added tests for new functionality
  • I have updated documentation if needed
  • My branch is up to date with main
  • This PR introduces breaking changes (if yes, fill out details below)
  • If this PR changes documentation, I have built and previewed it locally with
    jb build docs
  • No critical issues raised by AI reviewers (/gemini review)

Breaking Change Details (if applicable):

Additional Context


Need help? Check the Contributing Guide or ask in
GitHub Discussions!

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @HwVanICI, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the system's scalability by integrating Ray for distributed scheduling. It adds a Ray-based alternative to the local scheduling mechanism, allowing for seamless multi-node deployments of training and inference workloads. The core change involves the introduction of RayScheduler to manage worker actors and RayRPCServer for efficient, Ray-native inter-actor communication, moving away from HTTP-based RPC. This lays the groundwork for more robust and scalable distributed machine learning experiments.

Highlights

  • Ray Scheduler Introduction: A new RayScheduler is implemented, serving as a distributed alternative to the LocalScheduler for multi-node deployments.
  • Ray Actor-based Workers: The RayScheduler creates Ray actors for each training rank and rollout instance, with Ray managing device assignments.
  • RayRPCServer for Communication: A RayRPCServer class is introduced, mirroring the existing RPCServer but utilizing Ray remote calls for inter-process communication instead of HTTP.
  • Current Backend Support: The initial implementation is tested with vLLM using the "mp" distributed executor backend.
  • Future Backend Plans: Future work aims to support the "ray" backend, pending resolution of current limitations with EngineCore hooks.

@gemini-code-assist bot left a comment:

Code Review

This pull request introduces a RayScheduler as a distributed alternative to the LocalScheduler, enabling multi-node deployments. The implementation includes the RayScheduler class for managing Ray actors and placement groups, and the RayRPCServer actor which wraps the training/inference engines. The code is well-structured, with robust error handling and retry mechanisms for remote calls. I've identified a minor bug in a log message and a typo, and have also suggested a small design improvement for handling worker ports to enhance maintainability. Overall, this is a solid first implementation of Ray-based scheduling.

).remote()

# 0 needed to pad the list as the trainer takes index 1 for ports
worker_ports = ["0", str(master_port)]
Contributor (severity: medium):

The use of a padded list ["0", str(master_port)] for worker_ports seems a bit brittle, as it relies on an implicit contract with the consumer (the trainer) about which index to use. For future improvements, consider using a more descriptive data structure like a dictionary ({"master_port": master_port}), or making the consumer more robust to handle different port list formats. This would make the code easier to understand and maintain.

Reply:

This is because TrainController takes the port at index 1 instead of index 0.

HwVanICI and others added 3 commits December 17, 2025 11:45
@garrett4wade (Collaborator) left a comment:

LGTM except for some minor style issues.


  def tensor_container_to(
-     d: dict[str, Any] | torch.Tensor | list[torch.Tensor], *args, **kwargs
+     d: dict[str, Any] | torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor],
Collaborator:

The tuple annotation is not correct. It should be tuple[torch.Tensor, ...].

if self._engine is None:
    raise RuntimeError("Engine not initialized. Call create_engine() first")

should_bcast = kwargs.pop("_should_bcast", True)
Collaborator:

This key has been changed to should_broadcast.

Comment on lines 101 to 114
try:
    fn = getattr(self._engine, method)
    result = fn(*args, **kwargs)
    if isinstance(result, Future):
        result = result.result()
    # put back to cpu to mimic RPCServer encode/decode
    result = tensor_container_to(result, "cpu")
    return result
except Exception as e:
    self.logger.error(
        f"RayRPCServer Engine method '{method}' failed: {e}\n"
        f"{traceback.format_exc()}"
    )
    raise
Collaborator:

We may want to add some debug logging with logger.debug around the engine creation and method calls.
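
For example, a minimal sketch of what that could look like (call_with_debug_logging is a hypothetical helper, not the PR's code; the real change would go inside the snippet above):

import logging
import traceback

logger = logging.getLogger("RayRPCServer")

def call_with_debug_logging(engine, method, *args, **kwargs):
    # Log before and after dispatching to the wrapped engine.
    logger.debug("calling engine method %r with %d positional args", method, len(args))
    try:
        result = getattr(engine, method)(*args, **kwargs)
        logger.debug("engine method %r returned %s", method, type(result).__name__)
        return result
    except Exception:
        logger.error("engine method %r failed:\n%s", method, traceback.format_exc())
        raise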

Comment on lines 48 to 59
def __init__(
    self,
    gpu_devices: list[int] | None = None,
    log_dir: str | None = None,
    startup_timeout: float = 30.0,
    health_check_interval: float = 1.0,
    *,
    fileroot: str | None = None,
    experiment_name: str | None = None,
    trial_name: str | None = None,
    exp_config: BaseExperimentConfig | None = None,
):
Collaborator:

We don't have to maintain the same init APIs for different schedulers. If some parameters are not used, they can be removed.

Comment on lines 46 to 55
def ray_resource_type():
    if torch.cuda.is_available():
        return "GPU"

    from areal.platforms import is_npu_available

    if is_npu_available:
        return "NPU"

    return "CPU"
Collaborator:

We probably want to move it to areal.platforms or inline it in the Ray scheduler.

options = self._actor_resource_spec(spec.cpu, spec.gpu, spec.mem)

env = get_env_vars(
    "", ",".join([f"{k}={v}" for k, v in spec.env_vars.items()])
Collaborator:

We should pass in a cluster name, which should be an init argument of the Ray scheduler.

Comment on lines 596 to 597
ref = wi.actor.call.remote(method, *args, **kwargs)
result = await asyncio.to_thread(ray.get, ref, timeout=http_timeout)
Collaborator:

We can use Ray's native async APIs instead of threading; the threading approach has some risk of getting stuck.

ref: https://docs.ray.io/en/latest/ray-core/actors/async_api.html
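
For instance, a minimal sketch under the assumption that ObjectRefs are awaited directly (async_call is a hypothetical helper; wi and method follow the snippet above, and asyncio.wait_for supplies the timeout that ray.get previously enforced):

import asyncio

async def async_call(wi, method, *args, timeout: float = 60.0, **kwargs):
    ref = wi.actor.call.remote(method, *args, **kwargs)
    # Ray ObjectRefs are awaitable inside an asyncio event loop.
    return await asyncio.wait_for(ref, timeout=timeout)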

@garrett4wade (Collaborator) commented:

One more thing: it would be better if we wrote a unit test file to test the basic functionality of RayScheduler, e.g., in areal/tests/test_ray_scheduler.py.

@HwVanICI changed the title from "Ray Scheduler Implementation for Single Controller" to "[WIP] Ray Scheduler Implementation for Single Controller" on Dec 18, 2025
@HwVanICI (Collaborator, Author) commented:

I am putting this back to WIP to accommodate the RTensor changes.

@HwVanICI (Collaborator, Author) commented Dec 18, 2025:

To accommodate the RTensor changes, I will need to implement a RayRTensor class using the Ray Object Store instead of HTTP. I plan to refactor the RTensor class such that there is a BaseRTensor class that implements RTensor's shared staticmethods as classmethods. Then I would have RTensor and RayRTensor both extend from BaseRTensor. Functions such as _fetch and from_batched would of course be specific to the subclasses.
Let me know if you have any other preferences.
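
To make the proposal concrete, a rough sketch of the intended hierarchy (class and method bodies are placeholders, not the final code):

class BaseRTensor:
    # Shared helpers that are currently staticmethods on RTensor would move
    # here as classmethods; transport-specific pieces stay abstract.
    @classmethod
    def from_batched(cls, batched):
        raise NotImplementedError

    @classmethod
    def _fetch(cls, shard):
        raise NotImplementedError


class RTensor(BaseRTensor):
    @classmethod
    def _fetch(cls, shard):
        ...  # existing HTTP fetch path


class RayRTensor(BaseRTensor):
    @classmethod
    def _fetch(cls, shard):
        ...  # fetch from the Ray Object Store, e.g. via ray.get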

@HwVanICI (Collaborator, Author) commented:

/gemini review

@gemini-code-assist bot left a comment:

Code Review

This pull request introduces a RayScheduler as a distributed alternative to the existing LocalScheduler, enabling multi-node deployments. The implementation includes a RayRPCServer actor to manage engine lifecycle and communication, and a RayRTensor for distributed tensor representation. The changes are extensive and well-structured, introducing a significant new capability. My review focuses on ensuring resource management is robust, identifying potential bugs in the new logic, and suggesting improvements for performance and maintainability. Key areas of feedback include improving the reliability of resource cleanup, fixing a bug in placement group handling, and enhancing type safety and test coverage.

Comment on lines 627 to 631
def __del__(self):
    try:
        self.delete_workers()
    except Exception:
        pass
Contributor (severity: high):

Using __del__ for resource cleanup is unreliable in Python, as its execution is not guaranteed, especially in the presence of reference cycles. This can lead to leaked Ray actors and placement groups. An explicit shutdown() method should be provided and called by the user to ensure proper resource release. The broad except Exception: pass also dangerously hides any errors that might occur during cleanup.

    def shutdown(self):
        """Shuts down the scheduler and cleans up all associated workers and resources."""
        try:
            self.delete_workers()
        except Exception as e:
            logger.error(f"Error during RayScheduler shutdown: {e}", exc_info=True)

Collaborator (Author):

Explicit shutdown is done by calling delete_workers.

Comment on lines 1 to 111
import ray
from ray.util.state import summarize_actors

from areal.api.cli_args import BaseExperimentConfig
from areal.api.scheduler_api import (
    Job,
    SchedulingSpec,
)
from areal.scheduler.ray import RayScheduler, ray_resource_type


class TestRaySchedulerInitialization:
    def test_init(self):
        scheduler = RayScheduler(
            startup_timeout=60.0, exp_config=BaseExperimentConfig()
        )
        assert scheduler.startup_timeout == 60.0


class TestWorkerCreationAndDeletion:
    def test_create_delete_workers(self):
        ray.init()

        config = BaseExperimentConfig()

        scheduler = RayScheduler(startup_timeout=60.0, exp_config=config)

        job = Job(
            replicas=2,
            role="train",
            tasks=[
                SchedulingSpec(
                    cpu=1,
                    mem=1024,
                    gpu=1,
                ),
                SchedulingSpec(
                    cpu=1,
                    mem=1024,
                    gpu=1,
                ),
            ],
        )

        # create workers
        worker_ids = scheduler.create_workers(job)
        assert len(worker_ids) == 2
        assert len(scheduler._workers["train"]) == 2

        actor_summary = summarize_actors()

        assert (
            actor_summary["cluster"]["summary"]["RayRPCServer"]["state_counts"]["ALIVE"]
            == 2
        )
        assert len(scheduler.get_workers("train")) == 2

        scheduler._ping_workers("train")

        # delete workers
        scheduler.delete_workers()
        assert len(scheduler._workers["train"]) == 0

        actor_summary = summarize_actors()
        assert (
            actor_summary["cluster"]["summary"]["RayRPCServer"]["state_counts"]["DEAD"]
            == 2
        )


class TestUtilityFunctions:
    def test_utilities(self):
        _num_gpu_per_node = 16
        config = BaseExperimentConfig()

        config.cluster.n_gpus_per_node = _num_gpu_per_node

        scheduler = RayScheduler(startup_timeout=60.0, exp_config=config)

        schedulings = [
            SchedulingSpec(
                cpu=1,
                mem=1024,
                gpu=1,
            ),
            SchedulingSpec(
                cpu=1,
                mem=1024,
                gpu=1,
            ),
        ]

        new_schedulings = scheduler._prepare_worker_specs("train", 2, schedulings)
        assert len(new_schedulings) == 2
        for spec in new_schedulings:
            assert spec.cpu == 1
            assert spec.mem == 1024
            assert spec.gpu == 1

        # case where only 1 spec is passed but multiple workers
        new_schedulings = scheduler._prepare_worker_specs("train", 2, schedulings[0:])
        assert len(new_schedulings) == 2
        for spec in new_schedulings:
            assert spec.cpu == 1
            assert spec.mem == 1024
            assert spec.gpu == 1

        bundle_list = scheduler._create_bundle_list_gpu(1, 24, 1024)
        assert len(bundle_list) == 2
        for bundle in bundle_list:
            assert bundle[ray_resource_type()] <= _num_gpu_per_node
Contributor (severity: medium):

The added tests for RayScheduler cover basic initialization and worker creation/deletion, which is a good start. However, there is no test coverage for more complex and critical functionalities, such as call_engine, async_call_engine, and the RayRTensor logic. These are core components of the new Ray-based scheduling and should be tested to ensure correctness and prevent regressions.

HwVanICI and others added 6 commits December 18, 2025 16:49
@HwVanICI changed the title from "[WIP] Ray Scheduler Implementation for Single Controller" to "Ray Scheduler Implementation for Single Controller" on Dec 22, 2025
@HwVanICI (Collaborator, Author) commented:

I've performed the refactor and implemented the RayRTensor class. Should be ready for review again.

@garrett4wade (Collaborator) left a comment:

Hi @HwVanICI, thanks for the addition of RayRTensor, but there is currently too much redundant code across the two RTensor implementations. At this stage we should improve the code quality before merging:

There's no need to create an abstract base class. Ray only differs in how the data is fetched and in the in-memory form of the tensor metadata; that is the only part we should extend. I suggest subclassing only the ShardInfo class and letting it provide the fetch functionality in a dependency-injection manner, so that the top-level RTensor implementation barely changes.
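
A rough sketch of the suggested shape (hypothetical names; the real ShardInfo fields and RTensor call sites may differ):

import ray
import torch


class ShardInfo:
    # Stand-in for the existing metadata class; the HTTP fetch path lives here.
    def fetch(self) -> torch.Tensor:
        ...


class RayShardInfo(ShardInfo):
    # Ray-backed shard: the tensor lives in the Ray Object Store.
    def __init__(self, object_ref: "ray.ObjectRef"):
        self.object_ref = object_ref

    def fetch(self) -> torch.Tensor:
        # RTensor just calls shard.fetch(); it never needs to know the transport.
        return ray.get(self.object_ref)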

@HwVanICI (Collaborator, Author) commented:

Thanks for the suggestion. I have done as requested and updated the RayScheduler code to be compatible with the PPOTrainer changes.

Successfully merging this pull request may close these issues:

[Feature] RayScheduler support in single-controller mode