Changes from all commits (23 commits)
1f9f63b  Ray scheduler implementation (hlyli, Dec 17, 2025)
40d2428  Merge branch 'main' into ray_scheduler (HwVanICI, Dec 17, 2025)
3701cef  Update areal/scheduler/ray.py (HwVanICI, Dec 17, 2025)
a9ebec2  Update areal/scheduler/ray.py (HwVanICI, Dec 17, 2025)
aa4fec5  Stylistic changes and remove asyncio.to_thread from ray calls (hlyli, Dec 18, 2025)
18d7737  RayRTensor, RTensor refactor, and tests for RayScheduler (hlyli, Dec 19, 2025)
15c28a8  Merge branch 'main' into ray_scheduler (HwVanICI, Dec 19, 2025)
9942924  Fix typos (hlyli, Dec 19, 2025)
691b4d3  Merge branch 'ray_scheduler' of https://github.com/HwVanICI/AReaL int… (hlyli, Dec 19, 2025)
fd2b0b3  Update areal/scheduler/rpc/rtensor.py (HwVanICI, Dec 19, 2025)
75c5e53  Update areal/scheduler/ray.py (HwVanICI, Dec 19, 2025)
6e34fa2  Add gemini suggestions (hlyli, Dec 19, 2025)
bc3ee12  Merge branch 'ray_scheduler' of https://github.com/HwVanICI/AReaL int… (hlyli, Dec 19, 2025)
8681a27  Fix rtensor test regex assertion error (hlyli, Dec 22, 2025)
f18f77d  Merge branch 'main' into ray_scheduler (HwVanICI, Dec 22, 2025)
30abe0e  Tests for ray scheduler create and call engine (hlyli, Dec 22, 2025)
27ce76a  Merge branch 'ray_scheduler' of https://github.com/HwVanICI/AReaL int… (hlyli, Dec 22, 2025)
60ecdd8  Refactor ray implementation of rtensor to use dependency injection in… (hlyli, Dec 23, 2025)
dd221ca  Merge branch 'main' into ray_scheduler (HwVanICI, Dec 23, 2025)
9190ddd  Fix torch import (hlyli, Dec 23, 2025)
fb40ac2  Support PPOTrainer change for RayScheduler (hlyli, Dec 23, 2025)
5f7967b  Remove gsm8k_grpo_ray.py as it is handled in a unified script now. (hlyli, Dec 23, 2025)
265fa1c  Merge branch 'main' into ray_scheduler (HwVanICI, Dec 24, 2025)
16 changes: 2 additions & 14 deletions areal/controller/train_controller.py
@@ -4,7 +4,6 @@
 from pathlib import Path
 from typing import Any

-import aiohttp
 import torch.distributed as dist
 from torchdata.stateful_dataloader import StatefulDataLoader

@@ -579,25 +578,14 @@ def update_weights(self, meta: WeightUpdateMeta):
         raise ValueError(f"Unknown weight update type {meta.type}")

     async def _async_clear_batches(self, *targets: dict[str, RTensor]):
-        """Extract shard IDs and call /data/clear on each worker."""
+        """Extract shard IDs and clear tensors on each worker."""
         shards_by_node = RTensor.collect_shards(targets)

         if not shards_by_node:
             return

-        async def clear_node(node_addr, shard_ids):
-            async with aiohttp.ClientSession() as session:
-                async with session.delete(
-                    f"http://{node_addr}/data/clear", json={"shard_ids": shard_ids}
-                ) as resp:
-                    if resp.status == 200:
-                        result = await resp.json()
-                        logger.info(
-                            f"Cleared {result.get('cleared_count', 0)} shards on {node_addr}"
-                        )
-
         await asyncio.gather(
-            *[clear_node(addr, sids) for addr, sids in shards_by_node.items()],
+            *[RTensor.clear_node(addr, sids) for addr, sids in shards_by_node.items()],
             return_exceptions=True,
         )
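The deleted inline helper above is a good guide to what the new RTensor.clear_node presumably does. Below is a hedged sketch reconstructed from the removed code, not the actual areal/scheduler/rpc/rtensor.py implementation; the method name comes from the diff, but its exact signature and logging are assumptions.

import logging

import aiohttp

logger = logging.getLogger(__name__)


class RTensor:
    @staticmethod
    async def clear_node(node_addr: str, shard_ids: list[str]) -> None:
        # Mirror of the removed inline helper: one DELETE per worker against
        # its /data/clear endpoint, passing the shard IDs to drop.
        async with aiohttp.ClientSession() as session:
            async with session.delete(
                f"http://{node_addr}/data/clear", json={"shard_ids": shard_ids}
            ) as resp:
                if resp.status == 200:
                    result = await resp.json()
                    logger.info(
                        f"Cleared {result.get('cleared_count', 0)} shards on {node_addr}"
                    )

Moving the HTTP details onto RTensor keeps the controller transport-agnostic, which is presumably what lets the Ray-backed RayRTensor mentioned in the commit list plug in without further controller changes.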
8 changes: 8 additions & 0 deletions areal/core/remote_inf_engine.py
@@ -14,6 +14,7 @@
 from typing import Any, Protocol

 import aiohttp
+import ray
 import requests
 import torch.distributed as dist
 import uvloop
@@ -949,6 +950,13 @@ def launch_server(self, server_args: dict[str, Any]) -> LocalInfServerInfo:
         try:
             self._wait_for_server(address)
             self.local_server_processes.append(server_info)
+            if ray.is_initialized():
+                # do not return with process for ray as it is not picklable
+                return LocalInfServerInfo(
+                    host=server_args["host"],
+                    port=server_args["port"],
+                    process=None,
+                )
             return server_info
         except TimeoutError:
             logger.warning(
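The process=None branch exists because Ray serializes values that cross worker boundaries, and a live subprocess handle cannot be pickled. A minimal, self-contained illustration of that constraint (not project code; assumes a POSIX system with the sleep command available):

import pickle
import subprocess

# A Popen object holds OS-level state (pipes, an internal lock), so pickling
# it fails; this is why the Ray path strips the process handle from
# LocalInfServerInfo and keeps only host and port.
proc = subprocess.Popen(["sleep", "5"])
try:
    pickle.dumps(proc)
except TypeError as exc:
    print(f"Not picklable: {exc}")
finally:
    proc.terminate()
    proc.wait()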
4 changes: 3 additions & 1 deletion areal/experimental/trainer/rl.py
@@ -36,7 +36,7 @@
 from areal.engine.sglang_remote import RemoteSGLangEngine
 from areal.engine.vllm_remote import RemotevLLMEngine
 from areal.platforms import current_platform
-from areal.scheduler import LocalScheduler
+from areal.scheduler import LocalScheduler, RayScheduler
 from areal.utils import logging, perf_tracer, seeding, stats_tracker
 from areal.utils.dataloader import create_dataloader
 from areal.utils.environ import is_single_controller
@@ -395,6 +395,8 @@ def _init_scheduler(self) -> Scheduler:
         cfg = self.config.scheduler
         if cfg.type == "local":
             return LocalScheduler(exp_config=self.config)
+        elif cfg.type == "ray":
+            return RayScheduler(exp_config=self.config)
         raise NotImplementedError(f"Unknown scheduler type: {cfg.type}")

     def _create_dataloader(
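Selecting the "ray" scheduler type presumably requires a Ray runtime reachable from the trainer process; how RayScheduler itself calls into Ray is not shown in this hunk. The snippet below is only a hedged reminder of the standard Ray bootstrap, using public Ray APIs, not AReaL code.

import ray

# Start a local Ray runtime if none is attached; on a real cluster one would
# typically join the existing head node instead, e.g. ray.init(address="auto").
if not ray.is_initialized():
    ray.init()

print(ray.cluster_resources())  # sanity check: resources Ray can schedule on
ray.shutdown()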
5 changes: 2 additions & 3 deletions areal/scheduler/__init__.py
@@ -1,5 +1,4 @@
 from .local import LocalScheduler
+from .ray import RayScheduler

-__all__ = [
-    "LocalScheduler",
-]
+__all__ = ["LocalScheduler", "RayScheduler"]
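With the new export, both schedulers resolve from the package root. A hedged usage sketch (runnable only with areal installed; the config-driven selection is shown in the rl.py hunk above):

from areal.scheduler import LocalScheduler, RayScheduler

# The experimental RL trainer picks between these via config.scheduler.type
# ("local" or "ray") and passes exp_config to the chosen class.
print(LocalScheduler.__name__, RayScheduler.__name__)

Note that the unconditional import of .ray here presumably makes the ray package an import-time dependency of areal.scheduler, unless areal/scheduler/ray.py imports ray lazily.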