diff --git a/src/agentlab/agents/vl_agent/config.py b/src/agentlab/agents/vl_agent/config.py
new file mode 100644
index 00000000..c818e60a
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/config.py
@@ -0,0 +1,48 @@
+from browsergym.experiments.benchmark import HighLevelActionSetArgs
+from .vl_agent.ui_agent import UIAgentArgs
+from .vl_model.llama_model import LlamaModelArgs
+from .vl_model.openrouter_api_model import OpenRouterAPIModelArgs
+from .vl_prompt.ui_prompt import UIPromptArgs
+
+
+VL_MODEL_ARGS_DICT = {
+    "gpt_4o": OpenRouterAPIModelArgs(
+        base_url="https://openrouter.ai/api/v1",
+        model_id="openai/gpt-4o-2024-11-20",
+        max_tokens=8192,
+        reproducibility_config={"temperature": 0.1},
+    ),
+    "llama_32_11b": LlamaModelArgs(
+        model_path="meta-llama/Llama-3.2-11B-Vision-Instruct",
+        torch_dtype="bfloat16",
+        accelerator_config={"mixed_precision": "bf16", "cpu": False},
+        reproducibility_config={"temperature": 0.1},
+        max_length=32768,
+        max_new_tokens=8192,
+        checkpoint_file=None,
+        device=None,
+    ),
+}
+
+VL_PROMPT_ARGS_DICT = {
+    "ui_prompt": UIPromptArgs(
+        use_screenshot=True,
+        use_screenshot_som=False,
+        use_tabs=True,
+        use_history=True,
+        use_error=True,
+        use_abstract_example=True,
+        use_concrete_example=False,
+        extra_instruction=None,
+    )
+}
+
+VL_AGENT_ARGS_DICT = {
+    "ui_agent": UIAgentArgs(
+        main_vl_model_args=VL_MODEL_ARGS_DICT["gpt_4o"],
+        auxiliary_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"],
+        ui_prompt_args=VL_PROMPT_ARGS_DICT["ui_prompt"],
+        action_set_args=HighLevelActionSetArgs(subsets=["coord"]),
+        max_retry=4,
+    )
+}
diff --git a/src/agentlab/agents/vl_agent/main.py b/src/agentlab/agents/vl_agent/main.py
new file mode 100644
index 00000000..eeb1184c
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/main.py
@@ -0,0 +1,33 @@
+from agentlab.agents.vl_agent.config import VL_AGENT_ARGS_DICT
+from agentlab.experiments.study import Study
+import logging
+import os
+
+
+logging.getLogger().setLevel(logging.INFO)
+
+vl_agent_args_list = [VL_AGENT_ARGS_DICT["ui_agent"]]
+benchmark = "miniwob"
+os.environ["MINIWOB_URL"] = "file:///mnt/home/miniwob-plusplus/miniwob/html/miniwob/"
+reproducibility_mode = False
+relaunch = False
+n_jobs = 1
+
+
+if __name__ == "__main__":
+    if reproducibility_mode:
+        for vl_agent_args in vl_agent_args_list:
+            vl_agent_args.set_reproducibility_mode()
+    if relaunch:
+        study = Study.load_most_recent(contains=None)
+        study.find_incomplete(include_errors=True)
+    else:
+        study = Study(vl_agent_args_list, benchmark=benchmark, logging_level_stdout=logging.WARNING)
+    study.run(
+        n_jobs=n_jobs,
+        parallel_backend="sequential",
+        strict_reproducibility=reproducibility_mode,
+        n_relaunch=3,
+    )
+    if reproducibility_mode:
+        study.append_to_journal(strict_reproducibility=True)
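+
+
+# Run sketch (illustrative; assumes AgentLab is installed and the MiniWoB++
+# checkout referenced by MINIWOB_URL exists):
+#
+#     $ export OPENROUTER_API_KEY="..."   # required by OpenRouterAPIModel
+#     $ python -m agentlab.agents.vl_agent.main
+#
+# Set relaunch = True to resume the most recent study instead of starting a new one.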
diff --git a/src/agentlab/agents/vl_agent/utils.py b/src/agentlab/agents/vl_agent/utils.py
new file mode 100644
index 00000000..36db2dd9
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/utils.py
@@ -0,0 +1,39 @@
+from accelerate import dispatch_model, infer_auto_device_map
+from accelerate.utils.modeling import get_balanced_memory
+from PIL import Image
+from torch.nn import Module
+from typing import Union
+import base64
+import io
+import numpy as np
+
+
+def image_to_image_url(image: Union[Image.Image, np.ndarray]) -> str:
+    if isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+    if image.mode in ("RGBA", "LA"):
+        # JPEG has no alpha channel, so drop it before encoding
+        image = image.convert("RGB")
+    buffer = io.BytesIO()
+    image.save(buffer, format="JPEG")
+    image_base64 = base64.b64encode(buffer.getvalue()).decode()
+    image_url = f"data:image/jpeg;base64,{image_base64}"
+    return image_url
+
+
+def image_url_to_image(image_url: str) -> Image.Image:
+    image_base64 = image_url.replace("data:image/jpeg;base64,", "")
+    image_data = base64.b64decode(image_base64.encode())
+    buffer = io.BytesIO(image_data)
+    image = Image.open(buffer)
+    return image
+
+
+def auto_dispatch_model(model: Module, no_split_module_classes: list[str]) -> Module:
+    max_memory = get_balanced_memory(model, no_split_module_classes=no_split_module_classes)
+    device_map = infer_auto_device_map(
+        model,
+        max_memory=max_memory,
+        no_split_module_classes=no_split_module_classes,
+    )
+    model = dispatch_model(model, device_map=device_map)
+    return model
diff --git a/src/agentlab/agents/vl_agent/vl_agent/base.py b/src/agentlab/agents/vl_agent/vl_agent/base.py
new file mode 100644
index 00000000..8dde21a7
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/vl_agent/base.py
@@ -0,0 +1,47 @@
+from abc import ABC, abstractmethod
+from browsergym.core.action.highlevel import HighLevelActionSet
+from browsergym.experiments.benchmark import Benchmark
+from dataclasses import dataclass
+
+
+class VLAgent(ABC):
+    @property
+    @abstractmethod
+    def action_set(self) -> HighLevelActionSet:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_action(self, obs: dict) -> tuple[str, dict]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def obs_preprocessor(self, obs: dict) -> dict:
+        raise NotImplementedError
+
+
+@dataclass
+class VLAgentArgs(ABC):
+    @property
+    @abstractmethod
+    def agent_name(self) -> str:
+        raise NotImplementedError
+
+    @abstractmethod
+    def make_agent(self) -> VLAgent:
+        raise NotImplementedError
+
+    @abstractmethod
+    def prepare(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def close(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def set_reproducibility_mode(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def set_benchmark(self, benchmark: Benchmark, demo_mode: bool):
+        raise NotImplementedError
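+
+
+# Expected lifecycle of a concrete VLAgentArgs (illustrative; SomeVLAgentArgs is
+# hypothetical):
+#
+#     args = SomeVLAgentArgs(...)
+#     args.prepare()                           # acquire resources before the run
+#     agent = args.make_agent()
+#     obs = agent.obs_preprocessor(raw_obs)    # raw_obs comes from BrowserGym
+#     action, agent_info = agent.get_action(obs)
+#     args.close()                             # release resources after the run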
diff --git a/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py
new file mode 100644
index 00000000..c850894e
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py
@@ -0,0 +1,145 @@
+from agentlab.llm.llm_utils import ParseError, retry
+from agentlab.llm.tracking import cost_tracker_decorator
+from browsergym.core.action.highlevel import HighLevelActionSet
+from browsergym.experiments.agent import AgentInfo
+from browsergym.experiments.benchmark import Benchmark
+from browsergym.experiments.benchmark.base import HighLevelActionSetArgs
+from browsergym.utils.obs import overlay_som
+from copy import copy, deepcopy
+from dataclasses import asdict, dataclass
+from functools import cached_property
+from typing import Optional
+from .base import VLAgent, VLAgentArgs
+from ..vl_model.base import VLModelArgs
+from ..vl_prompt.ui_prompt import UIPromptArgs
+
+
+class UIAgent(VLAgent):
+    def __init__(
+        self,
+        main_vl_model_args: VLModelArgs,
+        auxiliary_vl_model_args: Optional[VLModelArgs],
+        ui_prompt_args: UIPromptArgs,
+        action_set_args: HighLevelActionSetArgs,
+        max_retry: int,
+    ):
+        self.main_vl_model = main_vl_model_args.make_model()
+        if auxiliary_vl_model_args is None:
+            self.auxiliary_vl_model = None
+        else:
+            self.auxiliary_vl_model = auxiliary_vl_model_args.make_model()
+        self.ui_prompt_args = ui_prompt_args
+        self.action_set_args = action_set_args
+        self.max_retry = max_retry
+        self.thoughts = []
+        self.actions = []
+
+    # cached_property: build the action set once per agent
+    @cached_property
+    def action_set(self) -> HighLevelActionSet:
+        return self.action_set_args.make_action_set()
+
+    @cost_tracker_decorator
+    def get_action(self, obs: dict) -> tuple[str, dict]:
+        ui_prompt = self.ui_prompt_args.make_prompt(
+            obs=obs, thoughts=self.thoughts, actions=self.actions, action_set=self.action_set
+        )
+        try:
+            messages = ui_prompt.get_messages()
+            answer = retry(
+                chat=self.main_vl_model,
+                messages=messages,
+                n_retry=self.max_retry,
+                parser=ui_prompt.parse_answer,
+            )
+            # the prompt starts as a single merged HumanMessage; each retry appends
+            # the model answer plus one error message
+            stats = {"num_main_retries": (len(messages) - 2) // 2}
+        except ParseError:
+            answer = {"thought": None, "action": None}
+            stats = {"num_main_retries": self.max_retry}
+        stats.update(self.main_vl_model.get_stats())
+        if self.auxiliary_vl_model is not None:
+            preliminary_answer = answer
+            ui_prompt = self.ui_prompt_args.make_prompt(
+                obs=obs,
+                thoughts=self.thoughts,
+                actions=self.actions,
+                action_set=self.action_set,
+                preliminary_answer=preliminary_answer,
+            )
+            try:
+                messages = ui_prompt.get_messages()
+                answer = retry(
+                    chat=self.auxiliary_vl_model,
+                    messages=messages,
+                    n_retry=self.max_retry,
+                    parser=ui_prompt.parse_answer,
+                )
+                stats["num_auxiliary_retries"] = (len(messages) - 2) // 2
+            except ParseError:
+                answer = {"thought": None, "action": None}
+                stats["num_auxiliary_retries"] = self.max_retry
+            stats.update(self.auxiliary_vl_model.get_stats())
+        else:
+            preliminary_answer = None
+        self.thoughts.append(str(answer["thought"]))
+        self.actions.append(str(answer["action"]))
+        agent_info = AgentInfo(
+            think=str(answer["thought"]), stats=stats, extra_info=preliminary_answer
+        )
+        return answer["action"], asdict(agent_info)
+
+    def obs_preprocessor(self, obs: dict) -> dict:
+        obs = copy(obs)
+        if self.ui_prompt_args.use_screenshot and self.ui_prompt_args.use_screenshot_som:
+            obs["screenshot"] = overlay_som(
+                obs["screenshot"], extra_properties=obs["extra_element_properties"]
+            )
+        return obs
+
+
+@dataclass
+class UIAgentArgs(VLAgentArgs):
+    main_vl_model_args: VLModelArgs
+    auxiliary_vl_model_args: Optional[VLModelArgs]
+    ui_prompt_args: UIPromptArgs
+    action_set_args: HighLevelActionSetArgs
+    max_retry: int
+
+    # plain property: dataclass instances with eq=True are unhashable, so
+    # functools.cache would raise a TypeError on first access
+    @property
+    def agent_name(self) -> str:
+        if self.auxiliary_vl_model_args is None:
+            return f"UIAgent-{self.main_vl_model_args.model_name}"
+        else:
+            return f"UIAgent-{self.main_vl_model_args.model_name}-{self.auxiliary_vl_model_args.model_name}"
+
+    def make_agent(self) -> UIAgent:
+        self.ui_agent = UIAgent(
+            main_vl_model_args=self.main_vl_model_args,
+            auxiliary_vl_model_args=self.auxiliary_vl_model_args,
+            ui_prompt_args=self.ui_prompt_args,
+            action_set_args=self.action_set_args,
+            max_retry=self.max_retry,
+        )
+        return self.ui_agent
+
+    def prepare(self):
+        self.main_vl_model_args.prepare()
+        if self.auxiliary_vl_model_args is not None:
+            self.auxiliary_vl_model_args.prepare()
+
+    def close(self):
+        self.main_vl_model_args.close()
+        if self.auxiliary_vl_model_args is not None:
+            self.auxiliary_vl_model_args.close()
+
+    def set_reproducibility_mode(self):
+        self.main_vl_model_args.set_reproducibility_mode()
+        if self.auxiliary_vl_model_args is not None:
+            self.auxiliary_vl_model_args.set_reproducibility_mode()
+
+    def set_benchmark(self, benchmark: Benchmark, demo_mode: bool):
+        self.ui_prompt_args.use_tabs = benchmark.is_multi_tab
+        self.action_set_args = deepcopy(benchmark.high_level_action_set_args)
+        if demo_mode:
+            self.action_set_args.demo_mode = "all_blue"
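+
+
+# Construction sketch (mirrors VL_AGENT_ARGS_DICT in config.py; illustrative):
+#
+#     agent_args = UIAgentArgs(
+#         main_vl_model_args=main_model_args,      # any VLModelArgs
+#         auxiliary_vl_model_args=None,            # None disables the second pass
+#         ui_prompt_args=ui_prompt_args,
+#         action_set_args=HighLevelActionSetArgs(subsets=["coord"]),
+#         max_retry=4,
+#     )
+#     agent = agent_args.make_agent()
+#
+# When auxiliary_vl_model_args is set, get_action prompts the auxiliary model a
+# second time, passing the main model's answer along as preliminary_answer.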
diff --git a/src/agentlab/agents/vl_agent/vl_model/base.py b/src/agentlab/agents/vl_agent/vl_model/base.py
new file mode 100644
index 00000000..ce188183
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/vl_model/base.py
@@ -0,0 +1,35 @@
+from abc import ABC, abstractmethod
+from agentlab.llm.llm_utils import AIMessage, Discussion
+
+
+class VLModel(ABC):
+    @abstractmethod
+    def __call__(self, messages: Discussion) -> AIMessage:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_stats(self) -> dict:
+        raise NotImplementedError
+
+
+class VLModelArgs(ABC):
+    @property
+    @abstractmethod
+    def model_name(self) -> str:
+        raise NotImplementedError
+
+    @abstractmethod
+    def make_model(self) -> VLModel:
+        raise NotImplementedError
+
+    @abstractmethod
+    def prepare(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def close(self):
+        raise NotImplementedError
+
+    @abstractmethod
+    def set_reproducibility_mode(self):
+        raise NotImplementedError
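+
+
+# Contract sketch for VLModel implementations (illustrative):
+#
+#     model = model_args.make_model()   # model_args: any concrete VLModelArgs
+#     answer = model(messages)          # Discussion in, AIMessage out
+#     stats = model.get_stats()         # per-call statistics, may be empty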
diff --git a/src/agentlab/agents/vl_agent/vl_model/llama_model.py b/src/agentlab/agents/vl_agent/vl_model/llama_model.py
new file mode 100644
index 00000000..968877df
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/vl_model/llama_model.py
@@ -0,0 +1,128 @@
+from accelerate import Accelerator
+from accelerate.utils.modeling import load_checkpoint_in_model
+from agentlab.llm.llm_utils import AIMessage, Discussion
+from dataclasses import dataclass
+from transformers import AutoProcessor, MllamaForConditionalGeneration
+from typing import Optional
+from .base import VLModel, VLModelArgs
+from ..utils import auto_dispatch_model, image_url_to_image
+
+
+class LlamaModel(VLModel):
+    def __init__(
+        self,
+        model_path: str,
+        torch_dtype: str,
+        accelerator_config: dict,
+        reproducibility_config: dict,
+        max_length: int,
+        max_new_tokens: int,
+    ):
+        self.model = MllamaForConditionalGeneration.from_pretrained(
+            model_path, torch_dtype=torch_dtype
+        )
+        self.processor = AutoProcessor.from_pretrained(model_path)
+        self.accelerator = Accelerator(**accelerator_config)
+        self.reproducibility_config = reproducibility_config
+        self.max_length = max_length
+        self.max_new_tokens = max_new_tokens
+
+    def __call__(self, messages: Discussion) -> AIMessage:
+        # convert OpenAI-style messages into the Mllama chat format, collecting
+        # the images separately for the processor
+        input_messages = []
+        input_images = []
+        for message in messages:
+            input_message = {"role": message["role"], "content": []}
+            if isinstance(message["content"], str):
+                input_message["content"].append({"type": "text", "text": message["content"]})
+            else:
+                for item in message["content"]:
+                    if item["type"] == "text":
+                        input_message["content"].append(item)
+                    elif item["type"] == "image_url":
+                        input_message["content"].append({"type": "image"})
+                        input_images.append(image_url_to_image(item["image_url"]["url"]))
+            input_messages.append(input_message)
+        input_text = self.processor.apply_chat_template(
+            input_messages, add_generation_prompt=True, tokenize=False
+        )
+        inputs = self.processor(
+            images=input_images,
+            text=input_text,
+            add_special_tokens=False,
+            return_tensors="pt",
+            truncation=True,
+            max_length=self.max_length,
+        ).to(self.model.device)
+        with self.accelerator.autocast():
+            output = self.model.generate(
+                **inputs,
+                eos_token_id=self.processor.tokenizer.eos_token_id,
+                max_new_tokens=self.max_new_tokens,
+                use_cache=True,
+                **self.reproducibility_config,
+            )
+        output_text = self.processor.tokenizer.batch_decode(
+            output[:, inputs["input_ids"].shape[1] :],
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False,
+        )[0]
+        return AIMessage([{"type": "text", "text": output_text}])
+
+    def get_stats(self) -> dict:
+        return {}
+
+
+@dataclass
+class LlamaModelArgs(VLModelArgs):
+    model_path: str
+    torch_dtype: str
+    accelerator_config: dict
+    reproducibility_config: dict
+    max_length: int
+    max_new_tokens: int
+    checkpoint_file: Optional[str]
+    device: Optional[str]
+
+    # plain property: dataclass instances are unhashable, so functools.cache
+    # cannot be used here
+    @property
+    def model_name(self) -> str:
+        return self.model_path.split("/")[-1].replace("-", "_").replace(".", "")
+
+    def make_model(self) -> LlamaModel:
+        llama_model = LlamaModel(
+            model_path=self.model_path,
+            torch_dtype=self.torch_dtype,
+            accelerator_config=self.accelerator_config,
+            reproducibility_config=self.reproducibility_config,
+            max_length=self.max_length,
+            max_new_tokens=self.max_new_tokens,
+        )
+        if self.checkpoint_file is not None:
+            load_checkpoint_in_model(llama_model.model, checkpoint=self.checkpoint_file)
+        if self.device is None:
+            # collect the transformer layer classes that must not be split across
+            # devices, then dispatch the model over the available GPUs
+            layer_classes = set()
+            for layer in llama_model.model.language_model.model.layers:
+                layer_classes.add(layer.__class__)
+            for layer in llama_model.model.vision_model.transformer.layers:
+                layer_classes.add(layer.__class__)
+            for layer in llama_model.model.vision_model.global_transformer.layers:
+                layer_classes.add(layer.__class__)
+            llama_model.model = auto_dispatch_model(
+                llama_model.model,
+                no_split_module_classes=[layer_class.__name__ for layer_class in layer_classes],
+            )
+        else:
+            llama_model.model = llama_model.model.to(self.device)
+        llama_model.model.eval()
+        self.llama_model = llama_model
+        return self.llama_model
+
+    def prepare(self):
+        pass
+
+    def close(self):
+        del self.llama_model.model
+
+    def set_reproducibility_mode(self):
+        self.reproducibility_config = {"do_sample": False}
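+
+
+# Loading sketch (illustrative; assumes GPUs with enough memory for the 11B
+# checkpoint, balanced automatically when device is None):
+#
+#     args = LlamaModelArgs(
+#         model_path="meta-llama/Llama-3.2-11B-Vision-Instruct",
+#         torch_dtype="bfloat16",
+#         accelerator_config={"mixed_precision": "bf16", "cpu": False},
+#         reproducibility_config={"temperature": 0.1},
+#         max_length=32768,
+#         max_new_tokens=8192,
+#         checkpoint_file=None,   # optionally a fine-tuned checkpoint
+#         device=None,
+#     )
+#     model = args.make_model()
+#     answer = model(messages)    # messages: an agentlab Discussion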
diff --git a/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py
new file mode 100644
index 00000000..d8c5b46f
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py
@@ -0,0 +1,70 @@
+from agentlab.llm.llm_utils import AIMessage, Discussion
+from dataclasses import dataclass
+from openai import OpenAI, RateLimitError
+from .base import VLModel, VLModelArgs
+import backoff
+import os
+
+
+class OpenRouterAPIModel(VLModel):
+    def __init__(
+        self,
+        base_url: str,
+        model_id: str,
+        max_tokens: int,
+        reproducibility_config: dict,
+    ):
+        self.client = OpenAI(base_url=base_url, api_key=os.getenv("OPENROUTER_API_KEY"))
+        self.model_id = model_id
+        self.max_tokens = max_tokens
+        self.reproducibility_config = reproducibility_config
+
+    def __call__(self, messages: Discussion) -> AIMessage:
+        @backoff.on_exception(backoff.expo, RateLimitError)
+        def get_response(messages, max_tokens, **kwargs):
+            completion = self.client.chat.completions.create(
+                model=self.model_id, messages=messages, max_tokens=max_tokens, **kwargs
+            )
+            try:
+                response = completion.choices[0].message.content or ""
+            except Exception:
+                response = ""
+            return response
+
+        response = get_response(messages, self.max_tokens, **self.reproducibility_config)
+        return AIMessage([{"type": "text", "text": response}])
+
+    def get_stats(self) -> dict:
+        return {}
+
+
+@dataclass
+class OpenRouterAPIModelArgs(VLModelArgs):
+    base_url: str
+    model_id: str
+    max_tokens: int
+    reproducibility_config: dict
+
+    # plain property: dataclass instances are unhashable, so functools.cache
+    # cannot be used here
+    @property
+    def model_name(self) -> str:
+        return self.model_id.split("/")[-1].replace("-", "_").replace(".", "")
+
+    def make_model(self) -> OpenRouterAPIModel:
+        self.openrouter_api_model = OpenRouterAPIModel(
+            base_url=self.base_url,
+            model_id=self.model_id,
+            max_tokens=self.max_tokens,
+            reproducibility_config=self.reproducibility_config,
+        )
+        return self.openrouter_api_model
+
+    def prepare(self):
+        pass
+
+    def close(self):
+        pass
+
+    def set_reproducibility_mode(self):
+        self.reproducibility_config = {"temperature": 0.0}
diff --git a/src/agentlab/agents/vl_agent/vl_prompt/base.py b/src/agentlab/agents/vl_agent/vl_prompt/base.py
new file mode 100644
index 00000000..3658ef4d
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/vl_prompt/base.py
@@ -0,0 +1,34 @@
+from abc import ABC, abstractmethod
+from agentlab.llm.llm_utils import Discussion
+from browsergym.core.action.highlevel import HighLevelActionSet
+from typing import Optional
+
+
+class VLPromptPart(ABC):
+    @abstractmethod
+    def get_message_content(self) -> list[dict]:
+        raise NotImplementedError
+
+
+class VLPrompt(ABC):
+    @abstractmethod
+    def get_messages(self) -> Discussion:
+        raise NotImplementedError
+
+    @abstractmethod
+    def parse_answer(self, answer_content: list[dict]) -> dict:
+        raise NotImplementedError
+
+
+class VLPromptArgs(ABC):
+    @abstractmethod
+    def make_prompt(
+        self,
+        obs: dict,
+        thoughts: list[str],
+        actions: list[str],
+        action_set: HighLevelActionSet,
+        extra_instruction: Optional[str] = None,
+        preliminary_answer: Optional[dict] = None,
+    ) -> VLPrompt:
+        raise NotImplementedError
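+
+
+# A prompt part produces OpenAI-style content items (illustrative; NotePromptPart
+# is hypothetical):
+#
+#     class NotePromptPart(VLPromptPart):
+#         def __init__(self, note: str):
+#             self.text = f"# Note\n{note}\n"
+#
+#         def get_message_content(self) -> list[dict]:
+#             return [{"type": "text", "text": self.text}]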
+""" + + def get_message_content(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + + +class GoalPromptPart(VLPromptPart): + def __init__(self, goal_object: list[dict]): + text = """\ +# The goal of the task +""" + for item in goal_object: + if item["type"] == "text": + text += f"""\ +{item['text']} +""" + self.text = text + + def get_message_content(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + + +class ScreenshotPromptPart(VLPromptPart): + def __init__(self, screenshot: Union[Image.Image, np.ndarray]): + self.text = """\ +# The screenshot of the current web page +""" + self.image_url = image_to_image_url(screenshot) + + def get_message_content(self) -> list[dict]: + return [ + {"type": "text", "text": self.text}, + {"type": "image_url", "image_url": {"url": self.image_url}}, + ] + + +class TabsPromptPart(VLPromptPart): + def __init__( + self, open_pages_titles: list[str], open_pages_urls: list[str], active_page_index: int + ): + text = """\ +# The open tabs of the browser +""" + for index, (title, url) in enumerate(zip(open_pages_titles, open_pages_urls)): + text += f"""\ +## Tab {index}{' (active tab)' if index == active_page_index else ''} +### Title +{title} +### URL +{url} +""" + self.text = text + + def get_message_content(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + + +class HistoryPromptPart(VLPromptPart): + def __init__(self, thoughts: list[str], actions: list[str]): + text = """\ +# The thoughts and actions of the previous steps +""" + for index, (thought, action) in enumerate(zip(thoughts, actions)): + text += f"""\ +## Step {index} +### Thought +{thought} +### Action +{action} +""" + self.text = text + + def get_message_content(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + + +class ErrorPromptPart(VLPromptPart): + def __init__( + self, + last_action_error: str, + logs_separator: str = "Call log:", + logs_limit: int = 5, + ): + text = """\ +# The error caused by the last action +""" + if logs_separator in last_action_error: + error, logs = last_action_error.split(logs_separator) + logs = logs.split("\n")[:logs_limit] + text += f"""\ +{error} +{logs_separator} +""" + for log in logs: + text += f"""\ +{log} +""" + else: + text += f"""\ +{last_action_error} +""" + self.text = text + + def get_message_content(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + +class PreliminaryAnswerPromptPart(VLPromptPart): + def __init__( + self, action_set_description: str, use_abstract_example: bool, use_concrete_example: bool + ): + text = f"""\ +# The action space +Here are all the actions you can take to interact with the browser. \ +They are Python functions based on the Playwright library. +{action_set_description} +# The format of the answer +Think about the action to take, and describe the location to take the action. \ +Your answer should include one thought and one location. +""" + if use_abstract_example: + text += """\ +# An abstract example of the answer + +The thought about the action. + + +The description of the location. + +""" + if use_concrete_example: + text += """\ +# A concrete example of the answer + +The goal is to click on the numbers in ascending order. \ +The smallest number visible on the screen is '1'. \ +I will use the 'mouse_click' action to directly click on the number '1'. + + +The number '1' in the top-left quadrant of the white area. 
+
+
+class PreliminaryAnswerContextPromptPart(VLPromptPart):
+    # surfaces the main model's draft answer so the auxiliary model can refine it;
+    # completes the preliminary_answer wiring used by UIAgent.get_action
+    def __init__(self, preliminary_answer: dict):
+        self.text = f"""\
+# The preliminary answer
+Here is a preliminary answer proposed by another model for the current step. \
+Use it as a reference when giving your own answer.
+## Thought
+{preliminary_answer['thought']}
+## Action
+{preliminary_answer['action']}
+"""
+
+    def get_message_content(self) -> list[dict]:
+        return [{"type": "text", "text": self.text}]
+
+
+class AnswerPromptPart(VLPromptPart):
+    def __init__(
+        self, action_set_description: str, use_abstract_example: bool, use_concrete_example: bool
+    ):
+        text = f"""\
+# The action space
+Here are all the actions you can take to interact with the browser. \
+They are Python functions based on the Playwright library.
+{action_set_description}
+# The format of the answer
+Think about the action to take, and choose the action from the action space. \
+Your answer should include one thought and one action.
+"""
+        if use_abstract_example:
+            text += """\
+# An abstract example of the answer
+<thought>
+The thought about the action.
+</thought>
+<action>
+The action to take.
+</action>
+"""
+        if use_concrete_example:
+            text += """\
+# A concrete example of the answer
+<thought>
+The goal is to click on the numbers in ascending order. \
+The smallest number visible on the screen is '1'. \
+I will use the 'mouse_click' action to directly click on the number '1'.
+</thought>
+<action>
+mouse_click(50, 50)
+</action>
+"""
+        self.text = text
+
+    def get_message_content(self) -> list[dict]:
+        return [{"type": "text", "text": self.text}]
+
+
+@dataclass
+class UIPrompt(VLPrompt):
+    introduction_prompt_part: IntroductionPromptPart
+    goal_prompt_part: GoalPromptPart
+    screenshot_prompt_part: Optional[ScreenshotPromptPart]
+    tabs_prompt_part: Optional[TabsPromptPart]
+    history_prompt_part: Optional[HistoryPromptPart]
+    error_prompt_part: Optional[ErrorPromptPart]
+    preliminary_answer_prompt_part: Optional[PreliminaryAnswerContextPromptPart]
+    answer_prompt_part: AnswerPromptPart
+    action_validator: callable
+
+    def get_messages(self) -> Discussion:
+        message_content = self.introduction_prompt_part.get_message_content()
+        message_content.extend(self.goal_prompt_part.get_message_content())
+        if self.screenshot_prompt_part is not None:
+            message_content.extend(self.screenshot_prompt_part.get_message_content())
+        if self.tabs_prompt_part is not None:
+            message_content.extend(self.tabs_prompt_part.get_message_content())
+        if self.history_prompt_part is not None:
+            message_content.extend(self.history_prompt_part.get_message_content())
+        if self.error_prompt_part is not None:
+            message_content.extend(self.error_prompt_part.get_message_content())
+        if self.preliminary_answer_prompt_part is not None:
+            message_content.extend(self.preliminary_answer_prompt_part.get_message_content())
+        message_content.extend(self.answer_prompt_part.get_message_content())
+        messages = Discussion([HumanMessage(message_content)])
+        messages.merge()
+        return messages
+
+    def parse_answer(self, answer_content: list[dict]) -> dict:
+        answer_text = answer_content[0]["text"]
+        answer_dict = {}
+        try:
+            answer_dict.update(parse_html_tags_raise(answer_text, keys=["thought", "action"]))
+        except ParseError as error:
+            # fall back to fenced code blocks when the tags are missing
+            answer_dict["parse_error"] = str(error)
+            answer_dict["thought"] = answer_text
+            code_blocks = extract_code_blocks(answer_text)
+            if len(code_blocks) == 0:
+                raise error
+            else:
+                answer_dict["action"] = "\n".join([block for _, block in code_blocks])
+        if answer_dict["action"] == "None":
+            answer_dict["action"] = None
+        else:
+            try:
+                self.action_validator(answer_dict["action"])
+            except Exception as error:
+                raise ParseError(str(error))
+        return answer_dict
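+
+
+# parse_answer accepts a tagged answer, or falls back to fenced code blocks when
+# the tags are missing (illustrative):
+#
+#     content = [{"type": "text", "text": "<thought>Click '1'.</thought>\n<action>mouse_click(50, 50)</action>"}]
+#     ui_prompt.parse_answer(content)
+#     # -> {"thought": "Click '1'.", "action": "mouse_click(50, 50)"}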
+
+
+@dataclass
+class UIPromptArgs(VLPromptArgs):
+    use_screenshot: bool
+    use_screenshot_som: bool
+    use_tabs: bool
+    use_history: bool
+    use_error: bool
+    use_abstract_example: bool
+    use_concrete_example: bool
+    extra_instruction: Optional[str]
+
+    def make_prompt(
+        self,
+        obs: dict,
+        thoughts: list[str],
+        actions: list[str],
+        action_set: HighLevelActionSet,
+        extra_instruction: Optional[str] = None,
+        preliminary_answer: Optional[dict] = None,
+    ) -> UIPrompt:
+        # signature mirrors VLPromptArgs.make_prompt, so the call in
+        # UIAgent.get_action that passes preliminary_answer is accepted
+        if extra_instruction is None:
+            extra_instruction = self.extra_instruction
+        introduction_prompt_part = IntroductionPromptPart(extra_instruction=extra_instruction)
+        goal_prompt_part = GoalPromptPart(obs["goal_object"])
+        if self.use_screenshot:
+            screenshot_prompt_part = ScreenshotPromptPart(obs["screenshot"])
+        else:
+            screenshot_prompt_part = None
+        if self.use_tabs:
+            tabs_prompt_part = TabsPromptPart(
+                open_pages_titles=obs["open_pages_titles"],
+                open_pages_urls=obs["open_pages_urls"],
+                active_page_index=obs["active_page_index"],
+            )
+        else:
+            tabs_prompt_part = None
+        if self.use_history and len(thoughts) == len(actions) > 0:
+            history_prompt_part = HistoryPromptPart(thoughts=thoughts, actions=actions)
+        else:
+            history_prompt_part = None
+        if self.use_error and obs["last_action_error"]:
+            error_prompt_part = ErrorPromptPart(obs["last_action_error"])
+        else:
+            error_prompt_part = None
+        if preliminary_answer is None:
+            preliminary_answer_prompt_part = None
+        else:
+            preliminary_answer_prompt_part = PreliminaryAnswerContextPromptPart(preliminary_answer)
+        answer_prompt_part = AnswerPromptPart(
+            action_set_description=action_set.describe(
+                with_long_description=True, with_examples=False
+            ),
+            use_abstract_example=self.use_abstract_example,
+            use_concrete_example=self.use_concrete_example,
+        )
+        self.ui_prompt = UIPrompt(
+            introduction_prompt_part=introduction_prompt_part,
+            goal_prompt_part=goal_prompt_part,
+            screenshot_prompt_part=screenshot_prompt_part,
+            tabs_prompt_part=tabs_prompt_part,
+            history_prompt_part=history_prompt_part,
+            error_prompt_part=error_prompt_part,
+            preliminary_answer_prompt_part=preliminary_answer_prompt_part,
+            answer_prompt_part=answer_prompt_part,
+            action_validator=action_set.to_python_code,
+        )
+        return self.ui_prompt
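+
+
+# End-to-end construction sketch (illustrative; obs keys follow BrowserGym):
+#
+#     ui_prompt = ui_prompt_args.make_prompt(
+#         obs=obs,
+#         thoughts=agent.thoughts,
+#         actions=agent.actions,
+#         action_set=agent.action_set,
+#         preliminary_answer=None,   # or the main model's parsed answer
+#     )
+#     messages = ui_prompt.get_messages()   # one merged HumanMessage
+#     answer = ui_prompt.parse_answer(model(messages)["content"])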