55 changes: 55 additions & 0 deletions areal/api/engine_api.py
@@ -10,6 +10,7 @@
from torchdata.stateful_dataloader import StatefulDataLoader

from areal.api.alloc_mode import ParallelStrategy
from areal.api.cli_args import PerfTracerConfig
from areal.api.io_struct import (
DeviceRuntimeInfo,
LocalInfServerInfo,
@@ -478,6 +479,33 @@ def offload(self) -> None:
def get_device_stats(self) -> DeviceRuntimeInfo:
raise NotImplementedError()

def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None:
"""Save performance tracer data.

Parameters
----------
step : int, optional
The current training step number, by default None
force : bool, optional
If True, force save regardless of internal conditions, by default False
"""

def config_perf_tracer(
self, config: PerfTracerConfig, rank: int, role: str
) -> None:
"""Configure performance tracer.

Parameters
----------
config : PerfTracerConfig
Configuration for the performance tracer.
rank : int
Rank of the current process within its role.
role : str
Role of this process: "master" in the driver process, or
"actor", "ref", "rollout", etc. in RPC workers.
"""


class InferenceEngine(abc.ABC):
def initialize(self, *args, **kwargs):
@@ -867,3 +895,30 @@ def export_stats(self) -> dict[str, float]:
The recorded scalar statistics.
"""
raise NotImplementedError()

def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None:
"""Save performance tracer data.

Parameters
----------
step : int, optional
The current training step number, by default None
force : bool, optional
If True, force save regardless of internal conditions, by default False
"""

def config_perf_tracer(
self, config: PerfTracerConfig, rank: int, role: str
) -> None:
"""Configure performance tracer.

Parameters
----------
config : PerfTracerConfig
Configuration for the performance tracer.
rank : int
Rank of the current process within its role.
role : str
Role of this process: "master" in the driver process, or
"actor", "ref", "rollout", etc. in RPC workers.
"""
21 changes: 20 additions & 1 deletion areal/controller/rollout_controller.py
@@ -11,7 +11,7 @@
from torchdata.stateful_dataloader import StatefulDataLoader

from areal.api.alloc_mode import AllocationMode
from areal.api.cli_args import InferenceEngineConfig, SchedulingSpec
from areal.api.cli_args import InferenceEngineConfig, PerfTracerConfig, SchedulingSpec
from areal.api.engine_api import InferenceEngine
from areal.api.io_struct import (
LocalInfServerInfo,
@@ -551,6 +551,25 @@ def export_stats(self) -> dict[str, float]:
final_stats[k] = v / counts[count_key]
return final_stats

def config_perf_tracer(self, config: PerfTracerConfig, role: str) -> None:
async def _call():
tasks = [
self.scheduler.async_call_engine(
worker_id=worker.id,
method="config_perf_tracer",
rank=rank,
role=role,
config=config,
)
for rank, worker in enumerate(self.workers)
]
return await asyncio.gather(*tasks)

asyncio.run(_call())

def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None:
self._collective_rpc("save_perf_tracer", step=step, force=force)

@property
def staleness_manager(self):
return self._staleness_manager
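The fan-out in `config_perf_tracer` has the same shape in both controllers; distilled into a free-standing coroutine (hypothetical names, for illustration only):

```python
import asyncio

async def fan_out(scheduler, workers, method: str, **kwargs):
    # One engine RPC per worker; each worker's position in the list
    # doubles as its rank, and all calls are awaited together.
    tasks = [
        scheduler.async_call_engine(
            worker_id=worker.id, method=method, rank=rank, **kwargs
        )
        for rank, worker in enumerate(workers)
    ]
    return await asyncio.gather(*tasks)
```

The two call sites differ only in how the event loop is entered: the rollout controller uses `asyncio.run(_call())`, while the train controller below reuses its `_run_async_task` helper.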
21 changes: 20 additions & 1 deletion areal/controller/train_controller.py
@@ -9,7 +9,7 @@
from torchdata.stateful_dataloader import StatefulDataLoader

from areal.api.alloc_mode import ParallelStrategy
from areal.api.cli_args import TrainEngineConfig
from areal.api.cli_args import PerfTracerConfig, TrainEngineConfig
from areal.api.engine_api import TrainEngine
from areal.api.io_struct import (
AllocationMode,
@@ -490,6 +490,25 @@ def connect_engine(self, rollout: RolloutController, meta: WeightUpdateMeta):
def get_device_stats(self):
return self._custom_function_call("get_device_stats")

def config_perf_tracer(self, config: PerfTracerConfig, role: str) -> None:
async def _call():
tasks = [
self.scheduler.async_call_engine(
worker_id=worker.id,
method="config_perf_tracer",
rank=rank,
role=role,
config=config,
)
for rank, worker in enumerate(self.workers)
]
return await asyncio.gather(*tasks)

self._run_async_task(_call())

def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None:
self._custom_function_call("save_perf_tracer", step=step, force=force)

def prepare_batch(
self,
dataloader: StatefulDataLoader,
19 changes: 17 additions & 2 deletions areal/engine/fsdp_engine.py
@@ -41,7 +41,7 @@
)

from areal.api.alloc_mode import FSDPParallelStrategy, ParallelStrategy
from areal.api.cli_args import TrainEngineConfig
from areal.api.cli_args import PerfTracerConfig, TrainEngineConfig
from areal.api.engine_api import InferenceEngine, TrainEngine
from areal.api.io_struct import (
DeviceRuntimeInfo,
@@ -59,7 +59,14 @@
)
from areal.models.transformers.ulyssess_patch import apply_monkey_patch
from areal.platforms import current_platform
from areal.utils import logging, name_resolve, names, pkg_version, stats_tracker
from areal.utils import (
logging,
name_resolve,
names,
perf_tracer,
pkg_version,
stats_tracker,
)
from areal.utils.constants import DIST_GROUP_DEFAULT_TIMEOUT
from areal.utils.data import (
MicroBatchItem,
@@ -616,6 +623,14 @@ def clear_batches(self, *args):
def get_device_stats(self) -> DeviceRuntimeInfo:
return DeviceRuntimeInfo.get_current()

def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None:
perf_tracer.save(step=step, force=force)

def config_perf_tracer(
self, config: PerfTracerConfig, rank: int, role: str
) -> None:
perf_tracer.configure(config, rank=rank, role=role)

def _make_parallel_strategy(
self, parallel_strategy: ParallelStrategy
) -> FSDPParallelStrategy:
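Both training engines (and the remote inference engines below) delegate to module-level functions in `areal.utils.perf_tracer`. That module is not part of this diff; a hypothetical sketch of the facade shape implied by these call sites:

```python
import json
from pathlib import Path

_state: dict | None = None  # process-global tracer state

def configure(config, rank: int, role: str) -> None:
    # Record identity once per process; subsequent save() calls use it.
    global _state
    _state = {"config": config, "rank": rank, "role": role, "events": []}

def save(step: int | None = None, force: bool = False) -> None:
    # Guarded no-op when never configured, which keeps unconditional
    # calls such as perf_tracer.save(force=True) in close() safe.
    if _state is None:
        return
    path = Path(f"perf_trace_{_state['role']}_rank{_state['rank']}.json")
    path.write_text(json.dumps({"step": step, "force": force,
                                "events": _state["events"]}))
```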
12 changes: 10 additions & 2 deletions areal/engine/megatron_engine.py
@@ -27,7 +27,7 @@
from transformers import PretrainedConfig

from areal.api.alloc_mode import MegatronParallelStrategy, ParallelStrategy
from areal.api.cli_args import MicroBatchSpec, TrainEngineConfig
from areal.api.cli_args import MicroBatchSpec, PerfTracerConfig, TrainEngineConfig
from areal.api.engine_api import InferenceEngine, TrainEngine
from areal.api.io_struct import (
DeviceRuntimeInfo,
@@ -47,7 +47,7 @@
from areal.models.mcore.hf_save import save_weights_to_hf_with_mbridge_fast
from areal.models.mcore.registry import make_hf_and_mcore_config, make_mcore_model
from areal.platforms import current_platform
from areal.utils import logging, name_resolve, names, stats_tracker
from areal.utils import logging, name_resolve, names, perf_tracer, stats_tracker
from areal.utils.constants import DIST_GROUP_DEFAULT_TIMEOUT
from areal.utils.data import (
MicroBatchItem,
@@ -695,6 +695,14 @@ def clear_batches(self, *args):
def get_device_stats(self) -> DeviceRuntimeInfo:
return DeviceRuntimeInfo.get_current()

def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None:
perf_tracer.save(step=step, force=force)

def config_perf_tracer(
self, config: PerfTracerConfig, rank: int, role: str
) -> None:
perf_tracer.configure(config, rank=rank, role=role)

def _make_parallel_strategy(
self, parallel_strategy: ParallelStrategy
) -> MegatronParallelStrategy:
12 changes: 10 additions & 2 deletions areal/engine/sglang_remote.py
@@ -8,7 +8,7 @@

from torchdata.stateful_dataloader import StatefulDataLoader

from areal.api.cli_args import InferenceEngineConfig, SGLangConfig
from areal.api.cli_args import InferenceEngineConfig, PerfTracerConfig, SGLangConfig
from areal.api.engine_api import InferenceEngine
from areal.api.io_struct import (
HttpGenerationResult,
@@ -26,7 +26,7 @@
from areal.core import RemoteInfEngine
from areal.core.workflow_executor import WorkflowExecutor
from areal.platforms import current_platform
from areal.utils import stats_tracker
from areal.utils import perf_tracer, stats_tracker
from areal.utils.launcher import TRITON_CACHE_PATH


@@ -364,3 +364,11 @@ def as_controller(

def clear_batches(self, *args):
"""Placeholder method of single-controller API."""

def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None:
perf_tracer.save(step=step, force=force)

def config_perf_tracer(
self, config: PerfTracerConfig, rank: int, role: str
) -> None:
perf_tracer.configure(config, rank=rank, role=role)
12 changes: 10 additions & 2 deletions areal/engine/vllm_remote.py
@@ -8,7 +8,7 @@

from torchdata.stateful_dataloader import StatefulDataLoader

from areal.api.cli_args import InferenceEngineConfig, vLLMConfig
from areal.api.cli_args import InferenceEngineConfig, PerfTracerConfig, vLLMConfig
from areal.api.engine_api import InferenceEngine
from areal.api.io_struct import (
HttpGenerationResult,
@@ -26,7 +26,7 @@
from areal.core import RemoteInfEngine
from areal.core.workflow_executor import WorkflowExecutor
from areal.platforms import current_platform
from areal.utils import stats_tracker
from areal.utils import perf_tracer, stats_tracker
from areal.utils.launcher import TRITON_CACHE_PATH


@@ -406,3 +406,11 @@ def as_controller(

def clear_batches(self, *args):
"""Placeholder method of single-controller API."""

def save_perf_tracer(self, step: int | None = None, force: bool = False) -> None:
perf_tracer.save(step=step, force=force)

def config_perf_tracer(
self, config: PerfTracerConfig, rank: int, role: str
) -> None:
perf_tracer.configure(config, rank=rank, role=role)
32 changes: 28 additions & 4 deletions areal/experimental/trainer/rl.py
@@ -58,9 +58,6 @@ def __init__(
valid_dataset: Dataset | None = None,
):
rank = int(os.getenv("RANK", "0"))
# Configure performance tracer
if config.perf_tracer is not None:
perf_tracer.configure(config.perf_tracer, rank=rank)

self.config = config
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
@@ -175,6 +172,8 @@ def __init__(
weight_update_meta=self.weight_update_meta,
)

self._config_perf_tracer()

def train(
self,
workflow: RolloutWorkflow | type[RolloutWorkflow] | str,
@@ -378,7 +377,7 @@ def train(
# Resume rollout
self.rollout.resume()

perf_tracer.save(step=global_step)
self._save_perf_tracer(step=global_step)

def close(self):
self.stats_logger.close()
@@ -391,6 +390,31 @@ def close(self):
self.actor.destroy()
perf_tracer.save(force=True)

def _config_perf_tracer(self):
if self.config.perf_tracer is None:
return
rank = int(os.getenv("RANK", "0"))
perf_tracer.configure(self.config.perf_tracer, rank=rank, role="master")
self.actor.config_perf_tracer(self.config.perf_tracer, role="actor")
if self.critic is not None:
self.critic.config_perf_tracer(self.config.perf_tracer, role="critic")
if self.ref is not None:
self.ref.config_perf_tracer(self.config.perf_tracer, role="ref")
self.rollout.config_perf_tracer(self.config.perf_tracer, role="rollout")
self.eval_rollout.config_perf_tracer(
self.config.perf_tracer, role="eval-rollout"
)

def _save_perf_tracer(self, step: int):
if self.config.perf_tracer is None:
return
self.actor.save_perf_tracer(step=step)
if self.ref is not None:
self.ref.save_perf_tracer(step=step)
if self.critic is not None:
self.critic.save_perf_tracer(step=step)
self.eval_rollout.save_perf_tracer(step=step)
self.rollout.save_perf_tracer(step=step)
perf_tracer.save(step=step)

def _init_scheduler(self) -> Scheduler:
cfg = self.config.scheduler
if cfg.type == "local":
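The explicit if-chain in `_config_perf_tracer` could equally be written table-driven; a hypothetical equivalent sketch (the driver's own "master" configuration is omitted, and all names follow the diff):

```python
def config_all_tracers(trainer) -> None:
    # Table-driven equivalent of the per-role fan-out above: same
    # behavior, with optional engines (critic, ref) skipped when None.
    cfg = trainer.config.perf_tracer
    if cfg is None:
        return
    for engine, role in [
        (trainer.actor, "actor"),
        (trainer.critic, "critic"),
        (trainer.ref, "ref"),
        (trainer.rollout, "rollout"),
        (trainer.eval_rollout, "eval-rollout"),
    ]:
        if engine is not None:
            engine.config_perf_tracer(cfg, role=role)
```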
20 changes: 16 additions & 4 deletions areal/experimental/trainer/sft.py
@@ -45,9 +45,6 @@ def __init__(
valid_dataset: Dataset | None = None,
):
rank = int(os.getenv("RANK", "0"))
# Configure performance tracer
if config.perf_tracer is not None:
perf_tracer.configure(config.perf_tracer, rank=rank)

self.config = config
self.processor, self.tokenizer = load_hf_processor_and_tokenizer(
@@ -119,6 +116,8 @@ def __init__(
self.train_dataloader,
)

self._config_perf_tracer()

def train(self):
config = self.config
start_step = (
@@ -221,13 +220,26 @@ def train(self):
epoch=epoch, epoch_step=step, global_step=global_step
)

perf_tracer.save(step=global_step)
self._save_perf_tracer(step=global_step)

def close(self):
self.stats_logger.close()
self.actor.destroy()
perf_tracer.save(force=True)

def _config_perf_tracer(self):
if self.config.perf_tracer is None:
return
rank = int(os.getenv("RANK", "0"))
perf_tracer.configure(self.config.perf_tracer, rank=rank, role="master")
self.actor.config_perf_tracer(self.config.perf_tracer, role="actor")

def _save_perf_tracer(self, step: int):
if self.config.perf_tracer is None:
return
self.actor.save_perf_tracer(step=step)
perf_tracer.save(step=step)

def _init_scheduler(self) -> Scheduler:
cfg = self.config.scheduler
if cfg.type == "local":