From 2b9288559f094ebcab906529fa604096ae81df39 Mon Sep 17 00:00:00 2001 From: Huawei Vancouver ICI Lab Date: Fri, 14 Nov 2025 16:36:14 -0800 Subject: [PATCH 1/5] Add rollout scale-up feature --- areal/api/cli_args.py | 25 ++ areal/core/remote_inf_engine.py | 39 +++ areal/engine/fsdp_engine.py | 13 +- areal/engine/vllm_remote.py | 11 +- areal/launcher/ray.py | 36 ++- areal/launcher/scaler/scaling_controller.py | 227 ++++++++++++++++++ .../thirdparty/vllm/vllm_worker_extension.py | 3 +- areal/utils/scaling.py | 65 +++++ examples/math/gsm8k_grpo.py | 18 +- examples/math/gsm8k_grpo_npu_scale.yaml | 157 ++++++++++++ 10 files changed, 583 insertions(+), 11 deletions(-) create mode 100644 areal/launcher/scaler/scaling_controller.py create mode 100644 areal/utils/scaling.py create mode 100644 examples/math/gsm8k_grpo_npu_scale.yaml diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 039177e08..086ce7a2a 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -928,6 +928,29 @@ class InferenceEngineConfig: ) +@dataclass +class ScalingConfig: + """Configuration for dynamic scaling of inference/training servers.""" + + enable_scaling: bool = field( + default=False, + metadata={"help": "Whether scaling is enabled (True/False)."}, + ) + + mode: str = field( + default="manual", + metadata={ + "help": "Scaling mode — can be 'manual' or 'auto'.", + "choices": ["manual", "auto"], + }, + ) + + scaling_controller_port: int = field( + default=8899, + metadata={"help": "HTTP port for the scale-up service endpoint."}, + ) + + @dataclass class _Timer: experiment_name: str = MISSING @@ -1341,6 +1364,8 @@ class BaseExperimentConfig: scheduler: SchedulerConfig = field(default_factory=SchedulerConfig) + scaling: ScalingConfig = field(default_factory=ScalingConfig) + @dataclass class SFTConfig(BaseExperimentConfig): diff --git a/areal/core/remote_inf_engine.py b/areal/core/remote_inf_engine.py index 34f4b2b00..f49149ec2 100644 --- a/areal/core/remote_inf_engine.py +++ b/areal/core/remote_inf_engine.py @@ -253,6 +253,7 @@ def __init__( self.workflow_executor: WorkflowExecutor self.local_server_processes: list[LocalInfServerInfo] = [] + self.update_servers = False def _get_or_create_session(self) -> aiohttp.ClientSession: """Get or create a ClientSession for the current thread/event loop. @@ -424,6 +425,27 @@ def get_version(self): with self.lock: return self._version + def refresh_addresses(self, new_addresses: list[str]) -> None: + """ + Refresh the list of available servers dynamically. + + Args: + new_addresses (list[str]): Updated list of server addresses. + """ + if not new_addresses: + raise RuntimeError("No servers provided when refreshing addresses.") + + # Only log if there's an actual change + if new_addresses != self.addresses: + self.logger.info(f"Refreshing server addresses: {new_addresses}") + + # Replace with the new set + self.addresses = new_addresses + + # Clamp server_idx to valid range + if self.server_idx >= len(self.addresses): + self.server_idx = 0 + def choose_server(self) -> str: """Choose a server based on the scheduling policy. @@ -437,7 +459,16 @@ def choose_server(self) -> str: NotImplementedError If schedule policy other than round-robin is used """ + + if self.update_servers: + name = names.gen_servers( + self.config.experiment_name, self.config.trial_name + ) + vllm_addrs = name_resolve.get_subtree(name) + self.refresh_addresses(vllm_addrs) + self.update_servers = False if self.config.schedule_policy == "round_robin": + self.server_idx %= len(self.addresses) server = self.addresses[self.server_idx] self.server_idx = (self.server_idx + 1) % len(self.addresses) return server @@ -591,6 +622,12 @@ def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]: assert meta.type == current_platform.communication_backend assert not self.distributed_weight_update_initialized + # Refresh the gen servers if there is scale request + name = names.gen_servers(self.config.experiment_name, self.config.trial_name) + vllm_addrs = name_resolve.get_subtree(name) + if vllm_addrs != self.addresses: + self.refresh_addresses(vllm_addrs) + fut = self.executor.submit( _init_weights_update_group_remote, self.backend, @@ -845,6 +882,8 @@ def pause(self): """Pause request submission for async rollout. Used during evaluation to prevent data over generation. """ + # Whenever pause for update weight, make update_servers True to dispatch request to new servers + self.update_servers = True return self.workflow_executor.pause() def resume(self): diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index 237a29d19..0f18f5dc4 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -83,6 +83,8 @@ def __init__(self, config: TrainEngineConfig): self.rank: int self.dp_head: int self.dp_rank: int + self.scaling_count = 0 + self.create_group_count = 0 @property def data_parallel_group(self) -> dist.ProcessGroup: @@ -376,14 +378,21 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta): ) self.weight_update_group = init_custom_process_group( backend=current_platform.communication_backend, - world_size=meta.alloc_mode.gen.world_size + 1, + world_size=meta.alloc_mode.gen.world_size + 1 + self.scaling_count, init_method=f"tcp://{meta.nccl_master_address}:{meta.nccl_master_port}", rank=0, - group_name=meta.nccl_group_name, + group_name=meta.nccl_group_name + str(self.create_group_count), timeout=DIST_GROUP_DEFAULT_TIMEOUT, ) + self.create_group_count += 1 fut.result() + self.rollout_engine._engine.backend.create_group_count += 1 + + def _re_init_weight_update_from_distributed(self, meta: WeightUpdateMeta): + self.weight_update_group_initialized = False + self._init_weight_update_from_distributed(meta) + self.weight_update_group_initialized = True @trace_perf("fsdp_engine.update_weights_from_distributed", category="comm") def _update_weights_from_distributed(self, meta: WeightUpdateMeta): diff --git a/areal/engine/vllm_remote.py b/areal/engine/vllm_remote.py index 306a0fc45..2abb5e703 100644 --- a/areal/engine/vllm_remote.py +++ b/areal/engine/vllm_remote.py @@ -29,6 +29,10 @@ class VLLMBackend: """vLLM-specific backend implementation for remote inference.""" + def __init__(self): + self.scaling_count = 0 + self.create_group_count = 0 + def build_generation_request( self, req: ModelRequest, with_lora: bool ) -> HttpRequest: @@ -106,7 +110,8 @@ def build_distributed_weight_update_requests( "names": [pspec.name for pspec in param_specs], "dtypes": [pspec.dtype for pspec in param_specs], "shapes": [pspec.shape for pspec in param_specs], - "group_name": meta.nccl_group_name, + "group_name": meta.nccl_group_name + + str(self.create_group_count), }, ), HttpRequest( @@ -128,9 +133,9 @@ def build_init_weights_group_request( "master_address": meta.nccl_master_address, "master_port": str(meta.nccl_master_port), "rank_offset": rank_offset, - "world_size": meta.alloc_mode.gen.world_size + 1, + "world_size": meta.alloc_mode.gen.world_size + 1 + self.scaling_count, "backend": current_platform.communication_backend, - "group_name": meta.nccl_group_name, + "group_name": meta.nccl_group_name + str(self.create_group_count), } return HttpRequest(endpoint="/areal_init_weights_update_group", payload=payload) diff --git a/areal/launcher/ray.py b/areal/launcher/ray.py index 7832a1419..77459ee26 100644 --- a/areal/launcher/ray.py +++ b/areal/launcher/ray.py @@ -1,6 +1,8 @@ import importlib.util +import os import pathlib import re +import subprocess import sys import time from collections.abc import Callable @@ -18,6 +20,7 @@ ClusterSpecConfig, LauncherConfig, RecoverConfig, + ScalingConfig, SGLangConfig, parse_cli_args, to_structured_cfg, @@ -41,7 +44,23 @@ RAY_WAIT_CHECK_TIME_INTERVAL = 5 # seconds DEFAULT_MAIN_FUNC_NAME = "main" RAY_LAUNCHER = None -RECOVER_TIME_INTERVAL = 10 # seconds +RECOVER_TIME_INTERVAL = 10 # second + + +def launch_scale_common(config_path: str): + """Launch scale_common.py as a background subprocess without blocking.""" + script_path = str( + pathlib.Path(__file__).resolve().parent.joinpath("scaler/scaling_controller.py") + ) + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + subprocess.Popen( + [sys.executable, script_path, config_path], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + env=env, + start_new_session=True, + ) def run_func(file_path, function_name, *args, **kwargs): @@ -334,8 +353,21 @@ def wait( def main(): - ray.init() + ray.init(address="auto") config, _ = parse_cli_args(sys.argv[1:]) + config_path = None + args = sys.argv[1:] + config.scaling = to_structured_cfg(config.scaling, ScalingConfig) + # Check whether enable scaling or not + if config.scaling.enable_scaling: + if "--config" in args: + idx = args.index("--config") + if idx + 1 < len(args): + config_path = args[idx + 1] + try: + launch_scale_common(config_path) + except Exception as e: + logger.info(f"[RayLauncher] Warning: Failed to scaler.py: {e}") ray_main(config, run_id=0) diff --git a/areal/launcher/scaler/scaling_controller.py b/areal/launcher/scaler/scaling_controller.py new file mode 100644 index 000000000..ebd31ef2e --- /dev/null +++ b/areal/launcher/scaler/scaling_controller.py @@ -0,0 +1,227 @@ +import sys +import threading +from pathlib import Path + +import ray +from fastapi import FastAPI, Request +from omegaconf import OmegaConf +from uvicorn import Config, Server + +import areal.utils.logging as logging +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import ( + ClusterSpecConfig, + LauncherConfig, + RecoverConfig, + ScalingConfig, + to_structured_cfg, + vLLMConfig, +) +from areal.platforms import is_npu_available +from areal.utils import name_resolve +from areal.utils.launcher import get_env_vars, wait_llm_server_addrs +from areal.utils.name_resolve import NameEntryNotFoundError + +logger = logging.getLogger("ScaleUpVLLM") +DEFAULT_MAIN_FUNC_NAME = "main" + + +def run_func(file_path: str, func_name: str, argv: list[str]): + """ + Import module by path and invoke the named function with a single `argv` list. + This matches vllm_server.main(argv) which expects sys.argv[2:]-style args. + """ + import importlib.util + + module_name = file_path.replace("/", "_").replace(".", "_") + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + func = getattr(module, func_name) + return func(argv) + + +def scale_up_vllm( + cfg, config_path: str, n_new_servers: int, expected: int, vllm_entry_point: str +): + # Make sub-configs structured like original launcher + cfg.launcher = to_structured_cfg(cfg.launcher, LauncherConfig) + cfg.cluster = to_structured_cfg(cfg.cluster, ClusterSpecConfig) + cfg.recover = to_structured_cfg(cfg.recover, RecoverConfig) + cfg.vllm = to_structured_cfg(cfg.vllm, vLLMConfig) + experiment_name = cfg.experiment_name + trial_name = cfg.trial_name + + # allocation_mode + allocation_mode = AllocationMode.from_str(cfg.allocation_mode) + vllm_tp_size = allocation_mode.gen.tp_size + n_existing_servers = allocation_mode.gen.dp_size + + cpus_per_gpu = cfg.launcher.inference_server_cpus_per_gpu + mem_per_gpu = cfg.launcher.inference_server_mem_per_gpu # MB per GPU + + # Submit new servers + remote_runner = None # we’ll bind ray.remote per device type + futures = [] + for i in range(n_new_servers): + argv = ["--config", config_path] + # Env vars same as original launcher for inference servers + env_vars = get_env_vars( + cfg.cluster.cluster_name, cfg.launcher.inference_server_env_vars + ) + + if is_npu_available: + remote_runner = ray.remote( + num_cpus=cpus_per_gpu * vllm_tp_size, + resources={"NPU": vllm_tp_size}, + memory=mem_per_gpu * vllm_tp_size * 1024 * 1024, # bytes + runtime_env={"env_vars": env_vars}, + )(run_func) + else: + remote_runner = ray.remote( + num_cpus=cpus_per_gpu * vllm_tp_size, + num_gpus=vllm_tp_size, + memory=mem_per_gpu * vllm_tp_size * 1024 * 1024, # bytes + runtime_env={"env_vars": env_vars}, + )(run_func) + + fut = remote_runner.remote(vllm_entry_point, DEFAULT_MAIN_FUNC_NAME, argv) + futures.append(fut) + + try: + ray.get(fut, timeout=5.0) + except ray.exceptions.GetTimeoutError: + pass + except ray.exceptions.RayTaskError as e: + logger.info(f"[ERROR] server {n_existing_servers + i} crashed immediately:") + logger.info(e) + raise + + # Wait until ALL (old + new) servers are registered + total_expected = expected + vllm_addrs = wait_llm_server_addrs( + experiment_name, + trial_name, + total_expected, + ) + + logger.info("\n[Scale-Up Completed]") + logger.info(f"Total servers expected: {len(vllm_addrs)}") + + +app = FastAPI() +shared_state = { + "cfg": None, + "config_path": None, + "num_rollout": None, + "Vllm_entry_point": None, +} + + +@app.post("/scale_up") +async def http_scale_up(request: Request): + """ + Manual scale-up endpoint. + Example usage: + curl -X POST localhost:8899/scale_up \ + -H "Content-Type: application/json" \ + -d '{"scaled_k": 1}' + """ + body = await request.json() + scaled_k = int(body.get("scaled_k", 1)) + cfg = shared_state["cfg"] + config_path = shared_state["config_path"] + num_rollout = shared_state["num_rollout"] + + if cfg is None or config_path is None: + return {"status": "error", "msg": "Scale server not initialized yet"} + + try: + logger.info(f"[HTTP] Received manual scale-up request: {scaled_k}") + shared_state["num_rollout"] = num_rollout + scaled_k + + name_resolve.add("scale_up_request", {"scaled_k": int(scaled_k)}, replace=True) + scale_up_vllm( + cfg, + config_path, + scaled_k, + num_rollout + scaled_k, + shared_state["vllm_entry_point"], + ) + try: + name_resolve.delete("scale_up_done") + except NameEntryNotFoundError: + pass + + name_resolve.add("scale_up_done", {"step": 0}) + logger.info( + f"[HTTP] Scale-up done. Total rollout={shared_state['num_rollout']}" + ) + return { + "status": "ok", + "scaled_k": scaled_k, + "new_total": shared_state["num_rollout"], + } + except Exception as e: + logger.error(f"[HTTP] Scale-up failed: {e}") + return {"status": "error", "msg": str(e)} + + +def run_http_server(): + """Run FastAPI server in background thread (non-blocking).""" + config = Config(app, host="0.0.0.0", port=HTTP_SCALE_PORT, log_level="info") + server = Server(config) + + def _serve(): + logger.info(f"[HTTP] Starting manual scale-up server on port {HTTP_SCALE_PORT}") + server.run() + + t = threading.Thread(target=_serve, daemon=False) + t.start() + logger.info("[HTTP] Manual scale-up service started in background.") + + +if __name__ == "__main__": + if len(sys.argv) < 2: + logger.info("Usage: python scaling_controller.py ") + sys.exit(1) + + config_path = sys.argv[1] + cfg = OmegaConf.load(config_path) + name_resolve.reconfigure(cfg.cluster.name_resolve) + experiment_name = cfg.experiment_name + trial_name = cfg.trial_name + + # allocation_mode + allocation_mode = AllocationMode.from_str(cfg.allocation_mode) + # Set-the-experiments-configs for rollout ------------------ + num_rollout = allocation_mode.gen.dp_size + + # Remove all the keys related to scaling before start the experiment + try: + name_resolve.delete("scale_up_request") + except NameEntryNotFoundError: + logger.info("no delete") + + try: + name_resolve.delete("scale_up_done") + except NameEntryNotFoundError: + pass + # Init the ray and conncet it to existing cluster + ray.init(address="auto", namespace=f"{experiment_name}_{trial_name}") + + # Get port for scale up + cfg.scaling = to_structured_cfg(cfg.scaling, ScalingConfig) + HTTP_SCALE_PORT = cfg.scaling.scaling_controller_port + + # Run http for scale-up + run_http_server() + + logger.info("[HTTP] Manual scale-up service started in background.") + vllm_entry_point = str(Path(__file__).resolve().parent.parent / "vllm_server.py") + shared_state["cfg"] = cfg + shared_state["config_path"] = config_path + shared_state["num_rollout"] = num_rollout + shared_state["vllm_entry_point"] = vllm_entry_point + logger.info(f"[HTTP] num_rollout initialized to {num_rollout}") diff --git a/areal/thirdparty/vllm/vllm_worker_extension.py b/areal/thirdparty/vllm/vllm_worker_extension.py index c1d471b29..d25f78c07 100644 --- a/areal/thirdparty/vllm/vllm_worker_extension.py +++ b/areal/thirdparty/vllm/vllm_worker_extension.py @@ -82,7 +82,8 @@ def init_update_weight_group( group_name: str, ): if getattr(self, "weight_update_group", None) is not None: - return True, "Success" + # If the group is there, make the current group None and create a new one for scaling + self.weight_update_group = None try: self.weight_update_group = init_custom_process_group( backend=backend, diff --git a/areal/utils/scaling.py b/areal/utils/scaling.py new file mode 100644 index 000000000..1a8435760 --- /dev/null +++ b/areal/utils/scaling.py @@ -0,0 +1,65 @@ +import ast +import time + +from areal.utils import logging, name_resolve +from areal.utils.name_resolve import NameEntryNotFoundError + +logger = logging.getLogger("Scaler") + + +def handle_scale_up(name_resolve: name_resolve, actor, rollout, weight_update_meta): + """ + Handle scale-up logic when scale_up_request is detected. + Requires: name_resolve, actor, rollout. + """ + new_scale = 0 + req_raw = None + try: + req_raw = name_resolve.get("scale_up_request") + new_scale = ast.literal_eval(req_raw)["scaled_k"] + except NameEntryNotFoundError: + logger.info("scale_up_request not found") + pass # no request → don't wait + + logger.info(f"scale_up_request {req_raw}") + + if req_raw: + # Now wait until scale_up_done is posted from scaler process + start = time.time() + try: + name_resolve.delete("scale_up_request") + except NameEntryNotFoundError: + pass + + while True: + try: + done_raw = name_resolve.get("scale_up_done") + except NameEntryNotFoundError: + done_raw = None + + if done_raw: + logger.info(f"[areal] Scale-up finished: {done_raw}") + name_resolve.add( + "scale_up_time", + {"time": time.time() - start}, + replace=True, + ) + + try: + name_resolve.delete("scale_up_request") + except NameEntryNotFoundError: + pass + try: + name_resolve.delete("scale_up_done") + except NameEntryNotFoundError: + pass + # Increase teh number of scale in rollout engine and actor. To get correct world size + actor.scaling_count = actor.scaling_count + new_scale + rollout._engine.backend.scaling_count = ( + rollout._engine.backend.scaling_count + new_scale + ) + rollout._engine.distributed_weight_update_initialized = False + actor._re_init_weight_update_from_distributed(weight_update_meta) + + break + time.sleep(0.5) diff --git a/examples/math/gsm8k_grpo.py b/examples/math/gsm8k_grpo.py index 9c4f4cfce..33e84ca33 100644 --- a/examples/math/gsm8k_grpo.py +++ b/examples/math/gsm8k_grpo.py @@ -10,14 +10,16 @@ from areal.dataset import get_custom_dataset from areal.engine.ppo.actor import FSDPPPOActor from areal.engine.sglang_remote import RemoteSGLangEngine +from areal.engine.vllm_remote import RemotevLLMEngine from areal.platforms import current_platform -from areal.utils import seeding, stats_tracker +from areal.utils import name_resolve, seeding, stats_tracker from areal.utils.dataloader import create_dataloader from areal.utils.device import log_gpu_stats from areal.utils.evaluator import Evaluator from areal.utils.hf_utils import load_hf_tokenizer from areal.utils.recover import RecoverHandler from areal.utils.saver import Saver +from areal.utils.scaling import handle_scale_up from areal.utils.stats_logger import StatsLogger from areal.workflow.rlvr import RLVRWorkflow @@ -38,6 +40,9 @@ def main(args): allocation_mode = AllocationMode.from_str(config.allocation_mode) parallel_strategy = allocation_mode.train assert parallel_strategy is not None + # Reconfig the name_resolve to connect exisisting name_resolve + name_resolve.reconfigure(config.cluster.name_resolve) + scale = config.scaling.enable_scaling # Initialize train engine actor = FSDPPPOActor(config=config.actor) @@ -70,9 +75,14 @@ def main(args): ) # Initialize inference engine - rollout = RemoteSGLangEngine(config.rollout) + if allocation_mode.gen_backend == "vllm": + rollout = RemotevLLMEngine(config.rollout) + eval_rollout = RemotevLLMEngine(deepcopy(config.rollout)) + elif allocation_mode.gen_backend == "sglang": + rollout = RemoteSGLangEngine(config.rollout) + eval_rollout = RemoteSGLangEngine(deepcopy(config.rollout)) rollout.initialize(train_data_parallel_size=parallel_strategy.dp_size) - eval_rollout = RemoteSGLangEngine(deepcopy(config.rollout)) + # NOTE: eval does not have any offpolicyness control eval_rollout.config.max_head_offpolicyness = int(1e12) eval_rollout.initialize() @@ -178,6 +188,8 @@ def main(args): # pause inference for updating weights, save, and evaluation rollout.pause() + if dist.get_rank() == 0 and scale: + handle_scale_up(name_resolve, actor, rollout, weight_update_meta) with stats_tracker.record_timing("update_weights"): actor.update_weights(weight_update_meta) diff --git a/examples/math/gsm8k_grpo_npu_scale.yaml b/examples/math/gsm8k_grpo_npu_scale.yaml new file mode 100644 index 000000000..02129843c --- /dev/null +++ b/examples/math/gsm8k_grpo_npu_scale.yaml @@ -0,0 +1,157 @@ +experiment_name: gsm8k-grpo +trial_name: trial0 + +seed: 1 +total_train_epochs: 10 +tokenizer_path: ${actor.path} +async_training: true + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +allocation_mode: vllm.d4p1t1+d4p1t1 + +rollout: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + +actor: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen2.5-1.5B-Instruct + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 1.70e-5 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + backend: fsdp + group_size: ${gconfig.n_samples} + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + behav_imp_weight_cap: 5.0 + dynamic_sampling: false + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + +ref: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + backend: fsdp + +# VLLM +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.8 + +# datasets +train_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +launcher: + inference_server_cpus_per_gpu: 4 + inference_server_mem_per_gpu: 32768 + trainer_cpus_per_gpu: 4 + trainer_mem_per_gpu: 32768 + +scaling: + enable_scaling: true + mode: manual + scaling_controller_port: 8899 From 3b8152c9d86d06ccd1c49dee7d581b1b9b0f86e2 Mon Sep 17 00:00:00 2001 From: Huawei Vancouver ICI Lab Date: Wed, 19 Nov 2025 11:28:09 -0800 Subject: [PATCH 2/5] Fix typo --- areal/launcher/scaler/scaling_controller.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/areal/launcher/scaler/scaling_controller.py b/areal/launcher/scaler/scaling_controller.py index ebd31ef2e..6b6f89d06 100644 --- a/areal/launcher/scaler/scaling_controller.py +++ b/areal/launcher/scaler/scaling_controller.py @@ -115,7 +115,7 @@ def scale_up_vllm( "cfg": None, "config_path": None, "num_rollout": None, - "Vllm_entry_point": None, + "vllm_entry_point": None, } From 375b2a0a5f502e765d88decc35183dfc634cfbaa Mon Sep 17 00:00:00 2001 From: Huawei Vancouver ICI Lab Date: Wed, 19 Nov 2025 11:22:20 -0800 Subject: [PATCH 3/5] Adjust logging based on Gemini suggestion --- areal/utils/scaling.py | 68 +++++++++++++++++++----------------------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/areal/utils/scaling.py b/areal/utils/scaling.py index 1a8435760..cf3a4ac13 100644 --- a/areal/utils/scaling.py +++ b/areal/utils/scaling.py @@ -12,54 +12,48 @@ def handle_scale_up(name_resolve: name_resolve, actor, rollout, weight_update_me Handle scale-up logic when scale_up_request is detected. Requires: name_resolve, actor, rollout. """ - new_scale = 0 - req_raw = None try: req_raw = name_resolve.get("scale_up_request") new_scale = ast.literal_eval(req_raw)["scaled_k"] except NameEntryNotFoundError: - logger.info("scale_up_request not found") - pass # no request → don't wait + return - logger.info(f"scale_up_request {req_raw}") + logger.info(f"Handling scale_up_request: {req_raw}") - if req_raw: - # Now wait until scale_up_done is posted from scaler process - start = time.time() + # Now wait until scale_up_done is posted from scaling_controller + start = time.time() + try: + name_resolve.delete("scale_up_request") + except NameEntryNotFoundError: + pass + + while True: try: - name_resolve.delete("scale_up_request") + done_raw = name_resolve.get("scale_up_done") except NameEntryNotFoundError: - pass + done_raw = None + + if done_raw: + logger.info(f"[areal] Scale-up finished: {done_raw}") + name_resolve.add( + "scale_up_time", + {"time": time.time() - start}, + replace=True, + ) - while True: try: - done_raw = name_resolve.get("scale_up_done") + name_resolve.delete("scale_up_done") except NameEntryNotFoundError: - done_raw = None + pass - if done_raw: - logger.info(f"[areal] Scale-up finished: {done_raw}") - name_resolve.add( - "scale_up_time", - {"time": time.time() - start}, - replace=True, - ) + # Increase the number of scale in rollout engine and actor. To get correct world size + actor.scaling_count = actor.scaling_count + new_scale + rollout._engine.backend.scaling_count = ( + rollout._engine.backend.scaling_count + new_scale + ) + rollout._engine.distributed_weight_update_initialized = False + actor._re_init_weight_update_from_distributed(weight_update_meta) - try: - name_resolve.delete("scale_up_request") - except NameEntryNotFoundError: - pass - try: - name_resolve.delete("scale_up_done") - except NameEntryNotFoundError: - pass - # Increase teh number of scale in rollout engine and actor. To get correct world size - actor.scaling_count = actor.scaling_count + new_scale - rollout._engine.backend.scaling_count = ( - rollout._engine.backend.scaling_count + new_scale - ) - rollout._engine.distributed_weight_update_initialized = False - actor._re_init_weight_update_from_distributed(weight_update_meta) + break - break - time.sleep(0.5) + time.sleep(0.5) From fb5b72fe66da3b4a6d7bb792ddc76901fa395be2 Mon Sep 17 00:00:00 2001 From: Huawei Vancouver ICI Lab Date: Wed, 19 Nov 2025 11:46:57 -0800 Subject: [PATCH 4/5] Fix config loading --- areal/launcher/ray.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/areal/launcher/ray.py b/areal/launcher/ray.py index 77459ee26..a9cde4edd 100644 --- a/areal/launcher/ray.py +++ b/areal/launcher/ray.py @@ -354,18 +354,12 @@ def wait( def main(): ray.init(address="auto") - config, _ = parse_cli_args(sys.argv[1:]) - config_path = None - args = sys.argv[1:] + config, config_file = parse_cli_args(sys.argv[1:]) config.scaling = to_structured_cfg(config.scaling, ScalingConfig) # Check whether enable scaling or not if config.scaling.enable_scaling: - if "--config" in args: - idx = args.index("--config") - if idx + 1 < len(args): - config_path = args[idx + 1] try: - launch_scale_common(config_path) + launch_scale_common(str(config_file)) except Exception as e: logger.info(f"[RayLauncher] Warning: Failed to scaler.py: {e}") ray_main(config, run_id=0) From 4f6eb013b0d0abd6dfea8d41329298eb75025615 Mon Sep 17 00:00:00 2001 From: Huawei Vancouver ICI Lab Date: Thu, 20 Nov 2025 10:44:46 -0800 Subject: [PATCH 5/5] Add locks etc. in scaling controller --- areal/launcher/scaler/scaling_controller.py | 87 ++++++++++++--------- 1 file changed, 49 insertions(+), 38 deletions(-) diff --git a/areal/launcher/scaler/scaling_controller.py b/areal/launcher/scaler/scaling_controller.py index 6b6f89d06..bf4722e31 100644 --- a/areal/launcher/scaler/scaling_controller.py +++ b/areal/launcher/scaler/scaling_controller.py @@ -1,3 +1,4 @@ +import importlib.util import sys import threading from pathlib import Path @@ -31,8 +32,6 @@ def run_func(file_path: str, func_name: str, argv: list[str]): Import module by path and invoke the named function with a single `argv` list. This matches vllm_server.main(argv) which expects sys.argv[2:]-style args. """ - import importlib.util - module_name = file_path.replace("/", "_").replace(".", "_") spec = importlib.util.spec_from_file_location(module_name, file_path) module = importlib.util.module_from_spec(spec) @@ -53,13 +52,12 @@ def scale_up_vllm( experiment_name = cfg.experiment_name trial_name = cfg.trial_name - # allocation_mode allocation_mode = AllocationMode.from_str(cfg.allocation_mode) vllm_tp_size = allocation_mode.gen.tp_size - n_existing_servers = allocation_mode.gen.dp_size + n_existing_servers = expected - n_new_servers cpus_per_gpu = cfg.launcher.inference_server_cpus_per_gpu - mem_per_gpu = cfg.launcher.inference_server_mem_per_gpu # MB per GPU + mem_per_gpu = cfg.launcher.inference_server_mem_per_gpu # Submit new servers remote_runner = None # we’ll bind ray.remote per device type @@ -117,12 +115,13 @@ def scale_up_vllm( "num_rollout": None, "vllm_entry_point": None, } +shared_state_lock = threading.Lock() @app.post("/scale_up") async def http_scale_up(request: Request): """ - Manual scale-up endpoint. + Scaling controller endpoint. Example usage: curl -X POST localhost:8899/scale_up \ -H "Content-Type: application/json" \ @@ -130,61 +129,70 @@ async def http_scale_up(request: Request): """ body = await request.json() scaled_k = int(body.get("scaled_k", 1)) - cfg = shared_state["cfg"] - config_path = shared_state["config_path"] - num_rollout = shared_state["num_rollout"] - if cfg is None or config_path is None: - return {"status": "error", "msg": "Scale server not initialized yet"} + with shared_state_lock: + cfg = shared_state["cfg"] + config_path = shared_state["config_path"] + num_rollout = shared_state["num_rollout"] + vllm_entry_point = shared_state["vllm_entry_point"] + + # More complete initialization check + if ( + cfg is None + or config_path is None + or num_rollout is None + or vllm_entry_point is None + ): + return {"status": "error", "msg": "Scale server not initialized yet"} + + new_total = num_rollout + scaled_k + shared_state["num_rollout"] = new_total try: logger.info(f"[HTTP] Received manual scale-up request: {scaled_k}") - shared_state["num_rollout"] = num_rollout + scaled_k - name_resolve.add("scale_up_request", {"scaled_k": int(scaled_k)}, replace=True) + scale_up_vllm( cfg, config_path, scaled_k, - num_rollout + scaled_k, - shared_state["vllm_entry_point"], + new_total, + vllm_entry_point, ) try: name_resolve.delete("scale_up_done") except NameEntryNotFoundError: pass - name_resolve.add("scale_up_done", {"step": 0}) - logger.info( - f"[HTTP] Scale-up done. Total rollout={shared_state['num_rollout']}" - ) + name_resolve.add("scale_up_done", {"done": 1}) + logger.info(f"[HTTP] Scale-up done. Total rollout={new_total}") return { "status": "ok", "scaled_k": scaled_k, - "new_total": shared_state["num_rollout"], + "new_total": new_total, } except Exception as e: logger.error(f"[HTTP] Scale-up failed: {e}") return {"status": "error", "msg": str(e)} -def run_http_server(): +def run_http_server(port: int): """Run FastAPI server in background thread (non-blocking).""" - config = Config(app, host="0.0.0.0", port=HTTP_SCALE_PORT, log_level="info") + config = Config(app, host="0.0.0.0", port=port, log_level="info") server = Server(config) def _serve(): - logger.info(f"[HTTP] Starting manual scale-up server on port {HTTP_SCALE_PORT}") + logger.info(f"[HTTP] Starting scaling controller server on port {port}") server.run() t = threading.Thread(target=_serve, daemon=False) t.start() - logger.info("[HTTP] Manual scale-up service started in background.") + logger.info("[HTTP] Scaling controller server started in background.") if __name__ == "__main__": if len(sys.argv) < 2: - logger.info("Usage: python scaling_controller.py ") + logger.info("Usage: python scaling_controller ") sys.exit(1) config_path = sys.argv[1] @@ -193,35 +201,38 @@ def _serve(): experiment_name = cfg.experiment_name trial_name = cfg.trial_name - # allocation_mode allocation_mode = AllocationMode.from_str(cfg.allocation_mode) - # Set-the-experiments-configs for rollout ------------------ num_rollout = allocation_mode.gen.dp_size # Remove all the keys related to scaling before start the experiment try: name_resolve.delete("scale_up_request") except NameEntryNotFoundError: - logger.info("no delete") + pass try: name_resolve.delete("scale_up_done") except NameEntryNotFoundError: pass - # Init the ray and conncet it to existing cluster + + # Init ray and connect it to existing cluster ray.init(address="auto", namespace=f"{experiment_name}_{trial_name}") # Get port for scale up cfg.scaling = to_structured_cfg(cfg.scaling, ScalingConfig) - HTTP_SCALE_PORT = cfg.scaling.scaling_controller_port - - # Run http for scale-up - run_http_server() + port = cfg.scaling.scaling_controller_port - logger.info("[HTTP] Manual scale-up service started in background.") + # Resolve vLLM entry point vllm_entry_point = str(Path(__file__).resolve().parent.parent / "vllm_server.py") - shared_state["cfg"] = cfg - shared_state["config_path"] = config_path - shared_state["num_rollout"] = num_rollout - shared_state["vllm_entry_point"] = vllm_entry_point + + # Initialize shared_state atomically before starting HTTP server + with shared_state_lock: + shared_state["cfg"] = cfg + shared_state["config_path"] = config_path + shared_state["num_rollout"] = num_rollout + shared_state["vllm_entry_point"] = vllm_entry_point + logger.info(f"[HTTP] num_rollout initialized to {num_rollout}") + + # Run http for scale-up (after shared_state is fully initialized) + run_http_server(port)