diff --git a/src/agentlab/agents/vl_agent/config.py b/src/agentlab/agents/vl_agent/config.py
new file mode 100644
index 00000000..c818e60a
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/config.py
@@ -0,0 +1,48 @@
+from browsergym.experiments.benchmark import HighLevelActionSetArgs
+from .vl_agent.ui_agent import UIAgentArgs
+from .vl_model.llama_model import LlamaModelArgs
+from .vl_model.openrouter_api_model import OpenRouterAPIModelArgs
+from .vl_prompt.ui_prompt import UIPromptArgs
+
+
+VL_MODEL_ARGS_DICT = {
+ "gpt_4o": OpenRouterAPIModelArgs(
+ base_url="https://openrouter.ai/api/v1",
+ model_id="openai/gpt-4o-2024-11-20",
+ max_tokens=8192,
+ reproducibility_config={"temperature": 0.1},
+ ),
+ "llama_32_11b": LlamaModelArgs(
+ model_path="meta-llama/Llama-3.2-11B-Vision-Instruct",
+ torch_dtype="bfloat16",
+ accelerator_config={"mixed_precision": "bf16", "cpu": False},
+ reproducibility_config={"temperature": 0.1},
+ max_length=32768,
+ max_new_tokens=8192,
+ checkpoint_file=None,
+ device=None,
+ ),
+}
+
+VL_PROMPT_ARGS_DICT = {
+ "ui_prompt": UIPromptArgs(
+ use_screenshot=True,
+ use_screenshot_som=False,
+ use_tabs=True,
+ use_history=True,
+ use_error=True,
+ use_abstract_example=True,
+ use_concrete_example=False,
+ extra_instruction=None,
+ )
+}
+
+VL_AGENT_ARGS_DICT = {
+ "ui_agent": UIAgentArgs(
+ main_vl_model_args=VL_MODEL_ARGS_DICT["gpt_4o"],
+ auxiliary_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"],
+ ui_prompt_args=VL_PROMPT_ARGS_DICT["ui_prompt"],
+ action_set_args=HighLevelActionSetArgs(subsets=["coord"]),
+ max_retry=4,
+ )
+}
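+
+
+# Selection sketch (illustrative, not used by the experiments): look up an agent
+# config by key, build the agent, and release its resources. The "gpt_4o" entry
+# assumes OPENROUTER_API_KEY is set; building "llama_32_11b" downloads the model
+# weights on first use.
+if __name__ == "__main__":
+    agent_args = VL_AGENT_ARGS_DICT["ui_agent"]
+    agent_args.prepare()
+    agent = agent_args.make_agent()
+    print(agent_args.agent_name)
+    agent_args.close()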
diff --git a/src/agentlab/agents/vl_agent/main.py b/src/agentlab/agents/vl_agent/main.py
new file mode 100644
index 00000000..eeb1184c
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/main.py
@@ -0,0 +1,33 @@
+from agentlab.agents.vl_agent.config import VL_AGENT_ARGS_DICT
+from agentlab.experiments.study import Study
+import logging
+import os
+
+
+logging.getLogger().setLevel(logging.INFO)
+
+vl_agent_args_list = [VL_AGENT_ARGS_DICT["ui_agent"]]
+benchmark = "miniwob"
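+# MINIWOB_URL must point at a local checkout of the MiniWoB++ HTML files; adjust
+# this path for your machine.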
+os.environ["MINIWOB_URL"] = "file:///mnt/home/miniwob-plusplus/miniwob/html/miniwob/"
+reproducibility_mode = False
+relaunch = False
+n_jobs = 1
+
+
+if __name__ == "__main__":
+ if reproducibility_mode:
+ for vl_agent_args in vl_agent_args_list:
+ vl_agent_args.set_reproducibility_mode()
+ if relaunch:
+ study = Study.load_most_recent(contains=None)
+ study.find_incomplete(include_errors=True)
+ else:
+ study = Study(vl_agent_args_list, benchmark=benchmark, logging_level_stdout=logging.WARNING)
+ study.run(
+ n_jobs=n_jobs,
+ parallel_backend="sequential",
+ strict_reproducibility=reproducibility_mode,
+ n_relaunch=3,
+ )
+ if reproducibility_mode:
+ study.append_to_journal(strict_reproducibility=True)
diff --git a/src/agentlab/agents/vl_agent/utils.py b/src/agentlab/agents/vl_agent/utils.py
new file mode 100644
index 00000000..36db2dd9
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/utils.py
@@ -0,0 +1,39 @@
+from accelerate import dispatch_model, infer_auto_device_map
+from accelerate.utils.modeling import get_balanced_memory
+from PIL import Image
+from torch.nn import Module
+from typing import Union
+import base64
+import io
+import numpy as np
+
+
+def image_to_image_url(image: Union[Image.Image, np.ndarray]) -> str:
+    if isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+    if image.mode not in ("RGB", "L"):
+        image = image.convert("RGB")
+ buffer = io.BytesIO()
+ image.save(buffer, format="JPEG")
+ image_base64 = base64.b64encode(buffer.getvalue()).decode()
+ image_url = f"data:image/jpeg;base64,{image_base64}"
+ return image_url
+
+
+def image_url_to_image(image_url: str) -> Image.Image:
+ image_base64 = image_url.replace("data:image/jpeg;base64,", "")
+ image_data = base64.b64decode(image_base64.encode())
+ buffer = io.BytesIO(image_data)
+ image = Image.open(buffer)
+ return image
+
+
+def auto_dispatch_model(model: Module, no_split_module_classes: list[str]) -> Module:
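+    # Balance the model across all visible devices: compute a per-device memory
+    # budget, derive a device map that never splits the listed module classes,
+    # then scatter the weights accordingly.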
+ max_memory = get_balanced_memory(model, no_split_module_classes=no_split_module_classes)
+ device_map = infer_auto_device_map(
+ model,
+ max_memory=max_memory,
+ no_split_module_classes=no_split_module_classes,
+ )
+ model = dispatch_model(model, device_map=device_map)
+ return model
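+
+
+# Round-trip sketch (illustrative): encode an image to a base64 data URL and decode
+# it back. JPEG is lossy, so pixel values may differ slightly after the round trip.
+if __name__ == "__main__":
+    original = Image.new("RGB", (64, 64), color="white")
+    url = image_to_image_url(original)
+    restored = image_url_to_image(url)
+    assert restored.size == original.size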
diff --git a/src/agentlab/agents/vl_agent/vl_agent/base.py b/src/agentlab/agents/vl_agent/vl_agent/base.py
new file mode 100644
index 00000000..8dde21a7
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/vl_agent/base.py
@@ -0,0 +1,47 @@
+from abc import ABC, abstractmethod
+from browsergym.core.action.highlevel import HighLevelActionSet
+from browsergym.experiments.benchmark import Benchmark
+from dataclasses import dataclass
+
+
+class VLAgent(ABC):
+ @property
+ @abstractmethod
+ def action_set(self) -> HighLevelActionSet:
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_action(self, obs: dict) -> tuple[str, dict]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def obs_preprocessor(self, obs: dict) -> dict:
+ raise NotImplementedError
+
+
+@dataclass
+class VLAgentArgs(ABC):
+ @property
+ @abstractmethod
+ def agent_name(self) -> str:
+ raise NotImplementedError
+
+ @abstractmethod
+ def make_agent(self) -> VLAgent:
+ raise NotImplementedError
+
+ @abstractmethod
+ def prepare(self):
+ raise NotImplementedError
+
+ @abstractmethod
+ def close(self):
+ raise NotImplementedError
+
+ @abstractmethod
+ def set_reproducibility_mode(self):
+ raise NotImplementedError
+
+ @abstractmethod
+ def set_benchmark(self, benchmark: Benchmark, demo_mode: bool):
+ raise NotImplementedError
diff --git a/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py
new file mode 100644
index 00000000..c850894e
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py
@@ -0,0 +1,145 @@
+from agentlab.llm.llm_utils import ParseError, retry
+from agentlab.llm.tracking import cost_tracker_decorator
+from browsergym.core.action.highlevel import HighLevelActionSet
+from browsergym.experiments.agent import AgentInfo
+from browsergym.experiments.benchmark import Benchmark
+from browsergym.experiments.benchmark.base import HighLevelActionSetArgs
+from browsergym.utils.obs import overlay_som
+from copy import copy, deepcopy
+from dataclasses import asdict, dataclass
+from functools import cached_property
+from typing import Optional
+from .base import VLAgent, VLAgentArgs
+from ..vl_model.base import VLModelArgs
+from ..vl_prompt.ui_prompt import UIPromptArgs
+
+
+class UIAgent(VLAgent):
+ def __init__(
+ self,
+ main_vl_model_args: VLModelArgs,
+ auxiliary_vl_model_args: Optional[VLModelArgs],
+ ui_prompt_args: UIPromptArgs,
+ action_set_args: HighLevelActionSetArgs,
+ max_retry: int,
+ ):
+ self.main_vl_model = main_vl_model_args.make_model()
+ if auxiliary_vl_model_args is None:
+ self.auxiliary_vl_model = None
+ else:
+ self.auxiliary_vl_model = auxiliary_vl_model_args.make_model()
+ self.ui_prompt_args = ui_prompt_args
+ self.action_set_args = action_set_args
+ self.max_retry = max_retry
+ self.thoughts = []
+ self.actions = []
+
+    @cached_property
+    def action_set(self) -> HighLevelActionSet:
+ return self.action_set_args.make_action_set()
+
+ @cost_tracker_decorator
+ def get_action(self, obs: dict) -> tuple[str, dict]:
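+        # Stage 1: the main model proposes an answer from the current observation
+        # and the running thought/action history.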
+ ui_prompt = self.ui_prompt_args.make_prompt(
+ obs=obs, thoughts=self.thoughts, actions=self.actions, action_set=self.action_set
+ )
+ try:
+ messages = ui_prompt.get_messages()
+ answer = retry(
+ chat=self.main_vl_model,
+ messages=messages,
+ n_retry=self.max_retry,
+ parser=ui_prompt.parse_answer,
+ )
+            stats = {"num_main_retries": (len(messages) - 2) // 2}
+ except ParseError:
+ answer = {"thought": None, "action": None}
+ stats = {"num_main_retries": self.max_retry}
+ stats.update(self.main_vl_model.get_stats())
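+        # Stage 2 (optional): show the main model's answer to the auxiliary model
+        # as a preliminary answer and let it produce the final one.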
+ if self.auxiliary_vl_model is not None:
+ preliminary_answer = answer
+ ui_prompt = self.ui_prompt_args.make_prompt(
+ obs=obs,
+ thoughts=self.thoughts,
+ actions=self.actions,
+ action_set=self.action_set,
+ preliminary_answer=preliminary_answer,
+ )
+ try:
+ messages = ui_prompt.get_messages()
+ answer = retry(
+ chat=self.auxiliary_vl_model,
+ messages=messages,
+ n_retry=self.max_retry,
+ parser=ui_prompt.parse_answer,
+ )
+ stats["num_auxiliary_retries"] = (len(messages) - 3) // 2
+ except ParseError:
+ answer = {"thought": None, "action": None}
+ stats["num_auxiliary_retries"] = self.max_retry
+ stats.update(self.auxiliary_vl_model.get_stats())
+ else:
+ preliminary_answer = None
+ self.thoughts.append(str(answer["thought"]))
+ self.actions.append(str(answer["action"]))
+ agent_info = AgentInfo(
+ think=str(answer["thought"]), stats=stats, extra_info=preliminary_answer
+ )
+ return answer["action"], asdict(agent_info)
+
+ def obs_preprocessor(self, obs: dict) -> dict:
+ obs = copy(obs)
+ if self.ui_prompt_args.use_screenshot and self.ui_prompt_args.use_screenshot_som:
+ obs["screenshot"] = overlay_som(
+ obs["screenshot"], extra_properties=obs["extra_element_properties"]
+ )
+ return obs
+
+
+@dataclass
+class UIAgentArgs(VLAgentArgs):
+ main_vl_model_args: VLModelArgs
+ auxiliary_vl_model_args: Optional[VLModelArgs]
+ ui_prompt_args: UIPromptArgs
+ action_set_args: HighLevelActionSetArgs
+ max_retry: int
+
+    @cached_property
+ def agent_name(self) -> str:
+ if self.auxiliary_vl_model_args is None:
+ return f"UIAgent-{self.main_vl_model_args.model_name}"
+ else:
+ return f"UIAgent-{self.main_vl_model_args.model_name}-{self.auxiliary_vl_model_args.model_name}"
+
+ def make_agent(self) -> UIAgent:
+ self.ui_agent = UIAgent(
+ main_vl_model_args=self.main_vl_model_args,
+ auxiliary_vl_model_args=self.auxiliary_vl_model_args,
+ ui_prompt_args=self.ui_prompt_args,
+ action_set_args=self.action_set_args,
+ max_retry=self.max_retry,
+ )
+ return self.ui_agent
+
+ def prepare(self):
+ self.main_vl_model_args.prepare()
+ if self.auxiliary_vl_model_args is not None:
+ self.auxiliary_vl_model_args.prepare()
+
+ def close(self):
+ self.main_vl_model_args.close()
+ if self.auxiliary_vl_model_args is not None:
+ self.auxiliary_vl_model_args.close()
+
+ def set_reproducibility_mode(self):
+ self.main_vl_model_args.set_reproducibility_mode()
+ if self.auxiliary_vl_model_args is not None:
+ self.auxiliary_vl_model_args.set_reproducibility_mode()
+
+ def set_benchmark(self, benchmark: Benchmark, demo_mode: bool):
+ self.ui_prompt_args.use_tabs = benchmark.is_multi_tab
+ self.action_set_args = deepcopy(benchmark.high_level_action_set_args)
+ if demo_mode:
+ self.action_set_args.demo_mode = "all_blue"
diff --git a/src/agentlab/agents/vl_agent/vl_model/base.py b/src/agentlab/agents/vl_agent/vl_model/base.py
new file mode 100644
index 00000000..ce188183
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/vl_model/base.py
@@ -0,0 +1,35 @@
+from abc import ABC, abstractmethod
+from agentlab.llm.llm_utils import AIMessage, Discussion
+
+
+class VLModel(ABC):
+ @abstractmethod
+ def __call__(self, messages: Discussion) -> AIMessage:
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_stats(self) -> dict:
+ raise NotImplementedError
+
+
+class VLModelArgs(ABC):
+ @property
+ @abstractmethod
+ def model_name(self) -> str:
+ raise NotImplementedError
+
+ @abstractmethod
+ def make_model(self) -> VLModel:
+ raise NotImplementedError
+
+ @abstractmethod
+ def prepare(self):
+ raise NotImplementedError
+
+ @abstractmethod
+ def close(self):
+ raise NotImplementedError
+
+ @abstractmethod
+ def set_reproducibility_mode(self):
+ raise NotImplementedError
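+
+
+# Implementation sketch (illustrative, not used elsewhere in this diff): the smallest
+# conforming VLModel, handy as a stand-in when testing prompts offline.
+class EchoModel(VLModel):
+    def __call__(self, messages: Discussion) -> AIMessage:
+        return AIMessage([{"type": "text", "text": "<thought>noop</thought><action>noop()</action>"}])
+
+    def get_stats(self) -> dict:
+        return {}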
diff --git a/src/agentlab/agents/vl_agent/vl_model/llama_model.py b/src/agentlab/agents/vl_agent/vl_model/llama_model.py
new file mode 100644
index 00000000..968877df
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/vl_model/llama_model.py
@@ -0,0 +1,128 @@
+from accelerate import Accelerator
+from accelerate.utils.modeling import load_checkpoint_in_model
+from agentlab.llm.llm_utils import AIMessage, Discussion
+from dataclasses import dataclass
+from functools import cached_property
+from transformers import AutoProcessor, MllamaForConditionalGeneration
+from typing import Optional
+from .base import VLModel, VLModelArgs
+from ..utils import auto_dispatch_model, image_url_to_image
+
+
+class LlamaModel(VLModel):
+ def __init__(
+ self,
+ model_path: str,
+ torch_dtype: str,
+ accelerator_config: dict,
+ reproducibility_config: dict,
+ max_length: int,
+ max_new_tokens: int,
+ ):
+ self.model = MllamaForConditionalGeneration.from_pretrained(
+ model_path, torch_dtype=torch_dtype
+ )
+ self.processor = AutoProcessor.from_pretrained(model_path)
+ self.accelerator = Accelerator(**accelerator_config)
+ self.reproducibility_config = reproducibility_config
+ self.max_length = max_length
+ self.max_new_tokens = max_new_tokens
+
+ def __call__(self, messages: Discussion) -> AIMessage:
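+        # Convert the agentlab Discussion into the chat format expected by
+        # apply_chat_template: text items pass through, image_url items become
+        # {"type": "image"} placeholders with the PIL images collected in order.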
+ input_messages = []
+ input_images = []
+ for message in messages:
+ input_message = {"role": message["role"], "content": []}
+ if isinstance(message["content"], str):
+ input_message["content"].append({"type": "text", "text": message["content"]})
+ else:
+ for item in message["content"]:
+ if item["type"] == "text":
+ input_message["content"].append(item)
+ elif item["type"] == "image_url":
+ input_message["content"].append({"type": "image"})
+ input_images.append(image_url_to_image(item["image_url"]["url"]))
+ input_messages.append(input_message)
+ input_text = self.processor.apply_chat_template(
+ input_messages, add_generation_prompt=True, tokenize=False
+ )
+        inputs = self.processor(
+ images=input_images,
+ text=input_text,
+ add_special_tokens=False,
+ return_tensors="pt",
+ truncation=True,
+ max_length=self.max_length,
+ ).to(self.model.device)
+ with self.accelerator.autocast():
+ output = self.model.generate(
+                **inputs,
+ eos_token_id=self.processor.tokenizer.eos_token_id,
+ max_new_tokens=self.max_new_tokens,
+ use_cache=True,
+ **self.reproducibility_config,
+ )
+ output_text = self.processor.tokenizer.batch_decode(
+            output[:, inputs["input_ids"].shape[1] :],
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False,
+ )[0]
+ return AIMessage([{"type": "text", "text": output_text}])
+
+ def get_stats(self) -> dict:
+ return {}
+
+
+@dataclass
+class LlamaModelArgs(VLModelArgs):
+ model_path: str
+ torch_dtype: str
+ accelerator_config: dict
+ reproducibility_config: dict
+ max_length: int
+ max_new_tokens: int
+ checkpoint_file: Optional[str]
+ device: Optional[str]
+
+    @cached_property
+ def model_name(self) -> str:
+ return self.model_path.split("/")[-1].replace("-", "_").replace(".", "")
+
+ def make_model(self) -> LlamaModel:
+ llama_model = LlamaModel(
+ model_path=self.model_path,
+ torch_dtype=self.torch_dtype,
+ accelerator_config=self.accelerator_config,
+ reproducibility_config=self.reproducibility_config,
+ max_length=self.max_length,
+ max_new_tokens=self.max_new_tokens,
+ )
+ if self.checkpoint_file is not None:
+ load_checkpoint_in_model(llama_model.model, checkpoint=self.checkpoint_file)
+ if self.device is None:
+ layer_classes = set()
+ for layer in llama_model.model.language_model.model.layers:
+ layer_classes.add(layer.__class__)
+ for layer in llama_model.model.vision_model.transformer.layers:
+ layer_classes.add(layer.__class__)
+ for layer in llama_model.model.vision_model.global_transformer.layers:
+ layer_classes.add(layer.__class__)
+ llama_model.model = auto_dispatch_model(
+ llama_model.model,
+ no_split_module_classes=[layer_class.__name__ for layer_class in layer_classes],
+ )
+ else:
+ llama_model.model = llama_model.model.to(self.device)
+ llama_model.model.eval()
+ self.llama_model = llama_model
+ return self.llama_model
+
+ def prepare(self):
+ pass
+
+ def close(self):
+ del self.llama_model.model
+
+ def set_reproducibility_mode(self):
+ self.reproducibility_config = {"do_sample": False}
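+
+
+# Loading sketch (illustrative; fetches the model weights on first use and needs one
+# or more GPUs with enough memory; device=None auto-dispatches across all of them):
+if __name__ == "__main__":
+    args = LlamaModelArgs(
+        model_path="meta-llama/Llama-3.2-11B-Vision-Instruct",
+        torch_dtype="bfloat16",
+        accelerator_config={"mixed_precision": "bf16", "cpu": False},
+        reproducibility_config={"do_sample": False},
+        max_length=32768,
+        max_new_tokens=256,
+        checkpoint_file=None,
+        device=None,
+    )
+    model = args.make_model()
+    print(args.model_name)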
diff --git a/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py
new file mode 100644
index 00000000..d8c5b46f
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py
@@ -0,0 +1,70 @@
+from agentlab.llm.llm_utils import AIMessage, Discussion
+from dataclasses import dataclass
+from functools import cached_property
+from openai import OpenAI, RateLimitError
+from .base import VLModel, VLModelArgs
+import backoff
+import os
+
+
+class OpenRouterAPIModel(VLModel):
+ def __init__(
+ self,
+ base_url: str,
+ model_id: str,
+ max_tokens: int,
+ reproducibility_config: dict,
+ ):
+ self.client = OpenAI(base_url=base_url, api_key=os.getenv("OPENROUTER_API_KEY"))
+ self.model_id = model_id
+ self.max_tokens = max_tokens
+ self.reproducibility_config = reproducibility_config
+
+ def __call__(self, messages: Discussion) -> AIMessage:
+ @backoff.on_exception(backoff.expo, RateLimitError)
+ def get_response(messages, max_tokens, **kwargs):
+ completion = self.client.chat.completions.create(
+ model=self.model_id, messages=messages, max_tokens=max_tokens, **kwargs
+ )
+            try:
+                response = completion.choices[0].message.content or ""
+            except (AttributeError, IndexError, TypeError):
+                response = ""
+ return response
+
+ response = get_response(messages, self.max_tokens, **self.reproducibility_config)
+ return AIMessage([{"type": "text", "text": response}])
+
+ def get_stats(self) -> dict:
+ return {}
+
+
+@dataclass
+class OpenRouterAPIModelArgs(VLModelArgs):
+ base_url: str
+ model_id: str
+ max_tokens: int
+ reproducibility_config: dict
+
+    @cached_property
+ def model_name(self) -> str:
+ return self.model_id.split("/")[-1].replace("-", "_").replace(".", "")
+
+ def make_model(self) -> OpenRouterAPIModel:
+ self.openrouter_api_model = OpenRouterAPIModel(
+ base_url=self.base_url,
+ model_id=self.model_id,
+ max_tokens=self.max_tokens,
+ reproducibility_config=self.reproducibility_config,
+ )
+ return self.openrouter_api_model
+
+ def prepare(self):
+ pass
+
+ def close(self):
+ pass
+
+ def set_reproducibility_mode(self):
+ self.reproducibility_config = {"temperature": 0.0}
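+
+
+# Call sketch (illustrative; requires OPENROUTER_API_KEY in the environment):
+if __name__ == "__main__":
+    args = OpenRouterAPIModelArgs(
+        base_url="https://openrouter.ai/api/v1",
+        model_id="openai/gpt-4o-2024-11-20",
+        max_tokens=64,
+        reproducibility_config={"temperature": 0.0},
+    )
+    model = args.make_model()
+    print(model([{"role": "user", "content": "Reply with the single word: ready"}]))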
diff --git a/src/agentlab/agents/vl_agent/vl_prompt/base.py b/src/agentlab/agents/vl_agent/vl_prompt/base.py
new file mode 100644
index 00000000..3658ef4d
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/vl_prompt/base.py
@@ -0,0 +1,34 @@
+from abc import ABC, abstractmethod
+from agentlab.llm.llm_utils import Discussion
+from browsergym.core.action.highlevel import HighLevelActionSet
+from typing import Optional
+
+
+class VLPromptPart(ABC):
+ @abstractmethod
+ def get_message_content(self) -> list[dict]:
+ raise NotImplementedError
+
+
+class VLPrompt(ABC):
+ @abstractmethod
+ def get_messages(self) -> Discussion:
+ raise NotImplementedError
+
+ @abstractmethod
+    def parse_answer(self, answer_content: list[dict]) -> dict:
+ raise NotImplementedError
+
+
+class VLPromptArgs(ABC):
+ @abstractmethod
+ def make_prompt(
+ self,
+ obs: dict,
+ thoughts: list[str],
+ actions: list[str],
+ action_set: HighLevelActionSet,
+ extra_instruction: Optional[str] = None,
+ preliminary_answer: Optional[dict] = None,
+ ) -> VLPrompt:
+ raise NotImplementedError
diff --git a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py
new file mode 100644
index 00000000..b50e3eea
--- /dev/null
+++ b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py
@@ -0,0 +1,313 @@
+from agentlab.llm.llm_utils import (
+ Discussion,
+ extract_code_blocks,
+ HumanMessage,
+ ParseError,
+ parse_html_tags_raise,
+)
+from browsergym.core.action.highlevel import HighLevelActionSet
+from dataclasses import dataclass
+from PIL import Image
+from typing import Callable, Optional, Union
+from .base import VLPrompt, VLPromptArgs, VLPromptPart
+from ..utils import image_to_image_url
+import numpy as np
+
+
+class IntroductionPromptPart(VLPromptPart):
+ def __init__(self):
+ self.text = """\
+You are an agent working to address a web-based task through step-by-step interactions with the browser. \
+To achieve the goal of the task, at each step, you need to submit an action according to the current state of the browser. \
+This action will be executed to update the state of the browser, and you will proceed to the next step.
+"""
+
+ def get_message_content(self) -> list[dict]:
+ return [{"type": "text", "text": self.text}]
+
+
+class GoalPromptPart(VLPromptPart):
+    def __init__(self, goal_object: list[dict], extra_instruction: Optional[str] = None):
+        text = """\
+# The goal of the task
+"""
+        for item in goal_object:
+            if item["type"] == "text":
+                text += f"""\
+{item['text']}
+"""
+        if extra_instruction is not None:
+            text += f"""\
+{extra_instruction}
+"""
+        self.text = text
+
+ def get_message_content(self) -> list[dict]:
+ return [{"type": "text", "text": self.text}]
+
+
+class ScreenshotPromptPart(VLPromptPart):
+ def __init__(self, screenshot: Union[Image.Image, np.ndarray]):
+ self.text = """\
+# The screenshot of the current web page
+"""
+ self.image_url = image_to_image_url(screenshot)
+
+ def get_message_content(self) -> list[dict]:
+ return [
+ {"type": "text", "text": self.text},
+ {"type": "image_url", "image_url": {"url": self.image_url}},
+ ]
+
+
+class TabsPromptPart(VLPromptPart):
+ def __init__(
+ self, open_pages_titles: list[str], open_pages_urls: list[str], active_page_index: int
+ ):
+ text = """\
+# The open tabs of the browser
+"""
+ for index, (title, url) in enumerate(zip(open_pages_titles, open_pages_urls)):
+ text += f"""\
+## Tab {index}{' (active tab)' if index == active_page_index else ''}
+### Title
+{title}
+### URL
+{url}
+"""
+ self.text = text
+
+ def get_message_content(self) -> list[dict]:
+ return [{"type": "text", "text": self.text}]
+
+
+class HistoryPromptPart(VLPromptPart):
+ def __init__(self, thoughts: list[str], actions: list[str]):
+ text = """\
+# The thoughts and actions of the previous steps
+"""
+ for index, (thought, action) in enumerate(zip(thoughts, actions)):
+ text += f"""\
+## Step {index}
+### Thought
+{thought}
+### Action
+{action}
+"""
+ self.text = text
+
+ def get_message_content(self) -> list[dict]:
+ return [{"type": "text", "text": self.text}]
+
+
+class ErrorPromptPart(VLPromptPart):
+ def __init__(
+ self,
+ last_action_error: str,
+ logs_separator: str = "Call log:",
+ logs_limit: int = 5,
+ ):
+ text = """\
+# The error caused by the last action
+"""
+ if logs_separator in last_action_error:
+            error, logs = last_action_error.split(logs_separator, 1)
+ logs = logs.split("\n")[:logs_limit]
+ text += f"""\
+{error}
+{logs_separator}
+"""
+ for log in logs:
+ text += f"""\
+{log}
+"""
+ else:
+ text += f"""\
+{last_action_error}
+"""
+ self.text = text
+
+ def get_message_content(self) -> list[dict]:
+ return [{"type": "text", "text": self.text}]
+
+
+class PreliminaryAnswerPromptPart(VLPromptPart):
+    def __init__(self, preliminary_answer: dict):
+        self.text = f"""\
+# The preliminary answer
+Here is a preliminary answer proposed by another model. \
+Treat it as a reference: keep what is correct and revise what is not.
+## Thought
+{preliminary_answer['thought']}
+## Action
+{preliminary_answer['action']}
+"""
+
+    def get_message_content(self) -> list[dict]:
+        return [{"type": "text", "text": self.text}]
+
+
+class AnswerPromptPart(VLPromptPart):
+ def __init__(
+ self, action_set_description: str, use_abstract_example: bool, use_concrete_example: bool
+ ):
+ text = f"""\
+# The action space
+Here are all the actions you can take to interact with the browser. \
+They are Python functions based on the Playwright library.
+{action_set_description}
+# The format of the answer
+Think about the action to take, and choose the action from the action space. \
+Your answer should include one thought and one action.
+"""
+        if use_abstract_example:
+            text += """\
+# An abstract example of the answer
+<thought>
+The thought about the action.
+</thought>
+<action>
+The action to take.
+</action>
+"""
+        if use_concrete_example:
+            text += """\
+# A concrete example of the answer
+<thought>
+The goal is to click on the numbers in ascending order. \
+The smallest number visible on the screen is '1'. \
+I will use the 'mouse_click' action to directly click on the number '1'.
+</thought>
+<action>
+mouse_click(50, 50)
+</action>
+"""
+ self.text = text
+
+ def get_message_content(self) -> list[dict]:
+ return [{"type": "text", "text": self.text}]
+
+
+@dataclass
+class UIPrompt(VLPrompt):
+ introduction_prompt_part: IntroductionPromptPart
+ goal_prompt_part: GoalPromptPart
+ screenshot_prompt_part: Optional[ScreenshotPromptPart]
+ tabs_prompt_part: Optional[TabsPromptPart]
+ history_prompt_part: Optional[HistoryPromptPart]
+    error_prompt_part: Optional[ErrorPromptPart]
+    preliminary_answer_prompt_part: Optional[PreliminaryAnswerPromptPart]
+    answer_prompt_part: AnswerPromptPart
+    action_validator: Callable
+
+ def get_messages(self) -> Discussion:
+ message_content = self.introduction_prompt_part.get_message_content()
+ message_content.extend(self.goal_prompt_part.get_message_content())
+ if self.screenshot_prompt_part is not None:
+ message_content.extend(self.screenshot_prompt_part.get_message_content())
+ if self.tabs_prompt_part is not None:
+ message_content.extend(self.tabs_prompt_part.get_message_content())
+ if self.history_prompt_part is not None:
+ message_content.extend(self.history_prompt_part.get_message_content())
+        if self.error_prompt_part is not None:
+            message_content.extend(self.error_prompt_part.get_message_content())
+        if self.preliminary_answer_prompt_part is not None:
+            message_content.extend(self.preliminary_answer_prompt_part.get_message_content())
+        message_content.extend(self.answer_prompt_part.get_message_content())
+ messages = Discussion([HumanMessage(message_content)])
+ messages.merge()
+ return messages
+
+ def parse_answer(self, answer_content: list[dict]) -> dict:
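+        # Prefer explicit <thought>/<action> tags; if they are missing, treat the
+        # whole reply as the thought and any fenced code blocks as the action,
+        # re-raising only when no code blocks exist either.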
+ answer_text = answer_content[0]["text"]
+ answer_dict = {}
+ try:
+ answer_dict.update(parse_html_tags_raise(answer_text, keys=["thought", "action"]))
+ except ParseError as error:
+ answer_dict["parse_error"] = str(error)
+ answer_dict["thought"] = answer_text
+ code_blocks = extract_code_blocks(answer_text)
+ if len(code_blocks) == 0:
+ raise error
+ else:
+ answer_dict["action"] = "\n".join([block for _, block in code_blocks])
+ if answer_dict["action"] == "None":
+ answer_dict["action"] = None
+ else:
+ try:
+ self.action_validator(answer_dict["action"])
+ except Exception as error:
+ raise ParseError(str(error))
+ return answer_dict
+
+
+@dataclass
+class UIPromptArgs(VLPromptArgs):
+ use_screenshot: bool
+ use_screenshot_som: bool
+ use_tabs: bool
+ use_history: bool
+ use_error: bool
+ use_abstract_example: bool
+ use_concrete_example: bool
+ extra_instruction: Optional[str]
+
+    def make_prompt(
+        self,
+        obs: dict,
+        thoughts: list[str],
+        actions: list[str],
+        action_set: HighLevelActionSet,
+        extra_instruction: Optional[str] = None,
+        preliminary_answer: Optional[dict] = None,
+    ) -> UIPrompt:
+        introduction_prompt_part = IntroductionPromptPart()
+        if extra_instruction is None:
+            extra_instruction = self.extra_instruction
+        goal_prompt_part = GoalPromptPart(obs["goal_object"], extra_instruction=extra_instruction)
+ if self.use_screenshot:
+ screenshot_prompt_part = ScreenshotPromptPart(obs["screenshot"])
+ else:
+ screenshot_prompt_part = None
+ if self.use_tabs:
+ tabs_prompt_part = TabsPromptPart(
+ open_pages_titles=obs["open_pages_titles"],
+ open_pages_urls=obs["open_pages_urls"],
+ active_page_index=obs["active_page_index"],
+ )
+ else:
+ tabs_prompt_part = None
+ if self.use_history and len(thoughts) == len(actions) > 0:
+ history_prompt_part = HistoryPromptPart(thoughts=thoughts, actions=actions)
+ else:
+ history_prompt_part = None
+        if self.use_error and obs["last_action_error"]:
+            error_prompt_part = ErrorPromptPart(obs["last_action_error"])
+        else:
+            error_prompt_part = None
+        if preliminary_answer is not None:
+            preliminary_answer_prompt_part = PreliminaryAnswerPromptPart(preliminary_answer)
+        else:
+            preliminary_answer_prompt_part = None
+ answer_prompt_part = AnswerPromptPart(
+ action_set_description=action_set.describe(
+ with_long_description=True, with_examples=False
+ ),
+ use_abstract_example=self.use_abstract_example,
+ use_concrete_example=self.use_concrete_example,
+ )
+ self.ui_prompt = UIPrompt(
+ introduction_prompt_part=introduction_prompt_part,
+ goal_prompt_part=goal_prompt_part,
+ screenshot_prompt_part=screenshot_prompt_part,
+ tabs_prompt_part=tabs_prompt_part,
+ history_prompt_part=history_prompt_part,
+            error_prompt_part=error_prompt_part,
+            preliminary_answer_prompt_part=preliminary_answer_prompt_part,
+ answer_prompt_part=answer_prompt_part,
+ action_validator=action_set.to_python_code,
+ )
+ return self.ui_prompt
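+
+
+# Parsing sketch (illustrative): exercise UIPrompt.parse_answer without a browser by
+# assembling a prompt by hand and feeding it a tagged model reply. The no-op
+# action_validator stands in for HighLevelActionSet.to_python_code.
+if __name__ == "__main__":
+    ui_prompt = UIPrompt(
+        introduction_prompt_part=IntroductionPromptPart(),
+        goal_prompt_part=GoalPromptPart([{"type": "text", "text": "Click the button."}]),
+        screenshot_prompt_part=None,
+        tabs_prompt_part=None,
+        history_prompt_part=None,
+        error_prompt_part=None,
+        preliminary_answer_prompt_part=None,
+        answer_prompt_part=AnswerPromptPart("mouse_click(x, y)", True, False),
+        action_validator=lambda action: None,
+    )
+    reply = "<thought>Click it.</thought>\n<action>mouse_click(10, 10)</action>"
+    print(ui_prompt.parse_answer([{"type": "text", "text": reply}]))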