[WIP] example: tau2-bench training #740
base: main
Conversation
Summary of Changes

Hello @dhh1995, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a new, self-contained example for training customer service agents within the tau2-bench simulation framework, leveraging AReaL's reinforcement learning capabilities. It includes necessary adjustments to the proxy server's reward handling for improved robustness and flexibility, alongside a new workflow to orchestrate training episodes. The addition of dedicated configuration and utility files streamlines the setup and execution of these complex training scenarios, making it easier to reproduce and extend the example.

Highlights
Code Review
This pull request introduces a new example for training customer service agents using the tau2-bench, which is a significant and valuable addition. The overall structure is well-designed, with clear separation of concerns between the agent logic, training script, and utilities. The modifications to existing proxy utilities, such as improved error handling and more flexible function returns, are also commendable. My review focuses on enhancing the new example files by addressing minor issues in documentation, improving code clarity, and suggesting small optimizations for better maintainability.
The code is modified from the [proxy](../experimental/proxy/README.md) example so that the training workflow (`tau2_train.py`) and the agent runner script (`tau2_agent.py`) can be decoupled, with common utilities in `tau2_utils.py`.
* `tau2_train.py`:
## Notes
1. When using litellm with multiprocessing, the `Queue bound to different event loop` error may occur. See also: [litellm issue #17813](https://github.com/BerriAI/litellm/issues/17813). This will not stop the training, but will make the outputs hard to read. You may use `grep -aivE "loop|queue|\^|asyncio|litellm"` to filter out the error messages before this issue is fixed.
There appears to be a typo in the litellm issue number. Issue #17813 does not exist. The correct issue number is likely #1781, which discusses the Queue bound to different event loop error. Please correct the link to ensure it points to the correct resource.
Suggested change:
- 1. When using litellm with multiprocessing, the `Queue bound to different event loop` error may occur. See also: [litellm issue #17813](https://github.com/BerriAI/litellm/issues/17813). This will not stop the training, but will make the outputs hard to read. You may use `grep -aivE "loop|queue|\^|asyncio|litellm"` to filter out the error messages before this issue is fixed.
+ 1. When using litellm with multiprocessing, the `Queue bound to different event loop` error may occur. See also: [litellm issue #1781](https://github.com/BerriAI/litellm/issues/1781). This will not stop the training, but will make the outputs hard to read. You may use `grep -aivE "loop|queue|\^|asyncio|litellm"` to filter out the error messages before this issue is fixed.
    tasks: list[Task] = registry.get_tasks_loader(domain)(split)
    for task in tasks:
        if task.id == task_id:
            return task
    raise ValueError(f"No task found with id {task_id} for domain {domain}")
The current implementation iterates through the list of tasks to find a match, which has a time complexity of O(n). For better performance, especially if the number of tasks grows, consider converting the list of tasks into a dictionary for O(1) lookups.
Suggested change:
-     tasks: list[Task] = registry.get_tasks_loader(domain)(split)
-     for task in tasks:
-         if task.id == task_id:
-             return task
-     raise ValueError(f"No task found with id {task_id} for domain {domain}")
+     tasks: list[Task] = registry.get_tasks_loader(domain)(split)
+     task_map = {task.id: task for task in tasks}
+     if task_id not in task_map:
+         raise ValueError(f"No task found with id {task_id} for domain {domain}")
+     return task_map[task_id]
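Building on that suggestion, note that the dictionary would still be rebuilt on every call. If this lookup runs repeatedly (for example once per training episode), the map could also be cached per `(domain, split)` pair. Below is a minimal sketch, assuming `registry` and `Task` are available as in the excerpt above and that the enclosing helper takes `(domain, split, task_id)`; the names are illustrative, not this PR's actual API:

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def _task_map(domain: str, split: str) -> dict:
    # Load the tasks once per (domain, split) and index them by id.
    # `registry` is assumed to be imported as in the PR's utilities.
    tasks = registry.get_tasks_loader(domain)(split)
    return {task.id: task for task in tasks}


def get_task(domain: str, split: str, task_id: str):
    # O(1) lookup against the cached map.
    task_map = _task_map(domain, split)
    if task_id not in task_map:
        raise ValueError(f"No task found with id {task_id} for domain {domain}")
    return task_map[task_id]
```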
# * Backup: use acreate to replace acompletion
# async def _acreate(*args, **kwargs):
#     # num_retries is a litellm kwarg; the raw OpenAI client does not accept it per call
#     kwargs.pop("num_retries", None)
#     completion = await client.chat.completions.create(*args, **kwargs)
#     return completion

# async def _acreate_with_base_url(*args, **kwargs):
#     kwargs.pop("num_retries", None)
#     async with AsyncOpenAI(base_url=self.econfig.user_llm_base_url) as client:
#         completion = await client.chat.completions.create(*args, **kwargs)
#         return completion
# Dump info to file
if "task_id" in data:
    real_task_id = data["task_id"][:120] + "-" + task_id
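The excerpt stops after deriving `real_task_id`. Purely as an illustration of where that value might be used (this is not code from the PR; `dump_dir` and the JSON format are assumptions, and `data` / `real_task_id` refer to the excerpt above), the dump step could look roughly like:

```python
import json
from pathlib import Path

dump_dir = "outputs/tau2_dumps"  # placeholder path, not from the PR

# Write the episode info to a JSON file keyed by the derived task id.
dump_path = Path(dump_dir) / f"{real_task_id}.json"
dump_path.parent.mkdir(parents=True, exist_ok=True)
with dump_path.open("w") as f:
    json.dump(data, f, indent=2, default=str)
```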
Description
This example demonstrates how to train customer service agents on $\tau^2$-Bench with AReaL's PPO/GRPO training pipeline. $\tau^2$-Bench provides realistic customer service simulation environments across multiple domains (retail, airline, telecom) in which agents must resolve users' requests both by calling agent-side tools and by guiding users through using their own tools.
Training reward curve on the telecom-small subset.
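For orientation only, the overall loop can be pictured as running one simulated episode per tau2-bench task and aggregating the rewards. The sketch below is hypothetical; `load_tasks` and `run_episode` are placeholders standing in for the utilities this PR actually provides in `tau2_utils.py` / `tau2_agent.py`:

```python
import statistics
from typing import Callable, Iterable


def mean_episode_reward(
    load_tasks: Callable[[str, str], Iterable],
    run_episode: Callable[..., float],
    domain: str = "telecom",
    split: str = "small",
) -> float:
    # Placeholder driver loop: load the tasks for a domain/split
    # (e.g. via registry.get_tasks_loader(domain)(split)), run one
    # episode per task, and return the mean reward.
    tasks = list(load_tasks(domain, split))
    rewards = [run_episode(task) for task in tasks]
    return statistics.mean(rewards)
```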
Related Issue
Fixes #(issue)
Type of Change
Checklist
Breaking Change Details (if applicable):
Additional Context
Need help? Check the Contributing Guide or ask in GitHub Discussions!