18 changes: 12 additions & 6 deletions areal/experimental/openai/proxy.py
@@ -222,14 +222,20 @@ async def set_reward(request: AReaLSetRewardRequest, session_id: str):
raise HTTPException(
status_code=400, detail=f"Session {session_id} not found"
)
if interaction_id is None:
# take the last interaction id
interaction_id = state.session_cache[
session_id
].completions.last_interaction_id

completions = state.session_cache[session_id].completions
if interaction_id not in completions:
if interaction_id is None:
# take the last interaction id
if len(completions) == 0:
logger.error(f"No interactions in session {session_id}")
raise HTTPException(
status_code=400, detail="No interactions in session"
)
interaction_id = completions.last_interaction_id
elif interaction_id not in completions:
logger.error(
f"Interaction {interaction_id} not found in session {session_id}"
)
raise HTTPException(
status_code=400, detail=f"Interaction {interaction_id} not found"
)
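For illustration, a minimal client-side sketch of the new behavior is shown below: omitting `interaction_id` now applies the reward to the most recent interaction, and a session without any interactions returns HTTP 400. The route path, port, and session id are assumptions for this sketch only; the actual path is defined by `RL_SET_REWARD_PATHNAME` and the request body follows `AReaLSetRewardRequest`.

```python
# Hypothetical client-side sketch (not the AReaL helper itself); the URL layout is assumed.
import asyncio

import aiohttp


async def reward_last_interaction(proxy_base_url: str, session_id: str, reward: float):
    async with aiohttp.ClientSession(proxy_base_url) as http:
        # interaction_id=None means "use the last interaction"; an empty session yields 400.
        payload = {"interaction_id": None, "reward": reward}
        async with http.post(f"/sessions/{session_id}/set_reward", json=payload) as resp:
            resp.raise_for_status()


asyncio.run(reward_last_interaction("http://localhost:8000", "demo-session", 1.0))
```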
20 changes: 18 additions & 2 deletions areal/utils/proxy_utils.py
@@ -142,7 +142,13 @@ async def _set_reward(
url: str = RL_SET_REWARD_PATHNAME,
):
payload = AReaLSetRewardRequest(interaction_id=interaction_id, reward=reward)
await post_json_with_retry(http_session, url=url, payload=payload)
try:
await post_json_with_retry(http_session, url=url, payload=payload)
except aiohttp.ClientResponseError as e:
if e.status == 400:
logger.error(f"[error code {e.status}] Error setting reward: {e.message}")
else:
raise e


async def set_interaction_reward(
@@ -196,7 +202,16 @@ def _get_float_reward(reward: float | int):
)

async with aiohttp.ClientSession(base_url) as session:
rewards = await func(data)
info = None
results = await func(data)
if isinstance(results, tuple):
if len(results) != 2:
raise ValueError(
f"Results must be a tuple of (rewards, info), got {len(results)}"
)
rewards, info = results
else:
rewards = results

if isinstance(rewards, dict):
for interaction_id, reward in rewards.items():
@@ -212,3 +227,4 @@ def _get_float_reward(reward: float | int):
reward=_get_float_reward(rewards),
url=pathname,
)
return info
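With this change, the reward callback invoked by the proxy utilities may return either the rewards alone or a `(rewards, info)` tuple, where `rewards` is either a dict keyed by interaction id or a single float. A minimal sketch of a compatible callback is shown below; the function name, the structure of `data`, and the contents of `info` are hypothetical.

```python
# Hypothetical reward callback matching the updated contract in proxy_utils.
async def score_session(data):
    # `data` is whatever the caller passes in; here it is assumed to map
    # interaction ids to per-interaction scores (illustration only).
    rewards = {interaction_id: float(score) for interaction_id, score in data.items()}
    info = {"num_interactions": len(rewards)}  # optional metadata returned to the caller
    return rewards, info  # returning just `rewards` (dict or float) also remains valid
```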
77 changes: 77 additions & 0 deletions examples/tau2/README.md
@@ -0,0 +1,77 @@
# Customer Service Agent Training with Tau2 Benchmark

## Overview

This example demonstrates how to train customer service agents on [$\tau^2$-Bench](https://github.com/sierra-research/tau2-bench) with AReaL's PPO/GRPO training pipeline. $\tau^2$-Bench provides realistic customer service simulation environments across multiple domains (retail, airline, telecom), in which agents must fulfill user requests both by calling their own tools and by guiding users to use the tools on the user's side.

## Code Architecture

The code is modified from the [proxy](../experimental/proxy/README.md) example so that the training workflow (`tau2_train.py`) and the agent runner script (`tau2_agent.py`) can be decoupled, with common utilities in `tau2_utils.py`.

* `tau2_train.py`: The training workflow that drives the PPO/GRPO training loop.
Contributor review comment (severity: medium): The description for tau2_train.py is incomplete. Please add a brief explanation of its role to improve the documentation's clarity and help users understand the example's structure.

* `tau2_agent.py`: Reuses the orchestrator, agent, and user simulator from the tau2-bench package to build the agent runner.

## Running the Example

### Prerequisites

Please make sure AReaL is set up and working by following the [installation guide](https://inclusionai.github.io/AReaL/tutorial/installation.html).

1. Install the (forked) tau2-bench package:
```bash
pip install git+https://github.com/dhh1995/tau2-bench.git@dhh/async-and-custom-completion
```
Note that the training relies on the async version of the agent and user simulator in the tau2-bench package. These changes will be merged into the [original tau2-bench repository](https://github.com/sierra-research/tau2-bench) later.

2. Set up the `TAU2_DATA_DIR` environment variable:
```bash
export TAU2_DATA_DIR=/path/to/tau2-bench/data
```

### Basic Training Command

1. Prepare the user simulator server.

If you are using self-hosted LLMs, you first need to set up a user simulator server. For example, when [using Qwen with SGLang](https://qwen.readthedocs.io/en/latest/deployment/sglang.html):
```bash
python3 -m sglang.launch_server --model-path Qwen/Qwen3-32B --host 0.0.0.0 --tool-call-parser qwen25 --chat-template ./qwen3_nonthinking.jinja --dp-size 8
```

Below we assume the hosted address is http://0.0.0.0:30000/v1/.
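As an optional sanity check, you can query the OpenAI-compatible endpoint before starting training. This sketch assumes the `openai` Python package is installed and that the served model name matches the launch command above:

```python
# Optional sanity check of the self-hosted user simulator endpoint.
from openai import OpenAI

# SGLang exposes an OpenAI-compatible API; the API key is not checked by default.
client = OpenAI(base_url="http://0.0.0.0:30000/v1", api_key="EMPTY")
print([m.id for m in client.models.list().data])

resp = client.chat.completions.create(
    model="Qwen/Qwen3-32B",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=16,
)
print(resp.choices[0].message.content)
```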

2. Run the training.

In this example, we use a `small` subset of the tau2-telecom domain, which contains 20 tasks, each consisting of a single subtask.

```bash
python3 -m areal.launcher.ray examples/tau2/tau2_train.py \
--config examples/tau2/config.yaml \
experiment_name=tau2-grpo \
trial_name=trial0 \
cluster.n_nodes=3 \
cluster.n_gpus_per_node=8 \
allocation_mode=sglang:d16+megatron:d2p4 \
gconfig.n_samples=16 \
actor.path=Qwen/Qwen2.5-7B-Instruct \
econfig.domain=telecom \
econfig.max_steps=30 \
train_dataset.path=tau2/small \
train_dataset.batch_size=8 \
user_llm_base_url=http://0.0.0.0:30000/v1/
```

This configuration uses 2 nodes for rollout, 1 node for training, and 1 node for the user simulator.
The training batch size is 8 and the group size is 16, resulting in 128 rollouts per step.

### Curve

The rollout reward on the training tasks is shown below.

![Curve](./curve.png)

With the above example configuration, one training step usually takes less than 10 minutes on average, depending on the hardware.

## Notes

1. When using litellm with multiprocessing, the `Queue bound to different event loop` error may occur. See also: [litellm issue #17813](https://github.com/BerriAI/litellm/issues/17813). This will not stop the training, but will make the outputs hard to read. You may use `grep -aivE "loop|queue|\^|asyncio|litellm"` to filter out the error messages before this issue is fixed.
Contributor review comment (severity: medium): There appears to be a typo in the litellm issue number. Issue #17813 does not exist. The correct issue number is likely #1781, which discusses the `Queue bound to different event loop` error. Please correct the link to ensure it points to the correct resource.

Suggested change: reference [litellm issue #1781](https://github.com/BerriAI/litellm/issues/1781) instead of #17813.

2. The trajectories will be dumped as `json` and `txt` files in the `generated/` directory. You can read and analyze them as needed, for example with the sketch below.
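A minimal sketch for inspecting the dumped JSON trajectories; the directory layout and file contents depend on the dump format, so the glob pattern below is only an assumption:

```python
# Hypothetical inspection script: iterate over dumped trajectory files and preview them.
import glob
import json

for path in sorted(glob.glob("generated/**/*.json", recursive=True)):
    with open(path) as f:
        trajectory = json.load(f)
    # Print a short preview of each trajectory; adapt to the actual schema as needed.
    print(path, json.dumps(trajectory, indent=2)[:500])
```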
174 changes: 174 additions & 0 deletions examples/tau2/config.yaml
@@ -0,0 +1,174 @@
experiment_name: tau2-grpo
trial_name: trial0

seed: 1
enable_offload: false
total_train_epochs: 50
tokenizer_path: ${actor.path}

do_eval: false
export_style: concat

cluster:
n_nodes: 3
n_gpus_per_node: 8
fileroot: /tmp/areal/experiments
name_resolve:
type: nfs
nfs_record_root: /tmp/areal/name_resolve

allocation_mode: sglang:d16+megatron:d2p4

rollout:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
max_concurrent_rollouts: 256
queue_size: null
consumer_batch_size: ${train_dataset.batch_size}
max_head_offpolicyness: 2
enable_rollout_tracing: false

gconfig:
n_samples: 16
min_new_tokens: 0
max_new_tokens: 512
greedy: false
temperature: 1.0

actor:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: Qwen/Qwen2.5-7B-Instruct
init_from_scratch: false
disable_dropout: true
gradient_checkpointing: true
dtype: bfloat16
mb_spec:
max_tokens_per_mb: 32000
optimizer:
type: adam
lr: 5e-6
weight_decay: 0.01
beta1: 0.9
beta2: 0.999
eps: 1e-8
lr_scheduler_type: constant
gradient_clipping: 1.0
warmup_steps_proportion: 0.001
group_size: ${gconfig.n_samples}
eps_clip: 0.4
temperature: ${gconfig.temperature}
reward_scaling: 10.0
reward_bias: -0.5
kl_ctl: 0.0
ppo_n_minibatches: 1
recompute_logprob: true
use_decoupled_loss: true
behav_imp_weight_cap: 5.0
dynamic_sampling: false
reward_norm:
mean_level: group
std_level: group
group_size: ${gconfig.n_samples}
adv_norm:
mean_level: batch
std_level: batch
max_new_tokens: ${gconfig.max_new_tokens}

ref:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: ${actor.path}
init_from_scratch: false
disable_dropout: true
dtype: ${actor.dtype}
mb_spec:
max_tokens_per_mb: 32000
optimizer: null

econfig:
domain: telecom
max_steps: 50
add_thinking_tool: true
solo_mode: false
user_llm_base_url: null # replace with your URL for the user LLM
user_llm: null # replace with your model name. Use 'openai/' for self-hosted openai-compatible server, e.g. 'openai/hosted'
user_llm_args:
temperature: 0.0
max_completion_tokens: 512
turn_discount: 1.0
invalid_format_penalty: 0.1

# SGLang
sglang:
model_path: ${actor.path}
random_seed: ${seed}
skip_tokenizer_init: true
dtype: ${actor.dtype}
max_running_requests: null
context_length: 32768
mem_fraction_static: 0.8

vllm:
model: ${actor.path}
seed: ${seed}
skip_tokenizer_init: false
dtype: ${actor.dtype}
max_model_len: 32768
gpu_memory_utilization: 0.9

# datasets
train_dataset:
batch_size: 8
pin_memory: true
num_workers: 4
path: tau2/train
type: rl
max_length: 1024

valid_dataset:
batch_size: 16
pin_memory: true
num_workers: 4
path: tau2/test
type: rl
drop_last: false

# Utilities
saver:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: null

recover:
mode: disabled
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: 3600

evaluator:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
freq_epochs: 1
freq_steps: null
freq_secs: null

stats_logger:
experiment_name: ${experiment_name}
trial_name: ${trial_name}
fileroot: ${cluster.fileroot}
wandb:
mode: disabled

launcher:
inference_server_cpus_per_gpu: 4
inference_server_mem_per_gpu: 32768
trainer_cpus_per_gpu: 4
trainer_mem_per_gpu: 32768
Binary file added examples/tau2/curve.png