diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py
index 039177e08..086ce7a2a 100644
--- a/areal/api/cli_args.py
+++ b/areal/api/cli_args.py
@@ -928,6 +928,29 @@ class InferenceEngineConfig:
     )
 
 
+@dataclass
+class ScalingConfig:
+    """Configuration for dynamic scaling of inference/training servers."""
+
+    enable_scaling: bool = field(
+        default=False,
+        metadata={"help": "Whether scaling is enabled (True/False)."},
+    )
+
+    mode: str = field(
+        default="manual",
+        metadata={
+            "help": "Scaling mode: 'manual' or 'auto'.",
+            "choices": ["manual", "auto"],
+        },
+    )
+
+    scaling_controller_port: int = field(
+        default=8899,
+        metadata={"help": "HTTP port for the scale-up service endpoint."},
+    )
+
+
 @dataclass
 class _Timer:
     experiment_name: str = MISSING
@@ -1341,6 +1364,8 @@ class BaseExperimentConfig:
 
     scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
 
+    scaling: ScalingConfig = field(default_factory=ScalingConfig)
+
 
 @dataclass
 class SFTConfig(BaseExperimentConfig):
diff --git a/areal/core/remote_inf_engine.py b/areal/core/remote_inf_engine.py
index 34f4b2b00..f49149ec2 100644
--- a/areal/core/remote_inf_engine.py
+++ b/areal/core/remote_inf_engine.py
@@ -253,6 +253,7 @@ def __init__(
 
         self.workflow_executor: WorkflowExecutor
         self.local_server_processes: list[LocalInfServerInfo] = []
+        self.update_servers = False
 
     def _get_or_create_session(self) -> aiohttp.ClientSession:
         """Get or create a ClientSession for the current thread/event loop.
@@ -424,6 +425,27 @@ def get_version(self):
         with self.lock:
             return self._version
 
+    def refresh_addresses(self, new_addresses: list[str]) -> None:
+        """
+        Refresh the list of available servers dynamically.
+
+        Args:
+            new_addresses (list[str]): Updated list of server addresses.
+        """
+        if not new_addresses:
+            raise RuntimeError("No servers provided when refreshing addresses.")
+
+        # Only log if there's an actual change
+        if new_addresses != self.addresses:
+            self.logger.info(f"Refreshing server addresses: {new_addresses}")
+
+        # Replace with the new set
+        self.addresses = new_addresses
+
+        # Clamp server_idx to the valid range
+        if self.server_idx >= len(self.addresses):
+            self.server_idx = 0
+
     def choose_server(self) -> str:
         """Choose a server based on the scheduling policy.
 
@@ -437,7 +459,16 @@ def choose_server(self) -> str:
         NotImplementedError
             If schedule policy other than round-robin is used
         """
+
+        if self.update_servers:
+            name = names.gen_servers(
+                self.config.experiment_name, self.config.trial_name
+            )
+            vllm_addrs = name_resolve.get_subtree(name)
+            self.refresh_addresses(vllm_addrs)
+            self.update_servers = False
         if self.config.schedule_policy == "round_robin":
+            self.server_idx %= len(self.addresses)
             server = self.addresses[self.server_idx]
             self.server_idx = (self.server_idx + 1) % len(self.addresses)
             return server
@@ -591,6 +622,12 @@ def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]:
         assert meta.type == current_platform.communication_backend
         assert not self.distributed_weight_update_initialized
 
+        # Refresh the generation servers if there is a scale-up request
+        name = names.gen_servers(self.config.experiment_name, self.config.trial_name)
+        vllm_addrs = name_resolve.get_subtree(name)
+        if vllm_addrs != self.addresses:
+            self.refresh_addresses(vllm_addrs)
+
         fut = self.executor.submit(
             _init_weights_update_group_remote,
             self.backend,
@@ -845,6 +882,8 @@ def pause(self):
         """Pause request submission for async rollout.
         Used during evaluation to prevent data over generation.
""" + # Whenever pause for update weight, make update_servers True to dispatch request to new servers + self.update_servers = True return self.workflow_executor.pause() def resume(self): diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index 237a29d19..0f18f5dc4 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -83,6 +83,8 @@ def __init__(self, config: TrainEngineConfig): self.rank: int self.dp_head: int self.dp_rank: int + self.scaling_count = 0 + self.create_group_count = 0 @property def data_parallel_group(self) -> dist.ProcessGroup: @@ -376,14 +378,21 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta): ) self.weight_update_group = init_custom_process_group( backend=current_platform.communication_backend, - world_size=meta.alloc_mode.gen.world_size + 1, + world_size=meta.alloc_mode.gen.world_size + 1 + self.scaling_count, init_method=f"tcp://{meta.nccl_master_address}:{meta.nccl_master_port}", rank=0, - group_name=meta.nccl_group_name, + group_name=meta.nccl_group_name + str(self.create_group_count), timeout=DIST_GROUP_DEFAULT_TIMEOUT, ) + self.create_group_count += 1 fut.result() + self.rollout_engine._engine.backend.create_group_count += 1 + + def _re_init_weight_update_from_distributed(self, meta: WeightUpdateMeta): + self.weight_update_group_initialized = False + self._init_weight_update_from_distributed(meta) + self.weight_update_group_initialized = True @trace_perf("fsdp_engine.update_weights_from_distributed", category="comm") def _update_weights_from_distributed(self, meta: WeightUpdateMeta): diff --git a/areal/engine/vllm_remote.py b/areal/engine/vllm_remote.py index 306a0fc45..2abb5e703 100644 --- a/areal/engine/vllm_remote.py +++ b/areal/engine/vllm_remote.py @@ -29,6 +29,10 @@ class VLLMBackend: """vLLM-specific backend implementation for remote inference.""" + def __init__(self): + self.scaling_count = 0 + self.create_group_count = 0 + def build_generation_request( self, req: ModelRequest, with_lora: bool ) -> HttpRequest: @@ -106,7 +110,8 @@ def build_distributed_weight_update_requests( "names": [pspec.name for pspec in param_specs], "dtypes": [pspec.dtype for pspec in param_specs], "shapes": [pspec.shape for pspec in param_specs], - "group_name": meta.nccl_group_name, + "group_name": meta.nccl_group_name + + str(self.create_group_count), }, ), HttpRequest( @@ -128,9 +133,9 @@ def build_init_weights_group_request( "master_address": meta.nccl_master_address, "master_port": str(meta.nccl_master_port), "rank_offset": rank_offset, - "world_size": meta.alloc_mode.gen.world_size + 1, + "world_size": meta.alloc_mode.gen.world_size + 1 + self.scaling_count, "backend": current_platform.communication_backend, - "group_name": meta.nccl_group_name, + "group_name": meta.nccl_group_name + str(self.create_group_count), } return HttpRequest(endpoint="/areal_init_weights_update_group", payload=payload) diff --git a/areal/launcher/ray.py b/areal/launcher/ray.py index 7832a1419..a9cde4edd 100644 --- a/areal/launcher/ray.py +++ b/areal/launcher/ray.py @@ -1,6 +1,8 @@ import importlib.util +import os import pathlib import re +import subprocess import sys import time from collections.abc import Callable @@ -18,6 +20,7 @@ ClusterSpecConfig, LauncherConfig, RecoverConfig, + ScalingConfig, SGLangConfig, parse_cli_args, to_structured_cfg, @@ -41,7 +44,23 @@ RAY_WAIT_CHECK_TIME_INTERVAL = 5 # seconds DEFAULT_MAIN_FUNC_NAME = "main" RAY_LAUNCHER = None -RECOVER_TIME_INTERVAL = 10 # seconds +RECOVER_TIME_INTERVAL = 10 # second + 
+
+def launch_scale_common(config_path: str):
+    """Launch scaler/scaling_controller.py as a background subprocess without blocking."""
+    script_path = str(
+        pathlib.Path(__file__).resolve().parent.joinpath("scaler/scaling_controller.py")
+    )
+    env = os.environ.copy()
+    env["PYTHONUNBUFFERED"] = "1"
+    subprocess.Popen(
+        [sys.executable, script_path, config_path],
+        stdout=subprocess.DEVNULL,
+        stderr=subprocess.DEVNULL,
+        env=env,
+        start_new_session=True,
+    )
 
 
 def run_func(file_path, function_name, *args, **kwargs):
@@ -334,8 +353,15 @@ def wait(
 
 
 def main():
-    ray.init()
-    config, _ = parse_cli_args(sys.argv[1:])
+    ray.init(address="auto")
+    config, config_file = parse_cli_args(sys.argv[1:])
+    config.scaling = to_structured_cfg(config.scaling, ScalingConfig)
+    # Check whether scaling is enabled
+    if config.scaling.enable_scaling:
+        try:
+            launch_scale_common(str(config_file))
+        except Exception as e:
+            logger.info(f"[RayLauncher] Warning: Failed to launch the scaling controller: {e}")
     ray_main(config, run_id=0)
 
 
diff --git a/areal/launcher/scaler/scaling_controller.py b/areal/launcher/scaler/scaling_controller.py
new file mode 100644
index 000000000..bf4722e31
--- /dev/null
+++ b/areal/launcher/scaler/scaling_controller.py
@@ -0,0 +1,238 @@
+import importlib.util
+import sys
+import threading
+from pathlib import Path
+
+import ray
+from fastapi import FastAPI, Request
+from omegaconf import OmegaConf
+from uvicorn import Config, Server
+
+import areal.utils.logging as logging
+from areal.api.alloc_mode import AllocationMode
+from areal.api.cli_args import (
+    ClusterSpecConfig,
+    LauncherConfig,
+    RecoverConfig,
+    ScalingConfig,
+    to_structured_cfg,
+    vLLMConfig,
+)
+from areal.platforms import is_npu_available
+from areal.utils import name_resolve
+from areal.utils.launcher import get_env_vars, wait_llm_server_addrs
+from areal.utils.name_resolve import NameEntryNotFoundError
+
+logger = logging.getLogger("ScaleUpVLLM")
+DEFAULT_MAIN_FUNC_NAME = "main"
+
+
+def run_func(file_path: str, func_name: str, argv: list[str]):
+    """
+    Import a module by path and invoke the named function with a single `argv` list.
+    This matches vllm_server.main(argv), which expects sys.argv[2:]-style args.
+ """ + module_name = file_path.replace("/", "_").replace(".", "_") + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + func = getattr(module, func_name) + return func(argv) + + +def scale_up_vllm( + cfg, config_path: str, n_new_servers: int, expected: int, vllm_entry_point: str +): + # Make sub-configs structured like original launcher + cfg.launcher = to_structured_cfg(cfg.launcher, LauncherConfig) + cfg.cluster = to_structured_cfg(cfg.cluster, ClusterSpecConfig) + cfg.recover = to_structured_cfg(cfg.recover, RecoverConfig) + cfg.vllm = to_structured_cfg(cfg.vllm, vLLMConfig) + experiment_name = cfg.experiment_name + trial_name = cfg.trial_name + + allocation_mode = AllocationMode.from_str(cfg.allocation_mode) + vllm_tp_size = allocation_mode.gen.tp_size + n_existing_servers = expected - n_new_servers + + cpus_per_gpu = cfg.launcher.inference_server_cpus_per_gpu + mem_per_gpu = cfg.launcher.inference_server_mem_per_gpu + + # Submit new servers + remote_runner = None # we’ll bind ray.remote per device type + futures = [] + for i in range(n_new_servers): + argv = ["--config", config_path] + # Env vars same as original launcher for inference servers + env_vars = get_env_vars( + cfg.cluster.cluster_name, cfg.launcher.inference_server_env_vars + ) + + if is_npu_available: + remote_runner = ray.remote( + num_cpus=cpus_per_gpu * vllm_tp_size, + resources={"NPU": vllm_tp_size}, + memory=mem_per_gpu * vllm_tp_size * 1024 * 1024, # bytes + runtime_env={"env_vars": env_vars}, + )(run_func) + else: + remote_runner = ray.remote( + num_cpus=cpus_per_gpu * vllm_tp_size, + num_gpus=vllm_tp_size, + memory=mem_per_gpu * vllm_tp_size * 1024 * 1024, # bytes + runtime_env={"env_vars": env_vars}, + )(run_func) + + fut = remote_runner.remote(vllm_entry_point, DEFAULT_MAIN_FUNC_NAME, argv) + futures.append(fut) + + try: + ray.get(fut, timeout=5.0) + except ray.exceptions.GetTimeoutError: + pass + except ray.exceptions.RayTaskError as e: + logger.info(f"[ERROR] server {n_existing_servers + i} crashed immediately:") + logger.info(e) + raise + + # Wait until ALL (old + new) servers are registered + total_expected = expected + vllm_addrs = wait_llm_server_addrs( + experiment_name, + trial_name, + total_expected, + ) + + logger.info("\n[Scale-Up Completed]") + logger.info(f"Total servers expected: {len(vllm_addrs)}") + + +app = FastAPI() +shared_state = { + "cfg": None, + "config_path": None, + "num_rollout": None, + "vllm_entry_point": None, +} +shared_state_lock = threading.Lock() + + +@app.post("/scale_up") +async def http_scale_up(request: Request): + """ + Scaling controller endpoint. 
+    Example usage:
+        curl -X POST localhost:8899/scale_up \
+            -H "Content-Type: application/json" \
+            -d '{"scaled_k": 1}'
+    """
+    body = await request.json()
+    scaled_k = int(body.get("scaled_k", 1))
+
+    with shared_state_lock:
+        cfg = shared_state["cfg"]
+        config_path = shared_state["config_path"]
+        num_rollout = shared_state["num_rollout"]
+        vllm_entry_point = shared_state["vllm_entry_point"]
+
+    # More complete initialization check
+    if (
+        cfg is None
+        or config_path is None
+        or num_rollout is None
+        or vllm_entry_point is None
+    ):
+        return {"status": "error", "msg": "Scale server not initialized yet"}
+
+    new_total = num_rollout + scaled_k
+    shared_state["num_rollout"] = new_total
+
+    try:
+        logger.info(f"[HTTP] Received manual scale-up request: {scaled_k}")
+        name_resolve.add("scale_up_request", {"scaled_k": int(scaled_k)}, replace=True)
+
+        scale_up_vllm(
+            cfg,
+            config_path,
+            scaled_k,
+            new_total,
+            vllm_entry_point,
+        )
+        try:
+            name_resolve.delete("scale_up_done")
+        except NameEntryNotFoundError:
+            pass
+
+        name_resolve.add("scale_up_done", {"done": 1})
+        logger.info(f"[HTTP] Scale-up done. Total rollout={new_total}")
+        return {
+            "status": "ok",
+            "scaled_k": scaled_k,
+            "new_total": new_total,
+        }
+    except Exception as e:
+        logger.error(f"[HTTP] Scale-up failed: {e}")
+        return {"status": "error", "msg": str(e)}
+
+
+def run_http_server(port: int):
+    """Run the FastAPI server in a background thread (non-blocking)."""
+    config = Config(app, host="0.0.0.0", port=port, log_level="info")
+    server = Server(config)
+
+    def _serve():
+        logger.info(f"[HTTP] Starting scaling controller server on port {port}")
+        server.run()
+
+    t = threading.Thread(target=_serve, daemon=False)
+    t.start()
+    logger.info("[HTTP] Scaling controller server started in background.")
+
+
+if __name__ == "__main__":
+    if len(sys.argv) < 2:
+        logger.info("Usage: python scaling_controller.py <config_path>")
+        sys.exit(1)
+
+    config_path = sys.argv[1]
+    cfg = OmegaConf.load(config_path)
+    name_resolve.reconfigure(cfg.cluster.name_resolve)
+    experiment_name = cfg.experiment_name
+    trial_name = cfg.trial_name
+
+    allocation_mode = AllocationMode.from_str(cfg.allocation_mode)
+    num_rollout = allocation_mode.gen.dp_size
+
+    # Remove all the keys related to scaling before starting the experiment
+    try:
+        name_resolve.delete("scale_up_request")
+    except NameEntryNotFoundError:
+        pass
+
+    try:
+        name_resolve.delete("scale_up_done")
+    except NameEntryNotFoundError:
+        pass
+
+    # Init Ray and connect to the existing cluster
+    ray.init(address="auto", namespace=f"{experiment_name}_{trial_name}")
+
+    # Get the port for scale-up
+    cfg.scaling = to_structured_cfg(cfg.scaling, ScalingConfig)
+    port = cfg.scaling.scaling_controller_port
+
+    # Resolve the vLLM entry point
+    vllm_entry_point = str(Path(__file__).resolve().parent.parent / "vllm_server.py")
+
+    # Initialize shared_state atomically before starting the HTTP server
+    with shared_state_lock:
+        shared_state["cfg"] = cfg
+        shared_state["config_path"] = config_path
+        shared_state["num_rollout"] = num_rollout
+        shared_state["vllm_entry_point"] = vllm_entry_point
+
+    logger.info(f"[HTTP] num_rollout initialized to {num_rollout}")
+
+    # Run the HTTP server for scale-up (after shared_state is fully initialized)
+    run_http_server(port)
diff --git a/areal/thirdparty/vllm/vllm_worker_extension.py b/areal/thirdparty/vllm/vllm_worker_extension.py
index c1d471b29..d25f78c07 100644
--- a/areal/thirdparty/vllm/vllm_worker_extension.py
+++ b/areal/thirdparty/vllm/vllm_worker_extension.py
@@ -82,7 +82,8 @@ def init_update_weight_group(
         group_name: str,
     ):
         if getattr(self, "weight_update_group", None) is not None:
-            return True, "Success"
+            # If a group already exists, drop it so a new one is created for scaling
+            self.weight_update_group = None
         try:
             self.weight_update_group = init_custom_process_group(
                 backend=backend,
diff --git a/areal/utils/scaling.py b/areal/utils/scaling.py
new file mode 100644
index 000000000..cf3a4ac13
--- /dev/null
+++ b/areal/utils/scaling.py
@@ -0,0 +1,59 @@
+import ast
+import time
+
+from areal.utils import logging, name_resolve
+from areal.utils.name_resolve import NameEntryNotFoundError
+
+logger = logging.getLogger("Scaler")
+
+
+def handle_scale_up(name_resolve, actor, rollout, weight_update_meta):
+    """
+    Handle scale-up logic when a scale_up_request is detected.
+    Requires: name_resolve, actor, rollout, and weight_update_meta.
+    """
+    try:
+        req_raw = name_resolve.get("scale_up_request")
+        new_scale = ast.literal_eval(req_raw)["scaled_k"]
+    except NameEntryNotFoundError:
+        return
+
+    logger.info(f"Handling scale_up_request: {req_raw}")
+
+    # Now wait until scale_up_done is posted by the scaling controller
+    start = time.time()
+    try:
+        name_resolve.delete("scale_up_request")
+    except NameEntryNotFoundError:
+        pass
+
+    while True:
+        try:
+            done_raw = name_resolve.get("scale_up_done")
+        except NameEntryNotFoundError:
+            done_raw = None
+
+        if done_raw:
+            logger.info(f"[areal] Scale-up finished: {done_raw}")
+            name_resolve.add(
+                "scale_up_time",
+                {"time": time.time() - start},
+                replace=True,
+            )
+
+            try:
+                name_resolve.delete("scale_up_done")
+            except NameEntryNotFoundError:
+                pass
+
+            # Bump the scaling counts on the actor and rollout backend so the world size is correct
+            actor.scaling_count = actor.scaling_count + new_scale
+            rollout._engine.backend.scaling_count = (
+                rollout._engine.backend.scaling_count + new_scale
+            )
+            rollout._engine.distributed_weight_update_initialized = False
+            actor._re_init_weight_update_from_distributed(weight_update_meta)
+
+            break
+
+        time.sleep(0.5)
diff --git a/examples/math/gsm8k_grpo.py b/examples/math/gsm8k_grpo.py
index 9c4f4cfce..33e84ca33 100644
--- a/examples/math/gsm8k_grpo.py
+++ b/examples/math/gsm8k_grpo.py
@@ -10,14 +10,16 @@
 from areal.dataset import get_custom_dataset
 from areal.engine.ppo.actor import FSDPPPOActor
 from areal.engine.sglang_remote import RemoteSGLangEngine
+from areal.engine.vllm_remote import RemotevLLMEngine
 from areal.platforms import current_platform
-from areal.utils import seeding, stats_tracker
+from areal.utils import name_resolve, seeding, stats_tracker
 from areal.utils.dataloader import create_dataloader
 from areal.utils.device import log_gpu_stats
 from areal.utils.evaluator import Evaluator
 from areal.utils.hf_utils import load_hf_tokenizer
 from areal.utils.recover import RecoverHandler
 from areal.utils.saver import Saver
+from areal.utils.scaling import handle_scale_up
 from areal.utils.stats_logger import StatsLogger
 from areal.workflow.rlvr import RLVRWorkflow
@@ -38,6 +40,9 @@ def main(args):
     allocation_mode = AllocationMode.from_str(config.allocation_mode)
     parallel_strategy = allocation_mode.train
     assert parallel_strategy is not None
+    # Reconfigure name_resolve to connect to the existing name_resolve backend
+    name_resolve.reconfigure(config.cluster.name_resolve)
+    scale = config.scaling.enable_scaling
 
     # Initialize train engine
     actor = FSDPPPOActor(config=config.actor)
@@ -70,9 +75,14 @@ def main(args):
     )
 
     # Initialize inference engine
-    rollout = RemoteSGLangEngine(config.rollout)
+    if allocation_mode.gen_backend == "vllm":
+        rollout = RemotevLLMEngine(config.rollout)
+        eval_rollout = RemotevLLMEngine(deepcopy(config.rollout))
+    elif allocation_mode.gen_backend == "sglang":
+        rollout = RemoteSGLangEngine(config.rollout)
+        eval_rollout = RemoteSGLangEngine(deepcopy(config.rollout))
     rollout.initialize(train_data_parallel_size=parallel_strategy.dp_size)
-    eval_rollout = RemoteSGLangEngine(deepcopy(config.rollout))
+
     # NOTE: eval does not have any offpolicyness control
     eval_rollout.config.max_head_offpolicyness = int(1e12)
     eval_rollout.initialize()
@@ -178,6 +188,8 @@ def main(args):
 
         # pause inference for updating weights, save, and evaluation
         rollout.pause()
+        if dist.get_rank() == 0 and scale:
+            handle_scale_up(name_resolve, actor, rollout, weight_update_meta)
 
         with stats_tracker.record_timing("update_weights"):
             actor.update_weights(weight_update_meta)
diff --git a/examples/math/gsm8k_grpo_npu_scale.yaml b/examples/math/gsm8k_grpo_npu_scale.yaml
new file mode 100644
index 000000000..02129843c
--- /dev/null
+++ b/examples/math/gsm8k_grpo_npu_scale.yaml
@@ -0,0 +1,157 @@
+experiment_name: gsm8k-grpo
+trial_name: trial0
+
+seed: 1
+total_train_epochs: 10
+tokenizer_path: ${actor.path}
+async_training: true
+
+cluster:
+  n_nodes: 1
+  n_gpus_per_node: 8
+  fileroot: /tmp/areal/experiments
+  name_resolve:
+    type: nfs
+    nfs_record_root: /tmp/areal/name_resolve
+
+allocation_mode: vllm.d4p1t1+d4p1t1
+
+rollout:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  max_concurrent_rollouts: 256
+  queue_size: null
+  consumer_batch_size: ${train_dataset.batch_size}
+  max_head_offpolicyness: 2
+  enable_rollout_tracing: false
+
+gconfig:
+  n_samples: 4
+  min_new_tokens: 0
+  max_new_tokens: 1024
+  greedy: false
+  temperature: 1.0
+
+actor:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: Qwen/Qwen2.5-1.5B-Instruct
+  init_from_scratch: false
+  disable_dropout: true
+  gradient_checkpointing: false
+  dtype: bfloat16
+  mb_spec:
+    max_tokens_per_mb: 10240
+  optimizer:
+    type: adam
+    lr: 1.70e-5
+    weight_decay: 0.017
+    beta1: 0.9
+    beta2: 0.999
+    eps: 1e-8
+    lr_scheduler_type: constant
+    gradient_clipping: 1.0
+    warmup_steps_proportion: 0.001
+  backend: fsdp
+  group_size: ${gconfig.n_samples}
+  eps_clip: 0.4
+  temperature: ${gconfig.temperature}
+  reward_scaling: 10.0
+  reward_bias: -0.5
+  kl_ctl: 0.0
+  ppo_n_minibatches: 1
+  recompute_logprob: true
+  use_decoupled_loss: true
+  behav_imp_weight_cap: 5.0
+  dynamic_sampling: false
+  reward_norm:
+    mean_level: group
+    std_level: group
+    group_size: ${gconfig.n_samples}
+  adv_norm:
+    mean_level: batch
+    std_level: batch
+  max_new_tokens: ${gconfig.max_new_tokens}
+
+ref:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: ${actor.path}
+  init_from_scratch: false
+  disable_dropout: true
+  dtype: ${actor.dtype}
+  mb_spec:
+    max_tokens_per_mb: 10240
+  optimizer: null
+  backend: fsdp
+
+# VLLM
+vllm:
+  model: ${actor.path}
+  seed: ${seed}
+  skip_tokenizer_init: true
+  dtype: ${actor.dtype}
+  max_model_len: 32768
+  gpu_memory_utilization: 0.8
+
+# datasets
+train_dataset:
+  batch_size: 256
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: openai/gsm8k
+  type: rl
+  max_length: 1024
+
+valid_dataset:
+  batch_size: 256
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: openai/gsm8k
+  type: rl
+
+# Utilities
+saver:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+recover:
+  mode: disabled
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: 3600
+
+evaluator:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+stats_logger:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
+
+launcher:
+  inference_server_cpus_per_gpu: 4
+  inference_server_mem_per_gpu: 32768
+  trainer_cpus_per_gpu: 4
+  trainer_mem_per_gpu: 32768
+
+scaling:
+  enable_scaling: true
+  mode: manual
+  scaling_controller_port: 8899