diff --git a/.gitignore b/.gitignore index fc61d50e..6234b3cd 100644 --- a/.gitignore +++ b/.gitignore @@ -172,6 +172,9 @@ outputs/ miniwob-plusplus/ .miniwob-server.pid debugging_results/ +docker_vm_data/ +OSWorld/ + # working files experiments/* \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index d52ef62d..eb18e557 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,13 +3,20 @@ "editor.formatOnSave": true, "editor.defaultFormatter": "ms-python.black-formatter", "editor.codeActionsOnSave": { - "source.organizeImports": "explicit", - "source.fixAll": "never" - } + "source.organizeImports": "always", + "source.fixAll": "always", + }, }, + "python.analysis.languageServerMode": "full", + "python.analysis.typeCheckingMode": "standard", "python.testing.pytestArgs": [ "tests" ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, + "files.watcherExclude": { + "**/.git/objects/**": true, + "**/.git/subtree-cache/**": true, + "**/node_modules/*/**": true + }, } \ No newline at end of file diff --git a/Makefile b/Makefile index b37fc1c4..23799f32 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: test setup miniwob lint stop-miniwob +.PHONY: test setup miniwob lint stop-miniwob osworld setup: @pip install -e . @@ -30,3 +30,23 @@ test: setup miniwob check-miniwob run-tests stop-miniwob lint: setup @black src/ --check --diff @darglint -v 2 -z short src/ + +osworld: + @echo "Setting up OSWorld..." + @git clone https://github.com/xlang-ai/OSWorld || true + @echo "Modifying OSWorld requirements.txt to remove pinned versions..." + @cd OSWorld && \ + sed -i.bak 's/numpy~=.*/numpy/' requirements.txt && \ + sed -i.bak 's/torch~=.*/torch/' requirements.txt && \ + sed -i.bak 's/torch$$/torch/' requirements.txt && \ + sed -i.bak 's/tqdm~=.*/tqdm/' requirements.txt && \ + sed -i.bak 's/pandas~=.*/pandas/' requirements.txt + @echo "Installing OSWorld requirements..." 
+ @cd OSWorld && pip install -r requirements.txt + @echo "Installing OSWorld in development mode..." + @cd OSWorld && pip install -e . + @echo "OSWorld setup completed!" + @echo "Next steps:" + @echo "1. Configure your VM (VMware/VirtualBox) according to OSWorld documentation" + @echo "2. Download or set up the Ubuntu VM image" + @echo "3. Run AgentLab with OSWorld tasks" \ No newline at end of file diff --git a/README.md b/README.md index 05b04c2b..b5807314 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,7 @@ AgentLab Features: | [GAIA](https://huggingface.co/spaces/gaia-benchmark/leaderboard) (soon) | - | - | None | - | - | live web | soon | | [Mind2Web-live](https://huggingface.co/datasets/iMeanAI/Mind2Web-Live) (soon) | - | - | None | - | - | live web | soon | | [MiniWoB](https://miniwob.farama.org/index.html) | [setup](https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/miniwob/README.md) | 125 | Medium | 10 | no | self hosted (static files) | soon | +| [OSWorld](https://os-world.github.io/) | [setup](https://github.com/ServiceNow/AgentLab/blob/main/src/agentlab/benchmarks/setup.md) | 369 | None | - | - | self hosted | soon | ## 🛠️ Setup AgentLab diff --git a/experiments/osworld_debug_task_ids.json b/experiments/osworld_debug_task_ids.json new file mode 100644 index 00000000..64d4e2c6 --- /dev/null +++ b/experiments/osworld_debug_task_ids.json @@ -0,0 +1,37 @@ +[ + { + "id": "550ce7e7-747b-495f-b122-acdc4d0b8e54", + "task": "I am checking our soccer club's to-do list for the last semester and adding strike-through sign on the line we have already accomplished. Could you help me add a strike-through on the first and second line?", + "complexity": 1 + }, + { + "id": "59f21cfb-0120-4326-b255-a5b827b38967", + "task": "Could you play the music video that's saved on my desktop for me via vlc?", + "complexity": 1 + }, + { + "id": "35253b65-1c19-4304-8aa4-6884b8218fc0", + "task": "Hey, I need a quick way back to this site. 
Could you whip up a shortcut on my desktop for me?", + "complexity": 1 + }, + { + "id": "0ed39f63-6049-43d4-ba4d-5fa2fe04a951", + "task": "Please help me change all the places in this document that say \"text\" to \"test\".", + "complexity": 1 + }, + { + "id": "5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57", + "task": "I am currently using an Ubuntu system, and I have wrongly deleted a poster of party night. Could you help me recover it from the Trash?", + "complexity": 1 + }, + { + "id": "510f64c8-9bcc-4be1-8d30-638705850618", + "task": "Could you start VS Code in folder ~/Desktop/project from the terminal?", + "complexity": 1 + }, + { + "id": "53ad5833-3455-407b-bbc6-45b4c79ab8fb", + "task": "Please help me use VS Code to open the \"project\" in the \"user\" folder under \"home\".", + "complexity": 1 + } +] \ No newline at end of file diff --git a/experiments/osworld_docker_test.py b/experiments/osworld_docker_test.py new file mode 100644 index 00000000..3b68db5d --- /dev/null +++ b/experiments/osworld_docker_test.py @@ -0,0 +1,37 @@ +import logging + +from desktop_env.desktop_env import DesktopEnv + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler()], +) + +example = { + "id": "94d95f96-9699-4208-98ba-3c3119edf9c2", + "instruction": "I want to install Spotify on my current system. 
Could you please help me?", + "config": [ + { + "type": "execute", + "parameters": { + "command": [ + "python", + "-c", + "import pyautogui; import time; pyautogui.click(960, 540); time.sleep(0.5);", + ] + }, + } + ], + "evaluator": { + "func": "check_include_exclude", + "result": {"type": "vm_command_line", "command": "which spotify"}, + "expected": {"type": "rule", "rules": {"include": ["spotify"], "exclude": ["not found"]}}, + }, +} + +env = DesktopEnv(action_space="pyautogui", provider_name="docker", os_type="Ubuntu") + +obs = env.reset(task_config=example) +obs, reward, done, info = env.step("pyautogui.rightClick()") +print(obs) diff --git a/experiments/run_osworld.py b/experiments/run_osworld.py new file mode 100644 index 00000000..36db0878 --- /dev/null +++ b/experiments/run_osworld.py @@ -0,0 +1,66 @@ +import json +import logging +import os + +from agentlab.agents.tool_use_agent.tool_use_agent import OSWORLD_CLAUDE +from agentlab.benchmarks.osworld import OsworldBenchmark +from agentlab.experiments.study import Study, make_study + +fmt = "%(asctime)s - %(levelname)s - %(name)s:%(lineno)d - %(funcName)s() - %(message)s" +logging.basicConfig(level=logging.INFO, force=True, format=fmt, handlers=[logging.StreamHandler()]) + + +def get_most_recent_incomplete_study() -> Study: + """ + Relaunch an existing study, this will continue incomplete experiments and relaunch errored experiments. + """ + study = Study.load_most_recent() + study.find_incomplete(include_errors=True) + return study + + +def get_task_ids() -> set[str]: + with open("experiments/osworld_debug_task_ids.json", "r") as f: + task_ids = json.load(f) + return set([task["id"] for task in task_ids]) + + +def main(): + n_jobs = 4 + use_vmware = True + relaunch = True + agent_args = [ + OSWORLD_CLAUDE, + # OSWORLD_OAI # performs poorly. 
+ ] # type: ignore + parallel_backend = "ray" + os.environ["AGENTLAB_DEBUG"] = os.environ.get("AGENTLAB_DEBUG", "1") + + study = make_study( + benchmark=OsworldBenchmark( + test_set_name="test_small.json" + ), # or test_all.json (Exper) # type: ignore + agent_args=agent_args, # type: ignore + comment="osworld debug 2", + logging_level=logging.INFO, + logging_level_stdout=logging.INFO, + ) + + if use_vmware: + for exp_args in study.exp_args_list: + exp_args.env_args.provider_name = "vmware" # type: ignore + exp_args.env_args.path_to_vm = "OSWorld/vmware_vm_data/Ubuntu0/Ubuntu0.vmx" # type: ignore + parallel_backend = "sequential" + + if os.environ.get("AGENTLAB_DEBUG"): + task_ids = get_task_ids() + study.exp_args_list = [exp_args for exp_args in study.exp_args_list if exp_args.env_args.task["id"] in task_ids] # type: ignore + print(f"Debug on {len(study.exp_args_list)} experiments") + n_jobs = 1 # Make sure to use 1 job when debugging in VS + + study = get_most_recent_incomplete_study() if relaunch else study + study.run(n_jobs=n_jobs, n_relaunch=1, parallel_backend=parallel_backend) + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index 4b0f3d17..6322ffd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,4 +26,4 @@ matplotlib ray[default] python-slugify pillow -gymnasium>=0.27 \ No newline at end of file +gymnasium>=0.27 diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index bec693ae..d39bdcc0 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -19,7 +19,11 @@ from PIL import Image from agentlab.agents import agent_utils +from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark +from bgym import Benchmark as BgymBenchmark from agentlab.agents.agent_args import AgentArgs +from agentlab.benchmarks.osworld import OSWorldActionSet +from 
agentlab.llm.base_api import BaseModelArgs from agentlab.llm.llm_utils import image_to_png_base64_url from agentlab.llm.response_api import ( APIPayload, @@ -36,7 +40,6 @@ @dataclass class Block(ABC): - def _init(self): """Initialize the block.""" pass @@ -169,6 +172,7 @@ class Obs(Block): use_tabs: bool = False # add_mouse_pointer: bool = False use_zoomed_webpage: bool = False + skip_preprocessing: bool = False def apply( self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput @@ -181,7 +185,6 @@ def apply( obs_msg.add_text(f"Last action error:\n{obs['last_action_error']}") if self.use_screenshot: - if self.use_som: screenshot = obs["screenshot_som"] else: @@ -231,7 +234,6 @@ def _format_tabs(obs): @dataclass class GeneralHints(Block): - use_hints: bool = True def apply(self, llm, discussion: StructuredDiscussion) -> dict: @@ -342,9 +344,10 @@ class PromptConfig: @dataclass class ToolUseAgentArgs(AgentArgs): - model_args: OpenAIResponseModelArgs = None + model_args: BaseModelArgs = None config: PromptConfig = None use_raw_page_output: bool = False # This attribute is used in loop.py to setup the env. 
+ action_set: bgym.AbstractActionSet | None = None def __post_init__(self): try: @@ -356,8 +359,9 @@ def make_agent(self) -> bgym.Agent: if self.config is None: self.config = DEFAULT_PROMPT_CONFIG return ToolUseAgent( - model_args=self.model_args, + model_args=self.model_args, # type: ignore config=self.config, + action_set=self.action_set, ) def prepare(self): @@ -366,17 +370,24 @@ def prepare(self): def close(self): return self.model_args.close_server() + def set_benchmark(self, benchmark: AgentLabBenchmark | BgymBenchmark, demo_mode: bool): + """Set benchmark specific flags.""" + benchmark_name = benchmark.name + if benchmark_name == "osworld": + self.config.obs.skip_preprocessing = True + class ToolUseAgent(bgym.Agent): def __init__( self, model_args: OpenAIResponseModelArgs, config: PromptConfig = None, + action_set: bgym.AbstractActionSet | None = None, ): self.model_args = model_args self.config = config - self.action_set = bgym.HighLevelActionSet( - self.config.action_subsets, multiaction=self.config.multiaction + self.action_set: bgym.AbstractActionSet = action_set or bgym.HighLevelActionSet( + self.config.action_subsets, multiaction=self.config.multiaction # type: ignore ) self.tools = self.action_set.to_tool_description(api=model_args.api) @@ -395,7 +406,8 @@ def __init__( def obs_preprocessor(self, obs): obs = copy(obs) - + if self.config.obs.skip_preprocessing: + return obs page = obs.pop("page", None) if page is not None: obs["screenshot"] = extract_screenshot(page) @@ -592,3 +604,49 @@ def get_action(self, obs: Any) -> float: model_args=GPT4_1_OPENROUTER_MODEL, config=DEFAULT_PROMPT_CONFIG, ) + +OSWORLD_CLAUDE = ToolUseAgentArgs( + model_args=CLAUDE_MODEL_CONFIG, + config=PromptConfig( + tag_screenshot=True, + goal=Goal(goal_as_system_msg=True), + obs=Obs( + use_last_error=True, + use_screenshot=True, + use_axtree=True, + use_dom=False, + use_som=False, + use_tabs=False, + ), + summarizer=Summarizer(do_summary=True), + 
general_hints=GeneralHints(use_hints=False), + task_hint=TaskHint(use_task_hint=False), + keep_last_n_obs=None, + multiaction=False, # whether to use multi-action or not + action_subsets=("coord",), # or "bid" + ), + action_set=OSWorldActionSet("computer_13"), # or "pyautogui" +) + +OSWORLD_OAI = ToolUseAgentArgs( + model_args=OPENAI_MODEL_CONFIG, + config=PromptConfig( + tag_screenshot=True, + goal=Goal(goal_as_system_msg=True), + obs=Obs( + use_last_error=True, + use_screenshot=True, + use_axtree=False, + use_dom=False, + use_som=False, + use_tabs=False, + ), + summarizer=Summarizer(do_summary=True), + general_hints=GeneralHints(use_hints=False), + task_hint=TaskHint(use_task_hint=False), + keep_last_n_obs=1, # keep only the last observation in the discussion + multiaction=False, # whether to use multi-action or not + action_subsets=("coord",), + ), + action_set=OSWorldActionSet("computer_13"), +) diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index 8b846f5f..61b1ab68 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -712,7 +712,7 @@ def dict_msg_to_markdown(d: dict): case "text": parts.append(f"\n```\n{item['text']}\n```\n") case "tool_use": - tool_use = _format_tool_call(item["name"], item["input"], item["call_id"]) + tool_use = _format_tool_call(item["name"], item["input"], item["id"]) parts.append(f"\n```\n{tool_use}\n```\n") case _: parts.append(f"\n```\n{str(item)}\n```\n") @@ -1337,7 +1337,7 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr horizontalalignment="right", rotation=0, clip_on=True, - antialiased=True, + # antialiased=True, fontweight=1000, backgroundcolor=color, ) diff --git a/src/agentlab/benchmarks/abstract_env.py b/src/agentlab/benchmarks/abstract_env.py index 33e09e22..4e460b8d 100644 --- a/src/agentlab/benchmarks/abstract_env.py +++ b/src/agentlab/benchmarks/abstract_env.py @@ -1,4 +1,6 @@ +import time from abc import ABC, 
abstractmethod +from functools import wraps import gymnasium as gym from dataclasses_json import DataClassJsonMixin @@ -71,3 +73,38 @@ def step(self, action: str): @abstractmethod def close(self): """Close any resources used by the environment""" + + +def add_step_timing_to_env_info_decorator(step_func): + """Decorator/wrapper that adds timing information to any step function. + + This wrapper can be applied to any step method to automatically + measure and include action execution timing in the env_info. + + Args: + step_func: The step function to wrap + + Returns: + Wrapped function that includes timing information + """ + + @wraps(step_func) + def wrapped_step(self, action: str): + action_exec_start = time.time() + obs, reward, terminated, truncated, env_info = step_func(self, action) + action_exec_stop = time.time() + + # Ensure env_info is a dictionary + if env_info is None: + env_info = {} + + if "action_exec_start" not in env_info: + env_info["action_exec_start"] = action_exec_start + if "action_exec_stop" not in env_info: + env_info["action_exec_stop"] = action_exec_stop + if "action_exec_timeout" not in env_info: + env_info["action_exec_timeout"] = 0.0 # Default to 0, override if needed + + return obs, reward, terminated, truncated, env_info + + return wrapped_step diff --git a/src/agentlab/benchmarks/osworld.py b/src/agentlab/benchmarks/osworld.py new file mode 100644 index 00000000..e08cfb5a --- /dev/null +++ b/src/agentlab/benchmarks/osworld.py @@ -0,0 +1,948 @@ +import ast +import importlib.util +import json +import logging +import os +import time +from copy import deepcopy +from dataclasses import dataclass +from io import BytesIO +from pathlib import Path +from typing import Any, Literal + +import numpy as np +from bgym import AbstractActionSet +from dataclasses_json import DataClassJsonMixin +from PIL import Image + +from agentlab.benchmarks.abstract_env import ( + AbstractBenchmark, + AbstractEnv, + AbstractEnvArgs, + 
add_step_timing_to_env_info_decorator, +) +from agentlab.benchmarks.osworld_axtree_preprocessing import ( + linearize_accessibility_tree, + tag_screenshot, +) + +spec = importlib.util.find_spec("desktop_env") +if spec is not None: # desktop_env is available + from desktop_env.actions import KEYBOARD_KEYS, X_MAX, Y_MAX + from desktop_env.desktop_env import DesktopEnv +else: + # If desktop_env is not available, set to None or default values + DesktopEnv = None + KEYBOARD_KEYS = [ + "\t", + "\n", + "\r", + " ", + "!", + '"', + "#", + "$", + "%", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + ":", + ";", + "<", + "=", + ">", + "?", + "@", + "[", + "\\", + "]", + "^", + "_", + "`", + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + "{", + "|", + "}", + "~", + "accept", + "add", + "alt", + "altleft", + "altright", + "apps", + "backspace", + "browserback", + "browserfavorites", + "browserforward", + "browserhome", + "browserrefresh", + "browsersearch", + "browserstop", + "capslock", + "clear", + "convert", + "ctrl", + "ctrlleft", + "ctrlright", + "decimal", + "del", + "delete", + "divide", + "down", + "end", + "enter", + "esc", + "escape", + "execute", + "f1", + "f10", + "f11", + "f12", + "f13", + "f14", + "f15", + "f16", + "f17", + "f18", + "f19", + "f2", + "f20", + "f21", + "f22", + "f23", + "f24", + "f3", + "f4", + "f5", + "f6", + "f7", + "f8", + "f9", + "final", + "fn", + "hanguel", + "hangul", + "hanja", + "help", + "home", + "insert", + "junja", + "kana", + "kanji", + "launchapp1", + "launchapp2", + "launchmail", + "launchmediaselect", + "left", + "modechange", + "multiply", + "nexttrack", + "nonconvert", + "num0", + "num1", + "num2", + "num3", + "num4", + "num5", + "num6", + "num7", + "num8", + "num9", + "numlock", + "pagedown", + "pageup", 
+ "pause", + "pgdn", + "pgup", + "playpause", + "prevtrack", + "print", + "printscreen", + "prntscrn", + "prtsc", + "prtscr", + "return", + "right", + "scrolllock", + "select", + "separator", + "shift", + "shiftleft", + "shiftright", + "sleep", + "stop", + "subtract", + "tab", + "up", + "volumedown", + "volumemute", + "volumeup", + "win", + "winleft", + "winright", + "yen", + "command", + "option", + "optionleft", + "optionright", + ] + X_MAX = 1920 + Y_MAX = 1080 + +logger = logging.getLogger(__name__) +COMPUTER_13_ACTIONS_OAI_CHATCOMPLETION_TOOLS = [ + { + "type": "function", + "function": { + "name": "move_to", + "description": "Move the cursor to the specified position", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "X coordinate", + "minimum": 0, + "maximum": X_MAX, + }, + "y": { + "type": "number", + "description": "Y coordinate", + "minimum": 0, + "maximum": Y_MAX, + }, + }, + "required": ["x", "y"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "click", + "description": "Click the left button if the button not specified, otherwise click the specified button; click at the current position if x and y are not specified, otherwise click at the specified position", + "parameters": { + "type": "object", + "properties": { + "button": { + "type": "string", + "enum": ["left", "right", "middle"], + "description": "Mouse button to click", + }, + "x": { + "type": "number", + "description": "X coordinate", + "minimum": 0, + "maximum": X_MAX, + }, + "y": { + "type": "number", + "description": "Y coordinate", + "minimum": 0, + "maximum": Y_MAX, + }, + "num_clicks": { + "type": "integer", + "enum": [1, 2, 3], + "description": "Number of clicks", + }, + }, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "mouse_down", + "description": "Press the left button if the button not specified, otherwise press the specified button", + "parameters": { + "type": "object", + 
"properties": { + "button": { + "type": "string", + "enum": ["left", "right", "middle"], + "description": "Mouse button to press", + } + }, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "mouse_up", + "description": "Release the left button if the button not specified, otherwise release the specified button", + "parameters": { + "type": "object", + "properties": { + "button": { + "type": "string", + "enum": ["left", "right", "middle"], + "description": "Mouse button to release", + } + }, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "right_click", + "description": "Right click at the current position if x and y are not specified, otherwise right click at the specified position", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "X coordinate", + "minimum": 0, + "maximum": X_MAX, + }, + "y": { + "type": "number", + "description": "Y coordinate", + "minimum": 0, + "maximum": Y_MAX, + }, + }, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "double_click", + "description": "Double click at the current position if x and y are not specified, otherwise double click at the specified position", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "X coordinate", + "minimum": 0, + "maximum": X_MAX, + }, + "y": { + "type": "number", + "description": "Y coordinate", + "minimum": 0, + "maximum": Y_MAX, + }, + }, + "required": [], + }, + }, + }, + { + "type": "function", + "function": { + "name": "drag_to", + "description": "Drag the cursor to the specified position with the left button pressed", + "parameters": { + "type": "object", + "properties": { + "x": { + "type": "number", + "description": "X coordinate", + "minimum": 0, + "maximum": X_MAX, + }, + "y": { + "type": "number", + "description": "Y coordinate", + "minimum": 0, + "maximum": Y_MAX, + }, + }, + 
"required": ["x", "y"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "scroll", + "description": "Scroll the mouse wheel up or down", + "parameters": { + "type": "object", + "properties": { + "dx": {"type": "integer", "description": "Horizontal scroll amount"}, + "dy": {"type": "integer", "description": "Vertical scroll amount"}, + }, + "required": ["dx", "dy"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "typing", + "description": "Type the specified text", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string", "description": "Text to type"}}, + "required": ["text"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "press", + "description": "Press the specified key and release it", + "parameters": { + "type": "object", + "properties": { + "key": {"type": "string", "enum": KEYBOARD_KEYS, "description": "Key to press"} + }, + "required": ["key"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "key_down", + "description": "Press the specified key", + "parameters": { + "type": "object", + "properties": { + "key": { + "type": "string", + "enum": KEYBOARD_KEYS, + "description": "Key to press down", + } + }, + "required": ["key"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "key_up", + "description": "Release the specified key", + "parameters": { + "type": "object", + "properties": { + "key": { + "type": "string", + "enum": KEYBOARD_KEYS, + "description": "Key to release", + } + }, + "required": ["key"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "hotkey", + "description": "Press the specified key combination", + "parameters": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": {"type": "string", "enum": KEYBOARD_KEYS}, + "description": "Array of keys to press simultaneously", + } + }, + "required": ["keys"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "wait", + 
"description": "Wait until the next action", + "parameters": {"type": "object", "properties": {}, "required": []}, + }, + }, + { + "type": "function", + "function": { + "name": "fail", + "description": "Decide the task cannot be performed", + "parameters": {"type": "object", "properties": {}, "required": []}, + }, + }, + { + "type": "function", + "function": { + "name": "done", + "description": "Decide the task is done", + "parameters": {"type": "object", "properties": {}, "required": []}, + }, + }, +] + + +class OsworldGym(AbstractEnv): + + def __init__( + self, + task: dict, + provider_name: str, + region: str | None, + path_to_vm: str | None, + snapshot_name: str, + action_space: str, + cache_dir: str, + screen_size: tuple[int, int], + headless: bool, + require_a11y_tree: bool, + require_terminal: bool, + os_type: str, + enable_proxy: bool, + max_steps: int, + exp_dir: Path, + record_video: bool = True, + ): + self.task = task + self.env_info = { + "provider_name": provider_name, + "region": region, + "path_to_vm": path_to_vm, + "snapshot_name": snapshot_name, + "action_space": action_space, + "cache_dir": cache_dir, + "screen_size": screen_size, + "headless": headless, + "require_a11y_tree": require_a11y_tree, + "require_terminal": require_terminal, + "os_type": os_type, + "enable_proxy": enable_proxy, + } + if DesktopEnv is None: + raise ImportError( + "desktop_env is not installed. Please install it (use `make osworld`) to use OSWorld Gym." 
+ ) + self.env = DesktopEnv( + action_space=action_space, + provider_name=provider_name, + region=region, # type: ignore + path_to_vm=path_to_vm, # type: ignore + snapshot_name=snapshot_name, + cache_dir=cache_dir, + screen_size=screen_size, # type: ignore + headless=headless, + require_a11y_tree=require_a11y_tree, + require_terminal=require_terminal, + os_type=os_type, + ) + self._step_count = 0 + self.max_steps = max_steps + self.exp_dir = exp_dir + self.record_video = record_video + + def reset(self, seed: int | None = None) -> tuple[dict[str, Any], dict[str, Any]]: + self.env.reset(task_config=self.task, seed=seed) + logging.info(f"Start solving task: {self.task['instruction']}") + time.sleep( + 60 + ) # Wait for the environment to be ready, as in https://github.com/xlang-ai/OSWorld/blob/main/lib_run_single.py#L15 + raw_obs = self.env._get_obs() # Get the initial observation + if self.record_video: + self.env.controller.start_recording() + logging.info("Started recording the environment video") + obs = self.to_agentlab_observation(raw_obs) + self._step_count = 0 + return obs, self.env_info + + @add_step_timing_to_env_info_decorator + def step(self, action: str): + """Execute the action in the OS-world environment.""" + env_action = self.agentlab_to_env_action(action) + logger.info(f"AgentLab Action returned: {action}, converted to: {env_action}") + raw_obs, reward, done, info = self.env.step(env_action) + logger.info(f"STEP {self.task['id']} {self._step_count + 1}/{self.max_steps}") + self._step_count += 1 + truncated = info.get("fail", False) or self._step_count >= self.max_steps + if done or truncated: + if done: + logger.info(f"Task {self.task['id']} completed successfully.") + else: + logger.warning(f"Task {self.task['id']} truncated after {self._step_count} steps.") + try: + reward = self.env.evaluate() + logger.info(f"Evaluated reward: {reward}") + except Exception as e: + logger.error(f"Failed to evaluate {self.task} task: {e}") + obs = 
self.to_agentlab_observation(raw_obs) + return obs, reward, done, truncated, info + + def agentlab_to_env_action(self, action: str) -> Any: + """Convert AgentLab agents action format to OSWorld action format.""" + if self.env.action_space == "computer_13": + return self.convert_agentlab_action_to_computer_13(action) + elif self.env.action_space == "pyautogui": + raise NotImplementedError( + "PyAutoGUI action space is not supported yet. Please use 'computer_13' action space." + ) + + def to_agentlab_observation(self, obs: dict[str, Any]) -> dict[str, Any]: + """Convert OSWorld observation to AgentLab format.""" + converted_obs = {} + + self._add_screenshot(converted_obs, obs) + # self._add_som_screenshot(converted_obs, obs) #TODO: test this + converted_obs["axtree_txt"] = linearize_accessibility_tree( + accessibility_tree=obs["accessibility_tree"], platform="ubuntu" + ) + converted_obs["last_action_error"] = "" # OSWorld doesn't provide this directly + converted_obs["focused_element_bid"] = "" # Extract from accessibility tree if available + converted_obs = self._add_browser_context(converted_obs) + converted_obs = self._add_task_context(converted_obs, obs) + + return converted_obs + + def convert_screenshot_to_numpy(self, screenshot) -> np.ndarray: + """Convert screenshot to numpy array format expected by AgentLab.""" + image = Image.open(BytesIO(screenshot)) + image = image.convert("RGB") if image.mode != "RGB" else image + return np.array(image) + + def _add_screenshot(self, converted_obs: dict[str, Any], obs: dict[str, Any]) -> None: + """Convert screenshot to numpy array format expected by AgentLab""" + converted_obs["screenshot"] = self.convert_screenshot_to_numpy(obs["screenshot"]) + + def _add_som_screenshot(self, converted_obs: dict[str, Any], obs: dict[str, Any]) -> None: + """Convert SOM screenshot to numpy array format expected by AgentLab""" + masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot( + obs["screenshot"], 
obs["accessibility_tree"], platform="ubuntu" + ) + converted_obs["som_screenshot"] = self.convert_screenshot_to_numpy(tagged_screenshot) + + def _add_browser_context(self, converted_obs: dict[str, Any]): + """Add browser-like context fields adapted for desktop environment.""" + converted_obs["url"] = "" + converted_obs["open_pages_urls"] = [] + converted_obs["open_pages_titles"] = [] + converted_obs["active_page_index"] = 0 + return converted_obs + + def _add_task_context(self, converted_obs: dict[str, Any], obs: dict[str, Any]): + """Add task and instruction context fields.""" + instruction = obs.get("instruction", "") + converted_obs["goal_object"] = [{"type": "text", "text": instruction}] + if obs.get("terminal"): + converted_obs["terminal_output"] = obs["terminal"] + return converted_obs + + def convert_agentlab_action_to_computer_13(self, action: str) -> dict[str, Any] | str: + """Convert action string to dictionary format. + + Args: + action (str): Action string in AgentLab format, e.g., "move_to(x=100, y=200)". + + Returns: + dict[str, Any] | str: Action in OSWorld Computer 13 format as a dictionary, + or a string for simple actions like "wait", "done", or "fail". + + Examples: + >>> env = OsworldGym(task={}, provider_name="vmware", region=None, path_to_vm=None, + ... snapshot_name="init_state", action_space="computer_13", + ... cache_dir="cache", screen_size=(1920, 1080), headless=True, + ... require_a11y_tree=True, require_terminal=False, os_type="Ubuntu", + ... 
enable_proxy=False, max_steps=50, exp_dir=Path(".")) + >>> env.convert_agentlab_action_to_computer_13("move_to(x=100, y=200)") + {'action_type': 'MOVE_TO', 'parameters': {'x': 100, 'y': 200}} + >>> env.convert_agentlab_action_to_computer_13("wait()") + 'WAIT' + """ + + action_type, action_args, action_kwargs = self.parse_agentlab_action_str_to_func_args( + action + ) + + if action_type in ["wait", "done", "fail"]: + return str(action_type).upper() + if action_args: + logger.warning( + f"""Action '{action_type}' has unexpected positional arguments: {action_args}. + OSWorld Computer 13 actions are processed as dictionaries.""" + ) + action_kwargs = action_kwargs if action_kwargs is not None else {} + + return {"action_type": str(action_type).upper(), "parameters": action_kwargs} + + @staticmethod + def parse_agentlab_action_str_to_func_args(action: str): + """Parse the agentlab action string to extract function name, args, and kwargs. + + Args: + action (str): Action string in AgentLab format, e.g., "move_to(x=100, y=200)". + + Returns: + tuple: A tuple containing the function name, a list of positional arguments, + and a dictionary of keyword arguments. 
+ + Examples: + >>> parse_agentlab_action_str_to_func_args("move_to(x=100, y=200)") + ('move_to', [], {'x': 100, 'y': 200}) + >>> parse_agentlab_action_str_to_func_args("hotkey(keys=['ctrl', 'alt', 't'])") + ('hotkey', [], {'keys': ['ctrl', 'alt', 't']}) + """ + try: + action = action.strip() + parsed = ast.parse(action, mode="eval") + if isinstance(parsed.body, ast.Call): + func_name = ast.unparse(parsed.body.func) + args = [ast.literal_eval(arg) for arg in parsed.body.args] + kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in parsed.body.keywords} + return func_name, args, kwargs + except Exception as e: + logger.warning( + f"Failed to parse agentlab agent's str function call: {action}, error: {e}" + ) + return None, None, None + + def close(self): + if self.record_video: + video_name = str(self.exp_dir / "recording.mp4") + self.env.controller.end_recording(video_name) + logger.info(f"Recorded video saved to {video_name}") + return self.env.close() + + +@dataclass +class OSWorldActionSet(AbstractActionSet, DataClassJsonMixin): + # TODO: Define and use agentlab AbstractActionSet + # AbstractActionSet should define some standard format to represent actions.(list of dict with keys that are MCP compatible) + # Should we have 'abstract function' here for action conversion for backend LLM with fixed action set like UI-Tars or Semi-fixed action set LLMs like OpenAI CUA? + # TODO: We need to support both 'action space as tools' and 'action space as prompt' for agentlab agents + # and have conversion functions to convert them to format acceptable by environment. 
+ action_space: Literal["computer_13", "pyautogui"] = "computer_13" + multiaction: bool = False + + def describe(self, with_long_description: bool = True, with_examples: bool = True) -> str: + """Describe the OSWorld action set for desktop interactions.""" + pass + + def example_action(self, abstract: bool) -> str: + """Provide example actions for the action set.""" + pass + + def to_python_code(self, action) -> str: + """We use the OS-world/desktop_env environment controller""" + pass + + def to_tool_description(self, api="openai"): + """Convert the action set to a tool description for Tool-Use LLMs. + + The default for openai is openai Response API tools format. + + Args: + api (str): The API format to use. Defaults to "openai". + + Returns: + list[dict]: List of tool descriptions in the specified API format. + + Raises: + ValueError: If an unsupported action space is specified. + """ + # TODO: Rename bgym AbstractActionSet 'to_tool_descriptor' method as 'to_tool_description' for consistency. + if self.action_space == "computer_13": + tools = COMPUTER_13_ACTIONS_OAI_CHATCOMPLETION_TOOLS + + else: + raise ValueError( + "Only 'computer_13' action space is currently supported for tool description." 
def format_chat_completion_tools_to_anthropic(tools: list[dict]) -> list[dict]:
    """Convert tools from OpenAI Chat Completion format to Anthropic tool format.

    Args:
        tools: List of tools in Chat Completion format, i.e. each entry nests
            its definition under a ``"function"`` key holding ``"name"``,
            ``"description"`` and ``"parameters"``.

    Returns:
        A new list with one dict per tool containing ``"name"``,
        ``"description"`` and ``"input_schema"`` (the original ``"parameters"``
        object, not a copy).
    """
    return [
        {
            "name": spec["name"],
            "description": spec["description"],
            "input_schema": spec["parameters"],
        }
        for spec in (entry["function"] for entry in tools)
    ]
def make_env(
    self, exp_dir: Path, action_mapping=None, use_raw_page_output: bool = False
) -> OsworldGym:
    """Build the ``OsworldGym`` described by these arguments.

    Args:
        exp_dir: Experiment directory handed to the gym as ``exp_dir``.
        action_mapping: Accepted but not used by this method.
        use_raw_page_output: Accepted but not used by this method.

    Returns:
        A new ``OsworldGym`` constructed from this object's fields.
    """
    logger.info(f"Creating OSWorld Gym with task: {self.task}")

    # Every one of these fields is forwarded to the gym under the same name.
    forwarded_fields = (
        "task",
        "provider_name",
        "region",
        "path_to_vm",
        "snapshot_name",
        "action_space",
        "cache_dir",
        "screen_size",
        "headless",
        "require_a11y_tree",
        "require_terminal",
        "os_type",
        "enable_proxy",
        "max_steps",
    )
    gym_kwargs = {field_name: getattr(self, field_name) for field_name in forwarded_fields}
    return OsworldGym(**gym_kwargs, exp_dir=exp_dir)
def find_leaf_nodes(xlm_file_str):
    """Return every childless element of an XML document, in document order.

    Args:
        xlm_file_str: XML document as a string. A falsy value (``None`` or
            ``""``) yields an empty list without parsing.

    Returns:
        List of ``xml.etree.ElementTree.Element`` objects that have no child
        elements. A root without children is itself the only leaf.
    """
    if not xlm_file_str:
        return []

    document_root = ET.fromstring(xlm_file_str)
    # Element.iter() walks the tree depth-first in document order, which is
    # the same order a recursive pre-order traversal visits the leaves.
    return [element for element in document_root.iter() if len(element) == 0]
def judge_node(node: ET.Element, platform="ubuntu", check_image=False) -> bool:
    """Decide whether an accessibility-tree node should be kept.

    A node is kept only when all of the following hold:

    * its tag starts with ``document``, ends with one of the recognised widget
      suffixes, or is one of the explicitly listed tags;
    * it is visible: on Ubuntu both the ``showing`` and ``visible`` state
      attributes must be ``"true"``, on Windows only ``visible``;
    * at least one of the ``enabled``, ``editable``, ``expandable`` or
      ``checkable`` state attributes is ``"true"``;
    * it has a non-empty ``name`` attribute or non-empty text, or, when
      ``check_image`` is set, its ``image`` attribute is ``"true"``;
    * its ``screencoord`` is non-negative and its ``size`` strictly positive.

    Args:
        node: Element of the parsed accessibility tree.
        platform: ``"ubuntu"`` or ``"windows"``; selects the XML namespaces.
        check_image: Also accept nodes whose only content is an image.

    Returns:
        ``True`` if the node passes every check, ``False`` otherwise. A node
        whose ``screencoord``/``size`` attribute cannot be parsed as a literal
        tuple is rejected rather than raising.

    Raises:
        ValueError: If ``platform`` is neither ``"ubuntu"`` nor ``"windows"``.
    """
    # Local import: only this function needs it.
    import ast

    if platform == "ubuntu":
        _state_ns = state_ns_ubuntu
        _component_ns = component_ns_ubuntu
    elif platform == "windows":
        _state_ns = state_ns_windows
        _component_ns = component_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")

    def has_state(flag: str) -> bool:
        """Return whether the namespaced state attribute ``flag`` is "true"."""
        return node.get(f"{{{_state_ns}}}{flag}", "false") == "true"

    tag = node.tag
    kept_suffixes = (
        "item",
        "button",
        "heading",
        "label",
        "scrollbar",
        "searchbox",
        "textbox",
        "link",
        "tabelement",
        "textfield",
        "textarea",
        "menu",
    )
    kept_tags = {
        "alert",
        "canvas",
        "check-box",
        "combo-box",
        "entry",
        "icon",
        "image",
        "paragraph",
        "scroll-bar",
        "section",
        "slider",
        "static",
        "table-cell",
        "terminal",
        "text",
        "netuiribbontab",
        "start",
        "trayclockwclass",
        "traydummysearchcontrol",
        "uiimage",
        "uiproperty",
        "uiribboncommandbar",
    }
    if not (tag.startswith("document") or tag.endswith(kept_suffixes) or tag in kept_tags):
        return False

    # Ubuntu reports both "showing" and "visible"; Windows only "visible".
    if platform == "ubuntu":
        is_visible = has_state("showing") and has_state("visible")
    else:
        is_visible = has_state("visible")
    if not is_visible:
        return False

    if not any(has_state(flag) for flag in ("enabled", "editable", "expandable", "checkable")):
        return False

    has_content = (
        node.get("name", "") != ""
        or bool(node.text)
        or (check_image and node.get("image", "false") == "true")
    )
    if not has_content:
        return False

    # The attribute values come from the inspected applications, so they are
    # parsed as literals only; eval() would execute whatever they contain.
    try:
        coordinates = ast.literal_eval(node.get(f"{{{_component_ns}}}screencoord", "(-1, -1)"))
        sizes = ast.literal_eval(node.get(f"{{{_component_ns}}}size", "(-1, -1)"))
        return bool(
            coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0
        )
    except (ValueError, SyntaxError, TypeError, IndexError):
        # Malformed geometry: the node cannot be placed on screen, so drop it.
        return False
Adjust the path to the font file you have or use a default one + font = ImageFont.truetype("arial.ttf", 15) + except IOError: + # Fallback to a basic font if the specified font can't be loaded + font = ImageFont.load_default() + + index = 1 + + # Loop over all the visible nodes and draw their bounding boxes + for _node in nodes: + coords_str = _node.attrib.get("{{{:}}}screencoord".format(_component_ns)) + size_str = _node.attrib.get("{{{:}}}size".format(_component_ns)) + + if coords_str and size_str: + try: + # Parse the coordinates and size from the strings + coords = tuple(map(int, coords_str.strip("()").split(", "))) + size = tuple(map(int, size_str.strip("()").split(", "))) + + import copy + + original_coords = copy.deepcopy(coords) + original_size = copy.deepcopy(size) + + if float(down_sampling_ratio) != 1.0: + # Downsample the coordinates and size + coords = tuple(int(coord * down_sampling_ratio) for coord in coords) + size = tuple(int(s * down_sampling_ratio) for s in size) + + # Check for negative sizes + if size[0] <= 0 or size[1] <= 0: + raise ValueError(f"Size must be positive, got: {size}") + + # Calculate the bottom-right corner of the bounding box + bottom_right = (coords[0] + size[0], coords[1] + size[1]) + + # Check that bottom_right > coords (x1 >= x0, y1 >= y0) + if bottom_right[0] < coords[0] or bottom_right[1] < coords[1]: + raise ValueError(f"Invalid coordinates or size, coords: {coords}, size: {size}") + + # Check if the area only contains one color + cropped_image = image.crop((*coords, *bottom_right)) + if len(set(list(cropped_image.getdata()))) == 1: + continue + + # Draw rectangle on image + draw.rectangle([coords, bottom_right], outline="red", width=1) + + # Draw index number at the bottom left of the bounding box with black background + text_position = ( + coords[0], + bottom_right[1], + ) # Adjust Y to be above the bottom right + text_bbox: Tuple[int, int, int, int] = draw.textbbox( + text_position, str(index), font=font, anchor="lb" + 
def print_nodes_with_indent(nodes, indent=0):
    """Print each node's tag and attributes, indenting children by two spaces.

    Nodes are printed depth-first in document order: a node first, then all of
    its descendants, then its next sibling.

    Args:
        nodes: Iterable of elements; each element is itself iterable over its
            children.
        indent: Number of spaces used for the first level.
    """
    # Explicit stack instead of recursion; entries are pushed in reverse so
    # that popping yields siblings in their original order.
    pending = [(element, indent) for element in reversed(list(nodes))]
    while pending:
        element, depth = pending.pop()
        print(" " * depth, element.tag, element.attrib)
        pending.extend((child, depth + 2) for child in reversed(list(element)))
def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
    """Draw numbered bounding boxes for the retained nodes onto a screenshot.

    Args:
        screenshot: Encoded image bytes, passed to ``draw_bounding_boxes``.
        accessibility_tree: Accessibility tree as an XML string.
        platform: ``"ubuntu"`` or ``"windows"``; selects the XML namespaces
            used both for filtering nodes and for reading their geometry.

    Returns:
        tuple: ``(marks, drew_nodes, tagged_screenshot, element_list)`` as
            produced by ``draw_bounding_boxes``. Note that the last two items
            are swapped relative to that function's own return order.
    """
    nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True)
    # Forward the platform: draw_bounding_boxes defaults to "ubuntu", so
    # without it Windows nodes would be read with the Ubuntu namespaces and
    # no box would be drawn.
    marks, drew_nodes, element_list, tagged_screenshot = draw_bounding_boxes(
        nodes, screenshot, platform=platform
    )

    return marks, drew_nodes, tagged_screenshot, element_list
+ return linearized_accessibility_tree diff --git a/src/agentlab/benchmarks/setup.md b/src/agentlab/benchmarks/setup.md new file mode 100644 index 00000000..071670d1 --- /dev/null +++ b/src/agentlab/benchmarks/setup.md @@ -0,0 +1,55 @@ +# Setup OSWorld in AgentLab + +This guide walks you through setting up the OSWorld benchmark in AgentLab for GUI automation testing. + +## Installation + +1. **Clone and install OSWorld repository:** + ```bash + make osworld + ``` + +2. **Complete OSWorld setup:** + - Navigate to the `OSWorld/` directory + - Follow the detailed setup instructions in the OSWorld README + - Download required VM images and configure virtual machines + + +## Usage + +### Entry Point Configuration + +The main entry point `experiments/run_osworld.py` is currently configured with hardcoded parameters. To modify the execution: + +1. **Edit the script directly** to change: + - `n_jobs`: Number of parallel jobs (default: 4, set to 1 for debugging) + - `use_vmware`: Set to `True` for VMware, `False` for other platforms + - `relaunch`: Whether to continue incomplete studies + - `agent_args`: List of agents to test (OSWORLD_CLAUDE, OSWORLD_OAI) + - `test_set_name`: Choose between "test_small.json" or "test_all.json" + +2. 
**Environment Variables:** + - `AGENTLAB_DEBUG=1`: Automatically runs the debug subset (7 tasks from `osworld_debug_task_ids.json`) + +### Running OSWorld Tasks + +We provide different subsets of tasks: + +- **Debug subset:** 7 tasks defined in `experiments/osworld_debug_task_ids.json` +- **Small subset:** Tasks from `test_small.json` +- **Full subset:** All tasks from `test_all.json` + +### Example Commands + +```bash +# Run with default debug subset (7 tasks) +python experiments/run_osworld.py +``` + + +### Configuration Notes + +- **VMware path:** Currently hardcoded to `"OSWorld/vmware_vm_data/Ubuntu0/Ubuntu0.vmx"` +- **Parallel execution:** Automatically switches to sequential when using VMware +- **Relaunch capability:** Can continue incomplete studies by loading the most recent study + diff --git a/src/agentlab/experiments/study.py b/src/agentlab/experiments/study.py index 810b8bc2..66efde19 100644 --- a/src/agentlab/experiments/study.py +++ b/src/agentlab/experiments/study.py @@ -17,6 +17,7 @@ from agentlab.agents.agent_args import AgentArgs from agentlab.analyze import inspect_results +from agentlab.benchmarks.abstract_env import AbstractEnvArgs from agentlab.experiments import reproducibility_util as repro from agentlab.experiments.exp_utils import RESULTS_DIR, add_dependencies from agentlab.experiments.launch_exp import ( @@ -744,7 +745,7 @@ def _convert_env_args(env_args_list): new_list = [] for ea in env_args_list: # already new → keep as‑is - if isinstance(ea, EnvArgs): + if isinstance(ea, (EnvArgs, AbstractEnvArgs)): new_list.append(ea) # old → convert elif isinstance(ea, BGymEnvArgs): diff --git a/src/agentlab/llm/response_api.py b/src/agentlab/llm/response_api.py index e8c74849..1bbeeebc 100644 --- a/src/agentlab/llm/response_api.py +++ b/src/agentlab/llm/response_api.py @@ -112,7 +112,6 @@ class LLMOutput: class MessageBuilder: def __init__(self, role: str): - self.role = role self.content: List[ContentItem] = [] self.responded_tool_calls: 
def mock_task_config() -> dict:
    """Return a representative OSWorld task configuration for the tests.

    The task has a single setup step that launches Chrome with remote
    debugging enabled, and an evaluator that matches the default search
    engine against the accepted names for Bing. A fresh dictionary is built
    on every call.
    """
    launch_chrome_step = {
        "type": "launch",
        "parameters": {"command": ["google-chrome", "--remote-debugging-port=1337"]},
    }
    search_engine_evaluator = {
        "func": "match_in_list",
        "result": {"type": "default_search_engine"},
        "expected": {"type": "rule", "rules": {"expected": ["Microsoft Bing", "Bing"]}},
    }
    return dict(
        id="bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
        snapshot="chrome",
        instruction="Can you make Bing the main search thingy when I look stuff up on the internet?",
        source="https://support.google.com/chrome/answer/95426",
        config=[launch_chrome_step],
        trajectory="trajectories/",
        related_apps=["chrome"],
        evaluator=search_engine_evaluator,
        proxy=False,
    )
@patch("agentlab.benchmarks.osworld.OsworldGym")
def test_make_env(self, mock_gym_class):
    """Check that ``make_env`` builds the gym with the task and experiment dir."""
    expected_task = mock_task_config()
    args = OsworldEnvArgs(task=expected_task, task_name="test_task")

    with tempfile.TemporaryDirectory() as scratch:
        target_dir = Path(scratch)
        args.make_env(target_dir)

        # The gym class is patched, so only the constructor call is inspected.
        mock_gym_class.assert_called_once()
        _, passed_kwargs = mock_gym_class.call_args
        assert passed_kwargs["task"] == expected_task
        assert passed_kwargs["exp_dir"] == target_dir
{"text": "line1\nline2"})), # Newlines + ('typing(text="tab\\there")', ("typing", [], {"text": "tab\there"})), # Tabs + ( + 'typing(text="quote\\"test")', + ("typing", [], {"text": 'quote"test'}), + ), # Escaped quotes + ( + 'typing(text="single\'quote")', + ("typing", [], {"text": "single'quote"}), + ), # Single quotes + ('typing(text="unicode: café")', ("typing", [], {"text": "unicode: café"})), # Unicode + # Edge cases with coordinates + ("move_to(x=0, y=0)", ("move_to", [], {"x": 0, "y": 0})), # Zero coordinates + ( + "move_to(x=-10, y=-20)", + ("move_to", [], {"x": -10, "y": -20}), + ), # Negative coordinates + ( + "move_to(x=9999, y=9999)", + ("move_to", [], {"x": 9999, "y": 9999}), + ), # Large coordinates + # Edge cases with lists + ("hotkey(keys=[])", ("hotkey", [], {"keys": []})), # Empty list + ("hotkey(keys=['ctrl'])", ("hotkey", [], {"keys": ["ctrl"]})), # Single key + ( + "hotkey(keys=['ctrl', 'shift', 'alt', 'a'])", + ("hotkey", [], {"keys": ["ctrl", "shift", "alt", "a"]}), + ), # Multiple keys + # Edge cases with boolean values + ("scroll(direction='up', clicks=3)", ("scroll", [], {"direction": "up", "clicks": 3})), + ( + "click(x=100, y=200, button='left')", + ("click", [], {"x": 100, "y": 200, "button": "left"}), + ), + # Edge cases with mixed parameter types + ( + "complex_action(text='test', x=50, enabled=True, items=['a', 'b'])", + ( + "complex_action", + [], + {"text": "test", "x": 50, "enabled": True, "items": ["a", "b"]}, + ), + ), + # Edge cases with whitespace + (" wait() ", ("wait", [], {})), # Leading/trailing spaces + ( + "move_to( x=100 , y=200 )", + ("move_to", [], {"x": 100, "y": 200}), + ), # Spaces around params + # Edge cases with special characters in strings + ( + 'typing(text="@#$%^&*()+={}[]|\\:;\'<>?,./")', + ("typing", [], {"text": "@#$%^&*()+={}[]|\\:;'<>?,./"}), + ), + ] + + for action_str, expected in test_cases: + result = OsworldGym.parse_agentlab_action_str_to_func_args(action_str) + assert result == expected, f"Failed 
def test_convert_agentlab_action_to_computer_13(self):
    """Check conversion of AgentLab action strings to Computer 13 actions."""
    gym_settings = dict(
        task=mock_task_config(),
        provider_name="docker",
        region=None,
        path_to_vm=None,
        snapshot_name="init_state",
        action_space="computer_13",
        cache_dir="cache",
        screen_size=(1920, 1080),
        headless=True,
        require_a11y_tree=True,
        require_terminal=False,
        os_type="Ubuntu",
        enable_proxy=False,
        max_steps=50,
    )
    # (action string, expected converted action): a parameterless action maps
    # to an upper-case string, the others to an action dictionary.
    expectations = [
        ("wait()", "WAIT"),
        (
            "move_to(x=100, y=200)",
            {"action_type": "MOVE_TO", "parameters": {"x": 100, "y": 200}},
        ),
        (
            'typing(text="hello")',
            {"action_type": "TYPING", "parameters": {"text": "hello"}},
        ),
    ]

    with tempfile.TemporaryDirectory() as scratch:
        # DesktopEnv is patched so that no VM or container is started.
        with patch("agentlab.benchmarks.osworld.DesktopEnv"):
            gym = OsworldGym(**gym_settings, exp_dir=Path(scratch))

            for action_str, expected in expectations:
                result = gym.convert_agentlab_action_to_computer_13(action_str)
                assert result == expected