diff --git a/areal/dataset/__init__.py b/areal/dataset/__init__.py index 3258c6087..dfa90c665 100644 --- a/areal/dataset/__init__.py +++ b/areal/dataset/__init__.py @@ -10,7 +10,14 @@ from transformers.processing_utils import ProcessorMixin from transformers.tokenization_utils_fast import PreTrainedTokenizerFast -VALID_DATASETS = ["gsm8k", "clevr_count_70k", "geometry3k", "hh-rlhf", "torl_data"] +VALID_DATASETS = [ + "gsm8k", + "clevr_count_70k", + "geometry3k", + "hh-rlhf", + "torl_data", + "terminal_bench", +] logger = logging.getLogger("Dataset") @@ -24,7 +31,6 @@ def _get_custom_dataset( processor: Optional["ProcessorMixin"] = None, **kwargs, ) -> "Dataset": - if "gsm8k" in path and type == "sft": from .gsm8k import get_gsm8k_sft_dataset @@ -105,6 +111,16 @@ def _get_custom_dataset( max_length=max_length, **kwargs, ) + elif "terminal_bench" in path and type == "rl": + from .terminal_bench import get_terminal_bench_rl_dataset + + return get_terminal_bench_rl_dataset( + path=path, + split=split, + tokenizer=tokenizer, + max_length=max_length, + **kwargs, + ) else: raise ValueError( f"Dataset {path} with split {split} and training type {type} is not supported. " diff --git a/areal/dataset/terminal_bench.py b/areal/dataset/terminal_bench.py new file mode 100644 index 000000000..bc351936b --- /dev/null +++ b/areal/dataset/terminal_bench.py @@ -0,0 +1,60 @@ +from typing import TYPE_CHECKING + +from datasets import load_dataset + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerFast + + +def get_terminal_bench_rl_dataset( + path: str, + split: str, + tokenizer: "PreTrainedTokenizerFast", + max_length: int | None = None, +): + """Load terminal-bench dataset for RL training. + + The dataset should be in parquet format with the following columns: + - prompt: The formatted prompt for the task + - task_name: Name of the task + - instruction: Raw instruction text + - extra_info: JSON string containing task metadata + """ + # Load from parquet file + dataset = load_dataset("parquet", data_files={split: path}, split=split) + + # The dataset already has the right format from the converter: + # - prompt: contains the formatted conversation + # - task_name, instruction, extra_info: metadata fields + + # For RL training, we need to extract messages from the prompt or extra_info + def process(sample): + # The prompt is already formatted, but we need to extract the instruction + # to create a messages structure for the workflow + instruction = sample.get("instruction", "") + task_name = sample.get("task_name", "") + dockerfile_contents = sample.get("dockerfile_contents", "") + + # Return data in the format expected by the workflow + return { + "instruction": instruction, + "task_name": task_name, + "dockerfile_contents": dockerfile_contents, + "extra_info": sample.get("extra_info", ""), + "data_source": sample.get("data_source", "terminal_bench"), + } + + dataset = dataset.map(process) + + # Filter out sequences longer than max_length if specified + if max_length is not None: + + def filter_length(samples): + # Tokenize instructions in batches for efficiency + instructions = samples["instruction"] + tokens_list = tokenizer(instructions, add_special_tokens=False)["input_ids"] + return [len(tokens) <= max_length for tokens in tokens_list] + + dataset = dataset.filter(filter_length, batched=True) + + return dataset diff --git a/areal/experimental/openai/client.py b/areal/experimental/openai/client.py index 18a2579d8..6dd3c321a 100644 --- a/areal/experimental/openai/client.py +++ 
b/areal/experimental/openai/client.py @@ -15,7 +15,6 @@ from openai.types.chat import ( ChatCompletion, ChatCompletionMessage, - ChatCompletionToolMessageParam, ChatCompletionToolParam, ) from openai.types.chat.chat_completion import Choice @@ -277,22 +276,11 @@ async def create( if is_omitted(input): raise ValueError("input is required for Responses.create") - def _convert_tool_output_format( - item: dict, - ) -> ChatCompletionToolMessageParam | dict: + def _convert_tool_output_format(item: dict) -> dict: """Convert custom tool output format to standard chat template format. - Converts openai.types.responses.response_input_item_param.FunctionCallOutput - to openai.types.chat.ChatCompletionToolMessageParam. - - Args: - item: Input dict, could be FunctionCallOutput from openai-agents SDK - with format: {'call_id': str, 'output': str, 'type': 'function_call_output'} - - Returns: - ChatCompletionToolMessageParam (TypedDict) with format: - {'role': 'tool', 'content': str, 'tool_call_id': str} - or the original dict if conversion is not needed. + Converts from: {'call_id': ..., 'output': ..., 'type': 'function_call_output'} + To: {'role': 'tool', 'content': ..., 'tool_call_id': ...} """ if ( isinstance(item, dict) diff --git a/assets/qwen3_8b_terminal.png b/assets/qwen3_8b_terminal.png new file mode 100644 index 000000000..236d1d96e Binary files /dev/null and b/assets/qwen3_8b_terminal.png differ diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/openai-agents/agent_terminal_workflow.py b/examples/openai-agents/agent_terminal_workflow.py new file mode 100644 index 000000000..21a037257 --- /dev/null +++ b/examples/openai-agents/agent_terminal_workflow.py @@ -0,0 +1,202 @@ +import asyncio +import logging +import os + +from agents import Agent as OpenAIAgent +from agents import ModelSettings, OpenAIProvider, RunConfig, SQLiteSession +from agents import Runner as OpenAIRunner +from terminal.env import TerminalEnv +from terminal.judge_agent import JudgeAgent, judge_from_env +from terminal.prompt import SYSTEM_PROMPT +from transformers import PreTrainedTokenizerFast + +from areal.api.cli_args import GenerationHyperparameters +from areal.api.workflow_api import RolloutWorkflow +from areal.experimental.openai import ArealOpenAI +from areal.utils import stats_tracker + +logger = logging.getLogger(__name__) + + +class TerminalAgent: + def __init__( + self, + tokenizer: PreTrainedTokenizerFast, + max_tokens_per_turn: int = 1024, + max_turns: int = 8, + max_total_tokens: int = 32768, + dump_dir: str | None = None, + rollout_stat_scope: str = "rollout", + ): + self.tokenizer = tokenizer + self.max_tokens_per_turn = max_tokens_per_turn + self.max_turns = max_turns + self.max_total_tokens = max_total_tokens + self.dump_dir = dump_dir + self.rollout_stat_scope = rollout_stat_scope + + async def run_agent(self, data, client: ArealOpenAI, judge_agent: JudgeAgent): + """Run the agent workflow for terminal task execution.""" + run_config = RunConfig( + model_provider=OpenAIProvider( + openai_client=client, + use_responses=True, + ), + tracing_disabled=True, + model_settings=ModelSettings( + temperature=1.0, + extra_args={"max_completion_tokens": self.max_tokens_per_turn}, + tool_choice="auto", + store=True, + ), + ) + + async with TerminalEnv( + task_name=data["task_name"], + dump_dir=self.dump_dir, + rollout_stat_scope=self.rollout_stat_scope, + ) as env: + # Create agent workflow with terminal tools + agent = OpenAIAgent( + 
name="Terminal Task Agent", + instructions=SYSTEM_PROMPT, + tools=env.get_tools(), + ) + session = SQLiteSession("terminal") + content = data["instruction"] + + max_attempts = self.max_turns + reward = 0 + judge_reward = 0 + tracker = stats_tracker.get(self.rollout_stat_scope) + + with tracker.record_timing("run_agent_total"): + error_count = 0.0 + attempts_used = 0.0 + for attempt in range(max_attempts): + attempts_used = float(attempt + 1) + try: + with tracker.record_timing("openai_runner_run"): + result = await OpenAIRunner.run( + agent, + input=content, + session=session, + run_config=run_config, + max_turns=30, + ) + except Exception as e: + logger.error(f"Error running agent: {e}") + error_count += 1.0 + break + + with tracker.record_timing("env_validate_reward"): + reward = env.reward() + if judge_agent: + with tracker.record_timing("judge_agent_reward"): + judge_reward = await judge_agent.get_reward_from_judge( + session=session, + dockerfile_contents=data["dockerfile_contents"], + ) + if judge_reward >= 0 and reward < 0.99: + reward = reward * 0.65 + judge_reward * 0.35 + + tracker.scalar( + reward=reward, + judge_reward=judge_reward, + attempt_index=float(attempt), + input_chars=float(len(content) if content else 0.0), + output_chars=float( + len(getattr(result, "final_output", "") or "") + ), + ) + + if isinstance(reward, float) and reward >= 0.99: + tracker.scalar(success=1.0) + break + + if attempt < max_attempts - 1: + content = f"""The previous attempt didn't complete the task successfully. + Please try a different approach. + Original task: {data["instruction"]} + + Previous attempt result: {result.final_output} + + Please analyze what went wrong and try again with a corrected approach.""" + else: + content = f"""This is your final attempt. Please be extremely careful. 
+ Original task: {data["instruction"]} + + Previous attempts: {result.final_output} + + Please provide a final, carefully executed solution.""" + tracker.scalar(success=0.0) + + tracker.scalar( + final_reward=reward, attempts_used=attempts_used, errors=error_count + ) + + client.set_final_reward(reward) + + return reward + + +class TerminalAgentWorkflow(RolloutWorkflow): + def __init__( + self, + gconfig: GenerationHyperparameters, + tokenizer: PreTrainedTokenizerFast, + dump_dir: str | None = None, + rollout_stat_scope: str = "rollout", + n_trajs: int = 1, + max_tokens: int = 32768, + max_turns: int = 8, + ): + self.gconfig = gconfig + self.gconfig.n_samples = 1 + self.tokenizer = tokenizer + self.dump_dir = dump_dir + self.max_tokens = max_tokens + self.rollout_stat_scope = rollout_stat_scope + if self.dump_dir is not None and not os.path.exists(self.dump_dir): + os.makedirs(self.dump_dir, exist_ok=True) + + # Search hyper-parameters + self.n_trajs = n_trajs + self.agent = TerminalAgent( + tokenizer=self.tokenizer, + max_tokens_per_turn=self.gconfig.max_new_tokens, + max_turns=max_turns, + max_total_tokens=max_tokens, + dump_dir=self.dump_dir, + rollout_stat_scope=self.rollout_stat_scope, + ) + self.judge_agent = judge_from_env() + + async def arun_episode(self, engine, data): + clients = [ + ArealOpenAI( + engine=engine, tokenizer=self.tokenizer, tool_call_parser="qwen25" + ) + for _ in range(self.n_trajs) + ] + + # Collect trajectories + rewards = await asyncio.gather( + *[ + self.agent.run_agent( + data=data, + client=clients[i], + judge_agent=self.judge_agent, + ) + for i in range(self.n_trajs) + ] + ) + for reward in rewards: + stats_tracker.get(self.rollout_stat_scope).scalar(reward=reward) + + interactions_with_reward = {} + for client in clients: + client.apply_reward_discount(turn_discount=0.9) + interactions = client.export_interactions(style="individual") + interactions_with_reward.update(interactions) + return interactions_with_reward diff --git a/examples/openai-agents/config.yaml b/examples/openai-agents/config.yaml index 7be7a094f..09b2eddd0 100644 --- a/examples/openai-agents/config.yaml +++ b/examples/openai-agents/config.yaml @@ -19,7 +19,7 @@ cluster: type: nfs nfs_record_root: /tmp/areal/name_resolve -allocation_mode: sglang.d4p1t1+d4p1t1 +allocation_mode: sglang.d4p1t1+d1p1t1c4 rollout: experiment_name: ${experiment_name} diff --git a/examples/openai-agents/terminal/README.md b/examples/openai-agents/terminal/README.md new file mode 100644 index 000000000..2b71f96d9 --- /dev/null +++ b/examples/openai-agents/terminal/README.md @@ -0,0 +1,124 @@ +# Guidance for starting training a terminal agent + +## Overview + +The terminal agent training system consists of two separate instances: + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Training Instance (Has GPU) │ +│ ┌──────────────────────────────────────────────────────────┐ │ +│ │ AReaL Training Pipeline │ │ +│ │ - Actor Model (Inference) │ │ +│ │ - Rollout Workers │ │ +│ │ - Dataset Loaders │ │ +│ │ - PPO/GRPO Training │ │ +│ └────────────────┬─────────────────────────────────────────┘ │ +│ │ │ +│ │ HTTP Requests │ +│ │ (Tool Calls) │ +│ ▼ │ +└───────────────────┼─────────────────────────────────────────────┘ + │ + │ +┌───────────────────┼─────────────────────────────────────────────┐ +│ │ MCP Server Instance (No GPU) │ +│ ┌────────────────▼─────────────────────────────────────────┐ │ +│ │ MCP Server (Flask) │ │ +│ │ - HTTP/SSE Endpoints │ │ +│ │ - Task Container Management │ │ +│ 
└──────────────────┬───────────────────────────────────────┘    │
+│                     │                                            │
+│                     ▼                                            │
+│  ┌──────────────────────────────────────────────────────────┐    │
+│  │              Docker-in-Docker (DinD)                     │    │
+│  │  ┌────────────┐  ┌────────────┐  ┌────────────┐          │    │
+│  │  │ Container  │  │ Container  │  │ Container  │  ...     │    │
+│  │  │  (Task 1)  │  │  (Task 2)  │  │  (Task 3)  │          │    │
+│  │  │  + tmux    │  │  + tmux    │  │  + tmux    │          │    │
+│  │  └────────────┘  └────────────┘  └────────────┘          │    │
+│  └──────────────────────────────────────────────────────────┘    │
+│                                                                   │
+└───────────────────────────────────────────────────────────────────┘
+```
+
+**Key Points:**
+
+- **Training Instance**: Runs the RL training loop, manages rollouts, has GPU
+- **MCP Server Instance**: Hosts terminal environments in Docker containers
+- **Communication**: The training instance sends tool calls to the MCP server over HTTP
+- **Isolation**: Each task runs in its own Docker container with a tmux session
+
+## Step 1: Start the MCP server
+
+1. Pull the image that will host the MCP server: `docker pull docker.io/cruizba/ubuntu-dind:latest`
+
+2. Enter the MCP server container and configure it:
+
+```bash
+# Prepare task data
+# Install https://github.com/laude-institute/terminal-bench
+# Download the dataset to a directory of your choice, e.g. /data/tb:
+tb datasets download -d terminal-bench-core --output-dir /data/tb
+
+# Install Python and the server dependencies
+apt-get update
+apt-get install -y python3-pip
+update-alternatives --install /usr/bin/python python /usr/bin/python3 10
+pip install --break-system-packages requests terminal-bench flask
+apt-get install -y fuse-overlayfs
+
+# Configure Docker address pools
+mkdir -p /etc/docker
+tee /etc/docker/daemon.json > /dev/null <<'EOF'
+{
+  "default-address-pools": [
+    {
+      "base": "172.16.0.0/12",
+      "size": 24
+    }
+  ]
+}
+EOF
+
+# Restart Docker
+supervisorctl shutdown
+start-docker.sh
+
+# Start the MCP server
+python -m examples.openai-agents.terminal.server --tasks-dir /data/tb/tasks --tasks-log-dir /data/tb/logs
+```
+
+3. Test the MCP server. You can also exercise the HTTP endpoints by hand, as sketched below:
+
+```bash
+python examples/openai-agents/terminal/test_client.py
+```
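+If you want to poke at the server manually, the task lifecycle endpoints used by `env.py` (`/tasks/start`, `/tasks/validate`, `/tasks/stop`) can be called directly. The snippet below is a minimal sketch, not part of the repository: it assumes the server is reachable at `MCP_SERVER_URL` (defaulting to `http://localhost:8000`, as in `env.py`) and uses a placeholder task name that you should replace with a task present under `--tasks-dir`.
+
+```python
+import os
+import uuid
+
+import requests
+
+base_url = os.environ.get("MCP_SERVER_URL", "http://localhost:8000")
+payload = {"uuid": str(uuid.uuid4()), "task_name": "hello-world"}  # placeholder task name
+
+# Start (or reuse) a task container; the server returns its container name.
+resp = requests.post(f"{base_url}/tasks/start", json=payload, timeout=360)
+resp.raise_for_status()
+container_name = resp.json()["container_name"]
+print("started:", container_name)
+
+# Run the task's test suite inside the container and read back the score.
+resp = requests.post(
+    f"{base_url}/tasks/validate", json={"container_name": container_name}, timeout=360
+)
+print("validation:", resp.json())
+
+# Stop the container and free its resources.
+requests.post(f"{base_url}/tasks/stop", json=payload, timeout=30).raise_for_status()
+```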
+## Step 2: Train the agent
+
+1. Prepare the datasets:
+
+```bash
+# Install https://github.com/laude-institute/terminal-bench
+# Download the dataset to a directory of your choice, e.g. /data/tb:
+tb datasets download -d terminal-bench-core --output-dir /data/tb
+
+# Convert the tasks to parquet format
+python examples/openai-agents/terminal/tasks_to_parquet_converter.py --tasks-dir /data/tb/tasks --output-dir /tmp/terminal_bench/easy-data/
+# Make sure the output dir contains `terminal_bench` in its name; the dataset loader dispatches on the path
+# Example: /tmp/terminal_bench/easy-data/train.parquet
+```
+
+2. Start the training task locally:
+
+```bash
+# Download Qwen3-4B-Thinking-2507 from Hugging Face
+huggingface-cli download Qwen/Qwen3-4B-Thinking-2507 --local-dir /storage/models/Qwen3-4B-Thinking-2507
+
+# Set the MCP server address, e.g. http://<mcp-server-ip>:8000
+export MCP_SERVER_URL=http://$MCP_SERVER_ADDRESS
+
+python3 -m areal.launcher.local examples/openai-agents/train_agents.py \
+    --config examples/openai-agents/config.yaml \
+    actor.path=/storage/models/Qwen3-4B-Thinking-2507 \
+    train_dataset.path=/tmp/terminal_bench/easy-data/train.parquet \
+    train_dataset.batch_size=4 \
+    gconfig.n_samples=4 \
+    trial_name=oss-qwen25-7b \
+    agent_type=multi_agent_terminal \
+    experiment_name=openai-agents-terminal \
+    n_trajs=1 \
+    max_turns=3 \
+    valid_dataset.path=/tmp/terminal_bench/easy-data/val.parquet \
+    stats_logger.wandb.mode=online
+```
+
+## Experiment Record
+
+![experiment record](../../../assets/qwen3_8b_terminal.png)
diff --git a/examples/openai-agents/terminal/__init__.py b/examples/openai-agents/terminal/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/openai-agents/terminal/env.py b/examples/openai-agents/terminal/env.py
new file mode 100644
index 000000000..8960fde52
--- /dev/null
+++ b/examples/openai-agents/terminal/env.py
@@ -0,0 +1,477 @@
+import asyncio
+import json
+import logging
+import os
+import re
+import uuid
+from pathlib import Path
+from typing import Any
+
+import aiofiles
+import aiofiles.os
+import requests
+from agents import FunctionTool, RunContextWrapper
+from mcp import ClientSession
+from mcp.client.sse import sse_client
+from pydantic import BaseModel, Field
+
+from areal.utils import stats_tracker
+
+logger = logging.getLogger(__name__)
+
+
+class TerminalEnv:
+    """Simple context manager for managing terminal container lifecycle."""
+
+    def __init__(
+        self,
+        task_name: str | None = None,
+        dump_dir: str | None = None,
+        max_retries: int = 3,
+        retry_delay: float = 1.0,
+        rollout_stat_scope: str = "rollout",
+        dockerfile_contents: str | None = None,
+    ):
+        self.base_url = os.environ.get("MCP_SERVER_URL", "http://localhost:8000")
+        self.task_name = task_name
+        self.container_name = None
+        self.uuid = str(uuid.uuid4())
+        self.dump_dir = dump_dir
+        self._mcp_session = None
+        self._sse_exit_stack = None
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+        self._connection_lock = asyncio.Lock()
+        self.rollout_stat_scope = rollout_stat_scope
+        self.dockerfile_contents = dockerfile_contents
+
+    async def _log_tool_call(
+        self, tool_name: str, arguments: dict, result: str, container_name: str
+    ):
+        """Log tool call to file asynchronously.
+ + Args: + tool_name: Name of the tool being called + arguments: Arguments passed to the tool + result: Result returned by the tool + container_name: Name of the container for the log filename + """ + if self.dump_dir is not None: + dump_path = Path(self.dump_dir) / "terminal" + await aiofiles.os.makedirs(dump_path, exist_ok=True) + log_file = dump_path / f"{container_name}.jsonl" + async with aiofiles.open(log_file, "a", encoding="utf-8") as f: + log_entry = { + "tool_name": tool_name, + "arguments": arguments, + "result": result, + } + await f.write(json.dumps(log_entry, ensure_ascii=False) + "\n") + + def __enter__(self) -> "TerminalEnv": + """Start the task container.""" + payload = {"uuid": self.uuid, "task_name": self.task_name} + try: + response = requests.post( + f"{self.base_url}/tasks/start", json=payload, timeout=360 + ) + response.raise_for_status() + data = response.json() + self.container_name = data["container_name"] + except requests.exceptions.RequestException as e: + raise RuntimeError(f"Failed to start task container: {e}") + return self + + async def _connect_mcp(self) -> None: + """Establish MCP connection with retry logic.""" + from contextlib import AsyncExitStack + + for attempt in range(self.max_retries): + try: + if self._sse_exit_stack is None: + self._sse_exit_stack = AsyncExitStack() + + logger.info( + f"Connecting to MCP server (attempt {attempt + 1}/{self.max_retries})..." + ) + read, write = await self._sse_exit_stack.enter_async_context( + sse_client( + url=f"{self.base_url}/sse", + timeout=60 * 3, + ) + ) + self._mcp_session = await self._sse_exit_stack.enter_async_context( + ClientSession(read, write) + ) + await self._mcp_session.initialize() + logger.info("MCP connection established successfully") + return + except Exception as e: + logger.warning(f"MCP connection attempt {attempt + 1} failed: {e}") + if attempt < self.max_retries - 1: + await asyncio.sleep( + self.retry_delay * (attempt + 1) + ) # Exponential backoff + # Clean up failed connection attempt + if self._sse_exit_stack: + try: + await self._sse_exit_stack.aclose() + except Exception: + pass + self._sse_exit_stack = None + self._mcp_session = None + else: + raise RuntimeError( + f"Failed to connect to MCP server after {self.max_retries} attempts: {e}" + ) + + async def _reconnect_mcp(self) -> None: + """Reconnect to MCP server after connection loss.""" + async with self._connection_lock: + logger.info("Reconnecting to MCP server...") + # Clean up old connection + if self._sse_exit_stack: + try: + await self._sse_exit_stack.aclose() + except Exception as e: + logger.warning(f"Error closing old connection: {e}") + self._sse_exit_stack = None + self._mcp_session = None + + # Establish new connection + await self._connect_mcp() + + async def __aenter__(self) -> "TerminalEnv": + """Async context manager entry - start container and MCP session.""" + self.__enter__() + # Initialize persistent MCP connection with retry + await self._connect_mcp() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop the task container on exit.""" + if self.container_name: + try: + payload = {"uuid": self.uuid, "task_name": self.task_name} + response = requests.post( + f"{self.base_url}/tasks/stop", json=payload, timeout=30 + ) + response.raise_for_status() + except Exception as e: + print(f"Warning: Failed to stop task {self.task_name}: {e}") + return False # Don't suppress exceptions + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit - cleanup MCP session 
and container.""" + if self._sse_exit_stack: + await self._sse_exit_stack.aclose() + self.__exit__(exc_type, exc_val, exc_tb) + return False + + async def _call_mcp_tool_with_retry(self, tool_name: str, arguments: dict) -> str: + """Call MCP tool with retry logic and error handling. + + Args: + tool_name: Name of the MCP tool to call + arguments: Arguments to pass to the tool + + Returns: + Tool result as string + + Raises: + RuntimeError: If tool call fails after all retries + """ + for attempt in range(self.max_retries): + try: + result = await self._mcp_session.call_tool( + name=tool_name, + arguments=arguments, + ) + output = ( + result.content[0].text + if result.content and len(result.content) > 0 + else "" + ) + return output + except Exception as e: + logger.warning(f"Tool call attempt {attempt + 1} failed: {e}") + if attempt < self.max_retries - 1: + try: + await self._reconnect_mcp() + except Exception as reconnect_error: + logger.error(f"Reconnection failed: {reconnect_error}") + if attempt == self.max_retries - 2: + raise + else: + raise RuntimeError( + f"Failed to call {tool_name} after {self.max_retries} attempts: {e}" + ) + + def get_tools(self) -> list[FunctionTool]: + """Create async function tools for terminal interaction. + + Returns: + List of async function tools for keystrokes and capture-pane operations. + """ + if not self.container_name: + raise RuntimeError( + "Container not started. Use TerminalEnv as context manager first." + ) + + def _format_terminal_output(output: str, keystrokes: str = "") -> str: + """Clean terminal output by removing shell prompts and the exact keystrokes command. + + Args: + output: Raw terminal output + keystrokes: The exact command that was sent to remove from output + """ + if not output: + return output + + lines = output.split("\n") + cleaned = [] + keystrokes_clean = keystrokes.strip() + + for line in lines: + # Remove shell prompt + line = re.sub(r"^[a-zA-Z0-9@]+:[^#]*#\s*", "", line) + + # Remove the exact keystrokes command if present at the start + if keystrokes_clean and line.startswith(keystrokes_clean): + line = line[len(keystrokes_clean) :].lstrip() + + # Keep non-empty lines + if line.strip(): + cleaned.append(line.rstrip()) + + return "\n".join(cleaned).strip() + + class ExecuteCommandArgs(BaseModel): + """Arguments for executing a terminal command.""" + + command: str = Field( + description="The terminal command to execute (e.g., 'ls -la', 'python script.py', 'mkdir new_folder')" + ) + wait_time_sec: float = Field( + default=2.0, + description="Time to wait after command execution in seconds. Use longer values for commands that take time to complete.", + ) + + class CurrentStatusArgs(BaseModel): + """Arguments for getting current directory status and file listing.""" + + include_hidden: bool = Field( + default=False, + description="Whether to include hidden files (starting with .) in the file listing. Set to True to see all files including dotfiles.", + ) + + class FileContentsArgs(BaseModel): + """Arguments for viewing file contents with different viewing modes.""" + + absolute_path: str = Field( + description="Absolute path to the file to view (e.g., '/home/user/file.txt', '/var/log/syslog')" + ) + tail_lines: int = Field( + default=0, + description="Number of lines to show from the end of the file. Use tail to see the last N lines (e.g., for logs). Set to 0 to disable tail view.", + ) + head_lines: int = Field( + default=0, + description="Number of lines to show from the beginning of the file. 
Use head to see the first N lines (e.g., for file headers). Set to 0 to disable head view.", + ) + + async def execute_command(ctx: RunContextWrapper[Any], args: str) -> str: + """Execute a command in the terminal and return the result. + + Args: + args: JSON string containing command and wait_time_sec parameters. + + Returns: + Command output after execution. + """ + try: + parsed_args = ExecuteCommandArgs.model_validate_json(args) + + # Call the simplified MCP wrapper + raw_output = await self._call_mcp_tool_with_retry( + "keystrokes", + { + "container_name": self.container_name, + "keystrokes": parsed_args.command, + "append_enter": True, + "wait_time_sec": parsed_args.wait_time_sec, + }, + ) + + # Clean the output + output = _format_terminal_output(raw_output, parsed_args.command) + + # Log and track metrics + await self._log_tool_call( + tool_name="execute_command", + arguments=parsed_args.model_dump(), + result=output, + container_name=self.container_name, + ) + + tracker = stats_tracker.get(self.rollout_stat_scope) + tracker.scalar( + tool_execute_success=1.0, + tool_execute_input_chars=float(len(parsed_args.command)), + tool_execute_output_chars=float(len(output)), + ) + + return output + except Exception as e: + stats_tracker.get(self.rollout_stat_scope).scalar( + tool_execute_error=1.0 + ) + return f"Error executing command: {e}" + + async def current_working_directory( + ctx: RunContextWrapper[Any], args: str + ) -> str: + """Get current working directory and list files. + + Args: + args: JSON string containing include_hidden parameter. + + Returns: + Current directory information and file listing. + """ + try: + parsed_args = CurrentStatusArgs.model_validate_json(args) + + # Get current working directory + pwd_result = await self._call_mcp_tool_with_retry( + "keystrokes", + { + "container_name": self.container_name, + "keystrokes": "pwd", + "append_enter": True, + "wait_time_sec": 1.0, + }, + ) + + # List files + ls_option = "-alh" if parsed_args.include_hidden else "-lh" + ls_result = await self._call_mcp_tool_with_retry( + "keystrokes", + { + "container_name": self.container_name, + "keystrokes": f"ls {ls_option}", + "append_enter": True, + "wait_time_sec": 1.0, + }, + ) + + # Clean and combine results + clean_pwd = _format_terminal_output(pwd_result, "pwd") + clean_ls = _format_terminal_output(ls_result, f"ls {ls_option}") + output = f"Current directory: {clean_pwd}\n\n Files: {clean_ls}" + + # Log and track metrics + await self._log_tool_call( + tool_name="current_working_directory", + arguments=parsed_args.model_dump(), + result=output, + container_name=self.container_name, + ) + + tracker = stats_tracker.get(self.rollout_stat_scope) + tracker.scalar( + tool_status_success=1.0, + tool_status_output_chars=float(len(output)), + ) + + return output + except Exception as e: + stats_tracker.get(self.rollout_stat_scope).scalar(tool_status_error=1.0) + return f"Error getting current status: {e}" + + async def file_contents(ctx: RunContextWrapper[Any], args: str) -> str: + try: + parsed_args = FileContentsArgs.model_validate_json(args) + + # Determine the command to use based on arguments + if parsed_args.tail_lines > 0: + command = ( + f"tail -n {parsed_args.tail_lines} {parsed_args.absolute_path}" + ) + elif parsed_args.head_lines > 0: + command = ( + f"head -n {parsed_args.head_lines} {parsed_args.absolute_path}" + ) + else: + command = f"cat {parsed_args.absolute_path}" + + # Get file contents + raw_result = await self._call_mcp_tool_with_retry( + "keystrokes", + { + 
"container_name": self.container_name, + "keystrokes": command, + "append_enter": True, + "wait_time_sec": 1.0, + }, + ) + + # Clean the output + result = _format_terminal_output(raw_result, command) + + # Remove trailing shell prompt if present + result = re.sub(r"\n[a-zA-Z0-9@]+:[^#]*#$", "", result) + + # Log and track metrics + await self._log_tool_call( + tool_name="file_contents", + arguments=parsed_args.model_dump(), + result=result, + container_name=self.container_name, + ) + + tracker = stats_tracker.get(self.rollout_stat_scope) + tracker.scalar( + tool_status_success=1.0, + tool_status_output_chars=float(len(result)), + ) + + return result + except Exception as e: + stats_tracker.get(self.rollout_stat_scope).scalar(tool_status_error=1.0) + return f"Error getting file contents: {e}" + + return [ + FunctionTool( + name="execute_command", + description="Execute a command in the terminal and return the result.", + params_json_schema=ExecuteCommandArgs.model_json_schema(), + on_invoke_tool=execute_command, + ), + FunctionTool( + name="current_working_directory", + description="Get current working directory and list files.", + params_json_schema=CurrentStatusArgs.model_json_schema(), + on_invoke_tool=current_working_directory, + ), + FunctionTool( + name="file_contents", + description="View file contents with different viewing modes (full content, head, or tail).", + params_json_schema=FileContentsArgs.model_json_schema(), + on_invoke_tool=file_contents, + ), + ] + + def reward(self) -> float: + """Reward function for the terminal environment.""" + try: + payload = {"container_name": self.container_name} + response = requests.post( + f"{self.base_url}/tasks/validate", json=payload, timeout=360 + ) + response.raise_for_status() + result = response.json() + return round(float(result["score"]), 2) + except Exception as e: + print(f"Error getting reward from API: {e}") + return 0.0 diff --git a/examples/openai-agents/terminal/judge_agent.py b/examples/openai-agents/terminal/judge_agent.py new file mode 100644 index 000000000..4ec57fbeb --- /dev/null +++ b/examples/openai-agents/terminal/judge_agent.py @@ -0,0 +1,71 @@ +import os + +from agents import Agent, ModelSettings, RunConfig, SQLiteSession +from agents import Runner as OpenAIRunner +from agents.extensions.models.litellm_model import LitellmModel +from pydantic import BaseModel + +from .prompt import JUDGE_PROMPT + + +class JudgeOutput(BaseModel): + score: float + + +class JudgeAgent: + def __init__(self, base_url: str, api_key: str, model_name: str): + self.base_url = base_url + self.api_key = api_key + # Parse model_name as comma-separated list for round-robin + self.model_names = [ + name.strip() for name in model_name.split(",") if name.strip() + ] + self.current_model_index = 0 + + async def get_reward_from_judge( + self, + session: SQLiteSession, + dockerfile_contents: str, + ) -> float: + items = await session.get_items() + + # Round-robin model selection + selected_model = self.model_names[self.current_model_index] + self.current_model_index = (self.current_model_index + 1) % len( + self.model_names + ) + + agent = Agent( + name="JudgeAgent", + instructions=JUDGE_PROMPT, + model=LitellmModel( + model=selected_model, + api_key=self.api_key, + base_url=self.base_url, + ), + output_type=JudgeOutput, + ) + try: + result = await OpenAIRunner.run( + agent, + input=items, + run_config=RunConfig( + tracing_disabled=True, + model_settings=ModelSettings( + temperature=0.0, + ), + ), + ) + judge_output = 
result.final_output_as(JudgeOutput) + return judge_output.score + except Exception: + return -1 + + +def judge_from_env() -> JudgeAgent | None: + base_url = os.environ.get("LITELLM_API_BASE", None) + api_key = os.environ.get("LITELLM_API_KEY", None) + model_name = os.environ.get("LITELLM_MODEL_NAME", None) + if not base_url or not api_key or not model_name: + return None + return JudgeAgent(base_url, api_key, model_name) diff --git a/examples/openai-agents/terminal/logging_config.py b/examples/openai-agents/terminal/logging_config.py new file mode 100644 index 000000000..f07bde53f --- /dev/null +++ b/examples/openai-agents/terminal/logging_config.py @@ -0,0 +1,30 @@ +import logging +import sys + + +def setup_logging(): + """Setup logging configuration for the MCP server.""" + # Configure root logger + logging.basicConfig( + level=logging.INFO, # Changed from DEBUG to INFO to reduce noise + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + # Suppress noisy third-party loggers + logging.getLogger("mcp.server.lowlevel.server").setLevel(logging.WARNING) + logging.getLogger("asyncio").setLevel(logging.WARNING) + logging.getLogger("litellm").setLevel(logging.WARNING) + + # Ensure terminal_bench loggers also output to stdout + terminal_logger = logging.getLogger("terminal_bench_server") + if not terminal_logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + terminal_logger.addHandler(handler) + terminal_logger.setLevel(logging.DEBUG) + terminal_logger.propagate = False # Prevent duplicate logging to root logger diff --git a/examples/openai-agents/terminal/models.py b/examples/openai-agents/terminal/models.py new file mode 100644 index 000000000..5ff2326c2 --- /dev/null +++ b/examples/openai-agents/terminal/models.py @@ -0,0 +1,793 @@ +"""All types and classes for the terminal-bench server.""" + +import asyncio +import contextvars +import logging +import os +import threading +import warnings +from pathlib import Path +from time import time +from typing import Any + +import docker +from pydantic import BaseModel +from starlette.responses import JSONResponse +from terminal_bench.handlers.trial_handler import Task +from terminal_bench.parsers.base_parser import UnitTestStatus +from terminal_bench.parsers.parser_factory import ParserFactory +from terminal_bench.terminal.docker_compose_manager import DockerComposeManager +from terminal_bench.terminal.tmux_session import TmuxSession +from tqdm import tqdm + +from .logging_config import setup_logging + +# Suppress SQLAlchemy 2.0 deprecation warnings from terminal_bench +warnings.filterwarnings("ignore", category=DeprecationWarning, module="terminal_bench") + +setup_logging() + +logger = logging.getLogger(__name__) + +# Context variable to track request start time +request_start_time = contextvars.ContextVar("request_start_time", default=None) + +# --- Pydantic Models for API and Tools --- + + +class TaskRequest(BaseModel): + uuid: str | None = None + task_name: str | None = None + container_name: str | None = None # Optional, auto-generated from uuid + task_name + + +class TaskContainer(DockerComposeManager): + def __init__( + self, + uuid: str, + task_name: str, + logs_dir: Path, + docker_compose_path: Path, + client_image_name: str, + no_rebuild: bool = False, + cleanup: bool = False, + ): + container_name = 
f"{uuid}__{task_name}" + + # Initialize parent class with required parameters + super().__init__( + client_container_name=container_name, + client_image_name=client_image_name, + docker_compose_path=docker_compose_path, + no_rebuild=no_rebuild, + cleanup=cleanup, + ) + + # Store additional attributes + self.uuid = uuid + self.task_name = task_name + self.container_name = container_name + self.logs_dir = logs_dir.joinpath(uuid) + ## override env + self.env["T_BENCH_TASK_LOGS_PATH"] = str(self.logs_dir.joinpath("client")) + self.env["T_BENCH_TASK_AGENT_LOGS_PATH"] = str(self.logs_dir.joinpath("agent")) + + +class TerminalBenchServer: + """ + Terminal Bench Server class that manages server state and operations. + Provides thread-safe access to shared resources using asyncio locks. + """ + + _is_running = False + preheat_image = False + + def __init__( + self, + tasks_dir: Path | None = None, + tasks_log_dir: Path | None = None, + preheat_image: bool = False, + ): + # In-memory registry for active tasks and their managers + # Format: {container_name: {"uuid": str, "task_name": str, "last_seen": float, "compose_manager": TaskContainer, "container": Container}} + self.active_tasks: dict[str, dict[str, Any]] = {} + # Cache for TmuxSession objects + self.tmux_sessions: dict[str, TmuxSession] = {} + # Docker client instance + self.docker_client = None + # Path to the tasks directory, needs to be configured + self.tasks_dir = tasks_dir or Path( + os.environ.get("T_BENCH_TASKS_DIR", "/app/tasks") + ) + # Path to the tasks logs directory + self.tasks_log_dir = tasks_log_dir or Path( + os.environ.get("T_BENCH_TASKS_LOG_DIR", "/var/logs/terminal-bench/") + ) + self.preheat_image = preheat_image + # Thread locks for thread-safe access to shared resources + # Using threading.Lock instead of asyncio.Lock for cross-thread safety + self.active_tasks_lock = threading.Lock() + self.tmux_sessions_lock = threading.Lock() + self.garbage_collector_task = None + # Background event loop and thread for GC + self._gc_loop = None + self._gc_thread = None + + def init_images_sync(self): + """Pre-build all Docker images from tasks directory.""" + + print("Initializing Docker images...") + if not self.tasks_dir.exists(): + print(f"Warning: Tasks directory {self.tasks_dir} does not exist") + return + + # Get all task directories that contain docker-compose.yaml + task_dirs = [d for d in self.tasks_dir.iterdir() if d.is_dir()] + total_tasks = len(task_dirs) + built_count = 0 + skipped_count = 0 + failed_count = 0 + + print(f"Found {total_tasks} task directories") + + # Use tqdm for progress tracking + with tqdm(total=total_tasks, desc="Building images", unit="task") as pbar: + for task_dir in task_dirs: + # Quit if server is shutting down + if not self._is_running: + pbar.set_description("Interrupted") + print("\nImage initialization interrupted by shutdown") + break + compose_path = task_dir / "docker-compose.yaml" + if not compose_path.exists(): + skipped_count += 1 + pbar.update(1) + pbar.set_postfix( + { + "built": built_count, + "skipped": skipped_count, + "failed": failed_count, + } + ) + continue + + task_name = task_dir.name + image_name = f"tb__{task_name.replace('.', '-')}__client" + + try: + # Check if image already exists + if self._image_exists(image_name): + pbar.set_description(f"Skipping {task_name[:30]}") + skipped_count += 1 + pbar.update(1) + pbar.set_postfix( + { + "built": built_count, + "skipped": skipped_count, + "failed": failed_count, + } + ) + continue + + pbar.set_description(f"Building 
{task_name[:30]}") + + # Create a temporary TaskContainer to build the image + temp_manager = TaskContainer( + uuid="temp_build", + task_name=task_name, + logs_dir=self.tasks_log_dir, + docker_compose_path=compose_path, + client_image_name=image_name, + no_rebuild=False, + ) + + # Build the image without starting the container + temp_manager.build() + built_count += 1 + pbar.update(1) + pbar.set_postfix( + { + "built": built_count, + "skipped": skipped_count, + "failed": failed_count, + } + ) + + except Exception as e: + failed_count += 1 + pbar.set_description(f"Failed {task_name[:30]}") + pbar.update(1) + pbar.set_postfix( + { + "built": built_count, + "skipped": skipped_count, + "failed": failed_count, + } + ) + print(f"\nError building {task_name}: {e}") + + print("\nImage initialization complete:") + print(f" Built: {built_count}") + print(f" Skipped: {skipped_count}") + print(f" Failed: {failed_count}") + print(f" Total: {total_tasks}") + + def _run_gc_loop(self): + """Run garbage collector in a separate thread with its own event loop.""" + self._gc_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._gc_loop) + try: + self._gc_loop.run_until_complete(self.garbage_collector()) + except Exception as e: + print(f"Error in GC loop: {e}") + finally: + self._gc_loop.close() + + def startup(self): + """Server startup logic (synchronous).""" + print("Server starting up...") + try: + self.docker_client = docker.from_env() + self._is_running = True + + if self.preheat_image: + self.init_images_sync() + + # Recover active tasks synchronously + self._recover_active_tasks_sync() + + # Start garbage collector in background thread + self._gc_thread = threading.Thread(target=self._run_gc_loop, daemon=True) + self._gc_thread.start() + + print("Server startup complete with following configuration:") + print(f"TASK DIR: {self.tasks_dir}") + print(f"RECOVER TASK NUM: {len(self.active_tasks)}") + print(f"TASK LOGS DIR: {self.tasks_log_dir}") + + except Exception as e: + self._is_running = False + print(f"Error during startup: {e}") + + def shutdown(self): + """Server shutdown logic (synchronous).""" + print("Server shutting down...") + self._is_running = False + + # Stop garbage collector loop + if self._gc_loop and self._gc_loop.is_running(): + self._gc_loop.call_soon_threadsafe(self._gc_loop.stop) + + # Wait for GC thread to finish + if self._gc_thread and self._gc_thread.is_alive(): + self._gc_thread.join(timeout=5.0) + + print("Server shutdown complete.") + + def _recover_active_tasks_sync(self): + """Scan for existing containers and repopulate the active_tasks registry (synchronous).""" + print("Recovering active tasks...") + if not self.docker_client: + print("Docker client not initialized, skipping recovery") + return + + containers = self.docker_client.containers.list() + + # No need for async lock in sync context + for container in containers: + parts = container.name.split("__") + if len(parts) == 2: + uuid, task_name = parts + print(f"Found existing container: {container.name}") + compose_path = self.tasks_dir / task_name / "docker-compose.yaml" + if compose_path.exists(): + image_name = f"tb__{task_name.replace('.', '-')}__client" + compose_manager = TaskContainer( + uuid=uuid, + task_name=task_name, + client_image_name=image_name, + docker_compose_path=compose_path, + logs_dir=self.tasks_log_dir, + no_rebuild=True, + ) + self.active_tasks[container.name] = { + "uuid": uuid, + "task_name": task_name, + "last_seen": time(), + "compose_manager": compose_manager, + "container": 
container, + } + print(f"Recovered task: {container.name}") + else: + print( + f"Warning: Could not find compose file for task {task_name} of container {container.name}" + ) + + async def garbage_collector(self): + """Periodically clean up idle containers.""" + while self._is_running: + try: + await asyncio.sleep(60) # Check every minute + print("Running garbage collector...") + current_time = time() + idle_timeout = 60 * 15 # 10 minutes + + with self.active_tasks_lock: + containers_to_cleanup = [] + for container_name, task_info in self.active_tasks.items(): + if current_time - task_info["last_seen"] > idle_timeout: + containers_to_cleanup.append(container_name) + + for container_name in containers_to_cleanup: + task_info = self.active_tasks[container_name] + print(f"Container {container_name} is idle. Shutting down.") + try: + task_info["compose_manager"].stop() + + # Clean up tmux session + with self.tmux_sessions_lock: + if container_name in self.tmux_sessions: + del self.tmux_sessions[container_name] + + del self.active_tasks[container_name] + print(f"Successfully cleaned up {container_name}.") + except Exception as e: + print( + f"Error during garbage collection for {container_name}: {e}" + ) + except asyncio.CancelledError: + print("Garbage collector cancelled.") + break + except Exception as e: + print(f"Error in garbage collector: {e}") + + async def start_task(self, req: TaskRequest) -> JSONResponse: + """Starts a new task container or returns an existing one.""" + op_start = time() + container_name = f"{req.uuid}__{req.task_name}" + logger.info(f"[PERF] start_task called for {container_name}") + + # Check if container already exists + check_start = time() + with self.active_tasks_lock: + if container_name in self.active_tasks: + self.active_tasks[container_name]["last_seen"] = time() + log_operation_time( + f"start_task (reused) for {container_name}", op_start + ) + return JSONResponse( + {"container_name": container_name, "status": "reused"} + ) + log_operation_time("start_task lock check", check_start) + + compose_path = self.tasks_dir / req.task_name / "docker-compose.yaml" + if not compose_path.exists(): + return JSONResponse( + {"error": f"Task '{req.task_name}' not found."}, + status_code=404, + ) + + logger.info(f"Creating new container for task: {req.task_name}") + + try: + image_name = f"tb__{req.task_name.replace('.', '-')}__client" + + # Check if image exists + check_image_start = time() + no_rebuild = self._image_exists(image_name) + log_operation_time( + f"start_task check image exists for {req.task_name}", check_image_start + ) + + # Create TaskContainer + create_manager_start = time() + compose_manager = TaskContainer( + uuid=req.uuid, + task_name=req.task_name, + client_image_name=image_name, + docker_compose_path=compose_path, + no_rebuild=no_rebuild, + cleanup=False, # set to False for future reuse + logs_dir=self.tasks_log_dir, + ) + log_operation_time( + f"start_task create TaskContainer for {req.task_name}", + create_manager_start, + ) + + # Start container (potentially slow) + start_container_time = time() + container = compose_manager.start() + log_operation_time( + f"start_task container.start() for {req.task_name}", + start_container_time, + ) + + # Register container + register_start = time() + with self.active_tasks_lock: + self.active_tasks[container_name] = { + "uuid": req.uuid, + "task_name": req.task_name, + "last_seen": time(), + "compose_manager": compose_manager, + "container": container, + } + log_operation_time( + f"start_task register 
container for {container_name}", register_start + ) + log_operation_time(f"start_task (created) for {container_name}", op_start) + + return JSONResponse({"container_name": container_name, "status": "created"}) + except Exception as e: + logger.error(f"Error starting container: {e}, task name: {req.task_name}") + log_operation_time(f"start_task (failed) for {container_name}", op_start) + return JSONResponse( + {"error": f"Failed to start container: {e}"}, status_code=500 + ) + + async def list_tasks(self) -> JSONResponse: + """Lists all active tasks.""" + with self.active_tasks_lock: + tasks = [ + { + "container_name": name, + "uuid": info["uuid"], + "task_name": info["task_name"], + "last_seen": info["last_seen"], + } + for name, info in self.active_tasks.items() + ] + return JSONResponse(tasks) + + async def get_tmux_session(self, container_name: str) -> TmuxSession: + """Get or create a TmuxSession for the given container.""" + with self.tmux_sessions_lock: + if container_name not in self.tmux_sessions: + print(f"Creating new tmux session for {container_name}") + + with self.active_tasks_lock: + if container_name not in self.active_tasks: + raise ValueError( + f"Container {container_name} not found in active tasks" + ) + + task_info = self.active_tasks[container_name] + container = task_info["container"] + user = container.attrs["Config"].get("User", "") + + self.tmux_sessions[container_name] = TmuxSession( + session_name="agent", + container=container, + user=user, + ) + self.tmux_sessions[container_name].start() + + return self.tmux_sessions[container_name] + + def create_validation_session(self, container_name: str): + session = TmuxSession( + session_name="test", + container=self.active_tasks[container_name]["container"], + user=self.active_tasks[container_name]["container"] + .attrs["Config"] + .get("User", ""), + ) + session.start() + return session + + def update_task_last_seen(self, container_name: str): + """Update the last_seen timestamp for a task.""" + with self.active_tasks_lock: + if container_name in self.active_tasks: + self.active_tasks[container_name]["last_seen"] = time() + + def validate_container(self, container_name: str) -> bool: + """Validate that the container exists.""" + with self.active_tasks_lock: + return container_name in self.active_tasks + + async def validate_task(self, container_name: str) -> JSONResponse: + """Run tests in the container and return test results.""" + op_start = time() + + if not container_name: + return JSONResponse( + {"error": "Container name is required"}, + status_code=400, + ) + + logger.info(f"[PERF] validate_container called for {container_name}") + + # Validate container exists + validate_start = time() + if not self.validate_container(container_name): + return JSONResponse( + {"error": f"Container '{container_name}' not found"}, status_code=404 + ) + log_operation_time( + f"validate_container check exists for {container_name}", validate_start + ) + + try: + # Update last seen timestamp + self.update_task_last_seen(container_name) + + # Create a temporary tmux session + session_start = time() + session = self.create_validation_session(container_name) + log_operation_time( + f"validate_container create tmux session for {container_name}", + session_start, + ) + + # Get task info + with self.active_tasks_lock: + task_info = self.active_tasks.get(container_name) + if not task_info: + return JSONResponse( + { + "error": f"Task info not found for container '{container_name}'" + }, + status_code=404, + ) + task_name = 
task_info["task_name"] + + # Get test timeout and parser from task config + task_dir = self.tasks_dir / task_name + task_config_path = task_dir / "task.yaml" + + max_test_timeout_sec = 60.0 # Default timeout + parser_name = "pytest" # Default parser + + if task_config_path.exists(): + try: + task = Task.from_yaml(task_config_path) + max_test_timeout_sec = task.max_test_timeout_sec + parser_name = task.parser_name + except Exception as e: + logger.warning(f"Failed to load task config for {task_name}: {e}") + + # Copy test files to container + run_tests_path = task_dir / "run-tests.sh" + test_dir = task_dir / "tests" + + if not run_tests_path.exists(): + return JSONResponse( + { + "container_name": container_name, + "status": "error", + "error": f"Test script not found: {run_tests_path}", + }, + status_code=404, + ) + + try: + # Use DockerComposeManager's static method to copy files + copy_files_start = time() + with self.active_tasks_lock: + container = self.active_tasks[container_name]["container"] + + paths_to_copy = [run_tests_path] + if test_dir.exists(): + paths_to_copy.append(test_dir) + + DockerComposeManager.copy_to_container( + container=container, paths=paths_to_copy, container_dir="/tests" + ) + log_operation_time( + f"validate_container copy test files for {container_name}", + copy_files_start, + ) + logger.info(f"Copied test files to container {container_name}") + except Exception as e: + logger.error(f"Failed to copy test files: {e}") + return JSONResponse( + { + "container_name": container_name, + "status": "error", + "error": f"Failed to copy test files: {str(e)}", + }, + status_code=500, + ) + + # Run test script + test_script_path = "/tests/run-tests.sh" + + logger.info( + f"Running tests for container {container_name} with timeout {max_test_timeout_sec}s" + ) + + try: + run_tests_start = time() + session.send_keys( + [f"bash {test_script_path}", "Enter"], + block=True, + max_timeout_sec=max_test_timeout_sec, + ) + log_operation_time( + f"validate_container run tests for {container_name}", + run_tests_start, + ) + except TimeoutError: + logger.warning(f"Test timeout for container {container_name}") + log_operation_time( + f"validate_container (timeout) for {container_name}", op_start + ) + return JSONResponse( + { + "container_name": container_name, + "status": "timeout", + "error": f"Test execution timed out after {max_test_timeout_sec} seconds", + } + ) + + # Capture test output + capture_start = time() + test_output = session.capture_pane(capture_entire=True) + log_operation_time( + f"validate_container capture output for {container_name}", capture_start + ) + + # Parse test results + try: + parser = ParserFactory.get_parser(parser_name) + results = parser.parse(test_output) + + # Calculate weighted score + score = _calculate_weighted_test_score(results, None) + + log_operation_time( + f"validate_container (success) for {container_name} with score {score}", + op_start, + ) + + return JSONResponse( + { + "container_name": container_name, + "status": "completed", + "score": score, + "raw_output": test_output, + }, + status_code=200, + ) + + except Exception as e: + logger.error(f"Error parsing test results for {task_name}: {e}") + log_operation_time( + f"validate_container (parse_error) for {container_name}", op_start + ) + return JSONResponse( + { + "container_name": container_name, + "status": "parse_error", + "error": f"Failed to parse test results: {str(e)}", + "raw_output": test_output, + }, + status_code=500, + ) + + except Exception as e: + logger.error( + f"Error 
validating container {container_name}: {e}", exc_info=True + ) + log_operation_time( + f"validate_container (failed) for {container_name}", op_start + ) + return JSONResponse( + {"error": f"Failed to validate container: {str(e)}"}, status_code=500 + ) + + async def stop_task(self, req: TaskRequest) -> JSONResponse: + """Stops and removes a task container for resource cleanup.""" + container_name = req.container_name + if not container_name: + container_name = f"{req.uuid}__{req.task_name}" + + with self.active_tasks_lock: + if container_name not in self.active_tasks: + return JSONResponse( + {"error": f"Container '{container_name}' not found."}, + status_code=404, + ) + + task_info = self.active_tasks[container_name] + + print(f"Stopping container: {container_name}") + + try: + # Stop the compose services + task_info["compose_manager"].stop() + + # Clean up tmux session + with self.tmux_sessions_lock: + if container_name in self.tmux_sessions: + del self.tmux_sessions[container_name] + + # Remove from active tasks + with self.active_tasks_lock: + del self.active_tasks[container_name] + + print(f"Successfully stopped and removed container: {container_name}") + return JSONResponse({"container_name": container_name, "status": "stopped"}) + except Exception as e: + print(f"Error stopping container {container_name}: {e}") + return JSONResponse( + {"error": f"Failed to stop container: {e}"}, status_code=500 + ) + + def _image_exists(self, image_name: str) -> bool: + """Check if a Docker image exists locally. + + Args: + image_name: Name of the Docker image to check + + Returns: + True if image exists, False otherwise + """ + try: + self.docker_client.images.get(image_name) + return True + except docker.errors.ImageNotFound: + return False + + +def log_operation_time(operation_name: str, start_time: float): + """Log the time taken for an operation.""" + duration = time() - start_time + logger.info( + f"[PERF] {operation_name} took {duration:.3f}s ({duration * 1000:.1f}ms)" + ) + + +def _calculate_weighted_test_score( + results: dict[str, UnitTestStatus], + test_weights: dict[str, float] | None, +) -> float: + """ + Calculate weighted score from test results. 
+ + Args: + results: Test name to status mapping + test_weights: Test name to weight mapping + + Returns: + Weighted score between 0.0 and 1.0 + """ + if not results: + return 0.0 + + # If no test weights provided or only placeholder, use equal weights + # Filter out placeholder key used when test_weights.json doesn't exist + filtered_weights = { + k: v for k, v in (test_weights or {}).items() if not k.startswith("_") + } + + if not filtered_weights: + equal_weight = 1.0 / len(results) + total_score = sum( + equal_weight if status == UnitTestStatus.PASSED else 0.0 + for status in results.values() + ) + return total_score + + # Calculate weighted score + total_score = 0.0 + total_weight = 0.0 + + for test_name, status in results.items(): + weight = filtered_weights.get(test_name, 0.0) + if weight > 0: + score = 1.0 if status == UnitTestStatus.PASSED else 0.0 + total_score += score * weight + total_weight += weight + + # Normalize if weights don't sum to 1.0 + if total_weight > 0: + return total_score / total_weight + + equal_weight = 1.0 / len(results) + return sum( + equal_weight if status == UnitTestStatus.PASSED else 0.0 + for status in results.values() + ) diff --git a/examples/openai-agents/terminal/prompt.py b/examples/openai-agents/terminal/prompt.py new file mode 100644 index 000000000..6825b0302 --- /dev/null +++ b/examples/openai-agents/terminal/prompt.py @@ -0,0 +1,264 @@ +SYSTEM_PROMPT = """ +You are a Terminal Agent operating inside a Linux shell environment. +You can interact with the system using provided tools with multiple retries. Never simulate, predict, or describe command results — always perform real actions through tool calls. +--- + +## WORKFLOW (STRICT) + +### Phase 1: Initial Exploration (First Actions) +1. **Start with `current_working_directory()`** to understand your current location +2. **Use `file_contents()`** to examine important files (README, configs, requirements) +3. **Create a mental map** of the directory structure and key files + +### Phase 2: Task Execution +1. **Plan your approach** based on what you discovered +2. **Download missing tools** using `execute_command()` if needed +3. **Use `execute_command()`** for implementation actions +4. **Use `file_contents()`** to verify changes and check results +5. **Use `current_working_directory()`** to monitor directory changes + +### Phase 3: Verification +1. **Check your work** with appropriate tools +2. **Test functionality** if applicable +3. 
**Provide concise summary** of what was accomplished + +--- + +## TOOL USAGE BEST PRACTICES + +### File Operations +- **Prefer `file_contents(head_lines=N)`** over `execute_command("head -n N file")` for better error handling +- **Prefer `file_contents(tail_lines=N)`** over `execute_command("tail -n N file")` for logs +- **Use `current_working_directory()`** before file operations to verify paths + +### Command Execution +- **Use appropriate wait_time_sec**: 1.0 for quick commands, 5.0+ for long operations +- **Check command results** before proceeding to next steps +- **Use absolute paths** when possible to avoid path issues + +### Error Handling +- **Read error messages carefully** from tool outputs +- **Try one corrective action** per failure (fix path, add permissions, install missing deps) +- **Use `current_working_directory()`** to verify file existence before operations + +--- + +## Error Categories and Solutions + +### Command Not Found +- **Symptom**: `bash: command: not found` +- **Solution**: Install missing package (use apt-get install) + +### Permission Denied +- **Symptom**: `bash: ./script: Permission denied` +- **Solution**: Use `chmod +x` to make the file executable + +### File Not Found +- **Symptom**: `No such file or directory` +- **Solution**: Check current directory, verify paths, use absolute paths + +### Syntax Errors +- **Symptom**: Command syntax issues +- **Solution**: Check command syntax, quote strings properly, escape special chars + +### Network Errors +- **Symptom**: Connection timeouts, DNS failures +- **Solution**: Check network connectivity, try different mirrors, use offline alternatives + +--- + +## EFFICIENCY TIPS +- **Use `file_contents()` with head/tail** for large files to avoid long outputs +- **Start with `current_working_directory()`** to understand the environment +- **Combine related operations** in logical sequences +- **Avoid redundant commands** - check results before repeating actions + +--- + +Be methodical, explore first, then execute. Use the right tool for each task and always verify your results. +""" + + +JUDGE_PROMPT = """ +You are a judge for a terminal task agent. You will be given the agent's session list, which contains the agent's actions and the environment's responses, so you can evaluate the agent's performance based on those actions and responses. +## Quick Reference: Scoring Overview + +**Score Range**: 0.00 to 1.00 (two decimal places) + +### Immediate Failure Conditions (Hard Caps) +- **No valid tool calls**: Max score 0.09 +- **Only parse errors**: Max score 0.30 +- **No initial todo creation**: Max score 0.40 +- **Skipped exploration phase**: Max score 0.50 + +### Primary Scoring Components +1. **Action Output Success** (35%) +2. **Todo Usage & Planning** (25%) +3. **Phase Adherence** (25%) +4. **Tool Usage Effectiveness** (15%) + +--- + +## Required Execution Phases + +Agents MUST follow these phases in order: + +1. **Planning** → Create initial todos (first action) including exploration tasks +2. **Exploration** → Read-only discovery of file structure, key files, and environment +3. **Plan Refinement** → Update todos based on findings +4. **Execution** → Implement the solution, adjust / maintain / extend plan where necessary +5. **Verification** → Test and validate + +**Phase violations incur significant penalties (-0.20 to -0.30)** + +--- + +## Detailed Scoring Criteria + +### 1.
Action Output Success (35% weight) + +**Evaluate**: +- Percentage of turns with valid actions +- Successful parsing and execution rate +- Recovery from failures + +### 2. Todo Usage & Planning (25% weight) + +**Requirements**: +- First action MUST create todos +- Initial todos should typically include exploration tasks (file structure, key files) unless user provides complete details +- Todo list is kept up to date throughout based on discoveries + +**Penalties**: +- No initial todos: Cap at 0.40 +- Never completing todos: -0.10 to -0.20 +- Poor maintenance: -0.05 to -0.15 + +### 3. Phase Adherence (25% weight) + +**Check for**: +- All 5 phases in correct order +- Evidence in both todos AND actions +- Extensive and relevant systematic exploration before implementation +- Proper refinement of plan based on discoveries + +**Violations**: +- Skipping phases: -0.20 to -0.30 +- Out of order execution: -0.15 to -0.25 + +### 4. Tool Usage Effectiveness (15% weight) + +**Good Tool Usage**: +- Purposeful actions progressing toward goal +- Appropriate tool selection +- Using simpler tools when available + +**Scratchpad Usage**: +- ✅ Reward (+0.05 to +0.10): Complex reasoning, hypothesis tracking +- ❌ Penalize (-0.05 to -0.10): Duplicating todos, chat-like usage + +**Tool Misuse Penalties** (-0.05 to -0.15): +- Meaningless action sequences +- Actions contradicting logical workflow +- Fundamental misunderstanding of tool purpose + +--- + +## Quality Modifiers + +### Error Recovery & Learning (+/- 0.10) +**Bonus Conditions**: +- Fixes parse errors and continues +- Adapts after command failures +- Shows clear improvement trajectory +- Error messages lead to corrected actions + +### Discovery Quality (+/- 0.20) +**Look for**: +- Systematic exploration +- Information synthesis across phases +- Building comprehensive understanding +- Effective use of scratchpad for insights + +### Efficiency & Focus (+/- 0.05) +**Assess**: +- Avoiding redundant actions +- Maintaining phase focus +- Clean action sequences +- Working within token constraints + +### Assumption Avoidance (+/- 0.15) +**Penalize** (-0.05 to -0.15): +- Acting on assumed file locations +- Implementing based on guesses +- Making changes without verification that they worked + +**Reward** (+0.05): +- Explicit verification before action +- Checking file existence +- Testing assumptions through exploration + +--- + +## Critical Penalty Areas + +### Overthinking Detection (-0.15 to -0.40) + +**CRITICAL: Thinking without action is heavily penalized. 
Take concrete actions immediately.** + +**Analysis Paralysis** (-0.15 to -0.30): +- Excessive thinking (10+ lines or multiple paragraphs in tags) with no corresponding actions +- Repeatedly questioning tool availability instead of trying them +- Over-analyzing instead of executing concrete actions +- Explaining basic syntax instead of using and testing it + +**Approach Switching Loops** (-0.10 to -0.25): +- Cycling through same options +- Revisiting rejected approaches + +**Redundant Action Attempts** (-0.15 to -0.30): +- Retrying completed tasks +- Ignoring "already completed" messages +- Creating duplicate todos + +**Writing Full Actions in Thinking** (-0.10 to -0.25): +- Drafting complete tool calls +- Writing out full code snippets instead of executing them +- Pre-planning entire scripts rather than building incrementally +- Long thinking blocks with no actions between them +- Note: Brief planning is good; extended thinking without action is not + +**Severity Scale**: +- Minor (1-2 patterns): -0.15 +- Moderate (3-4 patterns): -0.25 +- Severe (5+ patterns): -0.35 +- Extreme (prevents actions): -0.40 + +### Gaming Detection (-0.10 to -0.30) + +**Watch for**: +- Minimal actions to "check off" phases +- Artificial complexity for simple tasks +- Suspicious early mistakes with dramatic recovery +- Unnecessarily prolonged trajectories + +--- + +## Key Reminders + +✅ **Always Reward**: +- Planning-exploration first approach +- Clear phase progression +- Learning from errors +- Efficient execution +- Strategic scratchpad use + +❌ **Always Penalize**: +- No tool use +- Missing initial exploration +- Phase skipping +- Overthinking/paralysis +- Gaming behaviors + +⚠️ **Your Role**: Evaluate HOW the agent worked, not WHETHER the task was completed. Task completion is verified separately via automated unit tests. +""" diff --git a/examples/openai-agents/terminal/requirements.txt b/examples/openai-agents/terminal/requirements.txt new file mode 100644 index 000000000..6e7068bc1 --- /dev/null +++ b/examples/openai-agents/terminal/requirements.txt @@ -0,0 +1,4 @@ +FastMCP +pydantic +starlette +terminal_bench \ No newline at end of file diff --git a/examples/openai-agents/terminal/server.py b/examples/openai-agents/terminal/server.py new file mode 100644 index 000000000..9e42f3184 --- /dev/null +++ b/examples/openai-agents/terminal/server.py @@ -0,0 +1,222 @@ +"""An MCP server for running terminal-bench tasks independently.""" + +import argparse +import asyncio +import contextvars +import logging +import warnings +from pathlib import Path + +from fastmcp import FastMCP +from starlette.requests import Request + +from .logging_config import setup_logging +from .models import TaskRequest, TerminalBenchServer + +# Suppress SQLAlchemy 2.0 deprecation warnings from terminal_bench +warnings.filterwarnings("ignore", category=DeprecationWarning, module="terminal_bench") + +setup_logging() + +logger = logging.getLogger(__name__) + +# Context variable to track request start time +request_start_time = contextvars.ContextVar("request_start_time", default=None) + +server_instance = None +# --- MCP Server Setup with FastMCP --- + +mcp = FastMCP("t-bench-multi-task") + + +@mcp.tool() +async def keystrokes( + container_name: str, + keystrokes: str, + append_enter: bool = False, + wait_time_sec: float = 0.0, +) -> str: + """Send keystrokes to a tmux session and return the result. + + Args: + container_name: The name of the container to send keystrokes to + keystrokes: Keystrokes to execute in the terminal.
Use tmux-style escape sequences for special characters (e.g. C-c for ctrl-c) + append_enter: Whether to append a newline character to the end of the keystrokes (necessary to execute bash commands) + wait_time_sec: The number of expected seconds to wait for the command to complete + + Returns: + Terminal output after executing the keystrokes + """ + + # Validate container exists + if not server_instance.validate_container(container_name): + raise ValueError(f"Invalid or unknown container_name: {container_name}") + + # Update last seen timestamp + server_instance.update_task_last_seen(container_name) + + # Get or create tmux session + session = await server_instance.get_tmux_session(container_name) + + # Clear the terminal to avoid historical results in next calls + session.send_keys( + keys=["clear", "Enter"], + min_timeout_sec=0.1, + max_timeout_sec=0.1, + ) + + keys = [keystrokes, "Enter"] if append_enter else keystrokes + session.send_keys( + keys=keys, + min_timeout_sec=wait_time_sec, + max_timeout_sec=wait_time_sec, + ) + + # Capture the output before clearing + output = session.capture_pane() + + # Clear the terminal to avoid historical results in next calls + session.send_keys( + keys=["clear", "Enter"], + min_timeout_sec=0.1, + max_timeout_sec=0.1, + ) + + return output + + +@mcp.tool() +async def capture_pane( + container_name: str, + wait_before_capture_sec: float = 0.0, +) -> str: + """Capture the pane of a tmux session. + + Args: + container_name: The name of the container to capture the pane from + wait_before_capture_sec: The number of seconds to wait before capturing the pane. This is useful if you just executed a command and want to wait a bit to capture the output + + Returns: + Current terminal pane content + """ + + # Validate container exists + if not server_instance.validate_container(container_name): + raise ValueError(f"Invalid or unknown container_name: {container_name}") + + # Update last seen timestamp + server_instance.update_task_last_seen(container_name) + + # Get or create tmux session + session = await server_instance.get_tmux_session(container_name) + + if wait_before_capture_sec > 0: + await asyncio.sleep(wait_before_capture_sec) + + return session.capture_pane() + + +# --- Custom HTTP Routes for Task Management --- + + +@mcp.custom_route("/tasks/start", methods=["POST"]) +async def start_task_route(request: Request): + """Start a new task container.""" + data = await request.json() + req = TaskRequest.model_validate(data) + return await server_instance.start_task(req) + + +@mcp.custom_route("/tasks/stop", methods=["POST"]) +async def stop_task_route(request: Request): + """Stop a task container.""" + data = await request.json() + req = TaskRequest.model_validate(data) + return await server_instance.stop_task(req) + + +@mcp.custom_route("/tasks", methods=["GET"]) +async def list_tasks_route(request: Request): + """List all active tasks.""" + return await server_instance.list_tasks() + + +@mcp.custom_route("/tasks/validate", methods=["POST"]) +async def validate_container_route(request: Request): + """Validate a container and get test results.""" + data = await request.json() + req = TaskRequest.model_validate(data) + container_name = req.container_name + if not container_name or container_name == "": + container_name = f"{req.uuid}__{req.task_name}" + return await server_instance.validate_task(container_name) + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Terminal Bench 
MCP Server - Run terminal-bench tasks independently", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--tasks-dir", + type=Path, + default=None, + help="Path to the tasks directory (default: env T_BENCH_TASKS_DIR or /app/tasks)", + ) + + parser.add_argument( + "--tasks-log-dir", + type=Path, + default=None, + help="Path to the tasks logs directory (default: env T_BENCH_TASKS_LOG_DIR or /var/logs/terminal-bench/)", + ) + + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Host to bind the server to" + ) + + parser.add_argument( + "--port", type=int, default=8000, help="Port to bind the server to" + ) + + parser.add_argument( + "--preheat-image", + action="store_true", + default=False, + help="Preheat the docker image before starting the server", + ) + + return parser.parse_args() + + +def main(): + """Main entry point for the server.""" + args = parse_args() + + # Initialize and start the global server instance + global server_instance + server_instance = TerminalBenchServer( + tasks_dir=args.tasks_dir, + tasks_log_dir=args.tasks_log_dir, + preheat_image=args.preheat_image, + ) + + server_instance.startup() + + # Run the FastMCP server over SSE (blocks until interrupted) + try: + mcp.run( + transport="sse", + host=args.host, + port=args.port, + ) + except KeyboardInterrupt: + print("\nShutting down server...") + if server_instance: + server_instance.shutdown() + + +if __name__ == "__main__": + main() diff --git a/examples/openai-agents/terminal/tasks_to_parquet_converter.py b/examples/openai-agents/terminal/tasks_to_parquet_converter.py new file mode 100644 index 000000000..831332812 --- /dev/null +++ b/examples/openai-agents/terminal/tasks_to_parquet_converter.py @@ -0,0 +1,232 @@ +"""Convert Terminal Bench tasks to RLLM/VERL format.""" + +import argparse +import json +from pathlib import Path + +import pandas as pd +from terminal_bench_task import TBenchTrainingTask, load_terminal_bench_tasks +from tqdm import tqdm + + +def create_prompt_from_task( + task: TBenchTrainingTask, system_prompt: str | None = None +) -> str: + """Create a prompt from a terminal bench task.""" + if system_prompt is None: + system_prompt = ( + "You are an AI assistant helping to complete terminal-based tasks. " + "Follow the instructions carefully and use appropriate commands to accomplish the goal." + ) + + # Convert to string format (you might want to use a specific chat template) + prompt = ( + f"<|system|>\n{system_prompt}\n<|user|>\n{task.instruction}\n<|assistant|>\n" + ) + + return prompt + + +def convert_tasks_to_parquet( + tasks_dir: Path, + output_dir: Path, + train_split: float | None = None, + system_prompt: str | None = None, + task_names: list[str] | None = None, + test_tasks_dir: Path | None = None, +) -> None: + """Convert terminal bench tasks to parquet format for VERL training.
+ + Args: + tasks_dir: Directory containing terminal bench tasks (or train tasks if test_tasks_dir is provided) + output_dir: Output directory for parquet files + train_split: Fraction of data for training (ignored if test_tasks_dir is provided) + system_prompt: System prompt to use + task_names: Specific task names to convert + test_tasks_dir: Directory containing test tasks for validation set + """ + + # Load tasks + if test_tasks_dir is not None: + # Load train and test tasks separately + print(f"Loading training tasks from {tasks_dir}") + train_tasks = load_terminal_bench_tasks(tasks_dir, task_names) + print(f"Loaded {len(train_tasks)} training tasks") + + print(f"Loading validation tasks from {test_tasks_dir}") + val_tasks = load_terminal_bench_tasks(test_tasks_dir, task_names) + print(f"Loaded {len(val_tasks)} validation tasks") + + tasks = train_tasks + val_tasks + else: + # Load all tasks from single directory + print(f"Loading tasks from {tasks_dir}") + tasks = load_terminal_bench_tasks(tasks_dir, task_names) + print(f"Loaded {len(tasks)} tasks") + + # Create output directory + output_dir.mkdir(parents=True, exist_ok=True) + + # Prepare data for parquet + data_records = [] + + for task in tqdm(tasks, desc="Converting tasks"): + record = { + "prompt": create_prompt_from_task(task, system_prompt), + "task_name": task.task_name, + "task_path": str(task.task_path), + "instruction": task.instruction, + "data_source": "terminal_bench", # For reward_fn_key + "metadata": { + "test_weights": task.test_weights, + "max_test_timeout_sec": task.max_test_timeout_sec, + }, + } + + # Always include extra_info with task configuration + extra_info_dict = { + "task_name": task.task_name, + "task_path": str(task.task_path), + "instruction": task.instruction, + "test_weights": task.test_weights, + "dockerfile_contents": task.dockerfile_contents, + "py_test_file_contents": task.py_test_file_contents, + "max_test_timeout_sec": task.max_test_timeout_sec, + } + + # Include additional files if present + if task.additional_files: + extra_info_dict["additional_files"] = task.additional_files + + record["extra_info"] = json.dumps(extra_info_dict) + + data_records.append(record) + + # Create DataFrame + df = pd.DataFrame(data_records) + + # Split into train and validation + if test_tasks_dir is not None: + # Use pre-defined split based on directories + n_train = len(train_tasks) + train_df = df[:n_train] + val_df = df[n_train:] + else: + # Use train_split parameter + if train_split is None: + train_split = 0.9 + n_train = int(len(df) * train_split) + train_df = df[:n_train] + val_df = df[n_train:] + + # Save to parquet + train_path = output_dir / "train.parquet" + val_path = output_dir / "val.parquet" + + train_df.to_parquet(train_path, index=False) + val_df.to_parquet(val_path, index=False) + + print(f"Saved {len(train_df)} training examples to {train_path}") + print(f"Saved {len(val_df)} validation examples to {val_path}") + + # Also save task information for the reward function + tasks_info = {} + for task in tasks: + task_info = { + "task_path": str(task.task_path), + "test_weights": task.test_weights, + "dockerfile_contents": task.dockerfile_contents, + "py_test_file_contents": task.py_test_file_contents, + "max_test_timeout_sec": task.max_test_timeout_sec, + } + # Include additional files if present + if task.additional_files: + task_info["additional_files"] = task.additional_files + + tasks_info[task.task_name] = task_info + + tasks_info_path = output_dir / "tasks_info.json" + with open(tasks_info_path, 
"w") as f: + json.dump(tasks_info, f, indent=2) + + print(f"Saved task information to {tasks_info_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert Terminal Bench tasks to RLLM/VERL format." + ) + parser.add_argument( + "--tasks-dir", + type=str, + default="tasks", + help="Directory containing training tasks (default: tasks)", + ) + parser.add_argument( + "--test-tasks-dir", + type=str, + default="", + help="Directory containing test/validation tasks (default: empty)", + ) + parser.add_argument( + "--output-dir", + type=str, + default="data", + help="Output directory for parquet files (default: data)", + ) + parser.add_argument( + "--train-split", + type=float, + default=None, + help="Fraction of data for training (only used if --test-tasks-dir is not provided)", + ) + parser.add_argument( + "--system-prompt", type=str, default=None, help="Custom system prompt to use" + ) + parser.add_argument( + "--task-names", + type=str, + nargs="+", + default=None, + help="Specific task names to convert (optional)", + ) + + args = parser.parse_args() + + # Convert to Path objects + tasks_dir = Path(args.tasks_dir) + test_tasks_dir = ( + Path(args.test_tasks_dir) + if args.test_tasks_dir and args.test_tasks_dir.strip() + else None + ) + output_dir = Path(args.output_dir) + + # Check if directories exist + if not tasks_dir.exists(): + print(f"Error: {tasks_dir} directory not found") + return + + # Create output directory + output_dir.mkdir(parents=True, exist_ok=True) + + if test_tasks_dir: + print( + f"Converting {tasks_dir}/ -> train.parquet and {test_tasks_dir}/ -> val.parquet" + ) + else: + print(f"Converting {tasks_dir}/ with train_split={args.train_split}") + + # Convert to parquet only (with extra_info) + convert_tasks_to_parquet( + tasks_dir=tasks_dir, + output_dir=output_dir, + train_split=args.train_split, + system_prompt=args.system_prompt, + task_names=args.task_names, + test_tasks_dir=test_tasks_dir, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/openai-agents/terminal/terminal_bench_task.py b/examples/openai-agents/terminal/terminal_bench_task.py new file mode 100644 index 000000000..966a8c415 --- /dev/null +++ b/examples/openai-agents/terminal/terminal_bench_task.py @@ -0,0 +1,131 @@ +import json +from enum import Enum +from pathlib import Path + +import yaml +from pydantic import BaseModel +from tqdm import tqdm + + +class TBenchTaskDifficulty(Enum): + EASY = "easy" + MEDIUM = "medium" + HARD = "hard" + UNRATED = "unrated" + + +class TBenchTrainingTask(BaseModel): + """Data model for a task which follows the same format as terminal-bench.""" + + task_name: str + task_path: Path + instruction: str + difficulty: TBenchTaskDifficulty + test_weights: dict + dockerfile_contents: str + py_test_file_contents: str + max_test_timeout_sec: int = 300 # Default timeout + additional_files: dict | None = None # Maps file paths to contents + + +def load_terminal_bench_tasks( + tasks_dir: Path, + task_names: list[str] | None = None, +) -> list[TBenchTrainingTask]: + if task_names is None: + task_names = [p.name for p in tasks_dir.iterdir() if p.is_dir()] + + tasks = [] + for task_name in tqdm(task_names, desc="Loading tasks"): + task_path = tasks_dir / task_name + task_yaml = task_path / "task.yaml" + + if not task_yaml.exists(): + tqdm.write(f"Task YAML file not found: {task_yaml}") + continue + + with open(task_yaml, encoding="utf-8") as f: + task_data = yaml.safe_load(f) + + instruction = task_data.get("instruction") + if not instruction: + 
tqdm.write(f"Instruction not found in task YAML: {task_yaml}") + continue + + # Get max test timeout if specified + max_test_timeout_sec = task_data.get("max_test_timeout_sec", 300) + difficulty = task_data.get("difficulty", TBenchTaskDifficulty.UNRATED) + + # Load test weights + test_weights_path = task_path / "test_weights.json" + if test_weights_path.exists(): + with open(test_weights_path, encoding="utf-8") as f: + test_weights = json.load(f) + else: + # Use placeholder to avoid PyArrow empty struct error + test_weights = {"_no_weights": 1.0} + + # Load Dockerfile + dockerfile_path = task_path / "Dockerfile" + if not dockerfile_path.exists(): + tqdm.write(f"Dockerfile not found for task: {task_name}") + continue + with open(dockerfile_path, encoding="utf-8") as f: + dockerfile_contents = f.read() + if not dockerfile_contents: + tqdm.write(f"Dockerfile is empty for task: {task_name}") + continue + + # Load Python test file if it exists + py_test_file_path = task_path / "tests" / "test_outputs.py" + if not py_test_file_path.exists(): + tqdm.write(f"Python test file not found for task: {task_name}") + continue + with open(py_test_file_path, encoding="utf-8") as f: + py_test_file_contents = f.read() + if not py_test_file_contents: + tqdm.write(f"Python test file is empty for task: {task_name}") + continue + + # Load additional files if they exist + additional_files = {} + # List all files in the task directory (excluding standard files) + standard_files = {"Dockerfile", "task.yaml", "test_weights.json"} + standard_dirs = {"tests", "__pycache__"} + + for item in task_path.iterdir(): + if item.is_file() and item.name not in standard_files: + # Read the file and store with relative path + rel_path = item.relative_to(task_path) + try: + with open(item, encoding="utf-8") as f: + additional_files[str(rel_path)] = f.read() + except UnicodeDecodeError: + tqdm.write(f"Binary file found for task: {task_name}") + elif item.is_dir() and item.name not in standard_dirs: + # Recursively read files from subdirectories + for subfile in item.rglob("*"): + if subfile.is_file(): + rel_path = subfile.relative_to(task_path) + try: + with open(subfile, encoding="utf-8") as f: + additional_files[str(rel_path)] = f.read() + except UnicodeDecodeError: + # Skip binary files for now + tqdm.write(f"Binary file found for task: {task_name}") + + tasks.append( + TBenchTrainingTask( + task_name=task_name, + task_path=task_path, + instruction=instruction, + difficulty=difficulty, + test_weights=test_weights, + dockerfile_contents=dockerfile_contents, + py_test_file_contents=py_test_file_contents, + max_test_timeout_sec=max_test_timeout_sec, + additional_files=additional_files if additional_files else None, + ) + ) + + return tasks diff --git a/examples/openai-agents/terminal/test_client.py b/examples/openai-agents/terminal/test_client.py new file mode 100644 index 000000000..1de2bda0f --- /dev/null +++ b/examples/openai-agents/terminal/test_client.py @@ -0,0 +1,129 @@ +"""Simple MCP client to test terminal server functionality.""" + +import asyncio + +from mcp import ClientSession +from mcp.client.sse import sse_client + + +async def test_mcp_server(): + """Test the MCP server by calling tools.""" + server_url = "http://localhost:8000" + + print("Connecting to MCP server...") + async with sse_client(f"{server_url}/sse") as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + print("✓ Connected to MCP server\n") + + # List available tools + print("=" * 60) + print("Available 
Tools:") + print("=" * 60) + tools_result = await session.list_tools() + for tool in tools_result.tools: + print(f" • {tool.name}: {tool.description}") + print() + + # Test 1: Start a task + print("=" * 60) + print("Test 1: Starting a task container") + print("=" * 60) + import requests + + response = requests.post( + f"{server_url}/tasks/start", + json={"uuid": "test-123", "task_name": "hello-world"}, + timeout=60, + ) + if response.status_code == 200: + data = response.json() + container_name = data["container_name"] + print(f"✓ Task started: {container_name}\n") + + # Test 2: Send keystrokes + print("=" * 60) + print("Test 2: Sending keystrokes (echo 'Hello MCP')") + print("=" * 60) + result = await session.call_tool( + name="keystrokes", + arguments={ + "container_name": container_name, + "keystrokes": "echo 'Hello MCP'", + "append_enter": True, + "wait_time_sec": 1.0, + }, + ) + output = result.content[0].text if result.content else "" + print(f"Output:\n{output}\n") + + print("=" * 60) + print("Test 3: Capturing terminal pane") + print("=" * 60) + result = await session.call_tool( + name="capture_pane", + arguments={ + "container_name": container_name, + "wait_before_capture_sec": 0.5, + }, + ) + output = result.content[0].text if result.content else "" + print(f"Captured output:\n{output}\n") + + print("=" * 60) + print("Test 4: Listing active tasks") + print("=" * 60) + response = requests.get(f"{server_url}/tasks") + if response.status_code == 200: + tasks = response.json() + print(f"✓ Active tasks: {len(tasks)}") + for task in tasks: + print(f" • {task['container_name']} (UUID: {task['uuid']})") + print() + + # validate task + print("=" * 60) + print("Test 4: Validating task container") + print("=" * 60) + response = requests.post( + f"{server_url}/tasks/validate", + json={"uuid": "test-123", "task_name": "hello-world"}, + timeout=180, + ) + if response.status_code == 200: + print( + f"✓ Task validated successfully, score: {response.json().get('score')}\n" + ) + else: + print(f"✗ Failed to validate task: {response.text}\n") + + print("=" * 60) + print("Test 5: Stopping task container") + print("=" * 60) + response = requests.post( + f"{server_url}/tasks/stop", + json={"uuid": "test-123", "task_name": "hello-world"}, + timeout=30, + ) + if response.status_code == 200: + print("✓ Task stopped successfully\n") + else: + print(f"✗ Failed to stop task: {response.text}\n") + else: + print(f"✗ Failed to start task: {response.text}\n") + + print("=" * 60) + print("All tests completed!") + print("=" * 60) + + +if __name__ == "__main__": + try: + asyncio.run(test_mcp_server()) + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"\n✗ Error: {e}") + import traceback + + traceback.print_exc() diff --git a/examples/openai-agents/train_agents.py b/examples/openai-agents/train_agents.py index 3fda65287..202ac5ba3 100644 --- a/examples/openai-agents/train_agents.py +++ b/examples/openai-agents/train_agents.py @@ -28,7 +28,7 @@ class AgentRLConfig(GRPOConfig): default="math", metadata={ "help": "Type of agent workflow to use.", - "choices": ["math", "multi_agent_math"], + "choices": ["math", "multi_agent_math", "multi_agent_terminal"], }, ) n_trajs: int = field( @@ -123,6 +123,19 @@ def main(args): StatsLogger.get_log_path(config.stats_logger), "generated" ), ) + elif config.agent_type == "multi_agent_terminal": + from agent_terminal_workflow import TerminalAgentWorkflow + + workflow = TerminalAgentWorkflow( + gconfig=config.gconfig, + 
tokenizer=tokenizer, + n_trajs=config.n_trajs, + max_tokens=config.max_tokens_per_trajectory, + max_turns=config.max_turns, + dump_dir=os.path.join( + StatsLogger.get_log_path(config.stats_logger), "generated" + ), + ) else: raise ValueError(f"Unknown agent_type: {config.agent_type}.")
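
Usage note: as a quick sanity check of the new dataset path, one might load a parquet split produced by tasks_to_parquet_converter.py through the terminal_bench loader added in areal/dataset/terminal_bench.py. This is a minimal sketch, not part of the patch; the tokenizer name is a placeholder and the file path assumes the converter's default "data" output directory.

# Minimal smoke test for the terminal_bench RL dataset loader (sketch only).
from transformers import AutoTokenizer

from areal.dataset.terminal_bench import get_terminal_bench_rl_dataset

# Placeholder tokenizer; any fast tokenizer matching the training model should work.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

dataset = get_terminal_bench_rl_dataset(
    path="data/train.parquet",  # written by tasks_to_parquet_converter.py
    split="train",
    tokenizer=tokenizer,
    max_length=32768,  # filters out tasks whose instruction alone exceeds this token count
)

# Each record carries the fields the terminal workflow consumes.
sample = dataset[0]
print(sample["task_name"])
print(sample["instruction"][:200])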