From 1f9f63bbe54522a1ca6f87776d1f00736fb75ab9 Mon Sep 17 00:00:00 2001 From: Haley Date: Wed, 17 Dec 2025 11:11:52 -0800 Subject: [PATCH 01/15] Ray scheduler implementation --- areal/core/remote_inf_engine.py | 8 + areal/scheduler/ray.py | 633 +++++++++++++++++++ areal/scheduler/rpc/ray_rpc_server.py | 124 ++++ areal/utils/data.py | 6 +- areal/utils/device.py | 13 + examples/single-controller/gsm8k_grpo_ray.py | 215 +++++++ 6 files changed, 997 insertions(+), 2 deletions(-) create mode 100644 areal/scheduler/ray.py create mode 100644 areal/scheduler/rpc/ray_rpc_server.py create mode 100644 examples/single-controller/gsm8k_grpo_ray.py diff --git a/areal/core/remote_inf_engine.py b/areal/core/remote_inf_engine.py index b1a67ddbb..1fc84ca0b 100644 --- a/areal/core/remote_inf_engine.py +++ b/areal/core/remote_inf_engine.py @@ -14,6 +14,7 @@ from typing import Any, Protocol import aiohttp +import ray import requests import torch.distributed as dist import uvloop @@ -945,6 +946,13 @@ def launch_server(self, server_args: dict[str, Any]) -> LocalInfServerInfo: try: self._wait_for_server(address) self.local_server_processes.append(server_info) + if ray.is_initialized(): + # do not return with process for ray as it is not picklable + return LocalInfServerInfo( + host=server_args["host"], + port=server_args["port"], + process=None, + ) return server_info except TimeoutError: logger.warning( diff --git a/areal/scheduler/ray.py b/areal/scheduler/ray.py new file mode 100644 index 000000000..a2c9fdb19 --- /dev/null +++ b/areal/scheduler/ray.py @@ -0,0 +1,633 @@ +import asyncio +import math +import time +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any + +import ray +import ray.exceptions +from ray.runtime_env import RuntimeEnv +from ray.util.placement_group import ( + PlacementGroup, + placement_group, + remove_placement_group, +) +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from areal.api.cli_args import BaseExperimentConfig +from areal.api.scheduler_api import Job, Scheduler, SchedulingSpec, Worker +from areal.scheduler.exceptions import ( + EngineCallError, + WorkerCreationError, + WorkerFailedError, + WorkerNotFoundError, + WorkerTimeoutError, +) +from areal.scheduler.rpc.ray_rpc_server import RayRPCServer +from areal.utils import logging +from areal.utils.device import ray_resource_type +from areal.utils.launcher import get_env_vars +from areal.utils.ray import get_placement_group_master_ip_and_port + +logger = logging.getLogger("RayScheduler") + + +@dataclass +class RayWorkerInfo: + worker: Worker + actor: ray.actor.ActorHandle + role: str + placement_group: PlacementGroup + bundle_index: int + created_at: float + env_vars: dict[str, str] = field(default_factory=dict) + + +class RayScheduler(Scheduler): + def __init__( + self, + gpu_devices: list[int] | None = None, + log_dir: str | None = None, + startup_timeout: float = 30.0, + health_check_interval: float = 1.0, + *, + fileroot: str | None = None, + experiment_name: str | None = None, + trial_name: str | None = None, + exp_config: BaseExperimentConfig | None = None, + ): + # we do not set up logging dir as it is done by Ray + if log_dir is not None: + logger.warning( + f"log_dir {log_dir} will not be used for Ray. 
Check /tmp/ray/session_*/logs for Ray logs" + ) + self.exp_config = exp_config + self.gpu_devices = gpu_devices + + self.startup_timeout = startup_timeout + self.health_check_interval = health_check_interval + + self._workers: dict[str, list[RayWorkerInfo]] = defaultdict(list) + self._placement_groups: list[PlacementGroup] = [] + + def _prepare_worker_specs( + self, role: str, num_workers: int, schedulings: list[SchedulingSpec] | None + ) -> list[SchedulingSpec]: + if not schedulings: + raise WorkerCreationError( + role, "Invalid configuration", "Tasks SchedulingSpec must be provided" + ) + if len(schedulings) == 1: + return [schedulings[0]] * num_workers + + if len(schedulings) == num_workers: + return schedulings + + raise WorkerCreationError( + role, + "Invalid Configuration", + f"schedulings length ({len(schedulings)}) must be 1 or equal to replicas ({num_workers})", + ) + + def _bundle_spec(self, cpu: int, gpu: int, mem: int) -> dict: + """ + define a bundle dict for a given cpu, gpu, mem requirement + """ + device = ray_resource_type() + if device == "CPU" and gpu > 0: + raise ValueError( + f"Current detected device is CPU but specified number of GPUs is {gpu}" + ) + return { + "CPU": cpu, + device: float(gpu), + "memory": mem * 1024 * 1024, # convert mb to bytes + } + + def _create_bundle_list_gpu(self, cpu: int, gpu: int, mem: int) -> list[dict]: + """ + for dividing out resources so that 1 bundle can be contained on 1 node and creates a list of bundles + """ + bundle_list = [] + + n_gpus_per_node = self.exp_config.cluster.n_gpus_per_node + + if n_gpus_per_node == 0 and gpu > 0: + raise ValueError( + f"Requested {gpu} GPUs but number of GPUs per node is {n_gpus_per_node}" + ) + + if gpu < n_gpus_per_node: + return [self._bundle_spec(cpu, gpu, mem)] + + gpu_remaining_to_be_assigned = gpu + + while gpu_remaining_to_be_assigned > 0: + # do not want to take all gpus in node if we do not need that many + gpu_in_bundle = min(gpu_remaining_to_be_assigned, n_gpus_per_node) + + # for scaling the amount of cpu and memory relative to gpu in bundle + resource_per_node_multiplier = min(gpu_in_bundle / gpu, 1) + cpu_in_bundle = math.ceil(cpu * resource_per_node_multiplier) + mem_in_bundle = math.ceil(mem * resource_per_node_multiplier) + + bundle_list.append( + self._bundle_spec(cpu_in_bundle, gpu_in_bundle, mem_in_bundle) + ) + gpu_remaining_to_be_assigned -= gpu_in_bundle + + return bundle_list + + def _actor_resource_spec(self, cpu: int, gpu: int, mem: int) -> dict: + """ + create a dictionary for passing into ray actor options specifying resource requirements + """ + + device = ray_resource_type() + if device == "CPU" and gpu > 0: + raise ValueError( + f"Current detected device is CPU but specified number of GPUs is {gpu}" + ) + + return { + "num_cpus": cpu, + "resources": {device: float(gpu)}, + "memory": mem * 1024 * 1024, + } + + def _sum_resource_spec( + self, schedulings: list[SchedulingSpec] + ) -> tuple[int, int, int]: + num_cpu = sum(spec.cpu for spec in schedulings) + num_gpu = sum(spec.gpu for spec in schedulings) + num_mem = sum(spec.mem for spec in schedulings) + + return (num_cpu, num_gpu, num_mem) + + def _ping_workers(self, role: str, timeout: float | None = None): + worker_info_list = self._workers[role] + timeout = timeout if timeout is not None else self.startup_timeout + refs = [wi.actor.ping.remote() for wi in worker_info_list] + + ref_to_worker = {ref: wi for wi, ref in zip(worker_info_list, refs)} + + pending = refs + while pending: + ready, pending = ray.wait(pending, 
num_returns=1, timeout=timeout) + # ray.wait timed out + if len(ready) == 0: + raise WorkerTimeoutError(role, timeout) + + ref = ready[0] + + try: + # get to determine if this is a failed actor + ray.get(ref) + except ray.exceptions.GetTimeoutError: + failed_worker = ref_to_worker[ref] + raise WorkerTimeoutError(failed_worker.worker.id, timeout) + except ray.exceptions.RayActorError: + failed_worker = ref_to_worker[ref] + raise WorkerFailedError(failed_worker.worker.id, -1) + + def _create_rollout_workers( + self, role: str, schedulings: list[SchedulingSpec] + ) -> tuple[list[RayWorkerInfo], list[str]]: + """ + Crate rollout workers, assuming 1 worker per rollout instance. + + Parameters + --------- + role: str + schedulings: list[SchedulingSpec] + + Returns + -------- + Tuple[list[RayWorkerInfo], list[str]] + List of RayWorkerInfo of created workers + List of worker IDs created + """ + + worker_info_list: list[RayWorkerInfo] = [] + worker_ids: list[str] = [] + + # create placement_groups + for idx, spec in enumerate(schedulings): + worker_id = f"{role}/{idx}" + + bundles = [self._bundle_spec(spec.cpu, spec.gpu, spec.mem)] + pg = placement_group(bundles, strategy="PACK") + + try: + ray.get(pg.ready(), timeout=self.startup_timeout) + except ray.exceptions.GetTimeoutError: + logger.error( + f"Ray placement group timeout for train role {role}\n" + f"ray.nodes(): {ray.nodes()}" + f"bundles: {bundles}" + ) + raise + self._placement_groups.append(pg) + + master_ip, master_port = get_placement_group_master_ip_and_port( + pg, placement_group_bundle_index=0 + ) + + # define resources to actor + options = self._actor_resource_spec(spec.cpu, spec.gpu, spec.mem) + + env = get_env_vars( + "", ",".join([f"{k}={v}" for k, v in spec.env_vars.items()]) + ) + + if spec.env_vars: + env.update(spec.env_vars) + + actor = RayRPCServer.options( + **options, + name=worker_id, + runtime_env=RuntimeEnv(env_vars=env), + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=0, + placement_group_capture_child_tasks=True, + ), + ).remote() + + # 0 needed to pad the list as the trainer takes index 1 for ports + worker_ports = ["0", str(master_port)] + worker = Worker( + id=worker_id, ip=master_ip, worker_ports=worker_ports, engine_ports=[] + ) + + wi = RayWorkerInfo( + worker=worker, + actor=actor, + role=role, + placement_group=pg, + bundle_index=0, + created_at=time.time(), + env_vars=env, + ) + + worker_info_list.append(wi) + worker_ids.append(worker_id) + + return worker_info_list, worker_ids + + def _create_train_workers( + self, role: str, schedulings: list[SchedulingSpec] + ) -> tuple[list[RayWorkerInfo], list[str]]: + """ + Create workers for training roles. One PG per role with multiple bundles. + Assume 1 ray worker per train rank. 
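+        For example, two SchedulingSpecs requesting 1 GPU each produce a
+        single placement group sized for 2 GPUs (split into per-node bundles
+        capped at cluster.n_gpus_per_node), with one RayRPCServer actor per rank.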
+ + Parameters + --------- + role: str + schedulings: list[SchedulingSpec] + + Returns + -------- + Tuple[list[RayWorkerInfo], list[str]] + List of RayWorkerInfo of created workers + List of worker IDs created + """ + # build bundles + sum_cpu, sum_gpu, sum_mem = self._sum_resource_spec(schedulings) + bundles: list[dict[str, float]] = self._create_bundle_list_gpu( + sum_cpu, sum_gpu, sum_mem + ) + + pg = placement_group(bundles=bundles, strategy="PACK") + + try: + ray.get(pg.ready(), timeout=self.startup_timeout) + except ray.exceptions.GetTimeoutError: + logger.error( + f"Ray placement group timeout for train role {role}\n" + f"ray.nodes(): {ray.nodes()}" + f"bundles: {bundles}" + ) + raise + + self._placement_groups.append(pg) + + master_ip, master_port = get_placement_group_master_ip_and_port( + pg, placement_group_bundle_index=0 + ) + + worker_info_list: list[RayWorkerInfo] = [] + worker_ids: list[str] = [] + + for idx, spec in enumerate(schedulings): + worker_id = f"{role}/{idx}" + + options = self._actor_resource_spec(spec.cpu, spec.gpu, spec.mem) + + env = get_env_vars( + "", ",".join([f"{k}={v}" for k, v in spec.env_vars.items()]) + ) + + if spec.env_vars: + env.update(spec.env_vars) + + actor = RayRPCServer.options( + **options, + name=worker_id, + runtime_env=RuntimeEnv(env_vars=env), + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, placement_group_capture_child_tasks=True + ), + ).remote() + + worker_ports = ["0", str(master_port)] + worker = Worker( + id=worker_id, ip=master_ip, worker_ports=worker_ports, engine_ports=[] + ) + + wi = RayWorkerInfo( + worker=worker, + actor=actor, + role=role, + placement_group=pg, + bundle_index=None, # decided by ray + created_at=time.time(), + env_vars=env, + ) + worker_info_list.append(wi) + worker_ids.append(worker_id) + + return worker_info_list, worker_ids + + def create_workers(self, job: Job, *args, **kwargs) -> list[str]: + """ + Create worker actors. 
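+
+        Example (illustrative sketch; `config` here stands in for a loaded
+        experiment config):
+
+            scheduler = RayScheduler(exp_config=config)
+            job = Job(
+                role="train",
+                replicas=2,
+                tasks=[SchedulingSpec(cpu=1, gpu=1, mem=1024)],
+            )
+            worker_ids = scheduler.create_workers(job)  # ["train/0", "train/1"]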
+ + Parameters + -------- + job: Job + Job configuration with role, replicas, tasks, scheduling strategy + *args + Additional arguments (UNUSED) + **kwargs + Additional keyword arguments (UNUSED) + + Returns + -------- + list[str] + List of worker IDs created (e.g., ["rollout/0", "rollout/1]) + + Raises + -------- + WorkerCreationError + If worker creation fails + """ + role = job.role + if role in self._workers: + raise WorkerCreationError( + role, + "Worker group already exists", + f"Use delete_workers('{role}') first to remove existing workers.", + ) + + num_workers = job.replicas + if num_workers == 0: + raise WorkerCreationError( + role, "Invalud configuration", "replicas must be greater than 0" + ) + + schedulings = self._prepare_worker_specs(role, num_workers, job.tasks) + + strategy = job.scheduling_strategy + if strategy is None: + strategy_type = "separation" + else: + strategy_type = strategy.type or "separation" + if strategy_type == "colocation": + raise WorkerCreationError( + role, + "Unavailable strategy type", + "RayScheduler only supports separation strategy", + ) + + if role == "rollout": + worker_info_list, worker_ids = self._create_rollout_workers( + role, schedulings + ) + else: + worker_info_list, worker_ids = self._create_train_workers(role, schedulings) + + self._workers[role].extend(worker_info_list) + + self._ping_workers(role, self.startup_timeout) + + if self.exp_config is not None: + for rank, wi in enumerate(worker_info_list): + try: + wi.actor.configure.remote(self.exp_config, wi.role, rank) + except Exception as e: + logger.error( + f"Confgiure failed on worker {wi.worker.id}: {e}", exc_info=True + ) + self._cleanup_workers(worker_info_list) + raise WorkerCreationError( + role, "Worker configuration failed", str(e) + ) + + return worker_ids + + def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: + if role not in self._workers: + raise WorkerNotFoundError(role) + + worker_info_list = self._workers[role] + + self._ping_workers(role, timeout) + + return [wi.worker for wi in worker_info_list] + + def delete_workers(self, role: str | None = None): + """ + Delete workers and clean up resources + + Parameters + -------- + role: str, optional + Specific worker role to delete, or None to delete all + """ + if role is None: + roles = list(self._workers.keys()) + for r in roles: + self.delete_workers(r) + return + + if role not in self._workers: + logger.warning(f"Worker role '{role}' not found, skipping deletion") + return + + workers = self._workers[role] + logger.info(f"Deleting {len(workers)} workers for role '{role}'") + + self._cleanup_workers(workers) + + del self._workers[role] + + logger.info(f"Successfully deleted workers for role '{role}'") + + def _cleanup_workers(self, workers: list[RayWorkerInfo]): + for wi in workers: + actor = wi.actor + try: + actor.destroy.remote() + except Exception: + logger.warning( + f"Could not destroy remote actor {actor}, force killing actor" + ) + ray.kill(actor, no_restart=True) + + try: + remove_placement_group(wi.placement_group) + except Exception: + logger.warning(f"Could not remove placement group {placement_group}") + if wi.placement_group in self._placement_groups: + self._placement_groups.remove(wi.placement_group) + + def _get_worker_info_by_id(self, worker_id: str) -> RayWorkerInfo | None: + for worker_info_list in self._workers.values(): + for wi in worker_info_list: + if wi.worker.id == worker_id: + return wi + return None + + async def set_worker_env(self, worker_id: str, env: 
dict[str, str]) -> None: + wi = self._get_worker_info_by_id(worker_id) + if wi is None: + raise WorkerNotFoundError(worker_id) + if not env: + return + + ref = wi.actor.set_env.remote(env) + await asyncio.to_thread(ray.get, ref) + wi.env_vars.update(env) + + async def create_engine(self, worker_id: str, engine: str, *args, **kwargs) -> Any: + wi = self._get_worker_info_by_id(worker_id) + if wi is None: + raise WorkerNotFoundError(worker_id) + + if not isinstance(engine, str): + raise WorkerCreationError( + worker_id, f"Engine must be a string import path, got {type(engine)}" + ) + ref = wi.actor.create_engine.remote(engine, *args, **kwargs) + await asyncio.to_thread(ray.get, ref) + + def call_engine( + self, + worker_id: str, + method: str, + *args, + http_timeout: float = 7200.0, + max_retries: int = 3, + retry_delay: float = 1.0, + **kwargs, + ) -> Any: + wi = self._get_worker_info_by_id(worker_id) + if wi is None: + raise WorkerNotFoundError(worker_id) + + last_error: str | None = None + + for attempt in range(1, max_retries + 1): + try: + ref = wi.actor.call.remote(method, *args, **kwargs) + result = ray.get(ref, timeout=http_timeout) + if attempt > 1: + logger.info( + f"Method '{method}' on '{worker_id}' " + f"succeeded after {attempt} attempts" + ) + return result + except ray.exceptions.GetTimeoutError as e: + last_error = f"Timeout: {e}" + except ray.exceptions.RayActorError as e: + raise WorkerFailedError(worker_id, -1, str(e)) from e + except ray.exceptions.RayTaskError as e: + raise EngineCallError(worker_id, method, str(e), attempt) from e + except EngineCallError: + raise + except Exception as e: + last_error = f"Ray call failed: {e}" + + # Retry with exponential backoff + if attempt < max_retries: + delay = retry_delay * (2 ** (attempt - 1)) + logger.warning( + f"Method '{method}' failed on worker '{worker_id}' " + f"(attempt {attempt}/{max_retries}): {last_error}. " + f"Retrying in {delay:.1f}s..." + ) + time.sleep(delay) + + raise EngineCallError( + worker_id, method, last_error or "Max retries exceeded", attempt=max_retries + ) + + async def async_call_engine( + self, + worker_id: str, + method: str, + *args, + http_timeout: float = 7200.0, + max_retries: int = 3, + retry_delay: float = 1.0, + **kwargs, + ) -> Any: + wi = self._get_worker_info_by_id(worker_id) + if wi is None: + raise WorkerNotFoundError(worker_id) + + last_error: str | None = None + + for attempt in range(1, max_retries + 1): + try: + ref = wi.actor.call.remote(method, *args, **kwargs) + result = await asyncio.to_thread(ray.get, ref, timeout=http_timeout) + if attempt > 1: + logger.info( + f"Method '{method}' on '{worker_id}' " + f"succeeded after {attempt} attempts" + ) + return result + except ray.exceptions.GetTimeoutError as e: + last_error = f"Timeout: {e}" + except ray.exceptions.RayActorError as e: + raise WorkerFailedError(worker_id, -1, str(e)) from e + except ray.exceptions.RayTaskError as e: + raise EngineCallError(worker_id, method, str(e), attempt) from e + except EngineCallError: + raise + except Exception as e: + last_error = f"Ray async call failed: {e}" + + # Retry with exponential backoff + if attempt < max_retries: + delay = retry_delay * (2 ** (attempt - 1)) + logger.warning( + f"Method '{method}' failed on worker '{worker_id}' " + f"(attempt {attempt}/{max_retries}): {last_error}. " + f"Retrying in {delay:.1f}s..." 
+ ) + await asyncio.sleep(delay) + + raise EngineCallError( + worker_id, method, last_error or "Max retries exceeded", attempt=max_retries + ) + + def __del__(self): + try: + self.delete_workers() + except Exception: + pass diff --git a/areal/scheduler/rpc/ray_rpc_server.py b/areal/scheduler/rpc/ray_rpc_server.py new file mode 100644 index 000000000..e118a12f6 --- /dev/null +++ b/areal/scheduler/rpc/ray_rpc_server.py @@ -0,0 +1,124 @@ +import os +import traceback +from concurrent.futures import Future +from typing import Any + +import ray + +from areal.api.cli_args import BaseExperimentConfig +from areal.api.engine_api import InferenceEngine, TrainEngine +from areal.utils import logging, name_resolve, seeding +from areal.utils.data import ( + broadcast_tensor_container, + tensor_container_to, +) +from areal.utils.dynamic_import import import_from_string + + +@ray.remote +class RayRPCServer: + """ + Ray engine container. Represents either: + - one training world rank, or + - one rollout instance + + Placement group scheduling is controlled by the scheduler. + The actor is only responsible for the engine lifecycle and method calls + within this process. + """ + + def __init__(self): + self._engine: TrainEngine | InferenceEngine | None = None + self.logger = logging.getLogger("RayRPCServer") + + def _get_device(self): + # lazy resolve the device inside worker process + from areal.platforms import current_platform + + return current_platform.current_device() + + def ping(self) -> str: + return "ok" + + def configure(self, config: BaseExperimentConfig, role: str, rank: int) -> None: + name_resolve.reconfigure(config.cluster.name_resolve) + if isinstance(self._engine, TrainEngine): + seeding.set_random_seed(config.seed, key=f"{role}{rank}") + self.logger.info(f"RayRPCServer configured for role role={role}, rank={rank}") + + def set_env(self, env: dict[str, str]) -> None: + for k, v in env.items(): + os.environ[str(k)] = str(v) + + def create_engine(self, engine_path: str, *init_args, **init_kwargs) -> None: + try: + engine_class = import_from_string(engine_path) + if not issubclass(engine_class, (TrainEngine, InferenceEngine)): + raise TypeError( + f"Engine class must be a TrainEngine or InferenceEngine, but got {engine_class}" + ) + self._engine = engine_class(*init_args, **init_kwargs) + self.logger.info( + f"RayRPCServer Engine '{engine_path}' instantiated successfully!" + ) + except Exception as e: + self.logger.error( + f"RayRPCServer failed to create engine '{engine_path}' : {e}\n" + f"{traceback.format_exc()}" + ) + raise + + def call(self, method: str, *args, **kwargs) -> Any: + if self._engine is None: + raise RuntimeError("Engine not initialized. 
Call create_engine() first") + + should_bcast = kwargs.pop("_should_bcast", True) + + # keep broadcast behavior the same as RPCServer + try: + if should_bcast and isinstance(self._engine, TrainEngine): + device = self._get_device() + + args = tensor_container_to(args, device) + args = broadcast_tensor_container( + args, + src_rank=self._engine.current_data_parallel_head(), + group=self._engine.context_and_model_parallel_group, + ) + kwargs = tensor_container_to(kwargs, device) + kwargs = broadcast_tensor_container( + kwargs, + src_rank=self._engine.current_data_parallel_head(), + group=self._engine.context_and_model_parallel_group, + ) + except Exception as e: + self.logger.error( + f"RayRPCServer broadcast failed for '{method}': {e}\n" + f"{traceback.format_exc()}" + ) + raise + + try: + fn = getattr(self._engine, method) + result = fn(*args, **kwargs) + if isinstance(result, Future): + result = result.result() + # put back to cpu to mimic RPCServer encode/decode + result = tensor_container_to(result, "cpu") + return result + except Exception as e: + self.logger.error( + f"RayRPCServer Engine method '{method}' failed: {e}\n" + f"{traceback.format_exc()}" + ) + raise + + def destroy(self) -> None: + if self._engine is not None: + try: + self._engine.destroy() + self.logger.info("RayRPCServer Engine destroyed successfully") + except Exception as e: + self.logger.error(f"RayRPCServer error destroying engine: {e}") + self._engine = None + ray.actor.exit_actor() diff --git a/areal/utils/data.py b/areal/utils/data.py index b23f48cd4..6011ed4e2 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -336,7 +336,9 @@ def pad_and_stack_tensors_along_first_dim(tensor_list: list[torch.Tensor]): def tensor_container_to( - d: dict[str, Any] | torch.Tensor | list[torch.Tensor], *args, **kwargs + d: dict[str, Any] | torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor], + *args, + **kwargs, ): """Apply `t.to(*args, **kwargs)` to all tensors in the dictionary. Support nested dictionaries. 
@@ -344,7 +346,7 @@ def tensor_container_to( if torch.is_tensor(d): return d.to(*args, **kwargs) - if isinstance(d, list): + if isinstance(d, list) or isinstance(d, tuple): return [tensor_container_to(v, *args, **kwargs) for v in d] if isinstance(d, dict): diff --git a/areal/utils/device.py b/areal/utils/device.py index 550270713..38a5a4a54 100644 --- a/areal/utils/device.py +++ b/areal/utils/device.py @@ -1,5 +1,6 @@ import gc +import torch import torch.distributed as dist from areal.platforms import current_platform @@ -40,3 +41,15 @@ def clear_memory(): current_platform.synchronize() gc.collect() current_platform.empty_cache() + + +def ray_resource_type(): + if torch.cuda.is_available(): + return "GPU" + + from areal.platforms import is_npu_available + + if is_npu_available: + return "NPU" + + return "CPU" diff --git a/examples/single-controller/gsm8k_grpo_ray.py b/examples/single-controller/gsm8k_grpo_ray.py new file mode 100644 index 000000000..e8db9ec17 --- /dev/null +++ b/examples/single-controller/gsm8k_grpo_ray.py @@ -0,0 +1,215 @@ +import os +import sys + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import GRPOConfig, SGLangConfig, load_expr_config, vLLMConfig +from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta +from areal.controller.rollout_controller import RolloutController +from areal.controller.train_controller import TrainController +from areal.dataset import get_custom_dataset +from areal.engine.ppo.actor import FSDPPPOActor +from areal.engine.sglang_remote import RemoteSGLangEngine +from areal.engine.vllm_remote import RemotevLLMEngine +from areal.scheduler.ray import RayScheduler +from areal.utils import stats_tracker +from areal.utils.dataloader import create_dataloader +from areal.utils.device import log_gpu_stats +from areal.utils.evaluator import Evaluator +from areal.utils.hf_utils import load_hf_tokenizer +from areal.utils.recover import RecoverHandler +from areal.utils.saver import Saver +from areal.utils.stats_logger import StatsLogger + + +def main(args): + config, _ = load_expr_config(args, GRPOConfig) + config: GRPOConfig + + tokenizer = load_hf_tokenizer(config.tokenizer_path) + + # Create dataset and dataloaders + train_dataset = get_custom_dataset( + split="train", dataset_config=config.train_dataset, tokenizer=tokenizer + ) + + train_dataloader = create_dataloader( + train_dataset, + rank=0, + world_size=1, + dataset_config=config.train_dataset, + ) + + ft_spec = FinetuneSpec( + total_train_epochs=config.total_train_epochs, + dataset_size=len(train_dataloader) * config.train_dataset.batch_size, + train_batch_size=config.train_dataset.batch_size, + ) + + # Initialize scheduler + scheduler = RayScheduler(exp_config=config) + + # Initialize train controller + allocation_mode = AllocationMode.from_str(config.allocation_mode) + actor = TrainController(FSDPPPOActor, config=config.actor, scheduler=scheduler) + actor.initialize( + role="actor", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None + ) + + # Initialize inference engine + + if allocation_mode.gen_backend == "sglang": + engine_class = RemoteSGLangEngine + server_args = SGLangConfig.build_args( + sglang_config=config.sglang, + tp_size=allocation_mode.gen.tp_size, + base_gpu_id=0, + ) + elif allocation_mode.gen_backend == "vllm": + engine_class = RemotevLLMEngine + server_args = vLLMConfig.build_args( + vllm_config=config.vllm, + tp_size=allocation_mode.gen.tp_size, + pp_size=allocation_mode.gen.pp_size, + ) + else: + raise ValueError(f"Unsupported 
gen_backend: '{allocation_mode.gen_backend}'") + + rollout = RolloutController( + engine_class, config=config.rollout, scheduler=scheduler + ) + rollout.initialize( + role="rollout", + alloc_mode=allocation_mode, + server_args=server_args, + ) + + weight_update_meta = WeightUpdateMeta.from_disk( + experiment_name=config.experiment_name, + trial_name=config.trial_name, + file_root=config.cluster.fileroot, + ) + actor.connect_engine(rollout, weight_update_meta) + + ref = None + if config.actor.kl_ctl > 0 and config.ref is not None: + ref = TrainController(FSDPPPOActor, config=config.ref, scheduler=scheduler) + ref.initialize( + role="ref", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None + ) + + # Run training. + saver = Saver(config.saver, ft_spec) + stats_logger = StatsLogger(config, ft_spec) + evaluator = Evaluator(config.evaluator, ft_spec) + + recover_handler = RecoverHandler(config.recover, ft_spec) + + try: + recover_info = recover_handler.load( + actor, + saver, + evaluator, + stats_logger, + train_dataloader, + inference_engine=rollout, + weight_update_meta=weight_update_meta, + ) + start_step = ( + recover_info.last_step_info.next().global_step + if recover_info is not None + else 0 + ) + + total_epochs = config.total_train_epochs + steps_per_epoch = len(train_dataloader) + max_steps = total_epochs * steps_per_epoch + + for global_step in range(start_step, max_steps): + epoch = global_step // steps_per_epoch + step = global_step % steps_per_epoch + step_info = StepInfo( + global_step=global_step, + epoch=epoch, + epoch_step=step, + steps_per_epoch=steps_per_epoch, + ) + + with stats_tracker.record_timing("rollout"): + workflow_kwargs = dict( + reward_fn="areal.reward.gsm8k.gsm8k_reward_fn", + gconfig=config.gconfig, + tokenizer=config.tokenizer_path, + enable_thinking=False, + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), + "generated", + ), + ) + rollout_batch = actor.prepare_batch( + train_dataloader, + workflow="areal.workflow.rlvr.RLVRWorkflow", + workflow_kwargs=workflow_kwargs, + ) + + if config.actor.recompute_logprob or config.actor.use_decoupled_loss: + with stats_tracker.record_timing("recompute_logp"): + prox_logp = actor.compute_logp(rollout_batch) + rollout_batch["prox_logp"] = prox_logp + log_gpu_stats("recompute logp") + + if ref is not None: + with stats_tracker.record_timing("ref_logp"): + ref_logp = ref.compute_logp(rollout_batch) + rollout_batch["ref_logp"] = ref_logp + log_gpu_stats("ref logp") + + with stats_tracker.record_timing("compute_advantage"): + adv_batch = actor.compute_advantages(rollout_batch) + log_gpu_stats("compute advantages") + + with stats_tracker.record_timing("train_step"): + actor.ppo_update(adv_batch) + actor.step_lr_scheduler() + log_gpu_stats("ppo update") + # pause inference for updating weights, save, and evaluation + rollout.pause() + + with stats_tracker.record_timing("update_weights"): + actor.update_weights(weight_update_meta) + + actor.set_version(global_step + 1) + rollout.set_version(global_step + 1) + + with stats_tracker.record_timing("save"): + saver.save(actor, epoch, step, global_step, tokenizer=tokenizer) + + with stats_tracker.record_timing("checkpoint_for_recover"): + recover_handler.dump( + actor, + step_info, + saver, + evaluator, + stats_logger, + train_dataloader, + tokenizer=tokenizer, + ) + + with stats_tracker.record_timing("clear_batches"): + actor.clear_batches(rollout_batch, adv_batch) + + # Upload statistics to the logger (e.g., wandb) + stats_logger.commit(epoch, step, 
global_step, actor.export_stats()) + + # Resume rollout + rollout.resume() + + finally: + stats_logger.close() + rollout.destroy() + if ref is not None: + ref.destroy() + actor.destroy() + + +if __name__ == "__main__": + main(sys.argv[1:]) From 3701cef44184e36bf67a456d95e29a2a07633ea9 Mon Sep 17 00:00:00 2001 From: Huawei Vancouver ICI Lab Date: Wed, 17 Dec 2025 11:45:53 -0800 Subject: [PATCH 02/15] Update areal/scheduler/ray.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- areal/scheduler/ray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/areal/scheduler/ray.py b/areal/scheduler/ray.py index a2c9fdb19..d3fe170e7 100644 --- a/areal/scheduler/ray.py +++ b/areal/scheduler/ray.py @@ -490,7 +490,7 @@ def _cleanup_workers(self, workers: list[RayWorkerInfo]): try: remove_placement_group(wi.placement_group) except Exception: - logger.warning(f"Could not remove placement group {placement_group}") + logger.warning(f"Could not remove placement group {wi.placement_group}") if wi.placement_group in self._placement_groups: self._placement_groups.remove(wi.placement_group) From a9ebec2a3da356469da70273846da89f0dfd0720 Mon Sep 17 00:00:00 2001 From: Huawei Vancouver ICI Lab Date: Wed, 17 Dec 2025 11:46:04 -0800 Subject: [PATCH 03/15] Update areal/scheduler/ray.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- areal/scheduler/ray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/areal/scheduler/ray.py b/areal/scheduler/ray.py index d3fe170e7..2ea69cd50 100644 --- a/areal/scheduler/ray.py +++ b/areal/scheduler/ray.py @@ -429,7 +429,7 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: wi.actor.configure.remote(self.exp_config, wi.role, rank) except Exception as e: logger.error( - f"Confgiure failed on worker {wi.worker.id}: {e}", exc_info=True + f"Configure failed on worker {wi.worker.id}: {e}", exc_info=True ) self._cleanup_workers(worker_info_list) raise WorkerCreationError( From aa4fec58253ed3430358946c3d82b45bf995d42a Mon Sep 17 00:00:00 2001 From: Haley Date: Thu, 18 Dec 2025 15:35:42 -0800 Subject: [PATCH 04/15] Stylistic changes and remove asyncio.to_thread from ray calls --- areal/scheduler/ray.py | 42 +++++++++++++-------------- areal/scheduler/rpc/ray_rpc_server.py | 7 +++-- areal/utils/data.py | 2 +- areal/utils/device.py | 13 --------- 4 files changed, 26 insertions(+), 38 deletions(-) diff --git a/areal/scheduler/ray.py b/areal/scheduler/ray.py index 2ea69cd50..38d87ddb7 100644 --- a/areal/scheduler/ray.py +++ b/areal/scheduler/ray.py @@ -7,6 +7,7 @@ import ray import ray.exceptions +import torch from ray.runtime_env import RuntimeEnv from ray.util.placement_group import ( PlacementGroup, @@ -26,13 +27,24 @@ ) from areal.scheduler.rpc.ray_rpc_server import RayRPCServer from areal.utils import logging -from areal.utils.device import ray_resource_type from areal.utils.launcher import get_env_vars from areal.utils.ray import get_placement_group_master_ip_and_port logger = logging.getLogger("RayScheduler") +def ray_resource_type(): + if torch.cuda.is_available(): + return "GPU" + + from areal.platforms import is_npu_available + + if is_npu_available: + return "NPU" + + return "CPU" + + @dataclass class RayWorkerInfo: worker: Worker @@ -47,26 +59,12 @@ class RayWorkerInfo: class RayScheduler(Scheduler): def __init__( self, - gpu_devices: list[int] | None = None, - log_dir: str | None = None, startup_timeout: float = 30.0, - 
health_check_interval: float = 1.0, *, - fileroot: str | None = None, - experiment_name: str | None = None, - trial_name: str | None = None, exp_config: BaseExperimentConfig | None = None, ): - # we do not set up logging dir as it is done by Ray - if log_dir is not None: - logger.warning( - f"log_dir {log_dir} will not be used for Ray. Check /tmp/ray/session_*/logs for Ray logs" - ) self.exp_config = exp_config - self.gpu_devices = gpu_devices - self.startup_timeout = startup_timeout - self.health_check_interval = health_check_interval self._workers: dict[str, list[RayWorkerInfo]] = defaultdict(list) self._placement_groups: list[PlacementGroup] = [] @@ -238,7 +236,8 @@ def _create_rollout_workers( options = self._actor_resource_spec(spec.cpu, spec.gpu, spec.mem) env = get_env_vars( - "", ",".join([f"{k}={v}" for k, v in spec.env_vars.items()]) + self.exp_config, + ",".join([f"{k}={v}" for k, v in spec.env_vars.items()]), ) if spec.env_vars: @@ -327,7 +326,8 @@ def _create_train_workers( options = self._actor_resource_spec(spec.cpu, spec.gpu, spec.mem) env = get_env_vars( - "", ",".join([f"{k}={v}" for k, v in spec.env_vars.items()]) + self.exp_config, + ",".join([f"{k}={v}" for k, v in spec.env_vars.items()]), ) if spec.env_vars: @@ -508,8 +508,7 @@ async def set_worker_env(self, worker_id: str, env: dict[str, str]) -> None: if not env: return - ref = wi.actor.set_env.remote(env) - await asyncio.to_thread(ray.get, ref) + await wi.actor.set_env.remote(env) wi.env_vars.update(env) async def create_engine(self, worker_id: str, engine: str, *args, **kwargs) -> Any: @@ -521,8 +520,7 @@ async def create_engine(self, worker_id: str, engine: str, *args, **kwargs) -> A raise WorkerCreationError( worker_id, f"Engine must be a string import path, got {type(engine)}" ) - ref = wi.actor.create_engine.remote(engine, *args, **kwargs) - await asyncio.to_thread(ray.get, ref) + await wi.actor.create_engine.remote(engine, *args, **kwargs) def call_engine( self, @@ -594,7 +592,7 @@ async def async_call_engine( for attempt in range(1, max_retries + 1): try: ref = wi.actor.call.remote(method, *args, **kwargs) - result = await asyncio.to_thread(ray.get, ref, timeout=http_timeout) + result = await ref if attempt > 1: logger.info( f"Method '{method}' on '{worker_id}' " diff --git a/areal/scheduler/rpc/ray_rpc_server.py b/areal/scheduler/rpc/ray_rpc_server.py index e118a12f6..9b2312b34 100644 --- a/areal/scheduler/rpc/ray_rpc_server.py +++ b/areal/scheduler/rpc/ray_rpc_server.py @@ -53,6 +53,7 @@ def set_env(self, env: dict[str, str]) -> None: def create_engine(self, engine_path: str, *init_args, **init_kwargs) -> None: try: engine_class = import_from_string(engine_path) + self.logger.debug(f"Initializing engine {engine_class}") if not issubclass(engine_class, (TrainEngine, InferenceEngine)): raise TypeError( f"Engine class must be a TrainEngine or InferenceEngine, but got {engine_class}" @@ -69,14 +70,15 @@ def create_engine(self, engine_path: str, *init_args, **init_kwargs) -> None: raise def call(self, method: str, *args, **kwargs) -> Any: + self.logger.debug(f"Calling {method} with arguments {args=} {kwargs=}") if self._engine is None: raise RuntimeError("Engine not initialized. 
Call create_engine() first") - should_bcast = kwargs.pop("_should_bcast", True) + should_broadcast = kwargs.pop("should_broadcast", True) # keep broadcast behavior the same as RPCServer try: - if should_bcast and isinstance(self._engine, TrainEngine): + if should_broadcast and isinstance(self._engine, TrainEngine): device = self._get_device() args = tensor_container_to(args, device) @@ -105,6 +107,7 @@ def call(self, method: str, *args, **kwargs) -> Any: result = result.result() # put back to cpu to mimic RPCServer encode/decode result = tensor_container_to(result, "cpu") + self.logger.debug(f"Successfully completed RayRPCServer call {result}") return result except Exception as e: self.logger.error( diff --git a/areal/utils/data.py b/areal/utils/data.py index 6011ed4e2..c1f7b6d1d 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -336,7 +336,7 @@ def pad_and_stack_tensors_along_first_dim(tensor_list: list[torch.Tensor]): def tensor_container_to( - d: dict[str, Any] | torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor], + d: dict[str, Any] | torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...], *args, **kwargs, ): diff --git a/areal/utils/device.py b/areal/utils/device.py index 38a5a4a54..550270713 100644 --- a/areal/utils/device.py +++ b/areal/utils/device.py @@ -1,6 +1,5 @@ import gc -import torch import torch.distributed as dist from areal.platforms import current_platform @@ -41,15 +40,3 @@ def clear_memory(): current_platform.synchronize() gc.collect() current_platform.empty_cache() - - -def ray_resource_type(): - if torch.cuda.is_available(): - return "GPU" - - from areal.platforms import is_npu_available - - if is_npu_available: - return "NPU" - - return "CPU" From 18d7737d3cb131521709bc50b32e3de844582981 Mon Sep 17 00:00:00 2001 From: Haley Date: Thu, 18 Dec 2025 16:29:48 -0800 Subject: [PATCH 05/15] RayRTensor, RTensor refactor, and tests for RayScheduler --- areal/scheduler/rpc/ray_rpc_server.py | 13 + areal/scheduler/rpc/ray_rtensor.py | 120 ++++++++ areal/scheduler/rpc/rtensor.py | 381 ++++++++++++++------------ areal/tests/test_ray_scheduler.py | 111 ++++++++ 4 files changed, 451 insertions(+), 174 deletions(-) create mode 100644 areal/scheduler/rpc/ray_rtensor.py create mode 100644 areal/tests/test_ray_scheduler.py diff --git a/areal/scheduler/rpc/ray_rpc_server.py b/areal/scheduler/rpc/ray_rpc_server.py index 9b2312b34..da012bdd7 100644 --- a/areal/scheduler/rpc/ray_rpc_server.py +++ b/areal/scheduler/rpc/ray_rpc_server.py @@ -7,6 +7,7 @@ from areal.api.cli_args import BaseExperimentConfig from areal.api.engine_api import InferenceEngine, TrainEngine +from areal.scheduler.rpc.ray_rtensor import RayRTensor from areal.utils import logging, name_resolve, seeding from areal.utils.data import ( broadcast_tensor_container, @@ -74,6 +75,12 @@ def call(self, method: str, *args, **kwargs) -> Any: if self._engine is None: raise RuntimeError("Engine not initialized. 
Call create_engine() first") + raw_args = list(args) + raw_kwargs = kwargs.copy() + # fetch remote tensors if any + args = RayRTensor.localize(raw_args) + kwargs = RayRTensor.localize(raw_kwargs) + should_broadcast = kwargs.pop("should_broadcast", True) # keep broadcast behavior the same as RPCServer @@ -105,6 +112,12 @@ def call(self, method: str, *args, **kwargs) -> Any: result = fn(*args, **kwargs) if isinstance(result, Future): result = result.result() + # Convert all tensors to RTensors and store the tensor locally + layout = RayRTensor.extract_layout( + result, layouts=dict(args=raw_args, kwargs=raw_kwargs) + ) + if layout is not None: + result = RayRTensor.remotize(result, layout) # put back to cpu to mimic RPCServer encode/decode result = tensor_container_to(result, "cpu") self.logger.debug(f"Successfully completed RayRPCServer call {result}") diff --git a/areal/scheduler/rpc/ray_rtensor.py b/areal/scheduler/rpc/ray_rtensor.py new file mode 100644 index 000000000..5d65acbaa --- /dev/null +++ b/areal/scheduler/rpc/ray_rtensor.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import ray +import torch + +from areal.scheduler.rpc.rtensor import ( + BaseRTensor, + BaseTensorShardInfo, + _find_in_structure, + _pad_cat_dim0, +) + + +@dataclass +class RayTensorShardInfo(BaseTensorShardInfo): + ref: ray.ObjectRef + + +@dataclass +class RayRTensor(BaseRTensor): + def to_local(self) -> torch.Tensor: + if not self.data.is_meta: + return self.data + # Fetch all shards first + tensors = self._fetch() + self.data = _pad_cat_dim0(tensors) + return self.data + + def _fetch(self) -> list[torch.Tensor]: + return ray.get([s.ref for s in self.shards]) + + def split(self) -> list[RayRTensor]: + tensors = RayRTensor.split_tensor(self.data, self) + return [RayRTensor(shards=[s], data=t) for s, t in zip(self.shards, tensors)] + + @classmethod + def from_batched(cls, batch_tensor: torch.Tensor, layout: RayRTensor): + if not batch_tensor.is_cpu and not batch_tensor.is_meta: + raise ValueError("RTensor shards must be on CPU or meta device") + + tensors = cls.split_tensor(batch_tensor, layout) + + shards = [] + for tensor, shard_info in zip(tensors, layout.shards): + ref = ray.put(tensor) + info = RayTensorShardInfo( + ref=ref, + size=shard_info.size, + seqlens=shard_info.seqlens.copy(), + ) + shards.append(info) + + # Truncate at the maximum sequence length + # to prevent over-padding + if tensor.ndim > 1: + tensor = tensor[:, : max(shard_info.seqlens)] + + return cls(shards=shards, data=batch_tensor.to("meta")) + + @staticmethod + def extract_layout(obj: Any, layouts: Any): + layout_rtensor = _find_in_structure(layouts, RayRTensor) + result_tensor = _find_in_structure(obj, torch.Tensor) + + if layout_rtensor is None and result_tensor is not None: + if not isinstance(obj, dict): + raise RuntimeError( + "When input does not contain RayRTensor, " + "we expect to extract layouts from a dict batch " + f"returned by InferenceEngine. Get obj: {obj}, " + f"input layouts: {layouts}." 
+ ) + attn_mask = obj.get("attention_mask", None) + if attn_mask is None: + raise RuntimeError("`attention_mask` is not found") + + layout_rtensor = RayRTensor( + shards=[ + RayTensorShardInfo( + ref=None, # placeholder to be filled later + size=attn_mask.shape[0], + seqlens=[int(am.sum()) for am in attn_mask], + ) + ], + data=torch.empty_like(attn_mask, device="meta"), + ) + return layout_rtensor + + @staticmethod + def remotize(obj: Any, layout: RayRTensor) -> Any: + if isinstance(obj, torch.Tensor): + return RayRTensor.from_batched(obj.detach().cpu(), layout=layout) + + if isinstance(obj, dict): + return { + k: RayRTensor.remotize(obj=v, layout=layout) for k, v in obj.items() + } + + if isinstance(obj, list): + return [RayRTensor.remotize(obj=item, layout=layout) for item in obj] + + if isinstance(obj, tuple): + return tuple(RayRTensor.remotize(obj=item, layout=layout) for item in obj) + + return obj + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func is torch.cat: + return RayRTensor.cat(*args, **kwargs) + + raise NotImplementedError( + f"RayRTensor does not implement torch function {func}" + ) diff --git a/areal/scheduler/rpc/rtensor.py b/areal/scheduler/rpc/rtensor.py index e784e98fb..4b6b93e5b 100644 --- a/areal/scheduler/rpc/rtensor.py +++ b/areal/scheduler/rpc/rtensor.py @@ -2,6 +2,7 @@ import asyncio import uuid +from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass from threading import Lock @@ -16,62 +17,26 @@ @dataclass -class TensorShardInfo: - """Metadata for a single shard of an RTensor.""" - - shard_id: str - node_addr: str +class BaseTensorShardInfo(ABC): size: int # Batch size (shape[0]) of this shard seqlens: list[int] # Sequence lengths of this shard (from attention_mask) @dataclass -class RTensor: +class BaseRTensor(ABC): """Single tensor distributed as CPU shards across nodes.""" shards: list[TensorShardInfo] data: torch.Tensor + @abstractmethod def to_local(self) -> torch.Tensor: - """Fetch all shards via HTTP, concatenate along dim 0.""" - if not self.data.is_meta: - return self.data - # Fetch all shards first - tensors = self._fetch() - self.data = _pad_cat_dim0(tensors) - return self.data - - def _fetch(self): - """Fetch all shards synchronously.""" - - async def _fetch_all(): - async with aiohttp.ClientSession() as session: - return await asyncio.gather( - *[ - RTensor._fetch_tensor(session, s.shard_id, s.node_addr) - for s in self.shards - ] - ) - - return asyncio.run(_fetch_all()) + pass @staticmethod - async def _fetch_tensor( - session: aiohttp.ClientSession, shard_id: str, node_addr: str - ) -> torch.Tensor: - # Avoid circular import - from areal.scheduler.rpc.serialization import deserialize_value - - url = f"http://{node_addr}/data/{shard_id}" - async with session.get(url) as resp: - if resp.status != 200: - raise RuntimeError(f"Failed to fetch shard from {url}: {resp.status}") - data_bytes = await resp.read() - serialized_data = orjson.loads(data_bytes) - return deserialize_value(serialized_data) - - @staticmethod - def split_tensor(batch_tensor: torch.Tensor, layout: RTensor) -> list[torch.Tensor]: + def split_tensor( + batch_tensor: torch.Tensor, layout: BaseRTensor + ) -> list[torch.Tensor]: offsets = np.cumsum([0] + [shard.size for shard in layout.shards]) if offsets[-1] != batch_tensor.shape[0]: raise ValueError( @@ -80,54 +45,32 @@ def split_tensor(batch_tensor: torch.Tensor, layout: RTensor) -> list[torch.Tens # no 
clone here because they are read-only slices return [batch_tensor[a:b] for a, b in zip(offsets[:-1], offsets[1:])] - def split(self) -> list[RTensor]: - tensors = RTensor.split_tensor(self.data, self) - return [RTensor(shards=[s], data=t) for s, t in zip(self.shards, tensors)] + @abstractmethod + def split(self) -> list[BaseRTensor]: + pass @classmethod - def from_batched(cls, batch_tensor: torch.Tensor, layout: RTensor, node_addr: str): - if not batch_tensor.is_cpu and not batch_tensor.is_meta: - raise ValueError("RTensor shards must be on CPU or meta device") + @abstractmethod + def from_batched(cls, batch_tensor: torch.Tensor, layout: BaseRTensor): + pass - tensors = cls.split_tensor(batch_tensor, layout) - - shards = [] - for tensor, shard_info in zip(tensors, layout.shards): - sid = str(uuid.uuid4()) - info = TensorShardInfo( - shard_id=sid, - node_addr=node_addr, - size=shard_info.size, - seqlens=shard_info.seqlens.copy(), - ) - shards.append(info) - - # Truncate at the maximum sequence length - # to prevent over-padding - if tensor.ndim > 1: - tensor = tensor[:, : max(shard_info.seqlens)] - # Store locally - store(sid, tensor) - - return cls(shards=shards, data=batch_tensor.to("meta")) - - @staticmethod - def cat(rtensors: list[RTensor | torch.Tensor], dim=0) -> RTensor: + @classmethod + def cat(cls, rtensors: list[BaseRTensor | torch.Tensor], dim=0) -> BaseRTensor: """Concatenate RTensors along existing dimension.""" n_tensors = len(rtensors) if n_tensors == 0: - return RTensor(shards=[], data=torch.tensor([]).to("meta")) - n_rtensors = len([x for x in rtensors if isinstance(x, RTensor)]) + return cls(shards=[], data=torch.tensor([]).to("meta")) + n_rtensors = len([x for x in rtensors if isinstance(x, cls)]) # All RTensors if n_tensors == n_rtensors: if dim != 0: raise ValueError( - "RTensor.cat for multiple RTensors only supports dim=0" + "BaseRTensor.cat for multiple RTensors only supports dim=0" ) if any(t.data is None for t in rtensors): raise RuntimeError("Cannot concat rtensors with None data") - return RTensor( + return cls( shards=[shard for r in rtensors for shard in r.shards], data=_pad_cat_dim0([r.data for r in rtensors]), ) @@ -137,97 +80,50 @@ def cat(rtensors: list[RTensor | torch.Tensor], dim=0) -> RTensor: raise ValueError( "RTensor.cat only support concatenating a single RTensor with other torch.Tensor" ) - rt = [x for x in rtensors if isinstance(x, RTensor)][0] - return RTensor( + rt = [x for x in rtensors if isinstance(x, cls)][0] + return cls( shards=rt.shards, data=torch.cat( - [r.data if isinstance(r, RTensor) else r for r in rtensors], dim=dim + [r.data if isinstance(r, cls) else r for r in rtensors], dim=dim ), ) @staticmethod - def extract_layout(obj: Any, layouts: Any, node_addr: str | None): - # Determine if batched from input layouts - layout_rtensor = _find_in_structure(layouts, RTensor) - result_tensor = _find_in_structure(obj, torch.Tensor) - if layout_rtensor is None and result_tensor is not None: - if not isinstance(obj, dict): - raise RuntimeError( - "When input does not contain RTensor, " - "we expect to extract layouts from a dict batch " - f"returned by InferenceEngine. Get obj: {obj}, " - f"input layouts: {layouts}." 
- ) - attn_mask = obj.get("attention_mask", None) - if attn_mask is None: - raise RuntimeError("`attention_mask` is not found") - assert node_addr is not None - layout_rtensor = RTensor( - shards=[ - TensorShardInfo( - shard_id="", - node_addr=node_addr, - size=attn_mask.shape[0], - seqlens=[int(am.sum()) for am in attn_mask], - ) - ], - data=torch.empty_like(attn_mask, device="meta"), - ) - return layout_rtensor + @abstractmethod + def extract_layout(obj: Any, layouts: Any): + pass @staticmethod - def remotize(obj: Any, layout: RTensor, node_addr: str) -> Any: - if isinstance(obj, torch.Tensor): - return RTensor.from_batched( - obj.detach().cpu(), layout=layout, node_addr=node_addr - ) - - if isinstance(obj, dict): - return { - k: RTensor.remotize(obj=v, layout=layout, node_addr=node_addr) - for k, v in obj.items() - } - - if isinstance(obj, list): - return [ - RTensor.remotize(obj=item, layout=layout, node_addr=node_addr) - for item in obj - ] - - if isinstance(obj, tuple): - return tuple( - RTensor.remotize(obj=item, layout=layout, node_addr=node_addr) - for item in obj - ) + @abstractmethod + def remotize(obj: Any, layout: BaseRTensor) -> Any: + pass - return obj - - @staticmethod - def localize(obj: Any) -> Any: - """Convert RTensors to local tensors in nested structures. + @classmethod + def localize(cls, obj: Any) -> Any: + """Convert BaseRTensors to local tensors in nested structures. Inverse of remotize() - fetches remote data and converts to local tensors. """ - if isinstance(obj, RTensor): + if isinstance(obj, cls): return obj.to_local() if isinstance(obj, dict): - return {k: RTensor.localize(v) for k, v in obj.items()} + return {k: cls.localize(v) for k, v in obj.items()} if isinstance(obj, list): - return [RTensor.localize(item) for item in obj] + return [cls.localize(item) for item in obj] if isinstance(obj, tuple): - return tuple(RTensor.localize(item) for item in obj) + return tuple(cls.localize(item) for item in obj) return obj - @staticmethod + @classmethod def data_parallel_dispatch( - obj: Any, dp_size: int, group_indices: list[list[int]] | None = None + cls, obj: Any, dp_size: int, group_indices: list[list[int]] | None = None ) -> tuple[list[Any], list[list[int]] | None]: if group_indices is None: - layout_rtensor = _find_in_structure(obj, RTensor) + layout_rtensor = _find_in_structure(obj, cls) if layout_rtensor is not None: # FIXME: the next line prevents splitting a single trajectory into finer granularity seqlens = [sum(s.seqlens) for s in layout_rtensor.shards] @@ -237,21 +133,21 @@ def data_parallel_dispatch( ) # else: no RTensors found, will replicate scalars without group_indices - if isinstance(obj, RTensor): - tensors = RTensor.split_tensor(obj.data, obj) + if isinstance(obj, cls): + tensors = cls.split_tensor(obj.data, obj) # Split shards according to group assignments split_rtensors = [] for group_idxs in group_indices: # Collect shards for this group group_shards = [obj.shards[i] for i in group_idxs] group_data = _pad_cat_dim0([tensors[i] for i in group_idxs]) - split_rtensors.append(RTensor(shards=group_shards, data=group_data)) + split_rtensors.append(cls(shards=group_shards, data=group_data)) return split_rtensors, group_indices if isinstance(obj, dict): # Split each value, return list of dicts split_values = { - k: RTensor.data_parallel_dispatch(v, dp_size, group_indices)[0] + k: cls.data_parallel_dispatch(v, dp_size, group_indices)[0] for k, v in obj.items() } return [ @@ -261,7 +157,7 @@ def data_parallel_dispatch( if isinstance(obj, list): # Split 
each element split_elements = [ - RTensor.data_parallel_dispatch(elem, dp_size, group_indices)[0] + cls.data_parallel_dispatch(elem, dp_size, group_indices)[0] for elem in obj ] return [ @@ -271,7 +167,7 @@ def data_parallel_dispatch( if isinstance(obj, tuple): # Split each element split_elements = [ - RTensor.data_parallel_dispatch(elem, dp_size, group_indices)[0] + cls.data_parallel_dispatch(elem, dp_size, group_indices)[0] for elem in obj ] return [ @@ -282,9 +178,9 @@ def data_parallel_dispatch( # Non-RTensor objects: replicate to all groups return [obj] * dp_size, group_indices - @staticmethod + @classmethod def data_parallel_merge( - results: list[Any], group_indices: list[list[int]] | None + cls, results: list[Any], group_indices: list[list[int]] | None ) -> Any: if not results: return None @@ -294,24 +190,24 @@ def data_parallel_merge( # Check for raw tensors - not allowed if isinstance(first, torch.Tensor): raise TypeError( - "Regular tensors not allowed in merge - only RTensors. " - "Engine outputs should be automatically converted to RTensors." + "Regular tensors not allowed in merge - only BaseRTensors. " + "Engine outputs should be automatically converted to BaseRTensors." ) - if isinstance(first, RTensor): + if isinstance(first, cls): assert group_indices is not None rtensors = flat2d([r.split() for r in results]) indices = flat2d(group_indices) assert len(rtensors) == len(indices), (len(rtensors), len(indices)) inv_indices = np.zeros(len(indices), dtype=np.int64) inv_indices[indices] = np.arange(len(indices)) - return RTensor.cat([rtensors[i] for i in inv_indices]) + return cls.cat([rtensors[i] for i in inv_indices]) if isinstance(first, dict): merged = {} for key in first.keys(): values = [r[key] for r in results] - merged[key] = RTensor.data_parallel_merge( + merged[key] = cls.data_parallel_merge( values, group_indices=group_indices ) return merged @@ -321,7 +217,7 @@ def data_parallel_merge( for i in range(len(first)): elements = [r[i] for r in results] merged.append( - RTensor.data_parallel_merge(elements, group_indices=group_indices) + cls.data_parallel_merge(elements, group_indices=group_indices) ) return merged @@ -330,13 +226,166 @@ def data_parallel_merge( for i in range(len(first)): elements = [r[i] for r in results] merged.append( - RTensor.data_parallel_merge(elements, group_indices=group_indices) + cls.data_parallel_merge(elements, group_indices=group_indices) ) return tuple(merged) # Scalars: return first (assume synchronized) return first + @property + def shape(self): + return self.data.shape + + @property + def dtype(self): + return self.data.dtype + + @property + def device(self): + return self.data.device + + @property + def ndim(self): + return self.data.ndim + + +@dataclass +class TensorShardInfo: + """Metadata for a single shard of an RTensor.""" + + shard_id: str + node_addr: str + + +@dataclass +class RTensor(BaseRTensor): + def to_local(self) -> torch.Tensor: + """Fetch all shards via HTTP, concatenate along dim 0.""" + if not self.data.is_meta: + return self.data + # Fetch all shards first + tensors = self._fetch() + self.data = _pad_cat_dim0(tensors) + return self.data + + def _fetch(self): + """Fetch all shards synchronously.""" + + async def _fetch_all(): + async with aiohttp.ClientSession() as session: + return await asyncio.gather( + *[ + RTensor._fetch_tensor(session, s.shard_id, s.node_addr) + for s in self.shards + ] + ) + + return asyncio.run(_fetch_all()) + + @staticmethod + async def _fetch_tensor( + session: aiohttp.ClientSession, 
shard_id: str, node_addr: str + ) -> torch.Tensor: + # Avoid circular import + from areal.scheduler.rpc.serialization import deserialize_value + + url = f"http://{node_addr}/data/{shard_id}" + async with session.get(url) as resp: + if resp.status != 200: + raise RuntimeError(f"Failed to fetch shard from {url}: {resp.status}") + data_bytes = await resp.read() + serialized_data = orjson.loads(data_bytes) + return deserialize_value(serialized_data) + + def split(self) -> list[RTensor]: + tensors = RTensor.split_tensor(self.data, self) + return [RTensor(shards=[s], data=t) for s, t in zip(self.shards, tensors)] + + @classmethod + def from_batched(cls, batch_tensor: torch.Tensor, layout: RTensor, node_addr: str): + if not batch_tensor.is_cpu and not batch_tensor.is_meta: + raise ValueError("RTensor shards must be on CPU or meta device") + + tensors = cls.split_tensor(batch_tensor, layout) + + shards = [] + for tensor, shard_info in zip(tensors, layout.shards): + sid = str(uuid.uuid4()) + info = TensorShardInfo( + shard_id=sid, + node_addr=node_addr, + size=shard_info.size, + seqlens=shard_info.seqlens.copy(), + ) + shards.append(info) + + # Truncate at the maximum sequence length + # to prevent over-padding + if tensor.ndim > 1: + tensor = tensor[:, : max(shard_info.seqlens)] + # Store locally + store(sid, tensor) + + return cls(shards=shards, data=batch_tensor.to("meta")) + + @staticmethod + def extract_layout(obj: Any, layouts: Any, node_addr: str | None): + # Determine if batched from input layouts + layout_rtensor = _find_in_structure(layouts, RTensor) + result_tensor = _find_in_structure(obj, torch.Tensor) + if layout_rtensor is None and result_tensor is not None: + if not isinstance(obj, dict): + raise RuntimeError( + "When input does not contain RTensor, " + "we expect to extract layouts from a dict batch " + f"returned by InferenceEngine. Get obj: {obj}, " + f"input layouts: {layouts}." + ) + attn_mask = obj.get("attention_mask", None) + if attn_mask is None: + raise RuntimeError("`attention_mask` is not found") + assert node_addr is not None + layout_rtensor = RTensor( + shards=[ + TensorShardInfo( + shard_id="", + node_addr=node_addr, + size=attn_mask.shape[0], + seqlens=[int(am.sum()) for am in attn_mask], + ) + ], + data=torch.empty_like(attn_mask, device="meta"), + ) + return layout_rtensor + + @staticmethod + def remotize(obj: Any, layout: RTensor, node_addr: str) -> Any: + if isinstance(obj, torch.Tensor): + return RTensor.from_batched( + obj.detach().cpu(), layout=layout, node_addr=node_addr + ) + + if isinstance(obj, dict): + return { + k: RTensor.remotize(obj=v, layout=layout, node_addr=node_addr) + for k, v in obj.items() + } + + if isinstance(obj, list): + return [ + RTensor.remotize(obj=item, layout=layout, node_addr=node_addr) + for item in obj + ] + + if isinstance(obj, tuple): + return tuple( + RTensor.remotize(obj=item, layout=layout, node_addr=node_addr) + for item in obj + ) + + return obj + @staticmethod def collect_shards(obj: Any) -> dict[str, list[str]]: """Collect shard IDs grouped by node address from nested structure. 
@@ -361,22 +410,6 @@ def _collect(o): _collect(obj) return shards_by_node - @property - def shape(self): - return self.data.shape - - @property - def dtype(self): - return self.data.dtype - - @property - def device(self): - return self.data.device - - @property - def ndim(self): - return self.data.ndim - @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: diff --git a/areal/tests/test_ray_scheduler.py b/areal/tests/test_ray_scheduler.py new file mode 100644 index 000000000..f9091c059 --- /dev/null +++ b/areal/tests/test_ray_scheduler.py @@ -0,0 +1,111 @@ +import ray +from ray.util.state import summarize_actors + +from areal.api.cli_args import BaseExperimentConfig +from areal.api.scheduler_api import ( + Job, + SchedulingSpec, +) +from areal.scheduler.ray import RayScheduler, ray_resource_type + + +class TestRaySchedulerInitialization: + def test_init(self): + scheduler = RayScheduler( + startup_timeout=60.0, exp_config=BaseExperimentConfig() + ) + assert scheduler.startup_timeout == 60.0 + + +class TestWorkerCreationAndDeletion: + def test_create_delete_workers(self): + ray.init() + + config = BaseExperimentConfig() + + scheduler = RayScheduler(startup_timeout=60.0, exp_config=config) + + job = Job( + replicas=2, + role="train", + tasks=[ + SchedulingSpec( + cpu=1, + mem=1024, + gpu=1, + ), + SchedulingSpec( + cpu=1, + mem=1024, + gpu=1, + ), + ], + ) + + # create workers + worker_ids = scheduler.create_workers(job) + assert len(worker_ids) == 2 + assert len(scheduler._workers["train"]) == 2 + + actor_summary = summarize_actors() + + assert ( + actor_summary["cluster"]["summary"]["RayRPCServer"]["state_counts"]["ALIVE"] + == 2 + ) + assert len(scheduler.get_workers("train")) == 2 + + scheduler._ping_workers("train") + + # delete workers + scheduler.delete_workers() + assert len(scheduler._workers["train"]) == 0 + + actor_summary = summarize_actors() + assert ( + actor_summary["cluster"]["summary"]["RayRPCServer"]["state_counts"]["DEAD"] + == 2 + ) + + +class TestUtilityFunctions: + def test_utilities(self): + _num_gpu_per_node = 16 + config = BaseExperimentConfig() + + config.cluster.n_gpus_per_node = _num_gpu_per_node + + scheduler = RayScheduler(startup_timeout=60.0, exp_config=config) + + schedulings = [ + SchedulingSpec( + cpu=1, + mem=1024, + gpu=1, + ), + SchedulingSpec( + cpu=1, + mem=1024, + gpu=1, + ), + ] + + new_schedulings = scheduler._prepare_worker_specs("train", 2, schedulings) + assert len(new_schedulings) == 2 + for spec in new_schedulings: + assert spec.cpu == 1 + assert spec.mem == 1024 + assert spec.gpu == 1 + + # case where only 1 spec is passed but multiple workers + new_schedulings = scheduler._prepare_worker_specs("trian", 2, schedulings[0:]) + assert len(new_schedulings) == 2 + for spec in new_schedulings: + assert spec.cpu == 1 + assert spec.mem == 1024 + assert spec.gpu == 1 + + bundle_list = scheduler._create_bundle_list_gpu(1, 24, 1024) + assert len(bundle_list) == 2 + for bundle in bundle_list: + assert bundle[ray_resource_type()] <= _num_gpu_per_node From 99429247a8f3c44da4f575302d0c04a724468509 Mon Sep 17 00:00:00 2001 From: Haley Date: Thu, 18 Dec 2025 16:45:55 -0800 Subject: [PATCH 06/15] Fix typos --- areal/scheduler/rpc/rtensor.py | 7 ++++++- areal/tests/test_ray_scheduler.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/areal/scheduler/rpc/rtensor.py b/areal/scheduler/rpc/rtensor.py index 4b6b93e5b..7988ac97a 100644 --- a/areal/scheduler/rpc/rtensor.py +++ 
b/areal/scheduler/rpc/rtensor.py @@ -249,9 +249,14 @@ def device(self): def ndim(self): return self.data.ndim + @classmethod + @abstractmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + pass + @dataclass -class TensorShardInfo: +class TensorShardInfo(BaseTensorShardInfo): """Metadata for a single shard of an RTensor.""" shard_id: str diff --git a/areal/tests/test_ray_scheduler.py b/areal/tests/test_ray_scheduler.py index f9091c059..8bf3799cb 100644 --- a/areal/tests/test_ray_scheduler.py +++ b/areal/tests/test_ray_scheduler.py @@ -98,7 +98,7 @@ def test_utilities(self): assert spec.gpu == 1 # case where only 1 spec is passed but multiple workers - new_schedulings = scheduler._prepare_worker_specs("trian", 2, schedulings[0:]) + new_schedulings = scheduler._prepare_worker_specs("train", 2, schedulings[0:]) assert len(new_schedulings) == 2 for spec in new_schedulings: assert spec.cpu == 1 From fd2b0b39c1330a656e4933962c150a572e9a0db0 Mon Sep 17 00:00:00 2001 From: Huawei Vancouver ICI Lab Date: Thu, 18 Dec 2025 16:49:30 -0800 Subject: [PATCH 07/15] Update areal/scheduler/rpc/rtensor.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- areal/scheduler/rpc/rtensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/areal/scheduler/rpc/rtensor.py b/areal/scheduler/rpc/rtensor.py index 7988ac97a..e4cc0111a 100644 --- a/areal/scheduler/rpc/rtensor.py +++ b/areal/scheduler/rpc/rtensor.py @@ -26,7 +26,7 @@ class BaseTensorShardInfo(ABC): class BaseRTensor(ABC): """Single tensor distributed as CPU shards across nodes.""" - shards: list[TensorShardInfo] + shards: list[BaseTensorShardInfo] data: torch.Tensor @abstractmethod From 75c5e53a2bc58a8c5fb10d75e69972197cffca9e Mon Sep 17 00:00:00 2001 From: Huawei Vancouver ICI Lab Date: Thu, 18 Dec 2025 16:49:41 -0800 Subject: [PATCH 08/15] Update areal/scheduler/ray.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- areal/scheduler/ray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/areal/scheduler/ray.py b/areal/scheduler/ray.py index 38d87ddb7..99da95f23 100644 --- a/areal/scheduler/ray.py +++ b/areal/scheduler/ray.py @@ -51,7 +51,7 @@ class RayWorkerInfo: actor: ray.actor.ActorHandle role: str placement_group: PlacementGroup - bundle_index: int + bundle_index: int | None created_at: float env_vars: dict[str, str] = field(default_factory=dict) From 6e34fa2869c3fb8dd1d6ccdfc523a944672f251e Mon Sep 17 00:00:00 2001 From: Haley Date: Thu, 18 Dec 2025 16:58:02 -0800 Subject: [PATCH 09/15] Add gemini suggestions --- areal/scheduler/ray.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/areal/scheduler/ray.py b/areal/scheduler/ray.py index 38d87ddb7..3835b3d5c 100644 --- a/areal/scheduler/ray.py +++ b/areal/scheduler/ray.py @@ -67,6 +67,7 @@ def __init__( self.startup_timeout = startup_timeout self._workers: dict[str, list[RayWorkerInfo]] = defaultdict(list) + self._worker_info_by_id: dict[str, RayWorkerInfo] = {} self._placement_groups: list[PlacementGroup] = [] def _prepare_worker_specs( @@ -421,6 +422,9 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: self._workers[role].extend(worker_info_list) + for wi in worker_info_list: + self._worker_info_by_id[wi.worker.id] = wi + self._ping_workers(role, self.startup_timeout) if self.exp_config is not None: @@ -477,9 +481,11 @@ def delete_workers(self, role: str | None = None): 
logger.info(f"Successfully deleted workers for role '{role}'") def _cleanup_workers(self, workers: list[RayWorkerInfo]): + # Kill actors first for wi in workers: actor = wi.actor try: + # Asynchronously destroy actor actor.destroy.remote() except Exception: logger.warning( @@ -487,19 +493,18 @@ def _cleanup_workers(self, workers: list[RayWorkerInfo]): ) ray.kill(actor, no_restart=True) + # Collect unique placement groups and remove them + unique_pgs = {wi.placement_group for wi in workers} + for pg in unique_pgs: try: - remove_placement_group(wi.placement_group) + remove_placement_group(pg) except Exception: - logger.warning(f"Could not remove placement group {wi.placement_group}") - if wi.placement_group in self._placement_groups: - self._placement_groups.remove(wi.placement_group) + logger.warning(f"Could not remove placement group {pg}") + if pg in self._placement_groups: + self._placement_groups.remove(pg) def _get_worker_info_by_id(self, worker_id: str) -> RayWorkerInfo | None: - for worker_info_list in self._workers.values(): - for wi in worker_info_list: - if wi.worker.id == worker_id: - return wi - return None + return self._worker_info_by_id.get(worker_id, None) async def set_worker_env(self, worker_id: str, env: dict[str, str]) -> None: wi = self._get_worker_info_by_id(worker_id) @@ -625,6 +630,8 @@ async def async_call_engine( ) def __del__(self): + # delete in case delete_workers is not called from controllers + # explicit shutdown is by directly calling delete_workers try: self.delete_workers() except Exception: From 8681a271b74a7eeb58d6da8a448d7d75e1952746 Mon Sep 17 00:00:00 2001 From: Haley Date: Mon, 22 Dec 2025 09:51:19 -0800 Subject: [PATCH 10/15] Fix rtensor test regex assertion error --- areal/scheduler/rpc/rtensor.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/areal/scheduler/rpc/rtensor.py b/areal/scheduler/rpc/rtensor.py index e4cc0111a..4108f6616 100644 --- a/areal/scheduler/rpc/rtensor.py +++ b/areal/scheduler/rpc/rtensor.py @@ -189,10 +189,7 @@ def data_parallel_merge( # Check for raw tensors - not allowed if isinstance(first, torch.Tensor): - raise TypeError( - "Regular tensors not allowed in merge - only BaseRTensors. " - "Engine outputs should be automatically converted to BaseRTensors." 
- ) + raise TypeError("Regular tensors not allowed in merge - only RTensors") if isinstance(first, cls): assert group_indices is not None From 30abe0ef59492bb82cabf5429d52b2e3b7b18b6c Mon Sep 17 00:00:00 2001 From: Haley Date: Mon, 22 Dec 2025 12:14:48 -0800 Subject: [PATCH 11/15] Tests for ray scheduler create and call engine --- areal/tests/test_ray_scheduler.py | 54 ++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/areal/tests/test_ray_scheduler.py b/areal/tests/test_ray_scheduler.py index 8bf3799cb..caf4025a6 100644 --- a/areal/tests/test_ray_scheduler.py +++ b/areal/tests/test_ray_scheduler.py @@ -1,12 +1,12 @@ +import asyncio +from unittest.mock import Mock, patch + import ray from ray.util.state import summarize_actors from areal.api.cli_args import BaseExperimentConfig -from areal.api.scheduler_api import ( - Job, - SchedulingSpec, -) -from areal.scheduler.ray import RayScheduler, ray_resource_type +from areal.api.scheduler_api import Job, SchedulingSpec, Worker +from areal.scheduler.ray import RayScheduler, RayWorkerInfo, ray_resource_type class TestRaySchedulerInitialization: @@ -68,6 +68,50 @@ def test_create_delete_workers(self): ) +class TestWorkerCallEngine: + def test_create_call_engine(self): + # to simulate an awaitable None + async def async_none(*args, **kwargs): + return None + + config = BaseExperimentConfig() + + scheduler = RayScheduler(startup_timeout=60.0, exp_config=config) + ray_actor_handle = Mock() + ray_actor_handle.create_engine.remote = async_none + + worker = RayWorkerInfo( + worker=Worker(id="test/0", ip="0.0.0.0"), + actor=ray_actor_handle, + role="test", + placement_group=None, + bundle_index=0, + created_at=0, + ) + + scheduler._workers["test"] = [worker] + scheduler._worker_info_by_id[worker.worker.id] = worker + + # create engine + result = asyncio.run( + scheduler.create_engine( + worker.worker.id, "test_engines.DummyEngine", name="TestEngine" + ) + ) + assert result is None + + # sync + ray_actor_handle.call.remote = lambda x: None + with patch("areal.scheduler.ray.ray.get", return_value=None): + result = scheduler.call_engine(worker.worker.id, "test_fn") + assert result is None + + # async + ray_actor_handle.call.remote = async_none + result = asyncio.run(scheduler.async_call_engine(worker.worker.id, "test_fn")) + assert result is None + + class TestUtilityFunctions: def test_utilities(self): _num_gpu_per_node = 16 From 60ecdd89bf7d84058b0b0845e08f5f7a56516394 Mon Sep 17 00:00:00 2001 From: Haley Date: Tue, 23 Dec 2025 12:30:54 -0800 Subject: [PATCH 12/15] Refactor ray implementation of rtensor to use dependency injection instead of subclassing RTensor --- areal/controller/train_controller.py | 16 +- areal/scheduler/rpc/ray_rpc_server.py | 12 +- areal/scheduler/rpc/ray_rtensor.py | 120 ------- areal/scheduler/rpc/rtensor.py | 481 +++++++++++++++----------- 4 files changed, 279 insertions(+), 350 deletions(-) delete mode 100644 areal/scheduler/rpc/ray_rtensor.py diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index 2accc3b08..9e66b2a8c 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Any -import aiohttp import torch from torchdata.stateful_dataloader import StatefulDataLoader @@ -666,25 +665,14 @@ def update_weights(self, meta: WeightUpdateMeta): raise ValueError(f"Unknown weight update type {meta.type}") async def _async_clear_batches(self, *targets: dict[str, 
RTensor]): - """Extract shard IDs and call /data/clear on each worker.""" + """Extract shard IDs and clear tensors on each worker.""" shards_by_node = RTensor.collect_shards(targets) if not shards_by_node: return - async def clear_node(node_addr, shard_ids): - async with aiohttp.ClientSession() as session: - async with session.delete( - f"http://{node_addr}/data/clear", json={"shard_ids": shard_ids} - ) as resp: - if resp.status == 200: - result = await resp.json() - logger.info( - f"Cleared {result.get('cleared_count', 0)} shards on {node_addr}" - ) - await asyncio.gather( - *[clear_node(addr, sids) for addr, sids in shards_by_node.items()], + *[RTensor.clear_node(addr, sids) for addr, sids in shards_by_node.items()], return_exceptions=True, ) diff --git a/areal/scheduler/rpc/ray_rpc_server.py b/areal/scheduler/rpc/ray_rpc_server.py index da012bdd7..cb5e75a25 100644 --- a/areal/scheduler/rpc/ray_rpc_server.py +++ b/areal/scheduler/rpc/ray_rpc_server.py @@ -7,7 +7,7 @@ from areal.api.cli_args import BaseExperimentConfig from areal.api.engine_api import InferenceEngine, TrainEngine -from areal.scheduler.rpc.ray_rtensor import RayRTensor +from areal.scheduler.rpc.rtensor import RTensor from areal.utils import logging, name_resolve, seeding from areal.utils.data import ( broadcast_tensor_container, @@ -78,8 +78,8 @@ def call(self, method: str, *args, **kwargs) -> Any: raw_args = list(args) raw_kwargs = kwargs.copy() # fetch remote tensors if any - args = RayRTensor.localize(raw_args) - kwargs = RayRTensor.localize(raw_kwargs) + args = RTensor.localize(raw_args) + kwargs = RTensor.localize(raw_kwargs) should_broadcast = kwargs.pop("should_broadcast", True) @@ -113,11 +113,11 @@ def call(self, method: str, *args, **kwargs) -> Any: if isinstance(result, Future): result = result.result() # Convert all tensors to RTensors and store the tensor locally - layout = RayRTensor.extract_layout( - result, layouts=dict(args=raw_args, kwargs=raw_kwargs) + layout = RTensor.extract_layout( + result, layouts=dict(args=raw_args, kwargs=raw_kwargs), node_addr="" ) if layout is not None: - result = RayRTensor.remotize(result, layout) + result = RTensor.remotize(result, layout, node_addr="") # put back to cpu to mimic RPCServer encode/decode result = tensor_container_to(result, "cpu") self.logger.debug(f"Successfully completed RayRPCServer call {result}") diff --git a/areal/scheduler/rpc/ray_rtensor.py b/areal/scheduler/rpc/ray_rtensor.py deleted file mode 100644 index 5d65acbaa..000000000 --- a/areal/scheduler/rpc/ray_rtensor.py +++ /dev/null @@ -1,120 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any - -import ray -import torch - -from areal.scheduler.rpc.rtensor import ( - BaseRTensor, - BaseTensorShardInfo, - _find_in_structure, - _pad_cat_dim0, -) - - -@dataclass -class RayTensorShardInfo(BaseTensorShardInfo): - ref: ray.ObjectRef - - -@dataclass -class RayRTensor(BaseRTensor): - def to_local(self) -> torch.Tensor: - if not self.data.is_meta: - return self.data - # Fetch all shards first - tensors = self._fetch() - self.data = _pad_cat_dim0(tensors) - return self.data - - def _fetch(self) -> list[torch.Tensor]: - return ray.get([s.ref for s in self.shards]) - - def split(self) -> list[RayRTensor]: - tensors = RayRTensor.split_tensor(self.data, self) - return [RayRTensor(shards=[s], data=t) for s, t in zip(self.shards, tensors)] - - @classmethod - def from_batched(cls, batch_tensor: torch.Tensor, layout: RayRTensor): - if not batch_tensor.is_cpu and not 
batch_tensor.is_meta: - raise ValueError("RTensor shards must be on CPU or meta device") - - tensors = cls.split_tensor(batch_tensor, layout) - - shards = [] - for tensor, shard_info in zip(tensors, layout.shards): - ref = ray.put(tensor) - info = RayTensorShardInfo( - ref=ref, - size=shard_info.size, - seqlens=shard_info.seqlens.copy(), - ) - shards.append(info) - - # Truncate at the maximum sequence length - # to prevent over-padding - if tensor.ndim > 1: - tensor = tensor[:, : max(shard_info.seqlens)] - - return cls(shards=shards, data=batch_tensor.to("meta")) - - @staticmethod - def extract_layout(obj: Any, layouts: Any): - layout_rtensor = _find_in_structure(layouts, RayRTensor) - result_tensor = _find_in_structure(obj, torch.Tensor) - - if layout_rtensor is None and result_tensor is not None: - if not isinstance(obj, dict): - raise RuntimeError( - "When input does not contain RayRTensor, " - "we expect to extract layouts from a dict batch " - f"returned by InferenceEngine. Get obj: {obj}, " - f"input layouts: {layouts}." - ) - attn_mask = obj.get("attention_mask", None) - if attn_mask is None: - raise RuntimeError("`attention_mask` is not found") - - layout_rtensor = RayRTensor( - shards=[ - RayTensorShardInfo( - ref=None, # placeholder to be filled later - size=attn_mask.shape[0], - seqlens=[int(am.sum()) for am in attn_mask], - ) - ], - data=torch.empty_like(attn_mask, device="meta"), - ) - return layout_rtensor - - @staticmethod - def remotize(obj: Any, layout: RayRTensor) -> Any: - if isinstance(obj, torch.Tensor): - return RayRTensor.from_batched(obj.detach().cpu(), layout=layout) - - if isinstance(obj, dict): - return { - k: RayRTensor.remotize(obj=v, layout=layout) for k, v in obj.items() - } - - if isinstance(obj, list): - return [RayRTensor.remotize(obj=item, layout=layout) for item in obj] - - if isinstance(obj, tuple): - return tuple(RayRTensor.remotize(obj=item, layout=layout) for item in obj) - - return obj - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - - if func is torch.cat: - return RayRTensor.cat(*args, **kwargs) - - raise NotImplementedError( - f"RayRTensor does not implement torch function {func}" - ) diff --git a/areal/scheduler/rpc/rtensor.py b/areal/scheduler/rpc/rtensor.py index 4108f6616..f2e29cfe2 100644 --- a/areal/scheduler/rpc/rtensor.py +++ b/areal/scheduler/rpc/rtensor.py @@ -6,11 +6,12 @@ from collections import defaultdict from dataclasses import dataclass from threading import Lock -from typing import Any +from typing import Any, ClassVar import aiohttp import numpy as np import orjson +import ray import torch from areal.utils.datapack import ffd_allocate, flat2d @@ -20,23 +21,151 @@ class BaseTensorShardInfo(ABC): size: int # Batch size (shape[0]) of this shard seqlens: list[int] # Sequence lengths of this shard (from attention_mask) + shard_id: str + node_addr: str + + @classmethod + @abstractmethod + def fetch(cls: type[BaseTensorShardInfo], shards: BaseTensorShardInfo): + pass + + @classmethod + @abstractmethod + def store(cls: type[BaseTensorShardInfo], shard_id: str, tensor: torch.Tensor): + pass + + @classmethod + @abstractmethod + def create( + cls: type[BaseTensorShardInfo], + *, + size: int, + seqlens: list[int], + **kwargs, + ) -> BaseTensorShardInfo: + pass + + @classmethod + @abstractmethod + async def delete_by_shard_id(cls, node_addr, shard_ids): + pass + + +@dataclass +class TensorShardInfo(BaseTensorShardInfo): + """Metadata for a single shard of an RTensor.""" + 
+ @classmethod + def fetch(cls, shards): + """Fetch all shards synchronously.""" + + async def _fetch_all(): + async with aiohttp.ClientSession() as session: + return await asyncio.gather( + *[ + cls._fetch_tensor(session, s.shard_id, s.node_addr) + for s in shards + ] + ) + + return asyncio.run(_fetch_all()) + + @classmethod + async def _fetch_tensor( + cls, session: aiohttp.ClientSession, shard_id: str, node_addr: str + ) -> torch.Tensor: + # Avoid circular import + from areal.scheduler.rpc.serialization import deserialize_value + + url = f"http://{node_addr}/data/{shard_id}" + async with session.get(url) as resp: + if resp.status != 200: + raise RuntimeError(f"Failed to fetch shard from {url}: {resp.status}") + data_bytes = await resp.read() + serialized_data = orjson.loads(data_bytes) + return deserialize_value(serialized_data) + + @classmethod + def store(cls, shard_id: str, tensor: torch.Tensor): + store(shard_id, tensor) + return shard_id + + @classmethod + def create(cls, *, size, seqlens, shard_id: str, node_addr: str, **_): + return cls( + shard_id=shard_id, + node_addr=node_addr, + size=size, + seqlens=seqlens, + ) + + @classmethod + async def delete_by_shard_id(cls, node_addr, shard_ids): + async with aiohttp.ClientSession() as session: + async with session.delete( + f"http://{node_addr}/data/clear", json={"shard_ids": shard_ids} + ) as resp: + if resp.status == 200: + await resp.json() + + +@dataclass +class RayTensorShardInfo(BaseTensorShardInfo): + shard_id: ray.ObjectRef + + @classmethod + def fetch(cls, shards: RayTensorShardInfo) -> list[torch.Tensor]: + return ray.get([s.shard_id for s in shards]) + + @classmethod + def store(cls, shard_id: str, tensor: torch.Tensor): + return ray.put(tensor) + + @classmethod + def create(cls, *, size, seqlens, shard_id=None, **_): + return cls( + shard_id=shard_id, + size=size, + seqlens=seqlens, + node_addr="", + ) + + @classmethod + async def delete_by_shard_id(cls, node_addr, shard_ids): + ray.internal.free(shard_ids) @dataclass -class BaseRTensor(ABC): +class RTensor: """Single tensor distributed as CPU shards across nodes.""" shards: list[BaseTensorShardInfo] data: torch.Tensor + _tensor_info_cls: ClassVar[type[BaseTensorShardInfo]] = None + + @classmethod + def tensor_info_cls(cls) -> type[BaseTensorShardInfo]: + if cls._tensor_info_cls is None: + if ray.is_initialized(): + cls._tensor_info_cls = RayTensorShardInfo + else: + cls._tensor_info_cls = TensorShardInfo + return cls._tensor_info_cls - @abstractmethod def to_local(self) -> torch.Tensor: - pass + """Fetch all shards via HTTP, concatenate along dim 0.""" + if not self.data.is_meta: + return self.data + # Fetch all shards first + tensors = self._fetch() + self.data = _pad_cat_dim0(tensors) + return self.data + + def _fetch(self): + return RTensor.tensor_info_cls().fetch(self.shards) @staticmethod - def split_tensor( - batch_tensor: torch.Tensor, layout: BaseRTensor - ) -> list[torch.Tensor]: + def split_tensor(batch_tensor: torch.Tensor, layout: RTensor) -> list[torch.Tensor]: offsets = np.cumsum([0] + [shard.size for shard in layout.shards]) if offsets[-1] != batch_tensor.shape[0]: raise ValueError( @@ -45,32 +174,53 @@ def split_tensor( # no clone here because they are read-only slices return [batch_tensor[a:b] for a, b in zip(offsets[:-1], offsets[1:])] - @abstractmethod - def split(self) -> list[BaseRTensor]: - pass + def split(self) -> list[RTensor]: + tensors = RTensor.split_tensor(self.data, self) + return [RTensor(shards=[s], data=t) for s, t in zip(self.shards, 
tensors)] @classmethod - @abstractmethod - def from_batched(cls, batch_tensor: torch.Tensor, layout: BaseRTensor): - pass + def from_batched(cls, batch_tensor: torch.Tensor, layout: RTensor, node_addr: str): + if not batch_tensor.is_cpu and not batch_tensor.is_meta: + raise ValueError("RTensor shards must be on CPU or meta device") - @classmethod - def cat(cls, rtensors: list[BaseRTensor | torch.Tensor], dim=0) -> BaseRTensor: + tensors = cls.split_tensor(batch_tensor, layout) + + shards = [] + for tensor, shard_info in zip(tensors, layout.shards): + sid = str(uuid.uuid4()) + # Truncate at the maximum sequence length + # to prevent over-padding + if tensor.ndim > 1: + tensor = tensor[:, : max(shard_info.seqlens)] + # Store locally + shard_id = RTensor.tensor_info_cls().store(sid, tensor) + info = RTensor.tensor_info_cls().create( + size=shard_info.size, + seqlens=shard_info.seqlens.copy(), + shard_id=shard_id, + node_addr=node_addr, + ) + shards.append(info) + + return cls(shards=shards, data=batch_tensor.to("meta")) + + @staticmethod + def cat(rtensors: list[RTensor | torch.Tensor], dim=0) -> RTensor: """Concatenate RTensors along existing dimension.""" n_tensors = len(rtensors) if n_tensors == 0: - return cls(shards=[], data=torch.tensor([]).to("meta")) - n_rtensors = len([x for x in rtensors if isinstance(x, cls)]) + return RTensor(shards=[], data=torch.tensor([]).to("meta")) + n_rtensors = len([x for x in rtensors if isinstance(x, RTensor)]) # All RTensors if n_tensors == n_rtensors: if dim != 0: raise ValueError( - "BaseRTensor.cat for multiple RTensors only supports dim=0" + "RTensor.cat for multiple RTensors only supports dim=0" ) if any(t.data is None for t in rtensors): raise RuntimeError("Cannot concat rtensors with None data") - return cls( + return RTensor( shards=[shard for r in rtensors for shard in r.shards], data=_pad_cat_dim0([r.data for r in rtensors]), ) @@ -80,50 +230,97 @@ def cat(cls, rtensors: list[BaseRTensor | torch.Tensor], dim=0) -> BaseRTensor: raise ValueError( "RTensor.cat only support concatenating a single RTensor with other torch.Tensor" ) - rt = [x for x in rtensors if isinstance(x, cls)][0] - return cls( + rt = [x for x in rtensors if isinstance(x, RTensor)][0] + return RTensor( shards=rt.shards, data=torch.cat( - [r.data if isinstance(r, cls) else r for r in rtensors], dim=dim + [r.data if isinstance(r, RTensor) else r for r in rtensors], dim=dim ), ) @staticmethod - @abstractmethod - def extract_layout(obj: Any, layouts: Any): - pass + def extract_layout(obj: Any, layouts: Any, node_addr: str | None): + # Determine if batched from input layouts + layout_rtensor = _find_in_structure(layouts, RTensor) + result_tensor = _find_in_structure(obj, torch.Tensor) + if layout_rtensor is None and result_tensor is not None: + if not isinstance(obj, dict): + raise RuntimeError( + "When input does not contain RTensor, " + "we expect to extract layouts from a dict batch " + f"returned by InferenceEngine. Get obj: {obj}, " + f"input layouts: {layouts}." 
+ ) + attn_mask = obj.get("attention_mask", None) + if attn_mask is None: + raise RuntimeError("`attention_mask` is not found") + assert node_addr is not None + shard = RTensor.tensor_info_cls().create( + size=attn_mask.shape[0], + seqlens=[int(am.sum()) for am in attn_mask], + shard_id="", + node_addr=node_addr, + ) + + layout_rtensor = RTensor( + shards=[shard], + data=torch.empty_like(attn_mask, device="meta"), + ) + return layout_rtensor @staticmethod - @abstractmethod - def remotize(obj: Any, layout: BaseRTensor) -> Any: - pass + def remotize(obj: Any, layout: RTensor, node_addr: str) -> Any: + if isinstance(obj, torch.Tensor): + return RTensor.from_batched( + obj.detach().cpu(), layout=layout, node_addr=node_addr + ) - @classmethod - def localize(cls, obj: Any) -> Any: - """Convert BaseRTensors to local tensors in nested structures. + if isinstance(obj, dict): + return { + k: RTensor.remotize(obj=v, layout=layout, node_addr=node_addr) + for k, v in obj.items() + } + + if isinstance(obj, list): + return [ + RTensor.remotize(obj=item, layout=layout, node_addr=node_addr) + for item in obj + ] + + if isinstance(obj, tuple): + return tuple( + RTensor.remotize(obj=item, layout=layout, node_addr=node_addr) + for item in obj + ) + + return obj + + @staticmethod + def localize(obj: Any) -> Any: + """Convert RTensors to local tensors in nested structures. Inverse of remotize() - fetches remote data and converts to local tensors. """ - if isinstance(obj, cls): + if isinstance(obj, RTensor): return obj.to_local() if isinstance(obj, dict): - return {k: cls.localize(v) for k, v in obj.items()} + return {k: RTensor.localize(v) for k, v in obj.items()} if isinstance(obj, list): - return [cls.localize(item) for item in obj] + return [RTensor.localize(item) for item in obj] if isinstance(obj, tuple): - return tuple(cls.localize(item) for item in obj) + return tuple(RTensor.localize(item) for item in obj) return obj - @classmethod + @staticmethod def data_parallel_dispatch( - cls, obj: Any, dp_size: int, group_indices: list[list[int]] | None = None + obj: Any, dp_size: int, group_indices: list[list[int]] | None = None ) -> tuple[list[Any], list[list[int]] | None]: if group_indices is None: - layout_rtensor = _find_in_structure(obj, cls) + layout_rtensor = _find_in_structure(obj, RTensor) if layout_rtensor is not None: # FIXME: the next line prevents splitting a single trajectory into finer granularity seqlens = [sum(s.seqlens) for s in layout_rtensor.shards] @@ -133,21 +330,21 @@ def data_parallel_dispatch( ) # else: no RTensors found, will replicate scalars without group_indices - if isinstance(obj, cls): - tensors = cls.split_tensor(obj.data, obj) + if isinstance(obj, RTensor): + tensors = RTensor.split_tensor(obj.data, obj) # Split shards according to group assignments split_rtensors = [] for group_idxs in group_indices: # Collect shards for this group group_shards = [obj.shards[i] for i in group_idxs] group_data = _pad_cat_dim0([tensors[i] for i in group_idxs]) - split_rtensors.append(cls(shards=group_shards, data=group_data)) + split_rtensors.append(RTensor(shards=group_shards, data=group_data)) return split_rtensors, group_indices if isinstance(obj, dict): # Split each value, return list of dicts split_values = { - k: cls.data_parallel_dispatch(v, dp_size, group_indices)[0] + k: RTensor.data_parallel_dispatch(v, dp_size, group_indices)[0] for k, v in obj.items() } return [ @@ -157,7 +354,7 @@ def data_parallel_dispatch( if isinstance(obj, list): # Split each element split_elements = [ - 
cls.data_parallel_dispatch(elem, dp_size, group_indices)[0] + RTensor.data_parallel_dispatch(elem, dp_size, group_indices)[0] for elem in obj ] return [ @@ -167,7 +364,7 @@ def data_parallel_dispatch( if isinstance(obj, tuple): # Split each element split_elements = [ - cls.data_parallel_dispatch(elem, dp_size, group_indices)[0] + RTensor.data_parallel_dispatch(elem, dp_size, group_indices)[0] for elem in obj ] return [ @@ -178,9 +375,9 @@ def data_parallel_dispatch( # Non-RTensor objects: replicate to all groups return [obj] * dp_size, group_indices - @classmethod + @staticmethod def data_parallel_merge( - cls, results: list[Any], group_indices: list[list[int]] | None + results: list[Any], group_indices: list[list[int]] | None ) -> Any: if not results: return None @@ -189,22 +386,25 @@ def data_parallel_merge( # Check for raw tensors - not allowed if isinstance(first, torch.Tensor): - raise TypeError("Regular tensors not allowed in merge - only RTensors") + raise TypeError( + "Regular tensors not allowed in merge - only RTensors. " + "Engine outputs should be automatically converted to RTensors." + ) - if isinstance(first, cls): + if isinstance(first, RTensor): assert group_indices is not None rtensors = flat2d([r.split() for r in results]) indices = flat2d(group_indices) assert len(rtensors) == len(indices), (len(rtensors), len(indices)) inv_indices = np.zeros(len(indices), dtype=np.int64) inv_indices[indices] = np.arange(len(indices)) - return cls.cat([rtensors[i] for i in inv_indices]) + return RTensor.cat([rtensors[i] for i in inv_indices]) if isinstance(first, dict): merged = {} for key in first.keys(): values = [r[key] for r in results] - merged[key] = cls.data_parallel_merge( + merged[key] = RTensor.data_parallel_merge( values, group_indices=group_indices ) return merged @@ -214,7 +414,7 @@ def data_parallel_merge( for i in range(len(first)): elements = [r[i] for r in results] merged.append( - cls.data_parallel_merge(elements, group_indices=group_indices) + RTensor.data_parallel_merge(elements, group_indices=group_indices) ) return merged @@ -223,171 +423,13 @@ def data_parallel_merge( for i in range(len(first)): elements = [r[i] for r in results] merged.append( - cls.data_parallel_merge(elements, group_indices=group_indices) + RTensor.data_parallel_merge(elements, group_indices=group_indices) ) return tuple(merged) # Scalars: return first (assume synchronized) return first - @property - def shape(self): - return self.data.shape - - @property - def dtype(self): - return self.data.dtype - - @property - def device(self): - return self.data.device - - @property - def ndim(self): - return self.data.ndim - - @classmethod - @abstractmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - pass - - -@dataclass -class TensorShardInfo(BaseTensorShardInfo): - """Metadata for a single shard of an RTensor.""" - - shard_id: str - node_addr: str - - -@dataclass -class RTensor(BaseRTensor): - def to_local(self) -> torch.Tensor: - """Fetch all shards via HTTP, concatenate along dim 0.""" - if not self.data.is_meta: - return self.data - # Fetch all shards first - tensors = self._fetch() - self.data = _pad_cat_dim0(tensors) - return self.data - - def _fetch(self): - """Fetch all shards synchronously.""" - - async def _fetch_all(): - async with aiohttp.ClientSession() as session: - return await asyncio.gather( - *[ - RTensor._fetch_tensor(session, s.shard_id, s.node_addr) - for s in self.shards - ] - ) - - return asyncio.run(_fetch_all()) - - @staticmethod - async def _fetch_tensor( 
- session: aiohttp.ClientSession, shard_id: str, node_addr: str - ) -> torch.Tensor: - # Avoid circular import - from areal.scheduler.rpc.serialization import deserialize_value - - url = f"http://{node_addr}/data/{shard_id}" - async with session.get(url) as resp: - if resp.status != 200: - raise RuntimeError(f"Failed to fetch shard from {url}: {resp.status}") - data_bytes = await resp.read() - serialized_data = orjson.loads(data_bytes) - return deserialize_value(serialized_data) - - def split(self) -> list[RTensor]: - tensors = RTensor.split_tensor(self.data, self) - return [RTensor(shards=[s], data=t) for s, t in zip(self.shards, tensors)] - - @classmethod - def from_batched(cls, batch_tensor: torch.Tensor, layout: RTensor, node_addr: str): - if not batch_tensor.is_cpu and not batch_tensor.is_meta: - raise ValueError("RTensor shards must be on CPU or meta device") - - tensors = cls.split_tensor(batch_tensor, layout) - - shards = [] - for tensor, shard_info in zip(tensors, layout.shards): - sid = str(uuid.uuid4()) - info = TensorShardInfo( - shard_id=sid, - node_addr=node_addr, - size=shard_info.size, - seqlens=shard_info.seqlens.copy(), - ) - shards.append(info) - - # Truncate at the maximum sequence length - # to prevent over-padding - if tensor.ndim > 1: - tensor = tensor[:, : max(shard_info.seqlens)] - # Store locally - store(sid, tensor) - - return cls(shards=shards, data=batch_tensor.to("meta")) - - @staticmethod - def extract_layout(obj: Any, layouts: Any, node_addr: str | None): - # Determine if batched from input layouts - layout_rtensor = _find_in_structure(layouts, RTensor) - result_tensor = _find_in_structure(obj, torch.Tensor) - if layout_rtensor is None and result_tensor is not None: - if not isinstance(obj, dict): - raise RuntimeError( - "When input does not contain RTensor, " - "we expect to extract layouts from a dict batch " - f"returned by InferenceEngine. Get obj: {obj}, " - f"input layouts: {layouts}." - ) - attn_mask = obj.get("attention_mask", None) - if attn_mask is None: - raise RuntimeError("`attention_mask` is not found") - assert node_addr is not None - layout_rtensor = RTensor( - shards=[ - TensorShardInfo( - shard_id="", - node_addr=node_addr, - size=attn_mask.shape[0], - seqlens=[int(am.sum()) for am in attn_mask], - ) - ], - data=torch.empty_like(attn_mask, device="meta"), - ) - return layout_rtensor - - @staticmethod - def remotize(obj: Any, layout: RTensor, node_addr: str) -> Any: - if isinstance(obj, torch.Tensor): - return RTensor.from_batched( - obj.detach().cpu(), layout=layout, node_addr=node_addr - ) - - if isinstance(obj, dict): - return { - k: RTensor.remotize(obj=v, layout=layout, node_addr=node_addr) - for k, v in obj.items() - } - - if isinstance(obj, list): - return [ - RTensor.remotize(obj=item, layout=layout, node_addr=node_addr) - for item in obj - ] - - if isinstance(obj, tuple): - return tuple( - RTensor.remotize(obj=item, layout=layout, node_addr=node_addr) - for item in obj - ) - - return obj - @staticmethod def collect_shards(obj: Any) -> dict[str, list[str]]: """Collect shard IDs grouped by node address from nested structure. 
@@ -412,6 +454,25 @@ def _collect(o): _collect(obj) return shards_by_node + async def clear_node(node_addr, shard_ids): + await RTensor.tensor_info_cls().delete_by_shard_id(node_addr, shard_ids) + + @property + def shape(self): + return self.data.shape + + @property + def dtype(self): + return self.data.dtype + + @property + def device(self): + return self.data.device + + @property + def ndim(self): + return self.data.ndim + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: From 9190dddaa2ce1d138911962ae26767fd421d53b0 Mon Sep 17 00:00:00 2001 From: Haley Date: Tue, 23 Dec 2025 12:35:09 -0800 Subject: [PATCH 13/15] Fix torch import --- areal/controller/train_controller.py | 1 - 1 file changed, 1 deletion(-) diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index bad9664e8..902bca9d6 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Any -import torch import torch.distributed as dist from torchdata.stateful_dataloader import StatefulDataLoader From fb40ac2bc294e661f0dde96043ce2b8958ee4da7 Mon Sep 17 00:00:00 2001 From: Haley Date: Tue, 23 Dec 2025 14:56:44 -0800 Subject: [PATCH 14/15] Support PPOTrainer change for RayScheduler --- areal/experimental/trainer/rl.py | 4 +++- areal/scheduler/__init__.py | 5 ++--- areal/scheduler/ray.py | 9 +++++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/areal/experimental/trainer/rl.py b/areal/experimental/trainer/rl.py index 3c5332a99..a3464e035 100644 --- a/areal/experimental/trainer/rl.py +++ b/areal/experimental/trainer/rl.py @@ -36,7 +36,7 @@ from areal.engine.sglang_remote import RemoteSGLangEngine from areal.engine.vllm_remote import RemotevLLMEngine from areal.platforms import current_platform -from areal.scheduler import LocalScheduler +from areal.scheduler import LocalScheduler, RayScheduler from areal.utils import logging, perf_tracer, seeding, stats_tracker from areal.utils.dataloader import create_dataloader from areal.utils.environ import is_single_controller @@ -395,6 +395,8 @@ def _init_scheduler(self) -> Scheduler: cfg = self.config.scheduler if cfg.type == "local": return LocalScheduler(exp_config=self.config) + elif cfg.type == "ray": + return RayScheduler(exp_config=self.config) raise NotImplementedError(f"Unknown scheduler type: {cfg.type}") def _create_dataloader( diff --git a/areal/scheduler/__init__.py b/areal/scheduler/__init__.py index 042374a5a..b9a6d5a1e 100644 --- a/areal/scheduler/__init__.py +++ b/areal/scheduler/__init__.py @@ -1,5 +1,4 @@ from .local import LocalScheduler +from .ray import RayScheduler -__all__ = [ - "LocalScheduler", -] +__all__ = ["LocalScheduler", "RayScheduler"] diff --git a/areal/scheduler/ray.py b/areal/scheduler/ray.py index f5ae8bc4b..304cd0810 100644 --- a/areal/scheduler/ray.py +++ b/areal/scheduler/ray.py @@ -214,8 +214,9 @@ def _create_rollout_workers( # create placement_groups for idx, spec in enumerate(schedulings): worker_id = f"{role}/{idx}" - - bundles = [self._bundle_spec(spec.cpu, spec.gpu, spec.mem)] + # TODO: should later support some parameter whether to allocate gpus or not + gpu = 0 if "eval-rollout" in role else spec.gpu + bundles = [self._bundle_spec(spec.cpu, gpu, spec.mem)] pg = placement_group(bundles, strategy="PACK") try: @@ -234,7 +235,7 @@ def _create_rollout_workers( ) # define resources to actor - options = self._actor_resource_spec(spec.cpu, spec.gpu, spec.mem) + 
options = self._actor_resource_spec(spec.cpu, gpu, spec.mem) env = get_env_vars( self.exp_config, @@ -413,7 +414,7 @@ def create_workers(self, job: Job, *args, **kwargs) -> list[str]: "RayScheduler only supports separation strategy", ) - if role == "rollout": + if "rollout" in role: worker_info_list, worker_ids = self._create_rollout_workers( role, schedulings ) From 5f7967b069d9424da0b250af67b5960254a083af Mon Sep 17 00:00:00 2001 From: Haley Date: Tue, 23 Dec 2025 14:58:03 -0800 Subject: [PATCH 15/15] Remove gsm8k_grpo_ray.py as it is handled in a unified script now. --- examples/single-controller/gsm8k_grpo_ray.py | 215 ------------------- 1 file changed, 215 deletions(-) delete mode 100644 examples/single-controller/gsm8k_grpo_ray.py diff --git a/examples/single-controller/gsm8k_grpo_ray.py b/examples/single-controller/gsm8k_grpo_ray.py deleted file mode 100644 index e8db9ec17..000000000 --- a/examples/single-controller/gsm8k_grpo_ray.py +++ /dev/null @@ -1,215 +0,0 @@ -import os -import sys - -from areal.api.alloc_mode import AllocationMode -from areal.api.cli_args import GRPOConfig, SGLangConfig, load_expr_config, vLLMConfig -from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta -from areal.controller.rollout_controller import RolloutController -from areal.controller.train_controller import TrainController -from areal.dataset import get_custom_dataset -from areal.engine.ppo.actor import FSDPPPOActor -from areal.engine.sglang_remote import RemoteSGLangEngine -from areal.engine.vllm_remote import RemotevLLMEngine -from areal.scheduler.ray import RayScheduler -from areal.utils import stats_tracker -from areal.utils.dataloader import create_dataloader -from areal.utils.device import log_gpu_stats -from areal.utils.evaluator import Evaluator -from areal.utils.hf_utils import load_hf_tokenizer -from areal.utils.recover import RecoverHandler -from areal.utils.saver import Saver -from areal.utils.stats_logger import StatsLogger - - -def main(args): - config, _ = load_expr_config(args, GRPOConfig) - config: GRPOConfig - - tokenizer = load_hf_tokenizer(config.tokenizer_path) - - # Create dataset and dataloaders - train_dataset = get_custom_dataset( - split="train", dataset_config=config.train_dataset, tokenizer=tokenizer - ) - - train_dataloader = create_dataloader( - train_dataset, - rank=0, - world_size=1, - dataset_config=config.train_dataset, - ) - - ft_spec = FinetuneSpec( - total_train_epochs=config.total_train_epochs, - dataset_size=len(train_dataloader) * config.train_dataset.batch_size, - train_batch_size=config.train_dataset.batch_size, - ) - - # Initialize scheduler - scheduler = RayScheduler(exp_config=config) - - # Initialize train controller - allocation_mode = AllocationMode.from_str(config.allocation_mode) - actor = TrainController(FSDPPPOActor, config=config.actor, scheduler=scheduler) - actor.initialize( - role="actor", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None - ) - - # Initialize inference engine - - if allocation_mode.gen_backend == "sglang": - engine_class = RemoteSGLangEngine - server_args = SGLangConfig.build_args( - sglang_config=config.sglang, - tp_size=allocation_mode.gen.tp_size, - base_gpu_id=0, - ) - elif allocation_mode.gen_backend == "vllm": - engine_class = RemotevLLMEngine - server_args = vLLMConfig.build_args( - vllm_config=config.vllm, - tp_size=allocation_mode.gen.tp_size, - pp_size=allocation_mode.gen.pp_size, - ) - else: - raise ValueError(f"Unsupported gen_backend: '{allocation_mode.gen_backend}'") - - rollout = 
RolloutController( - engine_class, config=config.rollout, scheduler=scheduler - ) - rollout.initialize( - role="rollout", - alloc_mode=allocation_mode, - server_args=server_args, - ) - - weight_update_meta = WeightUpdateMeta.from_disk( - experiment_name=config.experiment_name, - trial_name=config.trial_name, - file_root=config.cluster.fileroot, - ) - actor.connect_engine(rollout, weight_update_meta) - - ref = None - if config.actor.kl_ctl > 0 and config.ref is not None: - ref = TrainController(FSDPPPOActor, config=config.ref, scheduler=scheduler) - ref.initialize( - role="ref", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None - ) - - # Run training. - saver = Saver(config.saver, ft_spec) - stats_logger = StatsLogger(config, ft_spec) - evaluator = Evaluator(config.evaluator, ft_spec) - - recover_handler = RecoverHandler(config.recover, ft_spec) - - try: - recover_info = recover_handler.load( - actor, - saver, - evaluator, - stats_logger, - train_dataloader, - inference_engine=rollout, - weight_update_meta=weight_update_meta, - ) - start_step = ( - recover_info.last_step_info.next().global_step - if recover_info is not None - else 0 - ) - - total_epochs = config.total_train_epochs - steps_per_epoch = len(train_dataloader) - max_steps = total_epochs * steps_per_epoch - - for global_step in range(start_step, max_steps): - epoch = global_step // steps_per_epoch - step = global_step % steps_per_epoch - step_info = StepInfo( - global_step=global_step, - epoch=epoch, - epoch_step=step, - steps_per_epoch=steps_per_epoch, - ) - - with stats_tracker.record_timing("rollout"): - workflow_kwargs = dict( - reward_fn="areal.reward.gsm8k.gsm8k_reward_fn", - gconfig=config.gconfig, - tokenizer=config.tokenizer_path, - enable_thinking=False, - dump_dir=os.path.join( - StatsLogger.get_log_path(config.stats_logger), - "generated", - ), - ) - rollout_batch = actor.prepare_batch( - train_dataloader, - workflow="areal.workflow.rlvr.RLVRWorkflow", - workflow_kwargs=workflow_kwargs, - ) - - if config.actor.recompute_logprob or config.actor.use_decoupled_loss: - with stats_tracker.record_timing("recompute_logp"): - prox_logp = actor.compute_logp(rollout_batch) - rollout_batch["prox_logp"] = prox_logp - log_gpu_stats("recompute logp") - - if ref is not None: - with stats_tracker.record_timing("ref_logp"): - ref_logp = ref.compute_logp(rollout_batch) - rollout_batch["ref_logp"] = ref_logp - log_gpu_stats("ref logp") - - with stats_tracker.record_timing("compute_advantage"): - adv_batch = actor.compute_advantages(rollout_batch) - log_gpu_stats("compute advantages") - - with stats_tracker.record_timing("train_step"): - actor.ppo_update(adv_batch) - actor.step_lr_scheduler() - log_gpu_stats("ppo update") - # pause inference for updating weights, save, and evaluation - rollout.pause() - - with stats_tracker.record_timing("update_weights"): - actor.update_weights(weight_update_meta) - - actor.set_version(global_step + 1) - rollout.set_version(global_step + 1) - - with stats_tracker.record_timing("save"): - saver.save(actor, epoch, step, global_step, tokenizer=tokenizer) - - with stats_tracker.record_timing("checkpoint_for_recover"): - recover_handler.dump( - actor, - step_info, - saver, - evaluator, - stats_logger, - train_dataloader, - tokenizer=tokenizer, - ) - - with stats_tracker.record_timing("clear_batches"): - actor.clear_batches(rollout_batch, adv_batch) - - # Upload statistics to the logger (e.g., wandb) - stats_logger.commit(epoch, step, global_step, actor.export_stats()) - - # Resume rollout - 
rollout.resume() - - finally: - stats_logger.close() - rollout.destroy() - if ref is not None: - ref.destroy() - actor.destroy() - - -if __name__ == "__main__": - main(sys.argv[1:])
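
Illustrative usage sketch (not part of the patch series). To make the patch-12 refactor concrete: RTensor now resolves its shard backend at runtime via tensor_info_cls() — RayTensorShardInfo when Ray is initialized, TensorShardInfo (the HTTP data store) otherwise — instead of relying on the RayRTensor subclass deleted above. The sketch below round-trips a toy batch through remotize()/localize() with the Ray backend. It assumes the patched areal.scheduler.rpc.rtensor module from this series and a local Ray runtime; in the real flow this conversion happens inside RayRPCServer.call rather than in user code, layouts={} stands in for "no RTensor inputs", and node_addr="" mirrors what the Ray RPC server passes. The tensor values are made up for illustration.

import ray
import torch

from areal.scheduler.rpc.rtensor import RTensor

if not ray.is_initialized():
    ray.init()  # makes RTensor.tensor_info_cls() resolve to RayTensorShardInfo

# A toy batch; attention_mask is required so extract_layout can infer seqlens.
batch = {
    "input_ids": torch.randint(0, 100, (2, 8)),
    "attention_mask": torch.ones(2, 8, dtype=torch.long),
}

# No RTensor appears in the inputs, so the layout is derived from the batch itself.
layout = RTensor.extract_layout(batch, layouts={}, node_addr="")

# Replace raw tensors with RTensors whose shards live in the Ray object store.
remote_batch = RTensor.remotize(batch, layout=layout, node_addr="")

# Later (e.g., on the controller side) fetch the shards back as local tensors.
local_batch = RTensor.localize(remote_batch)

# Single shard with uniform sequence lengths, so no padding is introduced.
assert torch.equal(local_batch["input_ids"], batch["input_ids"])

The same two calls work unchanged against the HTTP-backed TensorShardInfo; only tensor_info_cls() changes its answer, which is the point of the dependency-injection refactor.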