diff --git a/areal/controller/train_controller.py b/areal/controller/train_controller.py index d07f6637e..19c6d4599 100644 --- a/areal/controller/train_controller.py +++ b/areal/controller/train_controller.py @@ -4,7 +4,6 @@ from pathlib import Path from typing import Any -import aiohttp import torch.distributed as dist from torchdata.stateful_dataloader import StatefulDataLoader @@ -579,25 +578,14 @@ def update_weights(self, meta: WeightUpdateMeta): raise ValueError(f"Unknown weight update type {meta.type}") async def _async_clear_batches(self, *targets: dict[str, RTensor]): - """Extract shard IDs and call /data/clear on each worker.""" + """Extract shard IDs and clear tensors on each worker.""" shards_by_node = RTensor.collect_shards(targets) if not shards_by_node: return - async def clear_node(node_addr, shard_ids): - async with aiohttp.ClientSession() as session: - async with session.delete( - f"http://{node_addr}/data/clear", json={"shard_ids": shard_ids} - ) as resp: - if resp.status == 200: - result = await resp.json() - logger.info( - f"Cleared {result.get('cleared_count', 0)} shards on {node_addr}" - ) - await asyncio.gather( - *[clear_node(addr, sids) for addr, sids in shards_by_node.items()], + *[RTensor.clear_node(addr, sids) for addr, sids in shards_by_node.items()], return_exceptions=True, ) diff --git a/areal/core/remote_inf_engine.py b/areal/core/remote_inf_engine.py index bb123eeef..74dced343 100644 --- a/areal/core/remote_inf_engine.py +++ b/areal/core/remote_inf_engine.py @@ -14,6 +14,7 @@ from typing import Any, Protocol import aiohttp +import ray import requests import torch.distributed as dist import uvloop @@ -949,6 +950,13 @@ def launch_server(self, server_args: dict[str, Any]) -> LocalInfServerInfo: try: self._wait_for_server(address) self.local_server_processes.append(server_info) + if ray.is_initialized(): + # do not return the process handle under Ray, as it is not picklable + return LocalInfServerInfo( + host=server_args["host"], + port=server_args["port"], + process=None, + ) return server_info except TimeoutError: logger.warning( diff --git a/areal/experimental/trainer/rl.py b/areal/experimental/trainer/rl.py index 3c5332a99..a3464e035 100644 --- a/areal/experimental/trainer/rl.py +++ b/areal/experimental/trainer/rl.py @@ -36,7 +36,7 @@ from areal.engine.sglang_remote import RemoteSGLangEngine from areal.engine.vllm_remote import RemotevLLMEngine from areal.platforms import current_platform -from areal.scheduler import LocalScheduler +from areal.scheduler import LocalScheduler, RayScheduler from areal.utils import logging, perf_tracer, seeding, stats_tracker from areal.utils.dataloader import create_dataloader from areal.utils.environ import is_single_controller @@ -395,6 +395,8 @@ def _init_scheduler(self) -> Scheduler: cfg = self.config.scheduler if cfg.type == "local": return LocalScheduler(exp_config=self.config) + elif cfg.type == "ray": + return RayScheduler(exp_config=self.config) raise NotImplementedError(f"Unknown scheduler type: {cfg.type}") def _create_dataloader( diff --git a/areal/scheduler/__init__.py b/areal/scheduler/__init__.py index 042374a5a..b9a6d5a1e 100644 --- a/areal/scheduler/__init__.py +++ b/areal/scheduler/__init__.py @@ -1,5 +1,4 @@ from .local import LocalScheduler +from .ray import RayScheduler -__all__ = [ - "LocalScheduler", -] +__all__ = ["LocalScheduler", "RayScheduler"] diff --git a/areal/scheduler/ray.py b/areal/scheduler/ray.py new file mode 100644 index 000000000..304cd0810 --- /dev/null +++ b/areal/scheduler/ray.py @@ -0,0 +1,639
@@ +import asyncio +import math +import time +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any + +import ray +import ray.exceptions +import torch +from ray.runtime_env import RuntimeEnv +from ray.util.placement_group import ( + PlacementGroup, + placement_group, + remove_placement_group, +) +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from areal.api.cli_args import BaseExperimentConfig +from areal.api.scheduler_api import Job, Scheduler, SchedulingSpec, Worker +from areal.scheduler.exceptions import ( + EngineCallError, + WorkerCreationError, + WorkerFailedError, + WorkerNotFoundError, + WorkerTimeoutError, +) +from areal.scheduler.rpc.ray_rpc_server import RayRPCServer +from areal.utils import logging +from areal.utils.launcher import get_env_vars +from areal.utils.ray import get_placement_group_master_ip_and_port + +logger = logging.getLogger("RayScheduler") + + +def ray_resource_type(): + if torch.cuda.is_available(): + return "GPU" + + from areal.platforms import is_npu_available + + if is_npu_available: + return "NPU" + + return "CPU" + + +@dataclass +class RayWorkerInfo: + worker: Worker + actor: ray.actor.ActorHandle + role: str + placement_group: PlacementGroup + bundle_index: int | None + created_at: float + env_vars: dict[str, str] = field(default_factory=dict) + + +class RayScheduler(Scheduler): + def __init__( + self, + startup_timeout: float = 30.0, + *, + exp_config: BaseExperimentConfig | None = None, + ): + self.exp_config = exp_config + self.startup_timeout = startup_timeout + + self._workers: dict[str, list[RayWorkerInfo]] = defaultdict(list) + self._worker_info_by_id: dict[str, RayWorkerInfo] = {} + self._placement_groups: list[PlacementGroup] = [] + + def _prepare_worker_specs( + self, role: str, num_workers: int, schedulings: list[SchedulingSpec] | None + ) -> list[SchedulingSpec]: + if not schedulings: + raise WorkerCreationError( + role, "Invalid configuration", "job.tasks must provide at least one SchedulingSpec" + ) + if len(schedulings) == 1: + return [schedulings[0]] * num_workers + + if len(schedulings) == num_workers: + return schedulings + + raise WorkerCreationError( + role, + "Invalid configuration", + f"schedulings length ({len(schedulings)}) must be 1 or equal to replicas ({num_workers})", + ) + + def _bundle_spec(self, cpu: int, gpu: int, mem: int) -> dict: + """ + Build a bundle dict for the given CPU, GPU, and memory requirements. + """ + device = ray_resource_type() + if device == "CPU" and gpu > 0: + raise ValueError( + f"Current detected device is CPU but specified number of GPUs is {gpu}" + ) + return { + "CPU": cpu, + device: float(gpu), + "memory": mem * 1024 * 1024, # convert MB to bytes + } + + def _create_bundle_list_gpu(self, cpu: int, gpu: int, mem: int) -> list[dict]: + """ + Divide the requested resources into a list of bundles such that each bundle fits on a single node. + """ + bundle_list = [] + + n_gpus_per_node = self.exp_config.cluster.n_gpus_per_node + + if n_gpus_per_node == 0 and gpu > 0: + raise ValueError( + f"Requested {gpu} GPUs but number of GPUs per node is {n_gpus_per_node}" + ) + + if gpu < n_gpus_per_node: + return [self._bundle_spec(cpu, gpu, mem)] + + gpu_remaining_to_be_assigned = gpu + + while gpu_remaining_to_be_assigned > 0: + # do not take all GPUs on a node if fewer are needed + gpu_in_bundle = min(gpu_remaining_to_be_assigned, n_gpus_per_node) + + # scale CPU and memory proportionally to the GPUs in this bundle +
resource_per_node_multiplier = min(gpu_in_bundle / gpu, 1) + cpu_in_bundle = math.ceil(cpu * resource_per_node_multiplier) + mem_in_bundle = math.ceil(mem * resource_per_node_multiplier) + + bundle_list.append( + self._bundle_spec(cpu_in_bundle, gpu_in_bundle, mem_in_bundle) + ) + gpu_remaining_to_be_assigned -= gpu_in_bundle + + return bundle_list + + def _actor_resource_spec(self, cpu: int, gpu: int, mem: int) -> dict: + """ + Build the resource requirements dict to pass into Ray actor options. + """ + + device = ray_resource_type() + if device == "CPU" and gpu > 0: + raise ValueError( + f"Current detected device is CPU but specified number of GPUs is {gpu}" + ) + + return { + "num_cpus": cpu, + "resources": {device: float(gpu)}, + "memory": mem * 1024 * 1024, + } + + def _sum_resource_spec( + self, schedulings: list[SchedulingSpec] + ) -> tuple[int, int, int]: + num_cpu = sum(spec.cpu for spec in schedulings) + num_gpu = sum(spec.gpu for spec in schedulings) + num_mem = sum(spec.mem for spec in schedulings) + + return (num_cpu, num_gpu, num_mem) + + def _ping_workers(self, role: str, timeout: float | None = None): + worker_info_list = self._workers[role] + timeout = timeout if timeout is not None else self.startup_timeout + refs = [wi.actor.ping.remote() for wi in worker_info_list] + + ref_to_worker = {ref: wi for wi, ref in zip(worker_info_list, refs)} + + pending = refs + while pending: + ready, pending = ray.wait(pending, num_returns=1, timeout=timeout) + # ray.wait timed out + if len(ready) == 0: + raise WorkerTimeoutError(role, timeout) + + ref = ready[0] + + try: + # ray.get() raises if the actor has failed + ray.get(ref) + except ray.exceptions.GetTimeoutError: + failed_worker = ref_to_worker[ref] + raise WorkerTimeoutError(failed_worker.worker.id, timeout) + except ray.exceptions.RayActorError: + failed_worker = ref_to_worker[ref] + raise WorkerFailedError(failed_worker.worker.id, -1) + + def _create_rollout_workers( + self, role: str, schedulings: list[SchedulingSpec] + ) -> tuple[list[RayWorkerInfo], list[str]]: + """ + Create rollout workers, assuming 1 worker per rollout instance.
+ + Parameters + --------- + role: str + schedulings: list[SchedulingSpec] + + Returns + -------- + Tuple[list[RayWorkerInfo], list[str]] + List of RayWorkerInfo of created workers + List of worker IDs created + """ + + worker_info_list: list[RayWorkerInfo] = [] + worker_ids: list[str] = [] + + # create placement groups + for idx, spec in enumerate(schedulings): + worker_id = f"{role}/{idx}" + # TODO: add a parameter to control whether GPUs are allocated + gpu = 0 if "eval-rollout" in role else spec.gpu + bundles = [self._bundle_spec(spec.cpu, gpu, spec.mem)] + pg = placement_group(bundles, strategy="PACK") + + try: + ray.get(pg.ready(), timeout=self.startup_timeout) + except ray.exceptions.GetTimeoutError: + logger.error( + f"Ray placement group timeout for rollout role {role}\n" + f"ray.nodes(): {ray.nodes()}" + f"bundles: {bundles}" + ) + raise + self._placement_groups.append(pg) + + master_ip, master_port = get_placement_group_master_ip_and_port( + pg, placement_group_bundle_index=0 + ) + + # define resources for the actor + options = self._actor_resource_spec(spec.cpu, gpu, spec.mem) + + env = get_env_vars( + self.exp_config, + ",".join([f"{k}={v}" for k, v in spec.env_vars.items()]), + ) + + if spec.env_vars: + env.update(spec.env_vars) + + actor = RayRPCServer.options( + **options, + name=worker_id, + runtime_env=RuntimeEnv(env_vars=env), + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=0, + placement_group_capture_child_tasks=True, + ), + ).remote() + + # pad with a dummy "0" since the trainer reads the port at index 1 + worker_ports = ["0", str(master_port)] + worker = Worker( + id=worker_id, ip=master_ip, worker_ports=worker_ports, engine_ports=[] + ) + + wi = RayWorkerInfo( + worker=worker, + actor=actor, + role=role, + placement_group=pg, + bundle_index=0, + created_at=time.time(), + env_vars=env, + ) + + worker_info_list.append(wi) + worker_ids.append(worker_id) + + return worker_info_list, worker_ids + + def _create_train_workers( + self, role: str, schedulings: list[SchedulingSpec] + ) -> tuple[list[RayWorkerInfo], list[str]]: + """ + Create workers for training roles. One PG per role with multiple bundles. + Assume 1 ray worker per train rank.
+ + Parameters + --------- + role: str + schedulings: list[SchedulingSpec] + + Returns + -------- + Tuple[list[RayWorkerInfo], list[str]] + List of RayWorkerInfo of created workers + List of worker IDs created + """ + # build bundles + sum_cpu, sum_gpu, sum_mem = self._sum_resource_spec(schedulings) + bundles: list[dict[str, float]] = self._create_bundle_list_gpu( + sum_cpu, sum_gpu, sum_mem + ) + + pg = placement_group(bundles=bundles, strategy="PACK") + + try: + ray.get(pg.ready(), timeout=self.startup_timeout) + except ray.exceptions.GetTimeoutError: + logger.error( + f"Ray placement group timeout for train role {role}\n" + f"ray.nodes(): {ray.nodes()}" + f"bundles: {bundles}" + ) + raise + + self._placement_groups.append(pg) + + master_ip, master_port = get_placement_group_master_ip_and_port( + pg, placement_group_bundle_index=0 + ) + + worker_info_list: list[RayWorkerInfo] = [] + worker_ids: list[str] = [] + + for idx, spec in enumerate(schedulings): + worker_id = f"{role}/{idx}" + + options = self._actor_resource_spec(spec.cpu, spec.gpu, spec.mem) + + env = get_env_vars( + self.exp_config, + ",".join([f"{k}={v}" for k, v in spec.env_vars.items()]), + ) + + if spec.env_vars: + env.update(spec.env_vars) + + actor = RayRPCServer.options( + **options, + name=worker_id, + runtime_env=RuntimeEnv(env_vars=env), + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, placement_group_capture_child_tasks=True + ), + ).remote() + + worker_ports = ["0", str(master_port)] + worker = Worker( + id=worker_id, ip=master_ip, worker_ports=worker_ports, engine_ports=[] + ) + + wi = RayWorkerInfo( + worker=worker, + actor=actor, + role=role, + placement_group=pg, + bundle_index=None, # decided by ray + created_at=time.time(), + env_vars=env, + ) + worker_info_list.append(wi) + worker_ids.append(worker_id) + + return worker_info_list, worker_ids + + def create_workers(self, job: Job, *args, **kwargs) -> list[str]: + """ + Create worker actors. 
+ + Parameters + -------- + job: Job + Job configuration with role, replicas, tasks, scheduling strategy + *args + Additional arguments (UNUSED) + **kwargs + Additional keyword arguments (UNUSED) + + Returns + -------- + list[str] + List of worker IDs created (e.g., ["rollout/0", "rollout/1"]) + + Raises + -------- + WorkerCreationError + If worker creation fails + """ + role = job.role + if role in self._workers: + raise WorkerCreationError( + role, + "Worker group already exists", + f"Use delete_workers('{role}') first to remove existing workers.", + ) + + num_workers = job.replicas + if num_workers == 0: + raise WorkerCreationError( + role, "Invalid configuration", "replicas must be greater than 0" + ) + + schedulings = self._prepare_worker_specs(role, num_workers, job.tasks) + + strategy = job.scheduling_strategy + if strategy is None: + strategy_type = "separation" + else: + strategy_type = strategy.type or "separation" + if strategy_type == "colocation": + raise WorkerCreationError( + role, + "Unsupported strategy type", + "RayScheduler only supports separation strategy", + ) + + if "rollout" in role: + worker_info_list, worker_ids = self._create_rollout_workers( + role, schedulings + ) + else: + worker_info_list, worker_ids = self._create_train_workers(role, schedulings) + + self._workers[role].extend(worker_info_list) + + for wi in worker_info_list: + self._worker_info_by_id[wi.worker.id] = wi + + self._ping_workers(role, self.startup_timeout) + + if self.exp_config is not None: + for rank, wi in enumerate(worker_info_list): + try: + wi.actor.configure.remote(self.exp_config, wi.role, rank) + except Exception as e: + logger.error( + f"Configure failed on worker {wi.worker.id}: {e}", exc_info=True + ) + self._cleanup_workers(worker_info_list) + raise WorkerCreationError( + role, "Worker configuration failed", str(e) + ) + + return worker_ids + + def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: + if role not in self._workers: + raise WorkerNotFoundError(role) + + worker_info_list = self._workers[role] + + self._ping_workers(role, timeout) + + return [wi.worker for wi in worker_info_list] + + def delete_workers(self, role: str | None = None): + """ + Delete workers and clean up their resources. + + Parameters + -------- + role: str, optional + Specific worker role to delete, or None to delete all + """ + if role is None: + roles = list(self._workers.keys()) + for r in roles: + self.delete_workers(r) + return + + if role not in self._workers: + logger.warning(f"Worker role '{role}' not found, skipping deletion") + return + + workers = self._workers[role] + logger.info(f"Deleting {len(workers)} workers for role '{role}'") + + self._cleanup_workers(workers) + + del self._workers[role] + + logger.info(f"Successfully deleted workers for role '{role}'") + + def _cleanup_workers(self, workers: list[RayWorkerInfo]): + # Kill actors first + for wi in workers: + actor = wi.actor + try: + # Asynchronously destroy actor + actor.destroy.remote() + except Exception: + logger.warning( + f"Could not destroy remote actor {actor}, force killing actor" + ) + ray.kill(actor, no_restart=True) + + # Collect unique placement groups and remove them + unique_pgs = {wi.placement_group for wi in workers} + for pg in unique_pgs: + try: + remove_placement_group(pg) + except Exception: + logger.warning(f"Could not remove placement group {pg}") + if pg in self._placement_groups: + self._placement_groups.remove(pg) + + def _get_worker_info_by_id(self, worker_id: str) -> RayWorkerInfo | None: +
return self._worker_info_by_id.get(worker_id, None) + + async def set_worker_env(self, worker_id: str, env: dict[str, str]) -> None: + wi = self._get_worker_info_by_id(worker_id) + if wi is None: + raise WorkerNotFoundError(worker_id) + if not env: + return + + await wi.actor.set_env.remote(env) + wi.env_vars.update(env) + + async def create_engine(self, worker_id: str, engine: str, *args, **kwargs) -> Any: + wi = self._get_worker_info_by_id(worker_id) + if wi is None: + raise WorkerNotFoundError(worker_id) + + if not isinstance(engine, str): + raise WorkerCreationError( + worker_id, f"Engine must be a string import path, got {type(engine)}" + ) + await wi.actor.create_engine.remote(engine, *args, **kwargs) + + def call_engine( + self, + worker_id: str, + method: str, + *args, + http_timeout: float = 7200.0, + max_retries: int = 3, + retry_delay: float = 1.0, + **kwargs, + ) -> Any: + wi = self._get_worker_info_by_id(worker_id) + if wi is None: + raise WorkerNotFoundError(worker_id) + + last_error: str | None = None + + for attempt in range(1, max_retries + 1): + try: + ref = wi.actor.call.remote(method, *args, **kwargs) + result = ray.get(ref, timeout=http_timeout) + if attempt > 1: + logger.info( + f"Method '{method}' on '{worker_id}' " + f"succeeded after {attempt} attempts" + ) + return result + except ray.exceptions.GetTimeoutError as e: + last_error = f"Timeout: {e}" + except ray.exceptions.RayActorError as e: + raise WorkerFailedError(worker_id, -1, str(e)) from e + except ray.exceptions.RayTaskError as e: + raise EngineCallError(worker_id, method, str(e), attempt) from e + except EngineCallError: + raise + except Exception as e: + last_error = f"Ray call failed: {e}" + + # Retry with exponential backoff + if attempt < max_retries: + delay = retry_delay * (2 ** (attempt - 1)) + logger.warning( + f"Method '{method}' failed on worker '{worker_id}' " + f"(attempt {attempt}/{max_retries}): {last_error}. " + f"Retrying in {delay:.1f}s..." + ) + time.sleep(delay) + + raise EngineCallError( + worker_id, method, last_error or "Max retries exceeded", attempt=max_retries + ) + + async def async_call_engine( + self, + worker_id: str, + method: str, + *args, + http_timeout: float = 7200.0, + max_retries: int = 3, + retry_delay: float = 1.0, + **kwargs, + ) -> Any: + wi = self._get_worker_info_by_id(worker_id) + if wi is None: + raise WorkerNotFoundError(worker_id) + + last_error: str | None = None + + for attempt in range(1, max_retries + 1): + try: + ref = wi.actor.call.remote(method, *args, **kwargs) + result = await ref + if attempt > 1: + logger.info( + f"Method '{method}' on '{worker_id}' " + f"succeeded after {attempt} attempts" + ) + return result + except ray.exceptions.GetTimeoutError as e: + last_error = f"Timeout: {e}" + except ray.exceptions.RayActorError as e: + raise WorkerFailedError(worker_id, -1, str(e)) from e + except ray.exceptions.RayTaskError as e: + raise EngineCallError(worker_id, method, str(e), attempt) from e + except EngineCallError: + raise + except Exception as e: + last_error = f"Ray async call failed: {e}" + + # Retry with exponential backoff + if attempt < max_retries: + delay = retry_delay * (2 ** (attempt - 1)) + logger.warning( + f"Method '{method}' failed on worker '{worker_id}' " + f"(attempt {attempt}/{max_retries}): {last_error}. " + f"Retrying in {delay:.1f}s..." 
+ ) + await asyncio.sleep(delay) + + raise EngineCallError( + worker_id, method, last_error or "Max retries exceeded", attempt=max_retries + ) + + def __del__(self): + # fallback cleanup in case the controller never called delete_workers() + # explicit shutdown should call delete_workers() directly + try: + self.delete_workers() + except Exception: + pass diff --git a/areal/scheduler/rpc/ray_rpc_server.py b/areal/scheduler/rpc/ray_rpc_server.py new file mode 100644 index 000000000..cb5e75a25 --- /dev/null +++ b/areal/scheduler/rpc/ray_rpc_server.py @@ -0,0 +1,140 @@ +import os +import traceback +from concurrent.futures import Future +from typing import Any + +import ray + +from areal.api.cli_args import BaseExperimentConfig +from areal.api.engine_api import InferenceEngine, TrainEngine +from areal.scheduler.rpc.rtensor import RTensor +from areal.utils import logging, name_resolve, seeding +from areal.utils.data import ( + broadcast_tensor_container, + tensor_container_to, +) +from areal.utils.dynamic_import import import_from_string + + +@ray.remote +class RayRPCServer: + """ + Ray engine container. Represents either: + - one training world rank, or + - one rollout instance + + Placement group scheduling is controlled by the scheduler. + The actor is only responsible for the engine lifecycle and method calls + within this process. + """ + + def __init__(self): + self._engine: TrainEngine | InferenceEngine | None = None + self.logger = logging.getLogger("RayRPCServer") + + def _get_device(self): + # lazily resolve the device inside the worker process + from areal.platforms import current_platform + + return current_platform.current_device() + + def ping(self) -> str: + return "ok" + + def configure(self, config: BaseExperimentConfig, role: str, rank: int) -> None: + name_resolve.reconfigure(config.cluster.name_resolve) + if isinstance(self._engine, TrainEngine): + seeding.set_random_seed(config.seed, key=f"{role}{rank}") + self.logger.info(f"RayRPCServer configured for role={role}, rank={rank}") + + def set_env(self, env: dict[str, str]) -> None: + for k, v in env.items(): + os.environ[str(k)] = str(v) + + def create_engine(self, engine_path: str, *init_args, **init_kwargs) -> None: + try: + engine_class = import_from_string(engine_path) + self.logger.debug(f"Initializing engine {engine_class}") + if not issubclass(engine_class, (TrainEngine, InferenceEngine)): + raise TypeError( + f"Engine class must be a TrainEngine or InferenceEngine, but got {engine_class}" + ) + self._engine = engine_class(*init_args, **init_kwargs) + self.logger.info( + f"RayRPCServer Engine '{engine_path}' instantiated successfully!" + ) + except Exception as e: + self.logger.error( + f"RayRPCServer failed to create engine '{engine_path}': {e}\n" + f"{traceback.format_exc()}" + ) + raise + + def call(self, method: str, *args, **kwargs) -> Any: + self.logger.debug(f"Calling {method} with arguments {args=} {kwargs=}") + if self._engine is None: + raise RuntimeError("Engine not initialized.
Call create_engine() first") + + raw_args = list(args) + raw_kwargs = kwargs.copy() + # fetch remote tensors if any + args = RTensor.localize(raw_args) + kwargs = RTensor.localize(raw_kwargs) + + should_broadcast = kwargs.pop("should_broadcast", True) + + # keep broadcast behavior the same as RPCServer + try: + if should_broadcast and isinstance(self._engine, TrainEngine): + device = self._get_device() + + args = tensor_container_to(args, device) + args = broadcast_tensor_container( + args, + src_rank=self._engine.current_data_parallel_head(), + group=self._engine.context_and_model_parallel_group, + ) + kwargs = tensor_container_to(kwargs, device) + kwargs = broadcast_tensor_container( + kwargs, + src_rank=self._engine.current_data_parallel_head(), + group=self._engine.context_and_model_parallel_group, + ) + except Exception as e: + self.logger.error( + f"RayRPCServer broadcast failed for '{method}': {e}\n" + f"{traceback.format_exc()}" + ) + raise + + try: + fn = getattr(self._engine, method) + result = fn(*args, **kwargs) + if isinstance(result, Future): + result = result.result() + # Convert all tensors to RTensors and store the tensor locally + layout = RTensor.extract_layout( + result, layouts=dict(args=raw_args, kwargs=raw_kwargs), node_addr="" + ) + if layout is not None: + result = RTensor.remotize(result, layout, node_addr="") + # put back to cpu to mimic RPCServer encode/decode + result = tensor_container_to(result, "cpu") + self.logger.debug(f"Successfully completed RayRPCServer call {result}") + return result + except Exception as e: + self.logger.error( + f"RayRPCServer Engine method '{method}' failed: {e}\n" + f"{traceback.format_exc()}" + ) + raise + + def destroy(self) -> None: + if self._engine is not None: + try: + self._engine.destroy() + self.logger.info("RayRPCServer Engine destroyed successfully") + except Exception as e: + self.logger.error(f"RayRPCServer error destroying engine: {e}") + self._engine = None + ray.actor.exit_actor() diff --git a/areal/scheduler/rpc/rtensor.py b/areal/scheduler/rpc/rtensor.py index e784e98fb..f2e29cfe2 100644 --- a/areal/scheduler/rpc/rtensor.py +++ b/areal/scheduler/rpc/rtensor.py @@ -2,62 +2,77 @@ import asyncio import uuid +from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass from threading import Lock -from typing import Any +from typing import Any, ClassVar import aiohttp import numpy as np import orjson +import ray import torch from areal.utils.datapack import ffd_allocate, flat2d @dataclass -class TensorShardInfo: - """Metadata for a single shard of an RTensor.""" - - shard_id: str - node_addr: str +class BaseTensorShardInfo(ABC): size: int # Batch size (shape[0]) of this shard seqlens: list[int] # Sequence lengths of this shard (from attention_mask) + shard_id: str + node_addr: str + @classmethod + @abstractmethod + def fetch(cls: type[BaseTensorShardInfo], shards: BaseTensorShardInfo): + pass -@dataclass -class RTensor: - """Single tensor distributed as CPU shards across nodes.""" + @classmethod + @abstractmethod + def store(cls: type[BaseTensorShardInfo], shard_id: str, tensor: torch.Tensor): + pass - shards: list[TensorShardInfo] - data: torch.Tensor + @classmethod + @abstractmethod + def create( + cls: type[BaseTensorShardInfo], + *, + size: int, + seqlens: list[int], + **kwargs, + ) -> BaseTensorShardInfo: + pass - def to_local(self) -> torch.Tensor: - """Fetch all shards via HTTP, concatenate along dim 0.""" - if not self.data.is_meta: - return self.data - # Fetch all 
shards first - tensors = self._fetch() - self.data = _pad_cat_dim0(tensors) - return self.data + @classmethod + @abstractmethod + async def delete_by_shard_id(cls, node_addr, shard_ids): + pass - def _fetch(self): + +@dataclass +class TensorShardInfo(BaseTensorShardInfo): + """Metadata for a single shard of an RTensor.""" + + @classmethod + def fetch(cls, shards): """Fetch all shards synchronously.""" async def _fetch_all(): async with aiohttp.ClientSession() as session: return await asyncio.gather( *[ - RTensor._fetch_tensor(session, s.shard_id, s.node_addr) - for s in self.shards + cls._fetch_tensor(session, s.shard_id, s.node_addr) + for s in shards ] ) return asyncio.run(_fetch_all()) - @staticmethod + @classmethod async def _fetch_tensor( - session: aiohttp.ClientSession, shard_id: str, node_addr: str + cls, session: aiohttp.ClientSession, shard_id: str, node_addr: str ) -> torch.Tensor: # Avoid circular import from areal.scheduler.rpc.serialization import deserialize_value @@ -70,6 +85,85 @@ async def _fetch_tensor( serialized_data = orjson.loads(data_bytes) return deserialize_value(serialized_data) + @classmethod + def store(cls, shard_id: str, tensor: torch.Tensor): + store(shard_id, tensor) + return shard_id + + @classmethod + def create(cls, *, size, seqlens, shard_id: str, node_addr: str, **_): + return cls( + shard_id=shard_id, + node_addr=node_addr, + size=size, + seqlens=seqlens, + ) + + @classmethod + async def delete_by_shard_id(cls, node_addr, shard_ids): + async with aiohttp.ClientSession() as session: + async with session.delete( + f"http://{node_addr}/data/clear", json={"shard_ids": shard_ids} + ) as resp: + if resp.status == 200: + await resp.json() + + +@dataclass +class RayTensorShardInfo(BaseTensorShardInfo): + shard_id: ray.ObjectRef + + @classmethod + def fetch(cls, shards: RayTensorShardInfo) -> list[torch.Tensor]: + return ray.get([s.shard_id for s in shards]) + + @classmethod + def store(cls, shard_id: str, tensor: torch.Tensor): + return ray.put(tensor) + + @classmethod + def create(cls, *, size, seqlens, shard_id=None, **_): + return cls( + shard_id=shard_id, + size=size, + seqlens=seqlens, + node_addr="", + ) + + @classmethod + async def delete_by_shard_id(cls, node_addr, shard_ids): + ray.internal.free(shard_ids) + + +@dataclass +class RTensor: + """Single tensor distributed as CPU shards across nodes.""" + + shards: list[BaseTensorShardInfo] + data: torch.Tensor + _tensor_info_cls: ClassVar[type[BaseTensorShardInfo]] = None + + @classmethod + def tensor_info_cls(cls) -> type[BaseTensorShardInfo]: + if cls._tensor_info_cls is None: + if ray.is_initialized(): + cls._tensor_info_cls = RayTensorShardInfo + else: + cls._tensor_info_cls = TensorShardInfo + return cls._tensor_info_cls + + def to_local(self) -> torch.Tensor: + """Fetch all shards via HTTP, concatenate along dim 0.""" + if not self.data.is_meta: + return self.data + # Fetch all shards first + tensors = self._fetch() + self.data = _pad_cat_dim0(tensors) + return self.data + + def _fetch(self): + return RTensor.tensor_info_cls().fetch(self.shards) + @staticmethod def split_tensor(batch_tensor: torch.Tensor, layout: RTensor) -> list[torch.Tensor]: offsets = np.cumsum([0] + [shard.size for shard in layout.shards]) @@ -94,20 +188,19 @@ def from_batched(cls, batch_tensor: torch.Tensor, layout: RTensor, node_addr: st shards = [] for tensor, shard_info in zip(tensors, layout.shards): sid = str(uuid.uuid4()) - info = TensorShardInfo( - shard_id=sid, - node_addr=node_addr, - size=shard_info.size, - 
seqlens=shard_info.seqlens.copy(), - ) - shards.append(info) - # Truncate at the maximum sequence length # to prevent over-padding if tensor.ndim > 1: tensor = tensor[:, : max(shard_info.seqlens)] # Store locally - store(sid, tensor) + shard_id = RTensor.tensor_info_cls().store(sid, tensor) + info = RTensor.tensor_info_cls().create( + size=shard_info.size, + seqlens=shard_info.seqlens.copy(), + shard_id=shard_id, + node_addr=node_addr, + ) + shards.append(info) return cls(shards=shards, data=batch_tensor.to("meta")) @@ -162,15 +255,15 @@ def extract_layout(obj: Any, layouts: Any, node_addr: str | None): if attn_mask is None: raise RuntimeError("`attention_mask` is not found") assert node_addr is not None + shard = RTensor.tensor_info_cls().create( + size=attn_mask.shape[0], + seqlens=[int(am.sum()) for am in attn_mask], + shard_id="", + node_addr=node_addr, + ) + layout_rtensor = RTensor( - shards=[ - TensorShardInfo( - shard_id="", - node_addr=node_addr, - size=attn_mask.shape[0], - seqlens=[int(am.sum()) for am in attn_mask], - ) - ], + shards=[shard], data=torch.empty_like(attn_mask, device="meta"), ) return layout_rtensor @@ -361,6 +454,9 @@ def _collect(o): _collect(obj) return shards_by_node + async def clear_node(node_addr, shard_ids): + await RTensor.tensor_info_cls().delete_by_shard_id(node_addr, shard_ids) + @property def shape(self): return self.data.shape diff --git a/areal/tests/test_ray_scheduler.py b/areal/tests/test_ray_scheduler.py new file mode 100644 index 000000000..caf4025a6 --- /dev/null +++ b/areal/tests/test_ray_scheduler.py @@ -0,0 +1,155 @@ +import asyncio +from unittest.mock import Mock, patch + +import ray +from ray.util.state import summarize_actors + +from areal.api.cli_args import BaseExperimentConfig +from areal.api.scheduler_api import Job, SchedulingSpec, Worker +from areal.scheduler.ray import RayScheduler, RayWorkerInfo, ray_resource_type + + +class TestRaySchedulerInitialization: + def test_init(self): + scheduler = RayScheduler( + startup_timeout=60.0, exp_config=BaseExperimentConfig() + ) + assert scheduler.startup_timeout == 60.0 + + +class TestWorkerCreationAndDeletion: + def test_create_delete_workers(self): + ray.init() + + config = BaseExperimentConfig() + + scheduler = RayScheduler(startup_timeout=60.0, exp_config=config) + + job = Job( + replicas=2, + role="train", + tasks=[ + SchedulingSpec( + cpu=1, + mem=1024, + gpu=1, + ), + SchedulingSpec( + cpu=1, + mem=1024, + gpu=1, + ), + ], + ) + + # create workers + worker_ids = scheduler.create_workers(job) + assert len(worker_ids) == 2 + assert len(scheduler._workers["train"]) == 2 + + actor_summary = summarize_actors() + + assert ( + actor_summary["cluster"]["summary"]["RayRPCServer"]["state_counts"]["ALIVE"] + == 2 + ) + assert len(scheduler.get_workers("train")) == 2 + + scheduler._ping_workers("train") + + # delete workers + scheduler.delete_workers() + assert len(scheduler._workers["train"]) == 0 + + actor_summary = summarize_actors() + assert ( + actor_summary["cluster"]["summary"]["RayRPCServer"]["state_counts"]["DEAD"] + == 2 + ) + + +class TestWorkerCallEngine: + def test_create_call_engine(self): + # to simulate an awaitable None + async def async_none(*args, **kwargs): + return None + + config = BaseExperimentConfig() + + scheduler = RayScheduler(startup_timeout=60.0, exp_config=config) + ray_actor_handle = Mock() + ray_actor_handle.create_engine.remote = async_none + + worker = RayWorkerInfo( + worker=Worker(id="test/0", ip="0.0.0.0"), + actor=ray_actor_handle, + role="test", + 
placement_group=None, + bundle_index=0, + created_at=0, + ) + + scheduler._workers["test"] = [worker] + scheduler._worker_info_by_id[worker.worker.id] = worker + + # create engine + result = asyncio.run( + scheduler.create_engine( + worker.worker.id, "test_engines.DummyEngine", name="TestEngine" + ) + ) + assert result is None + + # sync + ray_actor_handle.call.remote = lambda x: None + with patch("areal.scheduler.ray.ray.get", return_value=None): + result = scheduler.call_engine(worker.worker.id, "test_fn") + assert result is None + + # async + ray_actor_handle.call.remote = async_none + result = asyncio.run(scheduler.async_call_engine(worker.worker.id, "test_fn")) + assert result is None + + +class TestUtilityFunctions: + def test_utilities(self): + _num_gpu_per_node = 16 + config = BaseExperimentConfig() + + config.cluster.n_gpus_per_node = _num_gpu_per_node + + scheduler = RayScheduler(startup_timeout=60.0, exp_config=config) + + schedulings = [ + SchedulingSpec( + cpu=1, + mem=1024, + gpu=1, + ), + SchedulingSpec( + cpu=1, + mem=1024, + gpu=1, + ), + ] + + new_schedulings = scheduler._prepare_worker_specs("train", 2, schedulings) + assert len(new_schedulings) == 2 + for spec in new_schedulings: + assert spec.cpu == 1 + assert spec.mem == 1024 + assert spec.gpu == 1 + + # case where only 1 spec is provided for multiple workers + new_schedulings = scheduler._prepare_worker_specs("train", 2, schedulings[:1]) + assert len(new_schedulings) == 2 + for spec in new_schedulings: + assert spec.cpu == 1 + assert spec.mem == 1024 + assert spec.gpu == 1 + + bundle_list = scheduler._create_bundle_list_gpu(1, 24, 1024) + assert len(bundle_list) == 2 + for bundle in bundle_list: + assert bundle[ray_resource_type()] <= _num_gpu_per_node diff --git a/areal/utils/data.py b/areal/utils/data.py index fafb591e9..715b83234 100644 --- a/areal/utils/data.py +++ b/areal/utils/data.py @@ -336,7 +336,9 @@ def pad_and_stack_tensors_along_first_dim(tensor_list: list[torch.Tensor]): def tensor_container_to( - d: dict[str, Any] | torch.Tensor | list[torch.Tensor], *args, **kwargs + d: dict[str, Any] | torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...], + *args, + **kwargs, ): """Apply `t.to(*args, **kwargs)` to all tensors in the dictionary. Support nested dictionaries. @@ -344,7 +346,7 @@ def tensor_container_to( if torch.is_tensor(d): return d.to(*args, **kwargs) - if isinstance(d, list): + if isinstance(d, (list, tuple)): return [tensor_container_to(v, *args, **kwargs) for v in d] if isinstance(d, dict):
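
Usage note: with this change, setting `scheduler.type: "ray"` in the experiment config makes the experimental RL trainer construct a `RayScheduler` instead of a `LocalScheduler`. The sketch below shows how the scheduler can also be driven directly, mirroring `areal/tests/test_ray_scheduler.py`. It is a minimal sketch under stated assumptions: a Ray cluster reachable via `ray.init()`, one GPU available per replica, and an illustrative engine import path that is not part of this diff.

# Minimal sketch of driving RayScheduler directly (assumptions noted above).
import ray

from areal.api.cli_args import BaseExperimentConfig
from areal.api.scheduler_api import Job, SchedulingSpec
from areal.scheduler import RayScheduler

ray.init()

scheduler = RayScheduler(startup_timeout=60.0, exp_config=BaseExperimentConfig())

# A single SchedulingSpec is broadcast to all replicas by _prepare_worker_specs().
job = Job(
    replicas=2,
    role="train",
    tasks=[SchedulingSpec(cpu=1, mem=1024, gpu=1)],
)
worker_ids = scheduler.create_workers(job)  # e.g. ["train/0", "train/1"]

# Engines live inside the RayRPCServer actors: create_engine() is async, and
# call_engine() retries with exponential backoff. The engine path below is an
# illustrative example only, not something this diff adds.
# await scheduler.create_engine(worker_ids[0], "areal.engine.sglang_remote.RemoteSGLangEngine", engine_config)
# scheduler.call_engine(worker_ids[0], "some_engine_method", should_broadcast=False)

scheduler.delete_workers()  # kills the actors and removes their placement groups
ray.shutdown()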