From 7432064fbef7583bf73f9bfaebd1d75caf74798e Mon Sep 17 00:00:00 2001 From: Gursimran Singh Date: Tue, 16 Dec 2025 15:27:47 -0800 Subject: [PATCH 1/4] Working and tested examples for grpo single controller lora using the vllm backend --- .../lora/gsm8k_grpo_vllm_single_controller.py | 242 ++++++++++++++++++ .../gsm8k_grpo_vllm_single_controller.yaml | 186 ++++++++++++++ 2 files changed, 428 insertions(+) create mode 100644 examples/lora/gsm8k_grpo_vllm_single_controller.py create mode 100644 examples/lora/gsm8k_grpo_vllm_single_controller.yaml diff --git a/examples/lora/gsm8k_grpo_vllm_single_controller.py b/examples/lora/gsm8k_grpo_vllm_single_controller.py new file mode 100644 index 000000000..c25e26c3b --- /dev/null +++ b/examples/lora/gsm8k_grpo_vllm_single_controller.py @@ -0,0 +1,242 @@ +import os +import sys + +from areal.api.alloc_mode import AllocationMode +from areal.api.cli_args import GRPOConfig, SGLangConfig, load_expr_config, vLLMConfig +from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta +from areal.controller.rollout_controller import RolloutController +from areal.controller.train_controller import TrainController +from areal.dataset import get_custom_dataset +from areal.engine.ppo.actor import FSDPPPOActor +from areal.engine.sglang_remote import RemoteSGLangEngine +from areal.engine.vllm_remote import RemotevLLMEngine +from areal.scheduler.local import LocalScheduler +from areal.utils import stats_tracker +from areal.utils.data import ( + cycle_dataloader, +) +from areal.utils.dataloader import create_dataloader +from areal.utils.device import log_gpu_stats +from areal.utils.evaluator import Evaluator +from areal.utils.hf_utils import load_hf_tokenizer +from areal.utils.recover import RecoverHandler +from areal.utils.saver import Saver +from areal.utils.stats_logger import StatsLogger + + +def main(args): + config, _ = load_expr_config(args, GRPOConfig) + config: GRPOConfig + + tokenizer = load_hf_tokenizer(config.tokenizer_path) + + # Create dataset and dataloaders + train_dataset = get_custom_dataset( + split="train", dataset_config=config.train_dataset, tokenizer=tokenizer + ) + + train_dataloader = create_dataloader( + train_dataset, + rank=0, + world_size=1, + dataset_config=config.train_dataset, + ) + + ft_spec = FinetuneSpec( + total_train_epochs=config.total_train_epochs, + dataset_size=len(train_dataloader) * config.train_dataset.batch_size, + train_batch_size=config.train_dataset.batch_size, + ) + + # Initialize scheduler + scheduler = LocalScheduler(exp_config=config) + + # Initialize train controller + allocation_mode = AllocationMode.from_str(config.allocation_mode) + actor = TrainController(FSDPPPOActor, config=config.actor, scheduler=scheduler) + actor.initialize( + role="actor", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None + ) + + # Initialize inference engine + + if allocation_mode.gen_backend == "sglang": + engine_class = RemoteSGLangEngine + server_args = SGLangConfig.build_args( + sglang_config=config.sglang, + tp_size=allocation_mode.gen.tp_size, + base_gpu_id=0, + ) + elif allocation_mode.gen_backend == "vllm": + engine_class = RemotevLLMEngine + server_args = vLLMConfig.build_args( + vllm_config=config.vllm, + tp_size=allocation_mode.gen.tp_size, + pp_size=allocation_mode.gen.pp_size, + ) + else: + raise ValueError(f"Unsupported gen_backend: '{allocation_mode.gen_backend}'") + + # import debugpy, os + # debugpy.listen(("0.0.0.0", 2500)) + # debugpy.wait_for_client() + # debugpy.breakpoint() + + rollout = 
RolloutController( + engine_class, config=config.rollout, scheduler=scheduler + ) + rollout.initialize( + role="rollout", + alloc_mode=allocation_mode, + server_args=server_args, + ) + + if config.actor.weight_update_mode == "disk": + weight_update_meta = WeightUpdateMeta.from_disk( + experiment_name=config.saver.experiment_name, + trial_name=config.saver.trial_name, + file_root=config.saver.fileroot, + use_lora=config.actor.use_lora, + lora_name=config.gconfig.lora_name, + lora_int_id=1, + base_model_name=config.actor.path, + ) + elif config.actor.weight_update_mode == "xccl": + weight_update_meta = WeightUpdateMeta.from_fsdp_xccl( + allocation_mode, + use_lora=config.actor.use_lora, + lora_name=config.gconfig.lora_name, + lora_int_id=1, # hard coded for the single lora example + base_model_name=config.actor.path, + ) + + actor.connect_engine(rollout, weight_update_meta) + + ref = None + if config.actor.kl_ctl > 0 and config.ref is not None: + ref = TrainController(FSDPPPOActor, config=config.ref, scheduler=scheduler) + ref.initialize( + role="ref", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None + ) + + # Run training. + saver = Saver(config.saver, ft_spec) + stats_logger = StatsLogger(config, ft_spec) + evaluator = Evaluator(config.evaluator, ft_spec) + + recover_handler = RecoverHandler(config.recover, ft_spec) + + try: + recover_info = recover_handler.load( + actor, + saver, + evaluator, + stats_logger, + train_dataloader, + inference_engine=rollout, + weight_update_meta=weight_update_meta, + ) + start_step = ( + recover_info.last_step_info.next().global_step + if recover_info is not None + else 0 + ) + + total_epochs = config.total_train_epochs + steps_per_epoch = len(train_dataloader) + max_steps = total_epochs * steps_per_epoch + + data_generator = cycle_dataloader(train_dataloader) + for global_step in range(start_step, max_steps): + epoch = global_step // steps_per_epoch + step = global_step % steps_per_epoch + step_info = StepInfo( + global_step=global_step, + epoch=epoch, + epoch_step=step, + steps_per_epoch=steps_per_epoch, + ) + + with stats_tracker.record_timing("rollout"): + workflow_kwargs = dict( + reward_fn="areal.reward.gsm8k.gsm8k_reward_fn", + gconfig=config.gconfig, + tokenizer=config.tokenizer_path, + enable_thinking=False, + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), + "generated", + ), + ) + if config.rollout.max_head_offpolicyness > 0: + batch = actor.prepare_batch( + train_dataloader, + workflow="areal.workflow.rlvr.RLVRWorkflow", + workflow_kwargs=workflow_kwargs, + ) + else: + batch = actor.rollout_batch( + next(data_generator), + workflow="areal.workflow.rlvr.RLVRWorkflow", + workflow_kwargs=workflow_kwargs, + ) + + if config.actor.recompute_logprob or config.actor.use_decoupled_loss: + with stats_tracker.record_timing("recompute_logp"): + logp = actor.compute_logp(batch) + batch["prox_logp"] = logp + log_gpu_stats("recompute logp") + + if ref is not None: + with stats_tracker.record_timing("ref_logp"): + batch["ref_logp"] = ref.compute_logp(batch) + log_gpu_stats("ref logp") + + with stats_tracker.record_timing("compute_advantage"): + batch = actor.compute_advantages(batch) + log_gpu_stats("compute advantages") + + with stats_tracker.record_timing("train_step"): + actor.ppo_update(batch) + actor.step_lr_scheduler() + log_gpu_stats("ppo update") + + # pause inference for updating weights, save, and evaluation + rollout.pause() + + with stats_tracker.record_timing("update_weights"): + actor.update_weights(weight_update_meta) 
+ + actor.set_version(global_step + 1) + rollout.set_version(global_step + 1) + + with stats_tracker.record_timing("save"): + saver.save(actor, epoch, step, global_step, tokenizer=tokenizer) + + with stats_tracker.record_timing("checkpoint_for_recover"): + recover_handler.dump( + actor, + step_info, + saver, + evaluator, + stats_logger, + train_dataloader, + tokenizer=tokenizer, + ) + + # Upload statistics to the logger (e.g., wandb) + stats_logger.commit(epoch, step, global_step, actor.export_stats()) + + # Resume rollout + rollout.resume() + + finally: + stats_logger.close() + rollout.destroy() + if ref is not None: + ref.destroy() + actor.destroy() + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/examples/lora/gsm8k_grpo_vllm_single_controller.yaml b/examples/lora/gsm8k_grpo_vllm_single_controller.yaml new file mode 100644 index 000000000..f62948414 --- /dev/null +++ b/examples/lora/gsm8k_grpo_vllm_single_controller.yaml @@ -0,0 +1,186 @@ +experiment_name: gsm8k-grpo +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +allocation_mode: vllm:d4p1t1+d4p1t1 + +rollout: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + use_lora: true + scheduling_spec: + - task_type: worker + port_count: 2 + cpu: 1 + gpu: 1 + mem: 1024 + cmd: python3 -m areal.scheduler.rpc.rpc_server + + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + lora_name: "lora-gsm8k" + +actor: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ./model/Qwen3-0.6B + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: false + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 1.70e-4 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + group_size: ${gconfig.n_samples} + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + behav_imp_weight_cap: 5.0 + dynamic_sampling: false + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + weight_update_mode: disk + use_lora: ${rollout.use_lora} + peft_type: lora + lora_rank: 16 + lora_alpha: 16 + target_modules: [all-linear] + scheduling_spec: + - task_type: worker + port_count: 2 + cpu: 1 + gpu: 1 + mem: 1024 + cmd: python3 -m areal.scheduler.rpc.rpc_server + +ref: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: + - task_type: worker + port_count: 2 + cpu: 1 + gpu: 1 + mem: 1024 + cmd: python3 -m areal.scheduler.rpc.rpc_server + +# vLLM +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + 
gpu_memory_utilization: 0.9 + enable_lora: ${rollout.use_lora} + lora_modules: '{"name": "${gconfig.lora_name}", "path": "./model/Qwen3.0.6B-16rank", "base_model_name": "${actor.path}"}' + enforce_eager: true + +# datasets +train_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +launcher: + inference_server_cpus_per_gpu: 4 + inference_server_mem_per_gpu: 32768 + trainer_cpus_per_gpu: 4 + trainer_mem_per_gpu: 32768 From 3309483a1c51656a46ddc3c468b2e5d212e31a6e Mon Sep 17 00:00:00 2001 From: Gursimran Singh Date: Tue, 16 Dec 2025 15:40:35 -0800 Subject: [PATCH 2/4] Cleaned up some debug statements --- examples/lora/gsm8k_grpo_vllm_single_controller.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/lora/gsm8k_grpo_vllm_single_controller.py b/examples/lora/gsm8k_grpo_vllm_single_controller.py index c25e26c3b..65702fca2 100644 --- a/examples/lora/gsm8k_grpo_vllm_single_controller.py +++ b/examples/lora/gsm8k_grpo_vllm_single_controller.py @@ -77,11 +77,6 @@ def main(args): else: raise ValueError(f"Unsupported gen_backend: '{allocation_mode.gen_backend}'") - # import debugpy, os - # debugpy.listen(("0.0.0.0", 2500)) - # debugpy.wait_for_client() - # debugpy.breakpoint() - rollout = RolloutController( engine_class, config=config.rollout, scheduler=scheduler ) From c7484a0c99bdd26a06400acdf3c2fa57d3e51546 Mon Sep 17 00:00:00 2001 From: Gursimran Singh Date: Fri, 2 Jan 2026 11:25:05 -0800 Subject: [PATCH 3/4] Updated and tested (performance matched with full RL) as per new design --- areal/api/io_struct.py | 2 +- areal/experimental/trainer/rl.py | 27 ++-- examples/math/gsm8k_grpo_lora.yaml | 192 +++++++++++++++++++++++++++++ 3 files changed, 212 insertions(+), 9 deletions(-) create mode 100644 examples/math/gsm8k_grpo_lora.yaml diff --git a/areal/api/io_struct.py b/areal/api/io_struct.py index 50cddc723..c550a8b3e 100644 --- a/areal/api/io_struct.py +++ b/areal/api/io_struct.py @@ -148,7 +148,7 @@ def from_disk( use_lora: bool = False, clear_checkpoint_after_load: bool = True, lora_name: str = "", - lora_int_id: int = 0, + lora_int_id: int = 1, base_model_name: str = "", ) -> "WeightUpdateMeta": from areal.utils.saver import Saver diff --git a/areal/experimental/trainer/rl.py b/areal/experimental/trainer/rl.py index 91a05252c..f28c10a15 100644 --- a/areal/experimental/trainer/rl.py +++ b/areal/experimental/trainer/rl.py @@ -132,14 +132,25 @@ def __init__( # Prepare weight update meta and connect to inference engine if self.config.actor.weight_update_mode == "disk": - self.weight_update_meta = WeightUpdateMeta.from_disk( - experiment_name=config.experiment_name, - trial_name=config.trial_name, - file_root=config.cluster.fileroot, - 
name="default", - use_lora=config.actor.use_lora, - clear_checkpoint_after_load=True, - ) + if config.actor.use_lora: + self.weight_update_meta = WeightUpdateMeta.from_disk( + experiment_name=config.experiment_name, + trial_name=config.trial_name, + file_root=config.cluster.fileroot, + name="default", + clear_checkpoint_after_load=True, + use_lora=config.actor.use_lora, + lora_name=config.gconfig.lora_name, + base_model_name=config.actor.path, + ) + else: + self.weight_update_meta = WeightUpdateMeta.from_disk( + experiment_name=config.experiment_name, + trial_name=config.trial_name, + file_root=config.cluster.fileroot, + name="default", + clear_checkpoint_after_load=True, + ) elif self.config.actor.weight_update_mode == "xccl": # NCCL/XCCL weight update if self.allocation_mode.train_backend == "megatron": diff --git a/examples/math/gsm8k_grpo_lora.yaml b/examples/math/gsm8k_grpo_lora.yaml new file mode 100644 index 000000000..b0cbdce36 --- /dev/null +++ b/examples/math/gsm8k_grpo_lora.yaml @@ -0,0 +1,192 @@ +experiment_name: gsm8k-grpo +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 3 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 16 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +allocation_mode: vllm:d8p1t1+d8p1t1 + + +scheduler: + type: local + + +rollout: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + use_lora: true + scheduling_spec: ${actor.scheduling_spec} + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 1024 + greedy: false + temperature: 1.0 + lora_name: "lora-gsm8k" + +actor: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen3-0.6B + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: true + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 1.70e-4 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + group_size: ${gconfig.n_samples} + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + behav_imp_weight_cap: 5.0 + dynamic_sampling: false + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + max_new_tokens: ${gconfig.max_new_tokens} + weight_update_mode: disk + use_lora: ${rollout.use_lora} + peft_type: lora + lora_rank: 16 + lora_alpha: 16 + target_modules: [all-linear] + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + cpu: 4 + mem: 32 + cmd: python3 -m areal.scheduler.rpc.rpc_server + env_vars: {} + +ref: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +# vLLM +vllm: + model: 
${actor.path}
+  seed: ${seed}
+  skip_tokenizer_init: false
+  dtype: ${actor.dtype}
+  max_model_len: 32768
+  gpu_memory_utilization: 0.8
+  enable_lora: ${rollout.use_lora}
+  lora_modules: '{"name": "${gconfig.lora_name}", "path": "./model/Qwen3.0.6B-16rank", "base_model_name": "${actor.path}"}'
+  enforce_eager: true
+
+# datasets
+train_dataset:
+  batch_size: 256
+  shuffle: true
+  pin_memory: true
+  num_workers: 4
+  path: openai/gsm8k
+  type: rl
+  max_length: 1024
+
+valid_dataset:
+  batch_size: 256
+  pin_memory: true
+  num_workers: 4
+  path: openai/gsm8k
+  type: rl
+
+# Utilities
+saver:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+recover:
+  mode: disabled
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: 3600
+
+evaluator:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: null
+  freq_steps: null
+  freq_secs: null
+
+stats_logger:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
+
+
+perf_tracer:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  enabled: false
+  session_tracer:
+    enabled: false

From bfa6b137de74539662ceafea40b70b431b409253 Mon Sep 17 00:00:00 2001
From: Gursimran Singh
Date: Fri, 2 Jan 2026 13:09:16 -0800
Subject: [PATCH 4/4] Removed old single-controller examples in the lora folder as they are no longer required

---
 .../lora/gsm8k_grpo_vllm_single_controller.py | 237 ------------------
 .../gsm8k_grpo_vllm_single_controller.yaml    | 186 --------------
 2 files changed, 423 deletions(-)
 delete mode 100644 examples/lora/gsm8k_grpo_vllm_single_controller.py
 delete mode 100644 examples/lora/gsm8k_grpo_vllm_single_controller.yaml

diff --git a/examples/lora/gsm8k_grpo_vllm_single_controller.py b/examples/lora/gsm8k_grpo_vllm_single_controller.py
deleted file mode 100644
index 65702fca2..000000000
--- a/examples/lora/gsm8k_grpo_vllm_single_controller.py
+++ /dev/null
@@ -1,237 +0,0 @@
-import os
-import sys
-
-from areal.api.alloc_mode import AllocationMode
-from areal.api.cli_args import GRPOConfig, SGLangConfig, load_expr_config, vLLMConfig
-from areal.api.io_struct import FinetuneSpec, StepInfo, WeightUpdateMeta
-from areal.controller.rollout_controller import RolloutController
-from areal.controller.train_controller import TrainController
-from areal.dataset import get_custom_dataset
-from areal.engine.ppo.actor import FSDPPPOActor
-from areal.engine.sglang_remote import RemoteSGLangEngine
-from areal.engine.vllm_remote import RemotevLLMEngine
-from areal.scheduler.local import LocalScheduler
-from areal.utils import stats_tracker
-from areal.utils.data import (
-    cycle_dataloader,
-)
-from areal.utils.dataloader import create_dataloader
-from areal.utils.device import log_gpu_stats
-from areal.utils.evaluator import Evaluator
-from areal.utils.hf_utils import load_hf_tokenizer
-from areal.utils.recover import RecoverHandler
-from areal.utils.saver import Saver
-from areal.utils.stats_logger import StatsLogger
-
-
-def main(args):
-    config, _ = load_expr_config(args, GRPOConfig)
-    config: GRPOConfig
-
-    tokenizer = load_hf_tokenizer(config.tokenizer_path)
-
-    # Create dataset and dataloaders
-    train_dataset = get_custom_dataset(
-        split="train", dataset_config=config.train_dataset, 
tokenizer=tokenizer - ) - - train_dataloader = create_dataloader( - train_dataset, - rank=0, - world_size=1, - dataset_config=config.train_dataset, - ) - - ft_spec = FinetuneSpec( - total_train_epochs=config.total_train_epochs, - dataset_size=len(train_dataloader) * config.train_dataset.batch_size, - train_batch_size=config.train_dataset.batch_size, - ) - - # Initialize scheduler - scheduler = LocalScheduler(exp_config=config) - - # Initialize train controller - allocation_mode = AllocationMode.from_str(config.allocation_mode) - actor = TrainController(FSDPPPOActor, config=config.actor, scheduler=scheduler) - actor.initialize( - role="actor", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None - ) - - # Initialize inference engine - - if allocation_mode.gen_backend == "sglang": - engine_class = RemoteSGLangEngine - server_args = SGLangConfig.build_args( - sglang_config=config.sglang, - tp_size=allocation_mode.gen.tp_size, - base_gpu_id=0, - ) - elif allocation_mode.gen_backend == "vllm": - engine_class = RemotevLLMEngine - server_args = vLLMConfig.build_args( - vllm_config=config.vllm, - tp_size=allocation_mode.gen.tp_size, - pp_size=allocation_mode.gen.pp_size, - ) - else: - raise ValueError(f"Unsupported gen_backend: '{allocation_mode.gen_backend}'") - - rollout = RolloutController( - engine_class, config=config.rollout, scheduler=scheduler - ) - rollout.initialize( - role="rollout", - alloc_mode=allocation_mode, - server_args=server_args, - ) - - if config.actor.weight_update_mode == "disk": - weight_update_meta = WeightUpdateMeta.from_disk( - experiment_name=config.saver.experiment_name, - trial_name=config.saver.trial_name, - file_root=config.saver.fileroot, - use_lora=config.actor.use_lora, - lora_name=config.gconfig.lora_name, - lora_int_id=1, - base_model_name=config.actor.path, - ) - elif config.actor.weight_update_mode == "xccl": - weight_update_meta = WeightUpdateMeta.from_fsdp_xccl( - allocation_mode, - use_lora=config.actor.use_lora, - lora_name=config.gconfig.lora_name, - lora_int_id=1, # hard coded for the single lora example - base_model_name=config.actor.path, - ) - - actor.connect_engine(rollout, weight_update_meta) - - ref = None - if config.actor.kl_ctl > 0 and config.ref is not None: - ref = TrainController(FSDPPPOActor, config=config.ref, scheduler=scheduler) - ref.initialize( - role="ref", alloc_mode=allocation_mode, ft_spec=ft_spec, addr=None - ) - - # Run training. 
- saver = Saver(config.saver, ft_spec) - stats_logger = StatsLogger(config, ft_spec) - evaluator = Evaluator(config.evaluator, ft_spec) - - recover_handler = RecoverHandler(config.recover, ft_spec) - - try: - recover_info = recover_handler.load( - actor, - saver, - evaluator, - stats_logger, - train_dataloader, - inference_engine=rollout, - weight_update_meta=weight_update_meta, - ) - start_step = ( - recover_info.last_step_info.next().global_step - if recover_info is not None - else 0 - ) - - total_epochs = config.total_train_epochs - steps_per_epoch = len(train_dataloader) - max_steps = total_epochs * steps_per_epoch - - data_generator = cycle_dataloader(train_dataloader) - for global_step in range(start_step, max_steps): - epoch = global_step // steps_per_epoch - step = global_step % steps_per_epoch - step_info = StepInfo( - global_step=global_step, - epoch=epoch, - epoch_step=step, - steps_per_epoch=steps_per_epoch, - ) - - with stats_tracker.record_timing("rollout"): - workflow_kwargs = dict( - reward_fn="areal.reward.gsm8k.gsm8k_reward_fn", - gconfig=config.gconfig, - tokenizer=config.tokenizer_path, - enable_thinking=False, - dump_dir=os.path.join( - StatsLogger.get_log_path(config.stats_logger), - "generated", - ), - ) - if config.rollout.max_head_offpolicyness > 0: - batch = actor.prepare_batch( - train_dataloader, - workflow="areal.workflow.rlvr.RLVRWorkflow", - workflow_kwargs=workflow_kwargs, - ) - else: - batch = actor.rollout_batch( - next(data_generator), - workflow="areal.workflow.rlvr.RLVRWorkflow", - workflow_kwargs=workflow_kwargs, - ) - - if config.actor.recompute_logprob or config.actor.use_decoupled_loss: - with stats_tracker.record_timing("recompute_logp"): - logp = actor.compute_logp(batch) - batch["prox_logp"] = logp - log_gpu_stats("recompute logp") - - if ref is not None: - with stats_tracker.record_timing("ref_logp"): - batch["ref_logp"] = ref.compute_logp(batch) - log_gpu_stats("ref logp") - - with stats_tracker.record_timing("compute_advantage"): - batch = actor.compute_advantages(batch) - log_gpu_stats("compute advantages") - - with stats_tracker.record_timing("train_step"): - actor.ppo_update(batch) - actor.step_lr_scheduler() - log_gpu_stats("ppo update") - - # pause inference for updating weights, save, and evaluation - rollout.pause() - - with stats_tracker.record_timing("update_weights"): - actor.update_weights(weight_update_meta) - - actor.set_version(global_step + 1) - rollout.set_version(global_step + 1) - - with stats_tracker.record_timing("save"): - saver.save(actor, epoch, step, global_step, tokenizer=tokenizer) - - with stats_tracker.record_timing("checkpoint_for_recover"): - recover_handler.dump( - actor, - step_info, - saver, - evaluator, - stats_logger, - train_dataloader, - tokenizer=tokenizer, - ) - - # Upload statistics to the logger (e.g., wandb) - stats_logger.commit(epoch, step, global_step, actor.export_stats()) - - # Resume rollout - rollout.resume() - - finally: - stats_logger.close() - rollout.destroy() - if ref is not None: - ref.destroy() - actor.destroy() - - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/examples/lora/gsm8k_grpo_vllm_single_controller.yaml b/examples/lora/gsm8k_grpo_vllm_single_controller.yaml deleted file mode 100644 index f62948414..000000000 --- a/examples/lora/gsm8k_grpo_vllm_single_controller.yaml +++ /dev/null @@ -1,186 +0,0 @@ -experiment_name: gsm8k-grpo -trial_name: trial0 - -seed: 1 -enable_offload: false -total_train_epochs: 10 -tokenizer_path: ${actor.path} - -cluster: - n_nodes: 1 - 
n_gpus_per_node: 8 - fileroot: /tmp/areal/experiments - name_resolve: - type: nfs - nfs_record_root: /tmp/areal/name_resolve - -allocation_mode: vllm:d4p1t1+d4p1t1 - -rollout: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - max_concurrent_rollouts: 256 - queue_size: null - consumer_batch_size: ${train_dataset.batch_size} - max_head_offpolicyness: 2 - enable_rollout_tracing: false - use_lora: true - scheduling_spec: - - task_type: worker - port_count: 2 - cpu: 1 - gpu: 1 - mem: 1024 - cmd: python3 -m areal.scheduler.rpc.rpc_server - - -gconfig: - n_samples: 4 - min_new_tokens: 0 - max_new_tokens: 1024 - greedy: false - temperature: 1.0 - lora_name: "lora-gsm8k" - -actor: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: ./model/Qwen3-0.6B - init_from_scratch: false - disable_dropout: true - gradient_checkpointing: false - dtype: bfloat16 - mb_spec: - max_tokens_per_mb: 10240 - optimizer: - type: adam - lr: 1.70e-4 - weight_decay: 0.017 - beta1: 0.9 - beta2: 0.999 - eps: 1e-8 - lr_scheduler_type: constant - gradient_clipping: 1.0 - warmup_steps_proportion: 0.001 - group_size: ${gconfig.n_samples} - eps_clip: 0.4 - temperature: ${gconfig.temperature} - reward_scaling: 10.0 - reward_bias: -0.5 - kl_ctl: 0.0 - ppo_n_minibatches: 1 - recompute_logprob: true - use_decoupled_loss: true - behav_imp_weight_cap: 5.0 - dynamic_sampling: false - reward_norm: - mean_level: group - std_level: group - group_size: ${gconfig.n_samples} - adv_norm: - mean_level: batch - std_level: batch - max_new_tokens: ${gconfig.max_new_tokens} - weight_update_mode: disk - use_lora: ${rollout.use_lora} - peft_type: lora - lora_rank: 16 - lora_alpha: 16 - target_modules: [all-linear] - scheduling_spec: - - task_type: worker - port_count: 2 - cpu: 1 - gpu: 1 - mem: 1024 - cmd: python3 -m areal.scheduler.rpc.rpc_server - -ref: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: ${actor.path} - init_from_scratch: false - disable_dropout: true - dtype: ${actor.dtype} - mb_spec: - max_tokens_per_mb: 10240 - optimizer: null - scheduling_strategy: - type: colocation - target: actor - scheduling_spec: - - task_type: worker - port_count: 2 - cpu: 1 - gpu: 1 - mem: 1024 - cmd: python3 -m areal.scheduler.rpc.rpc_server - -# vLLM -vllm: - model: ${actor.path} - seed: ${seed} - skip_tokenizer_init: false - dtype: ${actor.dtype} - max_model_len: 32768 - gpu_memory_utilization: 0.9 - enable_lora: ${rollout.use_lora} - lora_modules: '{"name": "${gconfig.lora_name}", "path": "./model/Qwen3.0.6B-16rank", "base_model_name": "${actor.path}"}' - enforce_eager: true - -# datasets -train_dataset: - batch_size: 256 - shuffle: true - pin_memory: true - num_workers: 4 - path: openai/gsm8k - type: rl - max_length: 1024 - -valid_dataset: - batch_size: 256 - shuffle: true - pin_memory: true - num_workers: 4 - path: openai/gsm8k - type: rl - -# Utilities -saver: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: null - -recover: - mode: disabled - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: 3600 - -evaluator: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - freq_epochs: 1 - freq_steps: null - freq_secs: null - -stats_logger: - experiment_name: ${experiment_name} - trial_name: ${trial_name} - fileroot: ${cluster.fileroot} - wandb: - mode: 
disabled - -launcher: - inference_server_cpus_per_gpu: 4 - inference_server_mem_per_gpu: 32768 - trainer_cpus_per_gpu: 4 - trainer_mem_per_gpu: 32768