diff --git a/areal/experimental/openai/proxy.py b/areal/experimental/openai/proxy.py
index ac20fcb98..b094941e1 100644
--- a/areal/experimental/openai/proxy.py
+++ b/areal/experimental/openai/proxy.py
@@ -222,14 +222,20 @@ async def set_reward(request: AReaLSetRewardRequest, session_id: str):
         raise HTTPException(
             status_code=400, detail=f"Session {session_id} not found"
         )
-    if interaction_id is None:
-        # take the last interaction id
-        interaction_id = state.session_cache[
-            session_id
-        ].completions.last_interaction_id
     completions = state.session_cache[session_id].completions
-    if interaction_id not in completions:
+    if interaction_id is None:
+        # take the last interaction id
+        if len(completions) == 0:
+            logger.error(f"No interactions in session {session_id}")
+            raise HTTPException(
+                status_code=400, detail="No interactions in session"
+            )
+        interaction_id = completions.last_interaction_id
+    elif interaction_id not in completions:
+        logger.error(
+            f"Interaction {interaction_id} not found in session {session_id}"
+        )
         raise HTTPException(
             status_code=400, detail=f"Interaction {interaction_id} not found"
         )
diff --git a/areal/utils/proxy_utils.py b/areal/utils/proxy_utils.py
index b4dff7a79..042024ee4 100644
--- a/areal/utils/proxy_utils.py
+++ b/areal/utils/proxy_utils.py
@@ -142,7 +142,13 @@ async def _set_reward(
     url: str = RL_SET_REWARD_PATHNAME,
 ):
     payload = AReaLSetRewardRequest(interaction_id=interaction_id, reward=reward)
-    await post_json_with_retry(http_session, url=url, payload=payload)
+    try:
+        await post_json_with_retry(http_session, url=url, payload=payload)
+    except aiohttp.ClientResponseError as e:
+        if e.status == 400:
+            logger.error(f"[error code {e.status}] Error setting reward: {e.message}")
+        else:
+            raise e
 
 
 async def set_interaction_reward(
@@ -196,7 +202,16 @@ def _get_float_reward(reward: float | int):
     )
 
     async with aiohttp.ClientSession(base_url) as session:
-        rewards = await func(data)
+        info = None
+        results = await func(data)
+        if isinstance(results, tuple):
+            if len(results) != 2:
+                raise ValueError(
+                    f"Results must be a tuple of (rewards, info), got {len(results)}"
+                )
+            rewards, info = results
+        else:
+            rewards = results
 
     if isinstance(rewards, dict):
         for interaction_id, reward in rewards.items():
@@ -212,3 +227,4 @@ def _get_float_reward(reward: float | int):
             reward=_get_float_reward(rewards),
             url=pathname,
         )
+    return info
diff --git a/examples/tau2/README.md b/examples/tau2/README.md
new file mode 100644
index 000000000..7b6aa3049
--- /dev/null
+++ b/examples/tau2/README.md
@@ -0,0 +1,77 @@
+# Customer Service Agent Training with the Tau2 Benchmark
+
+## Overview
+
+This example demonstrates how to train customer service agents on the [$\tau^2$-Bench](https://github.com/sierra-research/tau2-bench) with AReaL's PPO/GRPO training pipeline. $\tau^2$-Bench provides realistic customer service simulation environments across multiple domains (retail, airline, telecom), in which the agent must resolve the user's request both by calling its own tools and by guiding the user to use theirs.
+
+## Code Architecture
+
+The code is adapted from the [proxy](../experimental/proxy/README.md) example so that the training workflow (`tau2_train.py`) and the agent runner script (`tau2_agent.py`) are decoupled, with common utilities in `tau2_utils.py`; the contract between the two scripts is sketched below.
+
+* `tau2_train.py`: Builds the tau2 dataset, the rollout workflow, and the PPO training entry point.
+* `tau2_agent.py`: Reuses the orchestrator, agent, and user simulator from the tau2-bench package to build the runner.
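+
+The workflow in `tau2_train.py` spawns the agent script in a separate process, exports `OPENAI_BASE_URL` (plus `AREAL_SESSION_ID` and `AREAL_TASK_ID`) so that OpenAI-compatible calls are routed through AReaL's proxy server, and imports `run_and_submit` from `agent_module_path`. Below is a minimal sketch of a custom agent module that follows the same contract; the file name `my_agent.py` and the placeholder scoring logic are hypothetical, while `run_and_submit_rewards` and the `(reward, info)` return convention come from this example.
+
+```python
+# my_agent.py -- hypothetical drop-in replacement for tau2_agent.py
+from areal.utils.proxy_utils import run_and_submit_rewards
+
+
+async def run_agent_return_reward(data: dict):
+    # `data` carries the dumped gconfig/econfig plus the dataset row (task_id, split).
+    # OPENAI_BASE_URL already points at the per-session proxy endpoint, so any
+    # OpenAI-compatible client used here is traced for training.
+    task_id = data["task_id"]
+    reward = 0.0  # placeholder: run the episode and score it here
+    info = {"task_id": task_id}
+    # Returning (reward, info) hands `info` back to the workflow; a bare float also works.
+    return reward, info
+
+
+async def run_and_submit(data: dict):
+    # run_and_submit_rewards() calls the function above, posts the reward back to the
+    # proxy session, and returns the extra info to the training workflow.
+    return await run_and_submit_rewards(func=run_agent_return_reward, data=data)
+```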
+
+## Running the Example
+
+### Prerequisites
+
+Please make sure AReaL is set up and working by following the [installation guide](https://inclusionai.github.io/AReaL/tutorial/installation.html).
+
+1. Install the (forked) tau2-bench package:
+```bash
+pip install git+https://github.com/dhh1995/tau2-bench.git@dhh/async-and-custom-completion
+```
+Note that the training relies on the async version of the agent and user simulator in the tau2-bench package. These changes will be merged into the [original tau2-bench repository](https://github.com/sierra-research/tau2-bench) later.
+
+2. Set the `TAU2_DATA_DIR` environment variable:
+```bash
+export TAU2_DATA_DIR=/path/to/tau2-bench/data
+```
+
+### Basic Training Command
+
+1. Prepare the user simulator server.
+
+If you are using self-hosted LLMs, you first need to set up a user simulator server. For example, when [using Qwen with SGLang](https://qwen.readthedocs.io/en/latest/deployment/sglang.html):
+```bash
+python3 -m sglang.launch_server --model-path Qwen/Qwen3-32B --host 0.0.0.0 --tool-call-parser qwen25 --chat-template ./qwen3_nonthinking.jinja --dp-size 8
+```
+
+Below we assume the hosted address is http://0.0.0.0:30000/v1/.
+
+2. Run the training.
+
+In this example, we use the `small` subset of the tau2-telecom domain, which contains 20 tasks, each consisting of a single subtask.
+
+```bash
+python3 -m areal.launcher.ray examples/tau2/tau2_train.py \
+    --config examples/tau2/config.yaml \
+    experiment_name=tau2-grpo \
+    trial_name=trial0 \
+    cluster.n_nodes=3 \
+    cluster.n_gpus_per_node=8 \
+    allocation_mode=sglang:d16+megatron:d2p4 \
+    gconfig.n_samples=16 \
+    actor.path=Qwen/Qwen2.5-7B-Instruct \
+    econfig.domain=telecom \
+    econfig.max_steps=30 \
+    train_dataset.path=tau2/small \
+    train_dataset.batch_size=8 \
+    econfig.user_llm_base_url=http://0.0.0.0:30000/v1/
+```
+
+This configuration uses 2 nodes for rollout, 1 node for training, and 1 node for the user simulator.
+The training batch size is 8 and the group size is 16, resulting in 128 rollouts per step.
+
+### Training Curve
+
+The rollout reward on the training tasks is shown below.
+
+![Curve](./curve.png)
+
+With the above configuration, one step usually takes less than 10 minutes on average, depending on the hardware.
+
+## Notes
+
+1. When using litellm with multiprocessing, a `Queue bound to different event loop` error may occur. See also: [litellm issue #17813](https://github.com/BerriAI/litellm/issues/17813). This does not stop training, but it makes the outputs hard to read. Until the issue is fixed, you can filter the error messages with `grep -aivE "loop|queue|\^|asyncio|litellm"`.
+2. The trajectories are dumped as `json` and `txt` files in the `generated/` directory. You can read and analyze them as needed; see the sketch below.
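+
+As a quick sanity check, the dumped JSON files (serialized `Tau2RunInfo` objects) can be reloaded directly. A minimal sketch, assuming the files are collected under a local `generated/` directory (the actual directory lives under the experiment's log path):
+
+```python
+import glob
+import json
+
+# Each JSON file is a serialized Tau2RunInfo: reward, messages, task, reward_info, error.
+for path in sorted(glob.glob("generated/*.json")):
+    with open(path) as f:
+        info = json.load(f)
+    print(path, "reward:", info["reward"], "turns:", len(info["messages"]), "error:", info["error"])
+```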
diff --git a/examples/tau2/config.yaml b/examples/tau2/config.yaml
new file mode 100644
index 000000000..9141c76b5
--- /dev/null
+++ b/examples/tau2/config.yaml
@@ -0,0 +1,174 @@
+experiment_name: tau2-grpo
+trial_name: trial0
+
+seed: 1
+enable_offload: false
+total_train_epochs: 50
+tokenizer_path: ${actor.path}
+
+do_eval: false
+export_style: concat
+
+cluster:
+  n_nodes: 3
+  n_gpus_per_node: 8
+  fileroot: /tmp/areal/experiments
+  name_resolve:
+    type: nfs
+    nfs_record_root: /tmp/areal/name_resolve
+
+allocation_mode: sglang:d16+megatron:d2p4
+
+rollout:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  max_concurrent_rollouts: 256
+  queue_size: null
+  consumer_batch_size: ${train_dataset.batch_size}
+  max_head_offpolicyness: 2
+  enable_rollout_tracing: false
+
+gconfig:
+  n_samples: 16
+  min_new_tokens: 0
+  max_new_tokens: 512
+  greedy: false
+  temperature: 1.0
+
+actor:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: Qwen/Qwen2.5-7B-Instruct
+  init_from_scratch: false
+  disable_dropout: true
+  gradient_checkpointing: true
+  dtype: bfloat16
+  mb_spec:
+    max_tokens_per_mb: 32000
+  optimizer:
+    type: adam
+    lr: 5e-6
+    weight_decay: 0.01
+    beta1: 0.9
+    beta2: 0.999
+    eps: 1e-8
+    lr_scheduler_type: constant
+    gradient_clipping: 1.0
+    warmup_steps_proportion: 0.001
+  group_size: ${gconfig.n_samples}
+  eps_clip: 0.4
+  temperature: ${gconfig.temperature}
+  reward_scaling: 10.0
+  reward_bias: -0.5
+  kl_ctl: 0.0
+  ppo_n_minibatches: 1
+  recompute_logprob: true
+  use_decoupled_loss: true
+  behav_imp_weight_cap: 5.0
+  dynamic_sampling: false
+  reward_norm:
+    mean_level: group
+    std_level: group
+    group_size: ${gconfig.n_samples}
+  adv_norm:
+    mean_level: batch
+    std_level: batch
+  max_new_tokens: ${gconfig.max_new_tokens}
+
+ref:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  path: ${actor.path}
+  init_from_scratch: false
+  disable_dropout: true
+  dtype: ${actor.dtype}
+  mb_spec:
+    max_tokens_per_mb: 32000
+  optimizer: null
+
+econfig:
+  domain: telecom
+  max_steps: 50
+  add_thinking_tool: true
+  solo_mode: false
+  user_llm_base_url: null # replace with your URL for the user LLM
+  user_llm: null # replace with your model name. Use 'openai/' for a self-hosted openai-compatible server, e.g. 'openai/hosted'
+  user_llm_args:
+    temperature: 0.0
+    max_completion_tokens: 512
+  turn_discount: 1.0
+  invalid_format_penalty: 0.1
+
+# SGLang
+sglang:
+  model_path: ${actor.path}
+  random_seed: ${seed}
+  skip_tokenizer_init: true
+  dtype: ${actor.dtype}
+  max_running_requests: null
+  context_length: 32768
+  mem_fraction_static: 0.8
+
+vllm:
+  model: ${actor.path}
+  seed: ${seed}
+  skip_tokenizer_init: false
+  dtype: ${actor.dtype}
+  max_model_len: 32768
+  gpu_memory_utilization: 0.9
+
+# datasets
+train_dataset:
+  batch_size: 8
+  pin_memory: true
+  num_workers: 4
+  path: tau2/train
+  type: rl
+  max_length: 1024
+
+valid_dataset:
+  batch_size: 16
+  pin_memory: true
+  num_workers: 4
+  path: tau2/test
+  type: rl
+  drop_last: false
+
+# Utilities
+saver:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+recover:
+  mode: disabled
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: 3600
+
+evaluator:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  freq_epochs: 1
+  freq_steps: null
+  freq_secs: null
+
+stats_logger:
+  experiment_name: ${experiment_name}
+  trial_name: ${trial_name}
+  fileroot: ${cluster.fileroot}
+  wandb:
+    mode: disabled
+
+launcher:
+  inference_server_cpus_per_gpu: 4
+  inference_server_mem_per_gpu: 32768
+  trainer_cpus_per_gpu: 4
+  trainer_mem_per_gpu: 32768
diff --git a/examples/tau2/curve.png b/examples/tau2/curve.png
new file mode 100644
index 000000000..037b63190
Binary files /dev/null and b/examples/tau2/curve.png differ
diff --git a/examples/tau2/tau2_agent.py b/examples/tau2/tau2_agent.py
new file mode 100644
index 000000000..3a316b171
--- /dev/null
+++ b/examples/tau2/tau2_agent.py
@@ -0,0 +1,229 @@
+import asyncio
+import time
+
+from litellm import acompletion, register_model
+from tau2.agent.llm_agent import LLMAgent, LLMAgentState, LLMSoloAgent, LocalAgent
+from tau2.data_model.tasks import Task
+from tau2.environment.environment import Environment
+from tau2.environment.tool import Tool
+from tau2.evaluator.evaluator import EvaluationType, evaluate_simulation
+from tau2.orchestrator.orchestrator import Orchestrator
+from tau2.registry import registry
+from tau2.user.user_simulator import BaseUser, DummyUser, UserSimulator
+from tau2_utils import Tau2EnvConfig, Tau2RunInfo
+
+from areal.api.cli_args import GenerationHyperparameters
+from areal.utils import logging
+from areal.utils.proxy_utils import run_and_submit_rewards
+
+logger = logging.getLogger("Tau2 Agent")
+
+register_model(
+    {
+        "dummy": {
+            "input_cost_per_token": 0.0,
+            "output_cost_per_token": 0.0,
+            "litellm_provider": "openai",
+            "mode": "chat",
+        },
+    }
+)
+
+
+def _get_task(domain: str, task_id: str, split: str | None = None) -> Task:
+    tasks: list[Task] = registry.get_tasks_loader(domain)(split)
+    for task in tasks:
+        if task.id == task_id:
+            return task
+    raise ValueError(f"No task found with id {task_id} for domain {domain}")
+
+
+def think(thoughts: str):
+    """Use this tool to think. The thoughts will be visible in the history. Only use this tool to think when necessary."""
+    return "Your thoughts are recorded. Please continue your work."
+
+
+class Tau2Runner:
+    def __init__(self, econfig: Tau2EnvConfig, gconfig: GenerationHyperparameters):
+        self.econfig = econfig
+        self.gconfig = gconfig
+        self.domain = econfig.domain
+        self.solo_mode = econfig.solo_mode
+        self.gen_args = gconfig.to_openai_args_dict(api_format="completions")
+
+    def _get_environment(self) -> Environment:
+        environment_constructor = registry.get_env_constructor(self.domain)
+        return environment_constructor(solo_mode=self.solo_mode)
+
+    def _get_agent_and_user(self, task: Task, env: Environment, run_info: Tau2RunInfo):
+        agent_policy_doc = env.get_policy()
+        tools: list[Tool] = env.get_tools()
+        try:
+            user_tools = env.get_user_tools()
+        except Exception:
+            user_tools = []
+        if self.econfig.add_thinking_tool:
+            tools.append(Tool(think))
+
+        # * Backup: use acreate to replace acompletion
+        # async def _acreate(*args, **kwargs):
+        #     kwargs.pop("num_retries", None)
+        #     completion = await client.chat.completions.create(*args, **kwargs)
+        #     return completion
+
+        # async def _acreate_with_base_url(*args, **kwargs):
+        #     kwargs.pop("num_retries", None)
+        #     async with AsyncOpenAI(base_url=self.econfig.user_llm_base_url) as client:
+        #         completion = await client.chat.completions.create(*args, **kwargs)
+        #     return completion
+
+        async def _acompletion(*args, **kwargs):
+            start_time = time.perf_counter()
+            try:
+                return await acompletion(*args, **kwargs)
+            finally:
+                run_info.agent_time.append(time.perf_counter() - start_time)
+
+        async def _acompletion_with_base_url(*args, **kwargs):
+            start_time = time.perf_counter()
+            try:
+                return await acompletion(
+                    *args, base_url=self.econfig.user_llm_base_url, **kwargs
+                )
+            finally:
+                run_info.user_time.append(time.perf_counter() - start_time)
+
+        if self.solo_mode:
+            agent = LLMSoloAgent(
+                tools=tools + user_tools,
+                domain_policy=agent_policy_doc,
+                llm="dummy",
+                llm_args=self.gen_args,
+                task=task,
+                completion_fn=_acompletion,
+            )
+            user = DummyUser()
+        else:
+            agent = LLMAgent(
+                tools=tools,
+                domain_policy=agent_policy_doc,
+                llm="dummy",
+                llm_args=self.gen_args,
+                completion_fn=_acompletion,
+            )
+
+            user = UserSimulator(
+                tools=user_tools if len(user_tools) > 0 else None,
+                instructions=str(task.user_scenario),
+                llm=self.econfig.user_llm,
+                llm_args=self.econfig.user_llm_args,
+                completion_fn=_acompletion_with_base_url,
+            )
+        return agent, user
+
+    def _get_orchestrator(
+        self,
+        agent: LocalAgent[LLMAgentState],
+        user: BaseUser,
+        env: Environment,
+        task: Task,
+    ) -> Orchestrator:
+        return Orchestrator(
+            domain=self.domain,
+            agent=agent,
+            user=user,
+            environment=env,
+            task=task,
+            max_steps=self.econfig.max_steps,
+            # max_errors=self.econfig.max_errors,
+            # seed=self.econfig.seed,
+        )
+
+    async def run(self, task: Task) -> Tau2RunInfo:
+        domain = self.domain
+        solo_mode = self.solo_mode
+        logger.info(
+            f"STARTING SIMULATION: Domain: {domain}, Task: {task.id}, "
+            f"Solo Mode: {solo_mode}"
+        )
+
+        env = self._get_environment()
+        run_info = Tau2RunInfo(
+            reward=0.0,
+            task=task,
+            messages=[],
+            agent_time=[],
+            user_time=[],
+            reward_info=None,
+            error=None,
+        )
+        agent, user = self._get_agent_and_user(task=task, env=env, run_info=run_info)
+        orchestrator = self._get_orchestrator(
+            agent=agent, user=user, env=env, task=task
+        )
+
+        try:
+            simulation = await orchestrator.arun()
+            run_info.messages = simulation.messages
+        except Exception as e:
+            logger.error(
+                f"ERROR RUNNING SIMULATION: Domain: {domain}, Task: {task.id}, "
+                f"Agent: {agent.__class__.__name__}, User: {user.__class__.__name__}. "
+                f"Error running simulation: {e}. Setting reward to 0.0"
+            )
+            run_info.messages = orchestrator.get_trajectory()
+            run_info.error = str(e)
+            return run_info
+
+        try:
+            reward_info = evaluate_simulation(
+                domain=domain,
+                task=task,
+                simulation=simulation,
+                evaluation_type=EvaluationType.ALL,
+                solo_mode=solo_mode,
+            )
+            run_info.reward_info = reward_info
+            run_info.reward = reward_info.reward
+        except Exception as e:
+            logger.error(
+                f"ERROR EVALUATING SIMULATION: Domain: {domain}, Task: {task.id}, "
+                f"Agent: {agent.__class__.__name__}, User: {user.__class__.__name__}. "
+                f"Error evaluating simulation: {e}. Setting reward to 0.0"
+            )
+            run_info.reward_info = None
+            run_info.error = str(e)
+            return run_info
+
+        logger.info(
+            f"FINISHED SIMULATION: Domain: {domain}, Task: {task.id}, "
+            f"Agent: {agent.__class__.__name__}, User: {user.__class__.__name__}. "
+            f"Reward: {reward_info.reward}"
+        )
+        return run_info
+
+
+async def run_agent_return_reward(data: dict) -> tuple[float, Tau2RunInfo]:
+    econfig = Tau2EnvConfig(**data.get("econfig", {}))
+    gconfig = GenerationHyperparameters(**data.get("gconfig", {}))
+
+    domain = econfig.domain
+    split = data["split"]
+    task_id = data["task_id"]
+    task = _get_task(domain=domain, task_id=task_id, split=split)
+
+    tau2_runner = Tau2Runner(econfig, gconfig)
+    run_info = await tau2_runner.run(task)
+    return run_info.reward, run_info
+
+
+async def run_and_submit(data: dict):
+    return await run_and_submit_rewards(func=run_agent_return_reward, data=data)
+
+
+if __name__ == "__main__":
+    import json
+    import sys
+
+    data = json.loads(sys.stdin.readline())
+    asyncio.run(run_and_submit(data))
diff --git a/examples/tau2/tau2_train.py b/examples/tau2/tau2_train.py
new file mode 100644
index 000000000..6db522b2d
--- /dev/null
+++ b/examples/tau2/tau2_train.py
@@ -0,0 +1,317 @@
+import asyncio
+import os
+import sys
+import traceback
+import uuid
+from collections.abc import Callable
+from concurrent.futures import ProcessPoolExecutor
+from dataclasses import asdict, dataclass, field
+from typing import Any
+
+import aiofiles
+import numpy as np
+from datasets import Dataset
+from loguru import logger as loguru_logger
+from tau2.registry import registry
+from tau2_utils import Tau2EnvConfig, Tau2RunInfo
+
+from areal.api.cli_args import (
+    GenerationHyperparameters,
+    PPOConfig,
+    load_expr_config,
+)
+from areal.api.engine_api import InferenceEngine
+from areal.api.workflow_api import RolloutWorkflow
+from areal.experimental.openai.proxy import (
+    ProxyServer,
+    ProxySession,
+    ensure_end_with_slash,
+)
+from areal.experimental.trainer.rl import PPOTrainer
+from areal.utils import logging, stats_tracker
+from areal.utils.dynamic_import import import_from_string
+from areal.utils.hf_utils import load_hf_tokenizer
+from areal.utils.stats_logger import StatsLogger
+
+logger = logging.getLogger("Tau2 Example")
+
+
+# ================================ dataset ================================
+def get_tau2_dataset(
+    domain: str,
+    type: str = "rl",
+    split: str = "train",
+) -> Dataset:
+    """Create a HuggingFace Dataset from tau2 task IDs.
+
+    Args:
+        domain: The tau2 domain name, e.g., 'retail', 'airline', 'telecom'
+        split: Dataset split (e.g., 'train', 'test')
+        type: Dataset type (e.g., 'rl', 'sft'), only 'rl' is supported for now
+
+    Returns:
+        Dataset: HuggingFace Dataset containing task_id entries
+    """
+    assert type == "rl", "Only RL dataset is supported for now"
+    # TODO: support SFT dataset
+
+    splits_loader_fn = registry.get_task_splits_loader(domain)
+    if splits_loader_fn is None:
+        raise ValueError(f"No task splits loader found for domain {domain}")
+    splits = splits_loader_fn()
+    if split not in splits:
+        raise ValueError(
+            f"Split {split} not found in {splits}, available splits: {splits.keys()}"
+        )
+    task_ids = splits[split]
+    # print(f"domain: {domain}, split: {split}, task_ids: {task_ids}")
+
+    dataset_items = [{"task_id": task_id, "split": split} for task_id in task_ids]
+    dataset = Dataset.from_list(dataset_items)
+    return dataset
+
+
+def run_fn(func: Callable, extra_envs: dict, *args, **kwargs) -> Any:
+    for key, value in extra_envs.items():
+        os.environ[key] = value
+    try:
+        if asyncio.iscoroutinefunction(func):
+            return asyncio.run(func(*args, **kwargs))
+        else:
+            return func(*args, **kwargs)
+    except Exception as e:
+        traceback.print_exc()
+        raise e
+
+
+class Tau2Workflow(RolloutWorkflow):
+    def __init__(
+        self,
+        proxy_server: ProxyServer,
+        gconfig: GenerationHyperparameters,
+        econfig: Tau2EnvConfig,
+        base_url: str,
+        max_concurrent_processes: int,
+        agent_module_path: str,
+        agent_run_args: dict | None = None,
+        rollout_stat_scope: str = "rollout",
+        export_style: str = "concat",
+        max_tokens_per_mb: int = 32768,
+        dump_dir: str | None = None,
+    ):
+        self.proxy_server = proxy_server
+        self.group_size = gconfig.n_samples
+        self.gconfig = gconfig.new(n_samples=1)
+        self.econfig = econfig
+        self.base_url = ensure_end_with_slash(base_url)
+        self.process_pool = ProcessPoolExecutor(max_workers=max_concurrent_processes)
+        self.agent_func = import_from_string(
+            ".".join([agent_module_path, "run_and_submit"])
+        )
+        self.agent_run_args = agent_run_args or {}
+        self.rollout_stat_scope = rollout_stat_scope
+        self.export_style = export_style
+        self.max_tokens_per_mb = max_tokens_per_mb
+        self.dump_dir = dump_dir
+        if self.dump_dir is not None and not os.path.exists(self.dump_dir):
+            os.makedirs(self.dump_dir, exist_ok=True)
+
+    async def _run_episode(self, task_id: str, data: dict) -> Any:
+        process_data = {
+            "gconfig": asdict(self.gconfig),
+            "econfig": asdict(self.econfig),
+            "agent_run_args": self.agent_run_args,
+            **data,
+        }
+        async with ProxySession(base_url=self.base_url, task_id=task_id) as session:
+            extra_envs = {
+                "OPENAI_BASE_URL": session.session_url,
+                "OPENAI_API_KEY": os.environ["OPENAI_API_KEY"],
+                "AREAL_SESSION_ID": session.session_id,
+                "AREAL_TASK_ID": task_id,
+            }
+            return await asyncio.wrap_future(
+                self.process_pool.submit(
+                    run_fn,
+                    func=self.agent_func,
+                    extra_envs=extra_envs,
+                    data=process_data,
+                )
+            )
+
+    async def arun_episode(self, engine: InferenceEngine, data):
+        task_id = uuid.uuid4().hex  # use a unique task id for each run
+        run_infos: list[Tau2RunInfo] = await asyncio.gather(
+            *[self._run_episode(task_id, data) for _ in range(self.group_size)]
+        )
+        # the queue is prepared for the separated agent/trainer mode and is not used in this example
+        await ProxyServer.finish_task(
+            task_id, base_url=self.base_url, put_to_queue=False
+        )
+
+        session_ids = [f"{task_id}-{i}" for i in range(self.group_size)]
+        rewards, completions = await self.proxy_server.get_results(
+            session_ids, style=self.export_style
+        )
+
+        # log stats
+        for reward in rewards.values():
+            stats_tracker.get(self.rollout_stat_scope).scalar(reward=reward)
+
+        for info in run_infos:
+            stats_tracker.get(self.rollout_stat_scope).scalar(
+                steps_count=len(info.messages),
+                orchestrator_error=int(info.error is not None),
+            )
+
+            def add_to_stats(name: str, times: list[float]):
+                if len(times):
+                    for key in ["mean", "max", "min", "std", "sum"]:
+                        stats_tracker.get(self.rollout_stat_scope).scalar(
+                            **{f"{name}_time/{key}": getattr(np.array(times), key)()}
+                        )
+
+            add_to_stats("agent", info.agent_time)
+            add_to_stats("user", info.user_time)
+
+        # Dump info to file
+        if "task_id" in data:
+            real_task_id = data["task_id"][:120] + "-" + task_id
+        else:
+            real_task_id = task_id
+        for i, info in enumerate(run_infos):
+            try:
+                json_path = os.path.join(self.dump_dir, f"{real_task_id}-{i}.json")
+                async with aiofiles.open(json_path, "w") as f:
+                    await f.write(info.model_dump_json())
+
+                file_path = os.path.join(self.dump_dir, f"{real_task_id}-{i}.txt")
+                async with aiofiles.open(file_path, "a") as f:
+                    await f.write(str(info))
+            except Exception as e:
+                logger.error(f"Error dumping rollout to file: {e}")
+
+        if len(completions) != self.group_size:
+            raise RuntimeError(
+                f"Expected {self.group_size} completions, but got {len(completions)}"
+            )
+
+        return completions
+
+
+@dataclass
+class Tau2PPOConfig(PPOConfig):
+    econfig: Tau2EnvConfig = field(default_factory=Tau2EnvConfig)
+    tool_call_parser: str = field(
+        default="qwen25",
+        metadata={"help": "Tool call parser used by the ArealOpenAI client."},
+    )
+    export_style: str = field(
+        default="concat",
+        metadata={
+            "help": "Style for exporting completion results from the proxy server."
+        },
+    )
+    agent_module_path: str | None = field(
+        default="examples.tau2.tau2_agent",
+        metadata={"help": "Module path for the agent definition."},
+    )
+    agent_run_args: dict = field(
+        default_factory=dict,
+        metadata={"help": "Arguments for running the agent."},
+    )
+    do_eval: bool = field(
+        default=False,
+        metadata={"help": "Whether to do evaluation."},
+    )
+
+
+def main(args):
+    config, _ = load_expr_config(args, Tau2PPOConfig)
+    tokenizer = load_hf_tokenizer(config.tokenizer_path)
+    domain = config.econfig.domain
+
+    # Remove the default loguru handler of the tau2-bench package and redirect its logs to a file.
+    loguru_logger.remove()
+    loguru_logger.add(
+        os.path.join(StatsLogger.get_log_path(config.stats_logger), "tau2.log"),
+        level="INFO",
+    )
+
+    # Create dataset and dataloaders
+    train_dataset = get_tau2_dataset(
+        domain=domain,
+        type=config.train_dataset.type,
+        split=config.train_dataset.path.split("/")[-1],
+    )
+    valid_dataset = get_tau2_dataset(
+        domain=domain,
+        type=config.valid_dataset.type,
+        split=config.valid_dataset.path.split("/")[-1],
+    )
+
+    with PPOTrainer(
+        config,
+        train_dataset=train_dataset,
+        valid_dataset=valid_dataset,
+    ) as trainer:
+        world_size = trainer.actor.data_parallel_world_size
+        chat_template_type = "hf" if config.export_style == "individual" else "concat"
+        cpu_count = os.cpu_count() or 64
+        num_processes = config.rollout.max_concurrent_rollouts or cpu_count
+
+        def _get_server_and_workflow(rollout: InferenceEngine, is_eval: bool = False):
+            gconfig = (
+                config.gconfig.new(temperature=0.6, n_samples=1)
+                if is_eval
+                else config.gconfig
+            )
+            name = "train" if not is_eval else "eval"
+            rollout_stat_scope = "rollout" if not is_eval else "eval-rollout"
+            dump_dir = "generated" if not is_eval else "generated-eval"
+
+            server = ProxyServer(
+                rollout=rollout,
+                tokenizer=tokenizer,
+                tool_call_parser=config.tool_call_parser,
+                chat_template_type=chat_template_type,
+                name=f"{name} proxy server",
+            )
+            server.start(wait_until_ready=True)
+            workflow = Tau2Workflow(
+                proxy_server=server,
+                gconfig=gconfig,
+                econfig=config.econfig,
+                base_url=f"{server.public_addr}/v1",
+                max_concurrent_processes=num_processes // world_size,
+                agent_module_path=config.agent_module_path,
+                agent_run_args=config.agent_run_args,
+                rollout_stat_scope=rollout_stat_scope,
+                export_style=config.export_style,
+                max_tokens_per_mb=config.actor.mb_spec.max_tokens_per_mb,
+                dump_dir=os.path.join(
+                    StatsLogger.get_log_path(config.stats_logger),
+                    dump_dir,
+                ),
+            )
+            return server, workflow
+
+        proxy_server, workflow = _get_server_and_workflow(
+            trainer.rollout, is_eval=False
+        )
+
+        if config.do_eval:
+            eval_proxy_server, eval_workflow = _get_server_and_workflow(
+                trainer.eval_rollout, is_eval=True
+            )
+        else:
+            eval_proxy_server, eval_workflow = None, None
+
+        trainer.train(workflow, eval_workflow)
+        proxy_server.close()
+        if eval_proxy_server is not None:
+            eval_proxy_server.close()
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
diff --git a/examples/tau2/tau2_utils.py b/examples/tau2/tau2_utils.py
new file mode 100644
index 000000000..d579c5051
--- /dev/null
+++ b/examples/tau2/tau2_utils.py
@@ -0,0 +1,86 @@
+from dataclasses import dataclass, field
+
+import yaml
+from pydantic import BaseModel
+from tau2.data_model.message import Message
+from tau2.data_model.simulation import RewardInfo
+from tau2.data_model.tasks import Task
+
+
+class Tau2RunInfo(BaseModel):
+    reward: float
+    agent_time: list[float]
+    user_time: list[float]
+    messages: list[Message]
+    task: Task
+    reward_info: RewardInfo | None = None
+    error: str | None = None
+
+    def __str__(self):
+        s = f"[REWARD]: {self.reward}\n\n"
+        s += "[TASK]\n"
+        s += yaml.dump(self.task.model_dump()) + "\n"
+        if self.reward_info:
+            s += "[REWARD_INFO]\n"
+            s += yaml.dump(self.reward_info.model_dump()) + "\n"
+        s += f"[TURNS COUNT]: {len(self.messages)}\n"
+        s += "[MESSAGES]\n"
+        for message in self.messages:
+            turn_idx = message.turn_idx
+            role = message.role
+            content = message.content or ""
+            usage = getattr(message, "usage", {})
+            tool_calls = getattr(message, "tool_calls", None)
+            if tool_calls:
+                content += "\n[TOOL_CALLS]\n"
+                content += yaml.dump(
+                    [tool_call.model_dump() for tool_call in tool_calls]
+                )
+            s += f"[{turn_idx}][{role}]: {content}\n"
+            if usage:
+                s += f"[{turn_idx}][{role}][USAGE]: {yaml.dump(usage)}\n"
+        if len(self.agent_time):
+            s += f"[AGENT_TIME]: total {sum(self.agent_time)}, avg {sum(self.agent_time) / len(self.agent_time)}\n"
+        if len(self.user_time):
+            s += f"[USER_TIME]: total {sum(self.user_time)}, avg {sum(self.user_time) / len(self.user_time)}\n"
+        if self.error:
+            s += f"[ERROR]: {self.error}\n"
+        return s
+
+
+# ================================ config ================================
+# Customized config for tau2: adds the environment config.
+@dataclass
+class Tau2EnvConfig:
+    domain: str = field(
+        default="telecom",
+        metadata={
+            "help": "The tau2 domain name, e.g., 'retail', 'airline', 'telecom'."
+        },
+    )
+    max_steps: int = field(
+        default=100, metadata={"help": "Maximum number of steps per episode."}
+    )
+    add_thinking_tool: bool = field(
+        default=True, metadata={"help": "Whether to add a thinking tool."}
+    )
+    solo_mode: bool = field(
+        default=False, metadata={"help": "Whether to use solo mode."}
+    )
+    user_llm_base_url: str | None = field(
+        default=None,
+        metadata={"help": "The base URL of the user LLM."},
+    )
+    user_llm: str | None = field(
+        default=None,
+        metadata={"help": "The user LLM to use; defaults to the gpt-4.1 model."},
+    )
+    user_llm_args: dict | None = field(
+        default=None, metadata={"help": "The arguments for the user LLM."}
+    )
+    turn_discount: float = field(
+        default=1.0, metadata={"help": "Discount factor for turn-based learning."}
+    )
+    invalid_format_penalty: float = field(
+        default=0.1, metadata={"help": "Penalty for invalid format in completions."}
+    )