From 4f7c035b2d2a32bdbe5041b67806dfb68880059f Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Mon, 5 May 2025 22:33:22 +0000 Subject: [PATCH 01/29] update --- src/agentlab/agents/vl_agent/vl_agent.py | 100 ++++++++++++++++++ .../agents/vl_agent/vl_agent_config.py | 10 ++ src/agentlab/agents/vl_agent/vl_model.py | 32 ++++++ .../agents/vl_agent/vl_model_config.py | 11 ++ src/agentlab/agents/vl_agent/vl_prompt.py | 91 ++++++++++++++++ .../agents/vl_agent/vl_prompt_config.py | 29 +++++ 6 files changed, 273 insertions(+) create mode 100644 src/agentlab/agents/vl_agent/vl_agent.py create mode 100644 src/agentlab/agents/vl_agent/vl_agent_config.py create mode 100644 src/agentlab/agents/vl_agent/vl_model.py create mode 100644 src/agentlab/agents/vl_agent/vl_model_config.py create mode 100644 src/agentlab/agents/vl_agent/vl_prompt.py create mode 100644 src/agentlab/agents/vl_agent/vl_prompt_config.py diff --git a/src/agentlab/agents/vl_agent/vl_agent.py b/src/agentlab/agents/vl_agent/vl_agent.py new file mode 100644 index 00000000..b415cae5 --- /dev/null +++ b/src/agentlab/agents/vl_agent/vl_agent.py @@ -0,0 +1,100 @@ +from dataclasses import asdict, dataclass +from browsergym.experiments.agent import Agent, AgentInfo +from agentlab.agents import dynamic_prompting as dp +from agentlab.agents.agent_args import AgentArgs +from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry +from agentlab.llm.tracking import cost_tracker_decorator +from .vl_model import VLModelArgs +from .vl_prompt import VLPromptFlags, VLPrompt +import bgym + + +@dataclass +class VLAgentArgs(AgentArgs): + vl_model_args: VLModelArgs = None + vl_prompt_flags: VLPromptFlags = None + + def __post_init__(self): + self.agent_name = f"VLAgent-{self.vl_model_args.model_name}".replace("/", "_") + + def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode): + self.vl_prompt_flags.obs.use_tabs = benchmark.is_multi_tab + + def set_reproducibility_mode(self): + self.vl_model_args.temperature = 0 + + def prepare(self): + return self.vl_model_args.prepare() + + def close(self): + return self.vl_model_args.close() + + def make_agent(self): + return VLAgent(vl_model_args=self.vl_model_args, vl_prompt_flags=self.vl_prompt_flags) + + +class VLAgent(Agent): + def __init__(self, vl_model_args: VLModelArgs, vl_prompt_flags: VLPromptFlags): + self.vl_model_args = vl_model_args + self.vl_model = vl_model_args.make_model() + self.vl_prompt_flags = vl_prompt_flags + self.action_set = self.vl_prompt_flags.action_flags.action_set.make_action_set() + self._obs_preprocessor = dp.make_obs_preprocessor(vl_prompt_flags.obs_flags) + self.reset(seed=None) + + def obs_preprocessor(self, obs: dict) -> dict: + return self._obs_preprocessor(obs) + + @cost_tracker_decorator + def get_action(self, obs): + self.obs_history.append(obs) + main_prompt = VLPrompt( + action_set=self.action_set, + obs=obs, + actions=self.actions, + thoughts=self.thoughts, + flags=self.flags, + ) + + system_prompt = SystemMessage(dp.SystemPrompt().prompt) + try: + # TODO, we would need to further shrink the prompt if the retry + # cause it to be too long + + chat_messages = Discussion([system_prompt, main_prompt.prompt]) + ans_dict = retry( + self.chat_llm, + chat_messages, + n_retry=self.max_retry, + parser=main_prompt._parse_answer, + ) + ans_dict["busted_retry"] = 0 + # inferring the number of retries, TODO: make this less hacky + ans_dict["n_retry"] = (len(chat_messages) - 3) / 2 + except ParseError: + ans_dict = dict( + action=None, + n_retry=self.max_retry + 1, 
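                # all retries were consumed without a parsable answer: return
                # no action and flag the step as a busted retry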
+ busted_retry=1, + ) + + stats = self.chat_llm.get_stats() + stats["n_retry"] = ans_dict["n_retry"] + stats["busted_retry"] = ans_dict["busted_retry"] + + self.actions.append(ans_dict["action"]) + self.thoughts.append(ans_dict.get("think", None)) + + agent_info = AgentInfo( + think=ans_dict.get("think", None), + chat_messages=chat_messages, + stats=stats, + extra_info={"chat_model_args": asdict(self.chat_model_args)}, + ) + return ans_dict["action"], agent_info + + def reset(self, seed=None): + self.seed = seed + self.thoughts = [] + self.actions = [] + self.obs_history = [] diff --git a/src/agentlab/agents/vl_agent/vl_agent_config.py b/src/agentlab/agents/vl_agent/vl_agent_config.py new file mode 100644 index 00000000..9b8b96ec --- /dev/null +++ b/src/agentlab/agents/vl_agent/vl_agent_config.py @@ -0,0 +1,10 @@ +from .vl_agent import VLAgentArgs +from .vl_model_config import VL_MODEL_ARGS_DICT +from .vl_prompt_config import VL_PROMPT_FLAGS + + +VL_AGENT_ARGS_DICT = { + "vl_agent_llama_32_11b": VLAgentArgs( + vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], vl_prompt_flags=VL_PROMPT_FLAGS + ) +} diff --git a/src/agentlab/agents/vl_agent/vl_model.py b/src/agentlab/agents/vl_agent/vl_model.py new file mode 100644 index 00000000..88bb24bf --- /dev/null +++ b/src/agentlab/agents/vl_agent/vl_model.py @@ -0,0 +1,32 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass + + +class VLModel(ABC): + @abstractmethod + def __call__(self, messages: list[dict]) -> dict: + pass + + def get_stats(self): + return {} + + +@dataclass +class VLModelArgs(ABC): + model_name: str + max_total_tokens: int = None + max_input_tokens: int = None + max_new_tokens: int = None + temperature: float = 0.1 + vision_support: bool = False + log_probs: bool = False + + @abstractmethod + def make_model(self) -> VLModel: + pass + + def prepare(self): + pass + + def close(self): + pass diff --git a/src/agentlab/agents/vl_agent/vl_model_config.py b/src/agentlab/agents/vl_agent/vl_model_config.py new file mode 100644 index 00000000..d1dfa23f --- /dev/null +++ b/src/agentlab/agents/vl_agent/vl_model_config.py @@ -0,0 +1,11 @@ +from .vl_model import VLModelArgs + +VL_MODEL_ARGS_DICT = { + "llama_32_11b": VLModelArgs( + model_name="meta-llama/Llama-3.2-11B-Vision-Instruct", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=100_000, + vision_support=False, + ) +} diff --git a/src/agentlab/agents/vl_agent/vl_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt.py new file mode 100644 index 00000000..121a7d8a --- /dev/null +++ b/src/agentlab/agents/vl_agent/vl_prompt.py @@ -0,0 +1,91 @@ +import logging +from dataclasses import dataclass +from browsergym.core.action.base import AbstractActionSet +from agentlab.agents import dynamic_prompting as dp +from agentlab.llm.llm_utils import HumanMessage + + +@dataclass +class VLPromptFlags(dp.Flags): + obs_flags: dp.ObsFlags = None + action_flags: dp.ActionFlags = None + use_thinking: bool = True + use_concrete_example: bool = False + use_abstract_example: bool = True + enable_chat: bool = False + extra_instructions: str | None = None + + +class VLPrompt(dp.PromptElement): + def __init__( + self, + vl_prompt_flags: VLPromptFlags, + action_set: AbstractActionSet, + obs_history: list[dict], + actions: list[str], + thoughts: list[str], + ): + super().__init__() + self.vl_prompt_flags = vl_prompt_flags + if vl_prompt_flags.enable_chat: + self.instructions = dp.ChatInstructions( + obs_history[-1]["chat_messages"], + 
extra_instructions=vl_prompt_flags.extra_instructions, + ) + else: + if sum([msg["role"] == "user" for msg in obs_history[-1].get("chat_messages", [])]) > 1: + logging.warning( + "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." + ) + self.instructions = dp.GoalInstructions( + obs_history[-1]["goal_object"], + extra_instructions=vl_prompt_flags.extra_instructions, + ) + self.observation = dp.Observation(obs_history[-1], vl_prompt_flags.obs_flags) + self.history = dp.History(obs_history, actions, None, thoughts, vl_prompt_flags.obs_flags) + self.action_prompt = dp.ActionPrompt(action_set, action_flags=vl_prompt_flags.action_flags) + self.think = dp.Think(visible=lambda: vl_prompt_flags.use_thinking) + + @property + def _prompt(self) -> HumanMessage: + prompt = HumanMessage(self.instructions.prompt) + prompt.add_text( + f"""\ +{self.observation.prompt}\ +{self.history.prompt}\ +{self.action_prompt.prompt}\ +{self.think.prompt}\ +""" + ) + + if self.vl_prompt_flags.use_abstract_example: + prompt.add_text( + f""" +# Abstract Example + +Here is an abstract version of the answer with description of the content of +each tag. Make sure you follow this structure, but replace the content with your +answer: +{self.think.abstract_ex}\ +{self.action_prompt.abstract_ex}\ +""" + ) + + if self.vl_prompt_flags.use_concrete_example: + prompt.add_text( + f""" +# Concrete Example + +Here is a concrete example of how to format your answer. +Make sure to follow the template with proper tags: +{self.think.concrete_ex}\ +{self.action_prompt.concrete_ex}\ +""" + ) + return self.observation.add_screenshot(prompt) + + def _parse_answer(self, text_answer): + ans_dict = {} + ans_dict.update(self.think.parse_answer(text_answer)) + ans_dict.update(self.action_prompt.parse_answer(text_answer)) + return ans_dict diff --git a/src/agentlab/agents/vl_agent/vl_prompt_config.py b/src/agentlab/agents/vl_agent/vl_prompt_config.py new file mode 100644 index 00000000..b2116e05 --- /dev/null +++ b/src/agentlab/agents/vl_agent/vl_prompt_config.py @@ -0,0 +1,29 @@ +from .vl_prompt import VLPromptFlags +import agentlab.agents.dynamic_prompting as dp +import bgym + +VL_OBS_FLAGS = dp.ObsFlags( + use_tabs=True, + use_error_logs=True, + use_past_error_logs=False, + use_screenshot=True, + use_som=False, + openai_vision_detail="auto", +) + +VL_ACTION_FLAGS = dp.ActionFlags( + action_set=bgym.HighLevelActionSetArgs(subsets=["coord"]), + long_description=True, + individual_examples=False, +) + + +VL_PROMPT_FLAGS = VLPromptFlags( + obs=VL_OBS_FLAGS, + action=VL_ACTION_FLAGS, + use_thinking=True, + use_concrete_example=False, + use_abstract_example=True, + enable_chat=False, + extra_instructions=None, +) From e619dfdf8e6022861b3fcb357da6714da6abc602 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Tue, 6 May 2025 03:48:40 +0000 Subject: [PATCH 02/29] update --- src/agentlab/agents/vl_agent/vl_agent.py | 62 +++++++++---------- .../agents/vl_agent/vl_agent_config.py | 5 +- src/agentlab/agents/vl_agent/vl_model.py | 13 ++-- src/agentlab/agents/vl_agent/vl_prompt.py | 8 +-- .../agents/vl_agent/vl_prompt_config.py | 46 +++++++------- 5 files changed, 67 insertions(+), 67 deletions(-) diff --git a/src/agentlab/agents/vl_agent/vl_agent.py b/src/agentlab/agents/vl_agent/vl_agent.py index b415cae5..86b6f346 100644 --- a/src/agentlab/agents/vl_agent/vl_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent.py @@ -13,6 +13,7 @@ class VLAgentArgs(AgentArgs): vl_model_args: VLModelArgs = None 
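    # flags controlling how the observation/action prompt is assembled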
vl_prompt_flags: VLPromptFlags = None + max_retry: int = None def __post_init__(self): self.agent_name = f"VLAgent-{self.vl_model_args.model_name}".replace("/", "_") @@ -30,15 +31,20 @@ def close(self): return self.vl_model_args.close() def make_agent(self): - return VLAgent(vl_model_args=self.vl_model_args, vl_prompt_flags=self.vl_prompt_flags) + return VLAgent( + vl_model_args=self.vl_model_args, + vl_prompt_flags=self.vl_prompt_flags, + max_retry=self.max_retry, + ) class VLAgent(Agent): - def __init__(self, vl_model_args: VLModelArgs, vl_prompt_flags: VLPromptFlags): + def __init__(self, vl_model_args: VLModelArgs, vl_prompt_flags: VLPromptFlags, max_retry: int): self.vl_model_args = vl_model_args - self.vl_model = vl_model_args.make_model() self.vl_prompt_flags = vl_prompt_flags - self.action_set = self.vl_prompt_flags.action_flags.action_set.make_action_set() + self.max_retry = max_retry + self.vl_model = vl_model_args.make_model() + self.action_set = vl_prompt_flags.action_flags.action_set.make_action_set() self._obs_preprocessor = dp.make_obs_preprocessor(vl_prompt_flags.obs_flags) self.reset(seed=None) @@ -48,53 +54,45 @@ def obs_preprocessor(self, obs: dict) -> dict: @cost_tracker_decorator def get_action(self, obs): self.obs_history.append(obs) - main_prompt = VLPrompt( + vl_prompt = VLPrompt( + vl_prompt_flags=self.vl_prompt_flags, action_set=self.action_set, - obs=obs, + obs_history=self.obs_history, actions=self.actions, thoughts=self.thoughts, - flags=self.flags, ) - system_prompt = SystemMessage(dp.SystemPrompt().prompt) try: - # TODO, we would need to further shrink the prompt if the retry - # cause it to be too long - - chat_messages = Discussion([system_prompt, main_prompt.prompt]) - ans_dict = retry( - self.chat_llm, + chat_messages = Discussion([system_prompt, vl_prompt.prompt]) + answer = retry( + self.vl_model, chat_messages, n_retry=self.max_retry, - parser=main_prompt._parse_answer, + parser=vl_prompt.parse_answer, ) - ans_dict["busted_retry"] = 0 - # inferring the number of retries, TODO: make this less hacky - ans_dict["n_retry"] = (len(chat_messages) - 3) / 2 + answer["busted_retry"] = 0 + answer["n_retry"] = (len(chat_messages) - 3) / 2 except ParseError: - ans_dict = dict( + answer = dict( action=None, n_retry=self.max_retry + 1, busted_retry=1, ) - - stats = self.chat_llm.get_stats() - stats["n_retry"] = ans_dict["n_retry"] - stats["busted_retry"] = ans_dict["busted_retry"] - - self.actions.append(ans_dict["action"]) - self.thoughts.append(ans_dict.get("think", None)) - + stats = self.vl_model.get_stats() + stats["n_retry"] = answer["n_retry"] + stats["busted_retry"] = answer["busted_retry"] + self.actions.append(answer["action"]) + self.thoughts.append(answer.get("think", None)) agent_info = AgentInfo( - think=ans_dict.get("think", None), + think=answer.get("think", None), chat_messages=chat_messages, stats=stats, - extra_info={"chat_model_args": asdict(self.chat_model_args)}, + extra_info={"vl_model_args": asdict(self.vl_model_args)}, ) - return ans_dict["action"], agent_info + return answer["action"], agent_info def reset(self, seed=None): self.seed = seed - self.thoughts = [] - self.actions = [] self.obs_history = [] + self.actions = [] + self.thoughts = [] diff --git a/src/agentlab/agents/vl_agent/vl_agent_config.py b/src/agentlab/agents/vl_agent/vl_agent_config.py index 9b8b96ec..653b8921 100644 --- a/src/agentlab/agents/vl_agent/vl_agent_config.py +++ b/src/agentlab/agents/vl_agent/vl_agent_config.py @@ -1,10 +1,11 @@ from .vl_agent import VLAgentArgs 
from .vl_model_config import VL_MODEL_ARGS_DICT -from .vl_prompt_config import VL_PROMPT_FLAGS +from .vl_prompt_config import VL_PROMPT_FLAGS_DICT VL_AGENT_ARGS_DICT = { "vl_agent_llama_32_11b": VLAgentArgs( - vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], vl_prompt_flags=VL_PROMPT_FLAGS + vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], + vl_prompt_flags=VL_PROMPT_FLAGS_DICT["default"], ) } diff --git a/src/agentlab/agents/vl_agent/vl_model.py b/src/agentlab/agents/vl_agent/vl_model.py index 88bb24bf..9cfc2791 100644 --- a/src/agentlab/agents/vl_agent/vl_model.py +++ b/src/agentlab/agents/vl_agent/vl_model.py @@ -5,10 +5,11 @@ class VLModel(ABC): @abstractmethod def __call__(self, messages: list[dict]) -> dict: - pass + raise NotImplementedError + @abstractmethod def get_stats(self): - return {} + raise NotImplementedError @dataclass @@ -23,10 +24,12 @@ class VLModelArgs(ABC): @abstractmethod def make_model(self) -> VLModel: - pass + raise NotImplementedError + @abstractmethod def prepare(self): - pass + raise NotImplementedError + @abstractmethod def close(self): - pass + raise NotImplementedError diff --git a/src/agentlab/agents/vl_agent/vl_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt.py index 121a7d8a..c81ff6be 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt.py @@ -85,7 +85,7 @@ def _prompt(self) -> HumanMessage: return self.observation.add_screenshot(prompt) def _parse_answer(self, text_answer): - ans_dict = {} - ans_dict.update(self.think.parse_answer(text_answer)) - ans_dict.update(self.action_prompt.parse_answer(text_answer)) - return ans_dict + answer = {} + answer.update(self.think.parse_answer(text_answer)) + answer.update(self.action_prompt.parse_answer(text_answer)) + return answer diff --git a/src/agentlab/agents/vl_agent/vl_prompt_config.py b/src/agentlab/agents/vl_agent/vl_prompt_config.py index b2116e05..832ac2a4 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt_config.py +++ b/src/agentlab/agents/vl_agent/vl_prompt_config.py @@ -2,28 +2,26 @@ import agentlab.agents.dynamic_prompting as dp import bgym -VL_OBS_FLAGS = dp.ObsFlags( - use_tabs=True, - use_error_logs=True, - use_past_error_logs=False, - use_screenshot=True, - use_som=False, - openai_vision_detail="auto", -) -VL_ACTION_FLAGS = dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs(subsets=["coord"]), - long_description=True, - individual_examples=False, -) - - -VL_PROMPT_FLAGS = VLPromptFlags( - obs=VL_OBS_FLAGS, - action=VL_ACTION_FLAGS, - use_thinking=True, - use_concrete_example=False, - use_abstract_example=True, - enable_chat=False, - extra_instructions=None, -) +VL_PROMPT_FLAGS_DICT = { + "default": VLPromptFlags( + obs_flags=dp.ObsFlags( + use_tabs=True, + use_error_logs=True, + use_past_error_logs=False, + use_screenshot=True, + use_som=False, + openai_vision_detail="auto", + ), + action_flags=dp.ActionFlags( + action_set=bgym.HighLevelActionSetArgs(subsets=["coord"]), + long_description=True, + individual_examples=False, + ), + use_thinking=True, + use_concrete_example=False, + use_abstract_example=True, + enable_chat=False, + extra_instructions=None, + ) +} From c7aa049c73d70dd9b6cf9537b7e70f66cfcf0309 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Wed, 7 May 2025 06:22:10 +0000 Subject: [PATCH 03/29] update --- src/agentlab/agents/vl_agent/vl_agent.py | 134 ++++++++++++------ .../agents/vl_agent/vl_agent_config.py | 10 +- src/agentlab/agents/vl_agent/vl_model.py | 98 ++++++++++--- .../agents/vl_agent/vl_model_config.py | 16 ++- 
src/agentlab/agents/vl_agent/vl_prompt.py | 62 ++++---- .../agents/vl_agent/vl_prompt_config.py | 6 +- 6 files changed, 220 insertions(+), 106 deletions(-) diff --git a/src/agentlab/agents/vl_agent/vl_agent.py b/src/agentlab/agents/vl_agent/vl_agent.py index 86b6f346..65ab060c 100644 --- a/src/agentlab/agents/vl_agent/vl_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent.py @@ -1,61 +1,118 @@ -from dataclasses import asdict, dataclass -from browsergym.experiments.agent import Agent, AgentInfo +from abc import ABC, abstractmethod from agentlab.agents import dynamic_prompting as dp -from agentlab.agents.agent_args import AgentArgs -from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry +from agentlab.llm.llm_utils import Discussion, ParseError, retry, SystemMessage from agentlab.llm.tracking import cost_tracker_decorator +from browsergym.core.action.base import AbstractActionSet +from browsergym.experiments.agent import AgentInfo +from browsergym.experiments.benchmark import Benchmark +from copy import deepcopy +from dataclasses import asdict, dataclass from .vl_model import VLModelArgs -from .vl_prompt import VLPromptFlags, VLPrompt -import bgym +from .vl_prompt import VLPrompt, VLPromptFlags @dataclass -class VLAgentArgs(AgentArgs): - vl_model_args: VLModelArgs = None - vl_prompt_flags: VLPromptFlags = None - max_retry: int = None +class VLAgentArgs(ABC): + agent_name: str - def __post_init__(self): - self.agent_name = f"VLAgent-{self.vl_model_args.model_name}".replace("/", "_") + @abstractmethod + def make_agent(self): + raise NotImplementedError - def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode): - self.vl_prompt_flags.obs.use_tabs = benchmark.is_multi_tab + @abstractmethod + def prepare(self): + raise NotImplementedError + @abstractmethod + def close(self): + raise NotImplementedError + + @abstractmethod def set_reproducibility_mode(self): - self.vl_model_args.temperature = 0 + raise NotImplementedError - def prepare(self): - return self.vl_model_args.prepare() + @abstractmethod + def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): + raise NotImplementedError - def close(self): - return self.vl_model_args.close() + +@dataclass +class VLAgent(ABC): + action_set: AbstractActionSet + + @abstractmethod + def get_action(self, obs: dict): + raise NotImplementedError + + @abstractmethod + def obs_preprocessor(self, obs: dict): + raise NotImplementedError + + +@dataclass +class UIAgentArgs(VLAgentArgs): + general_model_args: VLModelArgs + grounding_model_args: VLModelArgs + prompt_flags: VLPromptFlags + max_retry: int + + def __post_init__(self): + self.agent_name = ( + f"ui_agent-{self.general_model_args.model_name}-{self.grounding_model_args.model_name}" + ) def make_agent(self): - return VLAgent( - vl_model_args=self.vl_model_args, - vl_prompt_flags=self.vl_prompt_flags, + return UIAgent( + general_model_args=self.general_model_args, + grounding_model_args=self.grounding_model_args, + prompt_flags=self.prompt_flags, max_retry=self.max_retry, ) + def prepare(self): + self.general_model_args.prepare() + self.grounding_model_args.prepare() -class VLAgent(Agent): - def __init__(self, vl_model_args: VLModelArgs, vl_prompt_flags: VLPromptFlags, max_retry: int): - self.vl_model_args = vl_model_args - self.vl_prompt_flags = vl_prompt_flags - self.max_retry = max_retry - self.vl_model = vl_model_args.make_model() - self.action_set = vl_prompt_flags.action_flags.action_set.make_action_set() - self._obs_preprocessor = 
dp.make_obs_preprocessor(vl_prompt_flags.obs_flags) - self.reset(seed=None) + def close(self): + self.general_model_args.close() + self.grounding_model_args.close() - def obs_preprocessor(self, obs: dict) -> dict: - return self._obs_preprocessor(obs) + def set_reproducibility_mode(self): + self.general_model_args.set_reproducibility_mode() + self.grounding_model_args.set_reproducibility_mode() + + def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): + self.prompt_flags.obs_flags.use_tabs = benchmark.is_multi_tab + self.prompt_flags.action_flags.action_set = deepcopy(benchmark.high_level_action_set_args) + if demo_mode: + self.prompt_flags.action_flags.action_set.demo_mode = "all_blue" + + +class UIAgent(VLAgent): + def __init__( + self, + general_model_args: VLModelArgs, + grounding_model_args: VLModelArgs, + prompt_flags: VLPromptFlags, + max_retry: int, + ): + self.general_model_args = general_model_args + self.grounding_model_args = grounding_model_args + self.prompt_flags = prompt_flags + self.max_retry = max_retry + self.general_model = general_model_args.make_model() + self.grounding_model = grounding_model_args.make_model() + self.action_set = prompt_flags.action_flags.action_set.make_action_set() + self._obs_preprocessor = dp.make_obs_preprocessor(prompt_flags.obs_flags) + self.obs_history = [] + self.actions = [] + self.thoughts = [] @cost_tracker_decorator - def get_action(self, obs): + def get_action(self, obs: dict): self.obs_history.append(obs) vl_prompt = VLPrompt( - vl_prompt_flags=self.vl_prompt_flags, + prompt_flags=self.prompt_flags, action_set=self.action_set, obs_history=self.obs_history, actions=self.actions, @@ -91,8 +148,5 @@ def get_action(self, obs): ) return answer["action"], agent_info - def reset(self, seed=None): - self.seed = seed - self.obs_history = [] - self.actions = [] - self.thoughts = [] + def obs_preprocessor(self, obs: dict): + return self._obs_preprocessor(obs) diff --git a/src/agentlab/agents/vl_agent/vl_agent_config.py b/src/agentlab/agents/vl_agent/vl_agent_config.py index 653b8921..0670127c 100644 --- a/src/agentlab/agents/vl_agent/vl_agent_config.py +++ b/src/agentlab/agents/vl_agent/vl_agent_config.py @@ -1,11 +1,13 @@ -from .vl_agent import VLAgentArgs +from .vl_agent import UIAgentArgs from .vl_model_config import VL_MODEL_ARGS_DICT from .vl_prompt_config import VL_PROMPT_FLAGS_DICT VL_AGENT_ARGS_DICT = { - "vl_agent_llama_32_11b": VLAgentArgs( - vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], - vl_prompt_flags=VL_PROMPT_FLAGS_DICT["default"], + "ui_agent-llama_32_11b": UIAgentArgs( + general_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], + grounding_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], + prompt_flags=VL_PROMPT_FLAGS_DICT["default"], + max_retry=4, ) } diff --git a/src/agentlab/agents/vl_agent/vl_model.py b/src/agentlab/agents/vl_agent/vl_model.py index 9cfc2791..8d0edf4c 100644 --- a/src/agentlab/agents/vl_agent/vl_model.py +++ b/src/agentlab/agents/vl_agent/vl_model.py @@ -1,35 +1,101 @@ from abc import ABC, abstractmethod +from accelerate.utils.modeling import load_checkpoint_in_model from dataclasses import dataclass +from transformers import AutoProcessor, MllamaForConditionalGeneration +from typing import Optional +import fnmatch +import os -class VLModel(ABC): +@dataclass +class VLModelArgs(ABC): + model_name: str + @abstractmethod - def __call__(self, messages: list[dict]) -> dict: + def make_model(self): raise NotImplementedError @abstractmethod - def get_stats(self): + def prepare(self): raise 
NotImplementedError - -@dataclass -class VLModelArgs(ABC): - model_name: str - max_total_tokens: int = None - max_input_tokens: int = None - max_new_tokens: int = None - temperature: float = 0.1 - vision_support: bool = False - log_probs: bool = False + @abstractmethod + def close(self): + raise NotImplementedError @abstractmethod - def make_model(self) -> VLModel: + def set_reproducibility_mode(self): raise NotImplementedError + +class VLModel(ABC): @abstractmethod - def prepare(self): + def __call__(self, messages: list[dict]): raise NotImplementedError @abstractmethod - def close(self): + def get_stats(self): raise NotImplementedError + + +@dataclass +class LlamaModelArgs(VLModelArgs): + model_path: str + torch_dtype: str + checkpoint_dir: Optional[str] + max_length: int + max_new_tokens: int + reproducibility_config: dict + + def make_model(self): + return LlamaModel( + model_path=self.model_path, + torch_dtype=self.torch_dtype, + checkpoint_dir=self.checkpoint_dir, + max_length=self.max_length, + max_new_tokens=self.max_new_tokens, + reproducibility_config=self.reproducibility_config, + ) + + def prepare(self): + pass + + def close(self): + pass + + def set_reproducibility_mode(self): + self.reproducibility_config = {"do_sample": False} + + +class LlamaModel(VLModel): + def __init__( + self, + model_path: str, + torch_dtype: str, + checkpoint_dir: str, + max_length: int, + max_new_tokens: int, + reproducibility_config: dict, + ): + self.model = MllamaForConditionalGeneration.from_pretrained( + model_path, torch_dtype=torch_dtype + ) + if checkpoint_dir is not None: + checkpoint_file = None + for item in os.listdir(checkpoint_dir): + if fnmatch.fnmatch(item, "pytorch_model*.bin") or fnmatch.fnmatch( + item, "model*.safetensors" + ): + checkpoint_file = os.path.join(checkpoint_dir, item) + break + load_checkpoint_in_model(self.model, checkpoint_file) + self.processor = AutoProcessor.from_pretrained(model_path) + self.max_length = max_length + self.max_new_tokens = max_new_tokens + self.reproducibility_config = reproducibility_config + + def __call__(self, messages: list[dict]): + pass + + def get_stats(self): + pass diff --git a/src/agentlab/agents/vl_agent/vl_model_config.py b/src/agentlab/agents/vl_agent/vl_model_config.py index d1dfa23f..1e163c87 100644 --- a/src/agentlab/agents/vl_agent/vl_model_config.py +++ b/src/agentlab/agents/vl_agent/vl_model_config.py @@ -1,11 +1,13 @@ -from .vl_model import VLModelArgs +from .vl_model import LlamaModelArgs VL_MODEL_ARGS_DICT = { - "llama_32_11b": VLModelArgs( - model_name="meta-llama/Llama-3.2-11B-Vision-Instruct", - max_total_tokens=200_000, - max_input_tokens=200_000, - max_new_tokens=100_000, - vision_support=False, + "llama_32_11b": LlamaModelArgs( + model_name="llama_32_11b", + model_path="meta-llama/Llama-3.2-11B-Vision-Instruct", + torch_dtype="bfloat16", + checkpoint_dir=None, + max_length=32768, + max_new_tokens=8192, + reproducibility_config={"temperature": 0.1}, ) } diff --git a/src/agentlab/agents/vl_agent/vl_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt.py index c81ff6be..92b683db 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt.py @@ -1,65 +1,56 @@ -import logging -from dataclasses import dataclass -from browsergym.core.action.base import AbstractActionSet from agentlab.agents import dynamic_prompting as dp from agentlab.llm.llm_utils import HumanMessage +from browsergym.core.action.base import AbstractActionSet +from dataclasses import dataclass +from typing import 
Optional @dataclass class VLPromptFlags(dp.Flags): - obs_flags: dp.ObsFlags = None - action_flags: dp.ActionFlags = None - use_thinking: bool = True - use_concrete_example: bool = False - use_abstract_example: bool = True - enable_chat: bool = False - extra_instructions: str | None = None + obs_flags: dp.ObsFlags + action_flags: dp.ActionFlags + use_thinking: bool + use_concrete_example: bool + use_abstract_example: bool + enable_chat: bool + extra_instructions: Optional[str] class VLPrompt(dp.PromptElement): def __init__( self, - vl_prompt_flags: VLPromptFlags, + prompt_flags: VLPromptFlags, action_set: AbstractActionSet, obs_history: list[dict], actions: list[str], thoughts: list[str], ): super().__init__() - self.vl_prompt_flags = vl_prompt_flags - if vl_prompt_flags.enable_chat: + if prompt_flags.enable_chat: self.instructions = dp.ChatInstructions( obs_history[-1]["chat_messages"], - extra_instructions=vl_prompt_flags.extra_instructions, + extra_instructions=prompt_flags.extra_instructions, ) else: - if sum([msg["role"] == "user" for msg in obs_history[-1].get("chat_messages", [])]) > 1: - logging.warning( - "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." - ) self.instructions = dp.GoalInstructions( obs_history[-1]["goal_object"], - extra_instructions=vl_prompt_flags.extra_instructions, + extra_instructions=prompt_flags.extra_instructions, ) - self.observation = dp.Observation(obs_history[-1], vl_prompt_flags.obs_flags) - self.history = dp.History(obs_history, actions, None, thoughts, vl_prompt_flags.obs_flags) - self.action_prompt = dp.ActionPrompt(action_set, action_flags=vl_prompt_flags.action_flags) - self.think = dp.Think(visible=lambda: vl_prompt_flags.use_thinking) - - @property - def _prompt(self) -> HumanMessage: - prompt = HumanMessage(self.instructions.prompt) - prompt.add_text( + self.observation = dp.Observation(obs_history[-1], prompt_flags.obs_flags) + self.history = dp.History(obs_history, actions, None, thoughts, prompt_flags.obs_flags) + self.think = dp.Think(visible=lambda: prompt_flags.use_thinking) + self.action_prompt = dp.ActionPrompt(action_set, action_flags=prompt_flags.action_flags) + self._prompt = HumanMessage(self.instructions.prompt) + self._prompt.add_text( f"""\ {self.observation.prompt}\ {self.history.prompt}\ -{self.action_prompt.prompt}\ {self.think.prompt}\ +{self.action_prompt.prompt}\ """ ) - - if self.vl_prompt_flags.use_abstract_example: - prompt.add_text( + if prompt_flags.use_abstract_example: + self._prompt.add_text( f""" # Abstract Example @@ -70,9 +61,8 @@ def _prompt(self) -> HumanMessage: {self.action_prompt.abstract_ex}\ """ ) - - if self.vl_prompt_flags.use_concrete_example: - prompt.add_text( + if prompt_flags.use_concrete_example: + self._prompt.add_text( f""" # Concrete Example @@ -82,7 +72,7 @@ def _prompt(self) -> HumanMessage: {self.action_prompt.concrete_ex}\ """ ) - return self.observation.add_screenshot(prompt) + self.observation.add_screenshot(self._prompt) def _parse_answer(self, text_answer): answer = {} diff --git a/src/agentlab/agents/vl_agent/vl_prompt_config.py b/src/agentlab/agents/vl_agent/vl_prompt_config.py index 832ac2a4..a77a8225 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt_config.py +++ b/src/agentlab/agents/vl_agent/vl_prompt_config.py @@ -1,6 +1,6 @@ -from .vl_prompt import VLPromptFlags import agentlab.agents.dynamic_prompting as dp -import bgym +from browsergym.experiments.benchmark import HighLevelActionSetArgs +from .vl_prompt 
import VLPromptFlags VL_PROMPT_FLAGS_DICT = { @@ -14,7 +14,7 @@ openai_vision_detail="auto", ), action_flags=dp.ActionFlags( - action_set=bgym.HighLevelActionSetArgs(subsets=["coord"]), + action_set=HighLevelActionSetArgs(subsets=["coord"]), long_description=True, individual_examples=False, ), From f6318b48eff7088ea89d7ab8a5b40c69bd3f60c4 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Wed, 7 May 2025 20:54:36 +0000 Subject: [PATCH 04/29] update --- .../{vl_prompt_config.py => config.py} | 26 ++- src/agentlab/agents/vl_agent/vl_agent.py | 154 +++++++++--------- .../agents/vl_agent/vl_agent_config.py | 13 -- src/agentlab/agents/vl_agent/vl_model.py | 84 +++++----- .../agents/vl_agent/vl_model_config.py | 13 -- src/agentlab/agents/vl_agent/vl_prompt.py | 22 +-- 6 files changed, 152 insertions(+), 160 deletions(-) rename src/agentlab/agents/vl_agent/{vl_prompt_config.py => config.py} (53%) delete mode 100644 src/agentlab/agents/vl_agent/vl_agent_config.py delete mode 100644 src/agentlab/agents/vl_agent/vl_model_config.py diff --git a/src/agentlab/agents/vl_agent/vl_prompt_config.py b/src/agentlab/agents/vl_agent/config.py similarity index 53% rename from src/agentlab/agents/vl_agent/vl_prompt_config.py rename to src/agentlab/agents/vl_agent/config.py index a77a8225..0a2dbdee 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt_config.py +++ b/src/agentlab/agents/vl_agent/config.py @@ -1,6 +1,21 @@ -import agentlab.agents.dynamic_prompting as dp from browsergym.experiments.benchmark import HighLevelActionSetArgs +from .vl_agent import UIAgentArgs +from .vl_model import LlamaModelArgs from .vl_prompt import VLPromptFlags +import agentlab.agents.dynamic_prompting as dp + + +VL_MODEL_ARGS_DICT = { + "llama_32_11b": LlamaModelArgs( + model_name="llama_32_11b", + model_path="meta-llama/Llama-3.2-11B-Vision-Instruct", + torch_dtype="bfloat16", + checkpoint_dir=None, + max_length=32768, + max_new_tokens=8192, + reproducibility_config={"temperature": 0.1}, + ) +} VL_PROMPT_FLAGS_DICT = { @@ -25,3 +40,12 @@ extra_instructions=None, ) } + +VL_AGENT_ARGS_DICT = { + "ui_agent-llama_32_11b": UIAgentArgs( + general_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], + grounding_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], + vl_prompt_flags=VL_PROMPT_FLAGS_DICT["default"], + max_retry=4, + ) +} diff --git a/src/agentlab/agents/vl_agent/vl_agent.py b/src/agentlab/agents/vl_agent/vl_agent.py index 65ab060c..487a5815 100644 --- a/src/agentlab/agents/vl_agent/vl_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent.py @@ -12,107 +12,66 @@ @dataclass -class VLAgentArgs(ABC): - agent_name: str - - @abstractmethod - def make_agent(self): - raise NotImplementedError - - @abstractmethod - def prepare(self): - raise NotImplementedError - - @abstractmethod - def close(self): - raise NotImplementedError +class VLAgent(ABC): + action_set: AbstractActionSet @abstractmethod - def set_reproducibility_mode(self): + def get_action(self, obs: dict) -> tuple[str, dict]: raise NotImplementedError @abstractmethod - def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): + def obs_preprocessor(self, obs: dict) -> dict: raise NotImplementedError @dataclass -class VLAgent(ABC): - action_set: AbstractActionSet +class VLAgentArgs(ABC): + agent_name: str @abstractmethod - def get_action(self, obs: dict): + def make_agent(self) -> VLAgent: raise NotImplementedError @abstractmethod - def obs_preprocessor(self, obs: dict): - raise NotImplementedError - - -@dataclass -class UIAgentArgs(VLAgentArgs): - general_model_args: 
VLModelArgs - grounding_model_args: VLModelArgs - prompt_flags: VLPromptFlags - max_retry: int - - def __post_init__(self): - self.agent_name = ( - f"ui_agent-{self.general_model_args.model_name}-{self.grounding_model_args.model_name}" - ) - - def make_agent(self): - return UIAgent( - general_model_args=self.general_model_args, - grounding_model_args=self.grounding_model_args, - prompt_flags=self.prompt_flags, - max_retry=self.max_retry, - ) - def prepare(self): - self.general_model_args.prepare() - self.grounding_model_args.prepare() + raise NotImplementedError + @abstractmethod def close(self): - self.general_model_args.close() - self.grounding_model_args.close() + raise NotImplementedError + @abstractmethod def set_reproducibility_mode(self): - self.general_model_args.set_reproducibility_mode() - self.grounding_model_args.set_reproducibility_mode() + raise NotImplementedError + @abstractmethod def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): - self.prompt_flags.obs_flags.use_tabs = benchmark.is_multi_tab - self.prompt_flags.action_flags.action_set = deepcopy(benchmark.high_level_action_set_args) - if demo_mode: - self.prompt_flags.action_flags.action_set.demo_mode = "all_blue" + raise NotImplementedError class UIAgent(VLAgent): def __init__( self, - general_model_args: VLModelArgs, - grounding_model_args: VLModelArgs, - prompt_flags: VLPromptFlags, + general_vl_model_args: VLModelArgs, + grounding_vl_model_args: VLModelArgs, + vl_prompt_flags: VLPromptFlags, max_retry: int, ): - self.general_model_args = general_model_args - self.grounding_model_args = grounding_model_args - self.prompt_flags = prompt_flags + self.general_vl_model = general_vl_model_args.make_model() + self.grounding_vl_model = grounding_vl_model_args.make_model() + self.action_set = vl_prompt_flags.action_flags.action_set.make_action_set() + self._obs_preprocessor = dp.make_obs_preprocessor(vl_prompt_flags.obs_flags) + self.vl_prompt_flags = vl_prompt_flags self.max_retry = max_retry - self.general_model = general_model_args.make_model() - self.grounding_model = grounding_model_args.make_model() - self.action_set = prompt_flags.action_flags.action_set.make_action_set() - self._obs_preprocessor = dp.make_obs_preprocessor(prompt_flags.obs_flags) self.obs_history = [] self.actions = [] self.thoughts = [] @cost_tracker_decorator - def get_action(self, obs: dict): + def get_action(self, obs: dict) -> tuple[str, dict]: self.obs_history.append(obs) vl_prompt = VLPrompt( - prompt_flags=self.prompt_flags, + vl_prompt_flags=self.vl_prompt_flags, action_set=self.action_set, obs_history=self.obs_history, actions=self.actions, @@ -120,33 +79,68 @@ def get_action(self, obs: dict): ) system_prompt = SystemMessage(dp.SystemPrompt().prompt) try: - chat_messages = Discussion([system_prompt, vl_prompt.prompt]) + messages = Discussion([system_prompt, vl_prompt.prompt]) answer = retry( - self.vl_model, - chat_messages, + self.general_vl_model, + messages, n_retry=self.max_retry, parser=vl_prompt.parse_answer, ) + answer["n_retry"] = (len(messages) - 3) / 2 answer["busted_retry"] = 0 - answer["n_retry"] = (len(chat_messages) - 3) / 2 except ParseError: answer = dict( action=None, + think=None, n_retry=self.max_retry + 1, busted_retry=1, ) - stats = self.vl_model.get_stats() + self.actions.append(answer["action"]) + self.thoughts.append(answer["think"]) + stats = self.general_vl_model.get_stats() stats["n_retry"] = answer["n_retry"] stats["busted_retry"] = answer["busted_retry"] - self.actions.append(answer["action"]) - 
self.thoughts.append(answer.get("think", None)) - agent_info = AgentInfo( - think=answer.get("think", None), - chat_messages=chat_messages, - stats=stats, - extra_info={"vl_model_args": asdict(self.vl_model_args)}, - ) - return answer["action"], agent_info + agent_info = AgentInfo(think=answer["think"], chat_messages=messages, stats=stats) + return answer["action"], asdict(agent_info) - def obs_preprocessor(self, obs: dict): + def obs_preprocessor(self, obs: dict) -> dict: return self._obs_preprocessor(obs) + + +@dataclass +class UIAgentArgs(VLAgentArgs): + general_vl_model_args: VLModelArgs + grounding_vl_model_args: VLModelArgs + vl_prompt_flags: VLPromptFlags + max_retry: int + + def __post_init__(self): + self.agent_name = f"ui_agent-{self.general_vl_model_args.model_name}-{self.grounding_vl_model_args.model_name}" + + def make_agent(self) -> UIAgent: + return UIAgent( + general_vl_model_args=self.general_vl_model_args, + grounding_vl_model_args=self.grounding_vl_model_args, + vl_prompt_flags=self.vl_prompt_flags, + max_retry=self.max_retry, + ) + + def prepare(self): + self.general_vl_model_args.prepare() + self.grounding_vl_model_args.prepare() + + def close(self): + self.general_vl_model_args.close() + self.grounding_vl_model_args.close() + + def set_reproducibility_mode(self): + self.general_vl_model_args.set_reproducibility_mode() + self.grounding_vl_model_args.set_reproducibility_mode() + + def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): + self.vl_prompt_flags.obs_flags.use_tabs = benchmark.is_multi_tab + self.vl_prompt_flags.action_flags.action_set = deepcopy( + benchmark.high_level_action_set_args + ) + if demo_mode: + self.vl_prompt_flags.action_flags.action_set.demo_mode = "all_blue" diff --git a/src/agentlab/agents/vl_agent/vl_agent_config.py b/src/agentlab/agents/vl_agent/vl_agent_config.py deleted file mode 100644 index 0670127c..00000000 --- a/src/agentlab/agents/vl_agent/vl_agent_config.py +++ /dev/null @@ -1,13 +0,0 @@ -from .vl_agent import UIAgentArgs -from .vl_model_config import VL_MODEL_ARGS_DICT -from .vl_prompt_config import VL_PROMPT_FLAGS_DICT - - -VL_AGENT_ARGS_DICT = { - "ui_agent-llama_32_11b": UIAgentArgs( - general_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], - grounding_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], - prompt_flags=VL_PROMPT_FLAGS_DICT["default"], - max_retry=4, - ) -} diff --git a/src/agentlab/agents/vl_agent/vl_model.py b/src/agentlab/agents/vl_agent/vl_model.py index 8d0edf4c..608b513c 100644 --- a/src/agentlab/agents/vl_agent/vl_model.py +++ b/src/agentlab/agents/vl_agent/vl_model.py @@ -7,12 +7,22 @@ import os +class VLModel(ABC): + @abstractmethod + def __call__(self, messages: list[dict]) -> dict: + raise NotImplementedError + + @abstractmethod + def get_stats(self) -> dict: + raise NotImplementedError + + @dataclass class VLModelArgs(ABC): model_name: str @abstractmethod - def make_model(self): + def make_model(self) -> VLModel: raise NotImplementedError @abstractmethod @@ -28,45 +38,6 @@ def set_reproducibility_mode(self): raise NotImplementedError -class VLModel(ABC): - @abstractmethod - def __call__(self, messages: list[dict]): - raise NotImplementedError - - @abstractmethod - def get_stats(self): - raise NotImplementedError - - -@dataclass -class LlamaModelArgs(VLModelArgs): - model_path: str - torch_dtype: str - checkpoint_dir: Optional[str] - max_length: int - max_new_tokens: int - reproducibility_config: dict - - def make_model(self): - return LlamaModel( - model_path=self.model_path, - 
torch_dtype=self.torch_dtype, - checkpoint_dir=self.checkpoint_dir, - max_length=self.max_length, - max_new_tokens=self.max_new_tokens, - reproducibility_config=self.reproducibility_config, - ) - - def prepare(self): - pass - - def close(self): - pass - - def set_reproducibility_mode(self): - self.reproducibility_config = {"do_sample": False} - - class LlamaModel(VLModel): def __init__( self, @@ -94,8 +65,37 @@ def __init__( self.max_new_tokens = max_new_tokens self.reproducibility_config = reproducibility_config - def __call__(self, messages: list[dict]): + def __call__(self, messages: list[dict]) -> dict: + pass + + def get_stats(self) -> dict: + pass + + +@dataclass +class LlamaModelArgs(VLModelArgs): + model_path: str + torch_dtype: str + checkpoint_dir: Optional[str] + max_length: int + max_new_tokens: int + reproducibility_config: dict + + def make_model(self) -> LlamaModel: + return LlamaModel( + model_path=self.model_path, + torch_dtype=self.torch_dtype, + checkpoint_dir=self.checkpoint_dir, + max_length=self.max_length, + max_new_tokens=self.max_new_tokens, + reproducibility_config=self.reproducibility_config, + ) + + def prepare(self): pass - def get_stats(self): + def close(self): pass + + def set_reproducibility_mode(self): + self.reproducibility_config = {"do_sample": False} diff --git a/src/agentlab/agents/vl_agent/vl_model_config.py b/src/agentlab/agents/vl_agent/vl_model_config.py deleted file mode 100644 index 1e163c87..00000000 --- a/src/agentlab/agents/vl_agent/vl_model_config.py +++ /dev/null @@ -1,13 +0,0 @@ -from .vl_model import LlamaModelArgs - -VL_MODEL_ARGS_DICT = { - "llama_32_11b": LlamaModelArgs( - model_name="llama_32_11b", - model_path="meta-llama/Llama-3.2-11B-Vision-Instruct", - torch_dtype="bfloat16", - checkpoint_dir=None, - max_length=32768, - max_new_tokens=8192, - reproducibility_config={"temperature": 0.1}, - ) -} diff --git a/src/agentlab/agents/vl_agent/vl_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt.py index 92b683db..338096f2 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt.py @@ -19,27 +19,27 @@ class VLPromptFlags(dp.Flags): class VLPrompt(dp.PromptElement): def __init__( self, - prompt_flags: VLPromptFlags, + vl_prompt_flags: VLPromptFlags, action_set: AbstractActionSet, obs_history: list[dict], actions: list[str], thoughts: list[str], ): super().__init__() - if prompt_flags.enable_chat: + if vl_prompt_flags.enable_chat: self.instructions = dp.ChatInstructions( obs_history[-1]["chat_messages"], - extra_instructions=prompt_flags.extra_instructions, + extra_instructions=vl_prompt_flags.extra_instructions, ) else: self.instructions = dp.GoalInstructions( obs_history[-1]["goal_object"], - extra_instructions=prompt_flags.extra_instructions, + extra_instructions=vl_prompt_flags.extra_instructions, ) - self.observation = dp.Observation(obs_history[-1], prompt_flags.obs_flags) - self.history = dp.History(obs_history, actions, None, thoughts, prompt_flags.obs_flags) - self.think = dp.Think(visible=lambda: prompt_flags.use_thinking) - self.action_prompt = dp.ActionPrompt(action_set, action_flags=prompt_flags.action_flags) + self.observation = dp.Observation(obs_history[-1], vl_prompt_flags.obs_flags) + self.history = dp.History(obs_history, actions, None, thoughts, vl_prompt_flags.obs_flags) + self.think = dp.Think(visible=lambda: vl_prompt_flags.use_thinking) + self.action_prompt = dp.ActionPrompt(action_set, action_flags=vl_prompt_flags.action_flags) self._prompt = 
HumanMessage(self.instructions.prompt) self._prompt.add_text( f"""\ @@ -49,7 +49,7 @@ def __init__( {self.action_prompt.prompt}\ """ ) - if prompt_flags.use_abstract_example: + if vl_prompt_flags.use_abstract_example: self._prompt.add_text( f""" # Abstract Example @@ -61,7 +61,7 @@ def __init__( {self.action_prompt.abstract_ex}\ """ ) - if prompt_flags.use_concrete_example: + if vl_prompt_flags.use_concrete_example: self._prompt.add_text( f""" # Concrete Example @@ -74,7 +74,7 @@ def __init__( ) self.observation.add_screenshot(self._prompt) - def _parse_answer(self, text_answer): + def _parse_answer(self, text_answer) -> dict: answer = {} answer.update(self.think.parse_answer(text_answer)) answer.update(self.action_prompt.parse_answer(text_answer)) From 414cf681ecf6595dd7c45622adba730177a33683 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Thu, 8 May 2025 06:13:13 +0000 Subject: [PATCH 05/29] update --- src/agentlab/agents/vl_agent/vl_agent.py | 23 +++++----- src/agentlab/agents/vl_agent/vl_model.py | 4 +- src/agentlab/agents/vl_agent/vl_prompt.py | 55 ++++++++++++----------- 3 files changed, 40 insertions(+), 42 deletions(-) diff --git a/src/agentlab/agents/vl_agent/vl_agent.py b/src/agentlab/agents/vl_agent/vl_agent.py index 487a5815..fffa17cd 100644 --- a/src/agentlab/agents/vl_agent/vl_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent.py @@ -81,25 +81,22 @@ def get_action(self, obs: dict) -> tuple[str, dict]: try: messages = Discussion([system_prompt, vl_prompt.prompt]) answer = retry( - self.general_vl_model, - messages, + chat=self.general_vl_model, + messages=messages, n_retry=self.max_retry, parser=vl_prompt.parse_answer, ) - answer["n_retry"] = (len(messages) - 3) / 2 - answer["busted_retry"] = 0 + num_tries = (len(messages) - 3) / 2 + num_busted_tries = 0 except ParseError: - answer = dict( - action=None, - think=None, - n_retry=self.max_retry + 1, - busted_retry=1, - ) + answer = {"action": None, "think": None} + num_tries = self.max_retry + 1 + num_busted_tries = 1 self.actions.append(answer["action"]) self.thoughts.append(answer["think"]) - stats = self.general_vl_model.get_stats() - stats["n_retry"] = answer["n_retry"] - stats["busted_retry"] = answer["busted_retry"] + stats = {"num_tries": num_tries, "num_busted_tries": num_busted_tries} + stats.update(self.general_vl_model.get_stats()) + stats.update(self.grounding_vl_model.get_stats()) agent_info = AgentInfo(think=answer["think"], chat_messages=messages, stats=stats) return answer["action"], asdict(agent_info) diff --git a/src/agentlab/agents/vl_agent/vl_model.py b/src/agentlab/agents/vl_agent/vl_model.py index 608b513c..eccb1ca8 100644 --- a/src/agentlab/agents/vl_agent/vl_model.py +++ b/src/agentlab/agents/vl_agent/vl_model.py @@ -66,10 +66,10 @@ def __init__( self.reproducibility_config = reproducibility_config def __call__(self, messages: list[dict]) -> dict: - pass + return {} def get_stats(self) -> dict: - pass + return {} @dataclass diff --git a/src/agentlab/agents/vl_agent/vl_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt.py index 338096f2..f37639df 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt.py @@ -28,53 +28,54 @@ def __init__( super().__init__() if vl_prompt_flags.enable_chat: self.instructions = dp.ChatInstructions( - obs_history[-1]["chat_messages"], + chat_messages=obs_history[-1]["chat_messages"], extra_instructions=vl_prompt_flags.extra_instructions, ) else: self.instructions = dp.GoalInstructions( - obs_history[-1]["goal_object"], + 
goal_object=obs_history[-1]["goal_object"], extra_instructions=vl_prompt_flags.extra_instructions, ) - self.observation = dp.Observation(obs_history[-1], vl_prompt_flags.obs_flags) - self.history = dp.History(obs_history, actions, None, thoughts, vl_prompt_flags.obs_flags) - self.think = dp.Think(visible=lambda: vl_prompt_flags.use_thinking) - self.action_prompt = dp.ActionPrompt(action_set, action_flags=vl_prompt_flags.action_flags) - self._prompt = HumanMessage(self.instructions.prompt) + self.observation = dp.Observation(obs=obs_history[-1], flags=vl_prompt_flags.obs_flags) + self.history = dp.History( + history_obs=obs_history, + actions=actions, + memories=None, + thoughts=thoughts, + flags=vl_prompt_flags.obs_flags, + ) + self.think = dp.Think(visible=vl_prompt_flags.use_thinking) + self.action_prompt = dp.ActionPrompt( + action_set=action_set, action_flags=vl_prompt_flags.action_flags + ) + self._prompt = HumanMessage(content=self.instructions.prompt) self._prompt.add_text( f"""\ -{self.observation.prompt}\ -{self.history.prompt}\ -{self.think.prompt}\ -{self.action_prompt.prompt}\ +{self.observation.prompt} +{self.history.prompt} +{self.think.prompt} +{self.action_prompt.prompt} """ ) if vl_prompt_flags.use_abstract_example: self._prompt.add_text( - f""" -# Abstract Example - -Here is an abstract version of the answer with description of the content of -each tag. Make sure you follow this structure, but replace the content with your -answer: -{self.think.abstract_ex}\ -{self.action_prompt.abstract_ex}\ + f"""\ +# Abstract Example: +{self.think.abstract_ex} +{self.action_prompt.abstract_ex} """ ) if vl_prompt_flags.use_concrete_example: self._prompt.add_text( - f""" -# Concrete Example - -Here is a concrete example of how to format your answer. 
-Make sure to follow the template with proper tags: -{self.think.concrete_ex}\ -{self.action_prompt.concrete_ex}\ + f"""\ +# Concrete Example: +{self.think.concrete_ex} +{self.action_prompt.concrete_ex} """ ) self.observation.add_screenshot(self._prompt) - def _parse_answer(self, text_answer) -> dict: + def _parse_answer(self, text_answer: str) -> dict: answer = {} answer.update(self.think.parse_answer(text_answer)) answer.update(self.action_prompt.parse_answer(text_answer)) From e6e9724e3769d9c1936e78c4660a36eae4e95e14 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Thu, 8 May 2025 20:57:42 +0000 Subject: [PATCH 06/29] update --- src/agentlab/agents/vl_agent/config.py | 4 +- src/agentlab/agents/vl_agent/vl_agent.py | 80 +++++++++++++++-------- src/agentlab/agents/vl_agent/vl_prompt.py | 71 ++++++++++---------- 3 files changed, 90 insertions(+), 65 deletions(-) diff --git a/src/agentlab/agents/vl_agent/config.py b/src/agentlab/agents/vl_agent/config.py index 0a2dbdee..67e2561d 100644 --- a/src/agentlab/agents/vl_agent/config.py +++ b/src/agentlab/agents/vl_agent/config.py @@ -43,8 +43,8 @@ VL_AGENT_ARGS_DICT = { "ui_agent-llama_32_11b": UIAgentArgs( - general_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], - grounding_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], + main_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], + auxiliary_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], vl_prompt_flags=VL_PROMPT_FLAGS_DICT["default"], max_retry=4, ) diff --git a/src/agentlab/agents/vl_agent/vl_agent.py b/src/agentlab/agents/vl_agent/vl_agent.py index fffa17cd..27a87b5e 100644 --- a/src/agentlab/agents/vl_agent/vl_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent.py @@ -7,6 +7,7 @@ from browsergym.experiments.benchmark import Benchmark from copy import deepcopy from dataclasses import asdict, dataclass +from typing import Optional from .vl_model import VLModelArgs from .vl_prompt import VLPrompt, VLPromptFlags @@ -52,13 +53,16 @@ def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): class UIAgent(VLAgent): def __init__( self, - general_vl_model_args: VLModelArgs, - grounding_vl_model_args: VLModelArgs, + main_vl_model_args: VLModelArgs, + auxiliary_vl_model_args: Optional[VLModelArgs], vl_prompt_flags: VLPromptFlags, max_retry: int, ): - self.general_vl_model = general_vl_model_args.make_model() - self.grounding_vl_model = grounding_vl_model_args.make_model() + self.main_vl_model = main_vl_model_args.make_model() + if auxiliary_vl_model_args is None: + self.auxiliary_vl_model = None + else: + self.auxiliary_vl_model = auxiliary_vl_model_args.make_model() self.action_set = vl_prompt_flags.action_flags.action_set.make_action_set() self._obs_preprocessor = dp.make_obs_preprocessor(vl_prompt_flags.obs_flags) self.vl_prompt_flags = vl_prompt_flags @@ -74,29 +78,43 @@ def get_action(self, obs: dict) -> tuple[str, dict]: vl_prompt_flags=self.vl_prompt_flags, action_set=self.action_set, obs_history=self.obs_history, - actions=self.actions, thoughts=self.thoughts, + actions=self.actions, ) - system_prompt = SystemMessage(dp.SystemPrompt().prompt) try: - messages = Discussion([system_prompt, vl_prompt.prompt]) + messages = Discussion( + [SystemMessage(dp.SystemPrompt().prompt), vl_prompt.get_message()] + ) answer = retry( - chat=self.general_vl_model, + chat=self.main_vl_model, messages=messages, n_retry=self.max_retry, parser=vl_prompt.parse_answer, ) - num_tries = (len(messages) - 3) / 2 - num_busted_tries = 0 + stats = {"num_main_retries": (len(messages) - 3) // 2} except ParseError: - 
answer = {"action": None, "think": None} - num_tries = self.max_retry + 1 - num_busted_tries = 1 - self.actions.append(answer["action"]) + answer = {"think": None, "action": None} + stats = {"num_main_retries": self.max_retry} + stats.update(self.main_vl_model.get_stats()) + if self.auxiliary_vl_model is not None: + try: + messages = Discussion( + [SystemMessage(dp.SystemPrompt().prompt), vl_prompt.get_message()] + ) + messages.add_text(f"{answer['think']}\n{answer['action']}\n") + answer = retry( + chat=self.auxiliary_vl_model, + messages=messages, + n_retry=self.max_retry, + parser=vl_prompt.parse_answer, + ) + stats["num_auxiliary_retries"] = (len(messages) - 3) // 2 + except ParseError: + answer = {"action": None, "think": None} + stats["num_auxiliary_retries"] = self.max_retry + stats.update(self.auxiliary_vl_model.get_stats()) self.thoughts.append(answer["think"]) - stats = {"num_tries": num_tries, "num_busted_tries": num_busted_tries} - stats.update(self.general_vl_model.get_stats()) - stats.update(self.grounding_vl_model.get_stats()) + self.actions.append(answer["action"]) agent_info = AgentInfo(think=answer["think"], chat_messages=messages, stats=stats) return answer["action"], asdict(agent_info) @@ -106,33 +124,39 @@ def obs_preprocessor(self, obs: dict) -> dict: @dataclass class UIAgentArgs(VLAgentArgs): - general_vl_model_args: VLModelArgs - grounding_vl_model_args: VLModelArgs + main_vl_model_args: VLModelArgs + auxiliary_vl_model_args: VLModelArgs vl_prompt_flags: VLPromptFlags max_retry: int def __post_init__(self): - self.agent_name = f"ui_agent-{self.general_vl_model_args.model_name}-{self.grounding_vl_model_args.model_name}" + if self.auxiliary_vl_model_args is None: + self.agent_name = f"ui_agent-{self.main_vl_model_args.model_name}" + else: + self.agent_name = f"ui_agent-{self.main_vl_model_args.model_name}-{self.auxiliary_vl_model_args.model_name}" def make_agent(self) -> UIAgent: return UIAgent( - general_vl_model_args=self.general_vl_model_args, - grounding_vl_model_args=self.grounding_vl_model_args, + main_vl_model_args=self.main_vl_model_args, + auxiliary_vl_model_args=self.auxiliary_vl_model_args, vl_prompt_flags=self.vl_prompt_flags, max_retry=self.max_retry, ) def prepare(self): - self.general_vl_model_args.prepare() - self.grounding_vl_model_args.prepare() + self.main_vl_model_args.prepare() + if self.auxiliary_vl_model_args is not None: + self.auxiliary_vl_model_args.prepare() def close(self): - self.general_vl_model_args.close() - self.grounding_vl_model_args.close() + self.main_vl_model_args.close() + if self.auxiliary_vl_model_args is not None: + self.auxiliary_vl_model_args.close() def set_reproducibility_mode(self): - self.general_vl_model_args.set_reproducibility_mode() - self.grounding_vl_model_args.set_reproducibility_mode() + self.main_vl_model_args.set_reproducibility_mode() + if self.auxiliary_vl_model_args is not None: + self.auxiliary_vl_model_args.set_reproducibility_mode() def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): self.vl_prompt_flags.obs_flags.use_tabs = benchmark.is_multi_tab diff --git a/src/agentlab/agents/vl_agent/vl_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt.py index f37639df..6addedce 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt.py @@ -22,58 +22,59 @@ def __init__( vl_prompt_flags: VLPromptFlags, action_set: AbstractActionSet, obs_history: list[dict], - actions: list[str], thoughts: list[str], + actions: list[str], ): super().__init__() - if 
vl_prompt_flags.enable_chat: + self.vl_prompt_flags = vl_prompt_flags + self.obs_history = obs_history + if self.vl_prompt_flags.enable_chat: self.instructions = dp.ChatInstructions( - chat_messages=obs_history[-1]["chat_messages"], - extra_instructions=vl_prompt_flags.extra_instructions, + chat_messages=self.obs_history[-1]["chat_messages"], + extra_instructions=self.vl_prompt_flags.extra_instructions, ) else: self.instructions = dp.GoalInstructions( - goal_object=obs_history[-1]["goal_object"], - extra_instructions=vl_prompt_flags.extra_instructions, + goal_object=self.obs_history[-1]["goal_object"], + extra_instructions=self.vl_prompt_flags.extra_instructions, ) - self.observation = dp.Observation(obs=obs_history[-1], flags=vl_prompt_flags.obs_flags) + self.observation = dp.Observation( + obs=self.obs_history[-1], flags=self.vl_prompt_flags.obs_flags + ) self.history = dp.History( - history_obs=obs_history, + history_obs=self.obs_history, actions=actions, memories=None, thoughts=thoughts, - flags=vl_prompt_flags.obs_flags, + flags=self.vl_prompt_flags.obs_flags, ) - self.think = dp.Think(visible=vl_prompt_flags.use_thinking) + self.think = dp.Think(visible=self.vl_prompt_flags.use_thinking) self.action_prompt = dp.ActionPrompt( - action_set=action_set, action_flags=vl_prompt_flags.action_flags + action_set=action_set, action_flags=self.vl_prompt_flags.action_flags ) - self._prompt = HumanMessage(content=self.instructions.prompt) - self._prompt.add_text( - f"""\ -{self.observation.prompt} -{self.history.prompt} -{self.think.prompt} -{self.action_prompt.prompt} -""" - ) - if vl_prompt_flags.use_abstract_example: - self._prompt.add_text( - f"""\ -# Abstract Example: -{self.think.abstract_ex} -{self.action_prompt.abstract_ex} -""" + self._prompt = f"{self.instructions.prompt}\n{self.observation.prompt}\n{self.history.prompt}\n{self.think.prompt}\n{self.action_prompt.prompt}\n" + if self.vl_prompt_flags.use_abstract_example: + self._prompt += ( + f"# Abstract Example:\n{self.think.abstract_ex}\n{self.action_prompt.abstract_ex}\n" ) - if vl_prompt_flags.use_concrete_example: - self._prompt.add_text( - f"""\ -# Concrete Example: -{self.think.concrete_ex} -{self.action_prompt.concrete_ex} -""" + if self.vl_prompt_flags.use_concrete_example: + self._prompt += ( + f"# Concrete Example:\n{self.think.concrete_ex}\n{self.action_prompt.concrete_ex}\n" ) - self.observation.add_screenshot(self._prompt) + + def get_message(self) -> HumanMessage: + message = HumanMessage(content=self.prompt) + if self.vl_prompt_flags.obs_flags.use_screenshot: + if self.vl_prompt_flags.obs_flags.use_som: + screenshot = self.obs_history[-1]["screenshot_som"] + message.add_text( + "## Screenshot:\nHere is a screenshot of the page, it is annotated with bounding boxes and corresponding bids:\n" + ) + else: + screenshot = self.obs_history[-1]["screenshot"] + message.add_text("## Screenshot:\nHere is a screenshot of the page:\n") + message.add_image(screenshot) + return message def _parse_answer(self, text_answer: str) -> dict: answer = {} From 6c272676e5cce6e78bfc10ec2f2b9af6ac01fd03 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Sat, 10 May 2025 13:33:51 +0000 Subject: [PATCH 07/29] update --- src/agentlab/agents/vl_agent/config.py | 18 +-- src/agentlab/agents/vl_agent/vl_agent.py | 54 +++---- src/agentlab/agents/vl_agent/vl_prompt.py | 167 ++++++++++++++-------- 3 files changed, 134 insertions(+), 105 deletions(-) diff --git a/src/agentlab/agents/vl_agent/config.py b/src/agentlab/agents/vl_agent/config.py index 
67e2561d..ae03f12a 100644 --- a/src/agentlab/agents/vl_agent/config.py +++ b/src/agentlab/agents/vl_agent/config.py @@ -1,7 +1,7 @@ from browsergym.experiments.benchmark import HighLevelActionSetArgs from .vl_agent import UIAgentArgs from .vl_model import LlamaModelArgs -from .vl_prompt import VLPromptFlags +from .vl_prompt import UIPromptArgs import agentlab.agents.dynamic_prompting as dp @@ -18,8 +18,9 @@ } -VL_PROMPT_FLAGS_DICT = { - "default": VLPromptFlags( +VL_PROMPT_ARGS_DICT = { + "ui_prompt-default": UIPromptArgs( + prompt_name="ui_prompt-default", obs_flags=dp.ObsFlags( use_tabs=True, use_error_logs=True, @@ -33,19 +34,20 @@ long_description=True, individual_examples=False, ), + extra_instructions=None, + enable_chat=False, use_thinking=True, - use_concrete_example=False, use_abstract_example=True, - enable_chat=False, - extra_instructions=None, + use_concrete_example=False, ) } VL_AGENT_ARGS_DICT = { - "ui_agent-llama_32_11b": UIAgentArgs( + "ui_agent-llama_32_11b-llama_32_11b": UIAgentArgs( + agent_name="ui_agent-llama_32_11b-llama_32_11b", main_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], auxiliary_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], - vl_prompt_flags=VL_PROMPT_FLAGS_DICT["default"], + ui_prompt_args=VL_PROMPT_ARGS_DICT["ui_prompt-default"], max_retry=4, ) } diff --git a/src/agentlab/agents/vl_agent/vl_agent.py b/src/agentlab/agents/vl_agent/vl_agent.py index 27a87b5e..0bde4906 100644 --- a/src/agentlab/agents/vl_agent/vl_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent.py @@ -2,26 +2,23 @@ from agentlab.agents import dynamic_prompting as dp from agentlab.llm.llm_utils import Discussion, ParseError, retry, SystemMessage from agentlab.llm.tracking import cost_tracker_decorator -from browsergym.core.action.base import AbstractActionSet from browsergym.experiments.agent import AgentInfo from browsergym.experiments.benchmark import Benchmark from copy import deepcopy from dataclasses import asdict, dataclass from typing import Optional from .vl_model import VLModelArgs -from .vl_prompt import VLPrompt, VLPromptFlags +from .vl_prompt import UIPromptArgs -@dataclass class VLAgent(ABC): - action_set: AbstractActionSet - @abstractmethod def get_action(self, obs: dict) -> tuple[str, dict]: raise NotImplementedError + @property @abstractmethod - def obs_preprocessor(self, obs: dict) -> dict: + def obs_preprocessor(self) -> callable: raise NotImplementedError @@ -55,7 +52,7 @@ def __init__( self, main_vl_model_args: VLModelArgs, auxiliary_vl_model_args: Optional[VLModelArgs], - vl_prompt_flags: VLPromptFlags, + ui_prompt_args: UIPromptArgs, max_retry: int, ): self.main_vl_model = main_vl_model_args.make_model() @@ -63,9 +60,7 @@ def __init__( self.auxiliary_vl_model = None else: self.auxiliary_vl_model = auxiliary_vl_model_args.make_model() - self.action_set = vl_prompt_flags.action_flags.action_set.make_action_set() - self._obs_preprocessor = dp.make_obs_preprocessor(vl_prompt_flags.obs_flags) - self.vl_prompt_flags = vl_prompt_flags + self.ui_prompt_args = ui_prompt_args self.max_retry = max_retry self.obs_history = [] self.actions = [] @@ -74,22 +69,16 @@ def __init__( @cost_tracker_decorator def get_action(self, obs: dict) -> tuple[str, dict]: self.obs_history.append(obs) - vl_prompt = VLPrompt( - vl_prompt_flags=self.vl_prompt_flags, - action_set=self.action_set, - obs_history=self.obs_history, - thoughts=self.thoughts, - actions=self.actions, - ) + ui_prompt = self.ui_prompt_args.make_prompt(self.obs_history, self.actions, self.thoughts) try: messages = 
Discussion( - [SystemMessage(dp.SystemPrompt().prompt), vl_prompt.get_message()] + [SystemMessage(dp.SystemPrompt().prompt), ui_prompt.get_message()] ) answer = retry( chat=self.main_vl_model, messages=messages, n_retry=self.max_retry, - parser=vl_prompt.parse_answer, + parser=ui_prompt.answer_parser, ) stats = {"num_main_retries": (len(messages) - 3) // 2} except ParseError: @@ -99,14 +88,14 @@ def get_action(self, obs: dict) -> tuple[str, dict]: if self.auxiliary_vl_model is not None: try: messages = Discussion( - [SystemMessage(dp.SystemPrompt().prompt), vl_prompt.get_message()] + [SystemMessage(dp.SystemPrompt().prompt), ui_prompt.get_message()] ) messages.add_text(f"{answer['think']}\n{answer['action']}\n") answer = retry( chat=self.auxiliary_vl_model, messages=messages, n_retry=self.max_retry, - parser=vl_prompt.parse_answer, + parser=ui_prompt.answer_parser, ) stats["num_auxiliary_retries"] = (len(messages) - 3) // 2 except ParseError: @@ -118,28 +107,23 @@ def get_action(self, obs: dict) -> tuple[str, dict]: agent_info = AgentInfo(think=answer["think"], chat_messages=messages, stats=stats) return answer["action"], asdict(agent_info) - def obs_preprocessor(self, obs: dict) -> dict: - return self._obs_preprocessor(obs) + @property + def obs_preprocessor(self) -> callable: + return dp.make_obs_preprocessor(self.ui_prompt_args.obs_flags) @dataclass class UIAgentArgs(VLAgentArgs): main_vl_model_args: VLModelArgs auxiliary_vl_model_args: VLModelArgs - vl_prompt_flags: VLPromptFlags + ui_prompt_args: UIPromptArgs max_retry: int - def __post_init__(self): - if self.auxiliary_vl_model_args is None: - self.agent_name = f"ui_agent-{self.main_vl_model_args.model_name}" - else: - self.agent_name = f"ui_agent-{self.main_vl_model_args.model_name}-{self.auxiliary_vl_model_args.model_name}" - def make_agent(self) -> UIAgent: return UIAgent( main_vl_model_args=self.main_vl_model_args, auxiliary_vl_model_args=self.auxiliary_vl_model_args, - vl_prompt_flags=self.vl_prompt_flags, + ui_prompt_args=self.ui_prompt_args, max_retry=self.max_retry, ) @@ -159,9 +143,7 @@ def set_reproducibility_mode(self): self.auxiliary_vl_model_args.set_reproducibility_mode() def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): - self.vl_prompt_flags.obs_flags.use_tabs = benchmark.is_multi_tab - self.vl_prompt_flags.action_flags.action_set = deepcopy( - benchmark.high_level_action_set_args - ) + self.ui_prompt_args.obs_flags.use_tabs = benchmark.is_multi_tab + self.ui_prompt_args.action_flags.action_set = deepcopy(benchmark.high_level_action_set_args) if demo_mode: - self.vl_prompt_flags.action_flags.action_set.demo_mode = "all_blue" + self.ui_prompt_args.action_flags.action_set.demo_mode = "all_blue" diff --git a/src/agentlab/agents/vl_agent/vl_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt.py index 6addedce..76c1bda4 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt.py @@ -1,83 +1,128 @@ +from abc import ABC, abstractmethod from agentlab.agents import dynamic_prompting as dp from agentlab.llm.llm_utils import HumanMessage -from browsergym.core.action.base import AbstractActionSet from dataclasses import dataclass -from typing import Optional +from PIL import Image +from typing import Optional, Union + + +class VLPrompt(ABC): + @abstractmethod + def get_message(self) -> HumanMessage: + raise NotImplementedError + + @abstractmethod + def answer_parser(self, answer_text: str) -> dict: + raise NotImplementedError + + +@dataclass +class VLPromptArgs(ABC): + 
prompt_name: str + + @abstractmethod + def make_prompt( + self, obs_history: list[dict], actions: list[str], thoughts: list[str] + ) -> VLPrompt: + raise NotImplementedError @dataclass -class VLPromptFlags(dp.Flags): +class UIPrompt(VLPrompt): + instructions: Union[dp.ChatInstructions, dp.GoalInstructions] + screenshot: Optional[Image.Image] + observation: dp.Observation + history: dp.History + think: dp.Think + action_prompt: dp.ActionPrompt + abstract_example: Optional[str] + concrete_example: Optional[str] + + def get_message(self) -> HumanMessage: + message = HumanMessage(self.instructions.prompt) + if self.screenshot is not None: + message.add_text("# Screenshot:\n") + message.add_image(self.screenshot) + message.add_text(self.observation.prompt) + message.add_text(self.history.prompt) + message.add_text(self.think.prompt) + message.add_text(self.action_prompt.prompt) + if self.abstract_example is not None: + message.add_text(self.abstract_example) + if self.concrete_example is not None: + message.add_text(self.concrete_example) + return message + + def answer_parser(self, answer_text: str) -> dict: + answer_dict = {} + answer_dict.update(self.think.parse_answer(answer_text)) + answer_dict.update(self.action_prompt.parse_answer(answer_text)) + return answer_dict + + +@dataclass +class UIPromptArgs(VLPromptArgs): obs_flags: dp.ObsFlags action_flags: dp.ActionFlags + extra_instructions: Optional[str] + enable_chat: bool use_thinking: bool - use_concrete_example: bool use_abstract_example: bool - enable_chat: bool - extra_instructions: Optional[str] - + use_concrete_example: bool -class VLPrompt(dp.PromptElement): - def __init__( - self, - vl_prompt_flags: VLPromptFlags, - action_set: AbstractActionSet, - obs_history: list[dict], - thoughts: list[str], - actions: list[str], - ): - super().__init__() - self.vl_prompt_flags = vl_prompt_flags - self.obs_history = obs_history - if self.vl_prompt_flags.enable_chat: - self.instructions = dp.ChatInstructions( - chat_messages=self.obs_history[-1]["chat_messages"], - extra_instructions=self.vl_prompt_flags.extra_instructions, + def make_prompt( + self, obs_history: list[dict], actions: list[str], thoughts: list[str] + ) -> UIPrompt: + if self.enable_chat: + instructions = dp.ChatInstructions( + chat_messages=obs_history[-1]["chat_messages"], + extra_instructions=self.extra_instructions, ) else: - self.instructions = dp.GoalInstructions( - goal_object=self.obs_history[-1]["goal_object"], - extra_instructions=self.vl_prompt_flags.extra_instructions, + instructions = dp.GoalInstructions( + goal_object=obs_history[-1]["goal_object"], + extra_instructions=self.extra_instructions, ) - self.observation = dp.Observation( - obs=self.obs_history[-1], flags=self.vl_prompt_flags.obs_flags - ) - self.history = dp.History( - history_obs=self.obs_history, + if self.obs_flags.use_screenshot: + if self.obs_flags.use_som: + screenshot = obs_history[-1]["screenshot_som"] + else: + screenshot = obs_history[-1]["screenshot"] + else: + screenshot = None + observation = dp.Observation(obs=obs_history[-1], flags=self.obs_flags) + history = dp.History( + history_obs=obs_history, actions=actions, memories=None, thoughts=thoughts, - flags=self.vl_prompt_flags.obs_flags, + flags=self.obs_flags, ) - self.think = dp.Think(visible=self.vl_prompt_flags.use_thinking) - self.action_prompt = dp.ActionPrompt( - action_set=action_set, action_flags=self.vl_prompt_flags.action_flags + think = dp.Think(visible=self.use_thinking) + action_prompt = dp.ActionPrompt( + 
action_set=self.action_flags.action_set.make_action_set(), + action_flags=self.action_flags, ) - self._prompt = f"{self.instructions.prompt}\n{self.observation.prompt}\n{self.history.prompt}\n{self.think.prompt}\n{self.action_prompt.prompt}\n" - if self.vl_prompt_flags.use_abstract_example: - self._prompt += ( - f"# Abstract Example:\n{self.think.abstract_ex}\n{self.action_prompt.abstract_ex}\n" + if self.use_abstract_example: + abstract_example = ( + f"# Abstract Example:\n{think.abstract_ex}{action_prompt.abstract_ex}" ) - if self.vl_prompt_flags.use_concrete_example: - self._prompt += ( - f"# Concrete Example:\n{self.think.concrete_ex}\n{self.action_prompt.concrete_ex}\n" + else: + abstract_example = None + if self.use_concrete_example: + concrete_example = ( + f"# Concrete Example:\n{think.concrete_ex}{action_prompt.concrete_ex}" ) + else: + concrete_example = None - def get_message(self) -> HumanMessage: - message = HumanMessage(content=self.prompt) - if self.vl_prompt_flags.obs_flags.use_screenshot: - if self.vl_prompt_flags.obs_flags.use_som: - screenshot = self.obs_history[-1]["screenshot_som"] - message.add_text( - "## Screenshot:\nHere is a screenshot of the page, it is annotated with bounding boxes and corresponding bids:\n" - ) - else: - screenshot = self.obs_history[-1]["screenshot"] - message.add_text("## Screenshot:\nHere is a screenshot of the page:\n") - message.add_image(screenshot) - return message - - def _parse_answer(self, text_answer: str) -> dict: - answer = {} - answer.update(self.think.parse_answer(text_answer)) - answer.update(self.action_prompt.parse_answer(text_answer)) - return answer + return UIPrompt( + instructions=instructions, + screenshot=screenshot, + observation=observation, + history=history, + think=think, + action_prompt=action_prompt, + abstract_example=abstract_example, + concrete_example=concrete_example, + ) From 1cfe4164aa1acac9999823814b6dfed6426119ad Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Sat, 10 May 2025 13:48:09 +0000 Subject: [PATCH 08/29] update --- src/agentlab/agents/vl_agent/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/agentlab/agents/vl_agent/config.py b/src/agentlab/agents/vl_agent/config.py index ae03f12a..7f58c4e1 100644 --- a/src/agentlab/agents/vl_agent/config.py +++ b/src/agentlab/agents/vl_agent/config.py @@ -17,7 +17,6 @@ ) } - VL_PROMPT_ARGS_DICT = { "ui_prompt-default": UIPromptArgs( prompt_name="ui_prompt-default", From 3b6561a604e4815fd904e581f4ddb068f585c9b0 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Sat, 10 May 2025 13:57:43 +0000 Subject: [PATCH 09/29] update --- src/agentlab/agents/vl_agent/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/agentlab/agents/vl_agent/config.py b/src/agentlab/agents/vl_agent/config.py index 7f58c4e1..a9578ed9 100644 --- a/src/agentlab/agents/vl_agent/config.py +++ b/src/agentlab/agents/vl_agent/config.py @@ -26,7 +26,6 @@ use_past_error_logs=False, use_screenshot=True, use_som=False, - openai_vision_detail="auto", ), action_flags=dp.ActionFlags( action_set=HighLevelActionSetArgs(subsets=["coord"]), From a1361ad5cb83cca92edb3dbdf3b436b337e85d2d Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Sat, 17 May 2025 04:57:07 +0000 Subject: [PATCH 10/29] update --- src/agentlab/agents/vl_agent/config.py | 5 +- src/agentlab/agents/vl_agent/utils.py | 16 ++ src/agentlab/agents/vl_agent/vl_agent.py | 16 +- src/agentlab/agents/vl_agent/vl_model.py | 10 +- src/agentlab/agents/vl_agent/vl_prompt.py | 330 ++++++++++++++++++---- 5 files changed, 311 insertions(+), 
66 deletions(-) create mode 100644 src/agentlab/agents/vl_agent/utils.py diff --git a/src/agentlab/agents/vl_agent/config.py b/src/agentlab/agents/vl_agent/config.py index a9578ed9..d29b274e 100644 --- a/src/agentlab/agents/vl_agent/config.py +++ b/src/agentlab/agents/vl_agent/config.py @@ -7,7 +7,6 @@ VL_MODEL_ARGS_DICT = { "llama_32_11b": LlamaModelArgs( - model_name="llama_32_11b", model_path="meta-llama/Llama-3.2-11B-Vision-Instruct", torch_dtype="bfloat16", checkpoint_dir=None, @@ -19,7 +18,6 @@ VL_PROMPT_ARGS_DICT = { "ui_prompt-default": UIPromptArgs( - prompt_name="ui_prompt-default", obs_flags=dp.ObsFlags( use_tabs=True, use_error_logs=True, @@ -41,8 +39,7 @@ } VL_AGENT_ARGS_DICT = { - "ui_agent-llama_32_11b-llama_32_11b": UIAgentArgs( - agent_name="ui_agent-llama_32_11b-llama_32_11b", + "ui_agent-default": UIAgentArgs( main_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], auxiliary_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], ui_prompt_args=VL_PROMPT_ARGS_DICT["ui_prompt-default"], diff --git a/src/agentlab/agents/vl_agent/utils.py b/src/agentlab/agents/vl_agent/utils.py new file mode 100644 index 00000000..e23714e8 --- /dev/null +++ b/src/agentlab/agents/vl_agent/utils.py @@ -0,0 +1,16 @@ +from PIL import Image +from typing import Union +import base64 +import io +import numpy as np + + +def image_to_image_url(image: Union[Image.Image, np.ndarray]): + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + if image.mode in ("RGBA", "LA"): + image = image.convert("RGB") + buffer = io.BytesIO() + image.save(buffer, format="JPEG") + image_base64 = base64.b64encode(buffer.getvalue()).decode() + return f"data:image/jpeg;base64,{image_base64}" diff --git a/src/agentlab/agents/vl_agent/vl_agent.py b/src/agentlab/agents/vl_agent/vl_agent.py index 0bde4906..6f6bccca 100644 --- a/src/agentlab/agents/vl_agent/vl_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent.py @@ -24,7 +24,10 @@ def obs_preprocessor(self) -> callable: @dataclass class VLAgentArgs(ABC): - agent_name: str + @property + @abstractmethod + def agent_name(self) -> str: + raise NotImplementedError @abstractmethod def make_agent(self) -> VLAgent: @@ -69,7 +72,12 @@ def __init__( @cost_tracker_decorator def get_action(self, obs: dict) -> tuple[str, dict]: self.obs_history.append(obs) - ui_prompt = self.ui_prompt_args.make_prompt(self.obs_history, self.actions, self.thoughts) + ui_prompt = self.ui_prompt_args.make_prompt( + obs_history=self.obs_history, + actions=self.actions, + thoughts=self.thoughts, + extra_instructions=None, + ) try: messages = Discussion( [SystemMessage(dp.SystemPrompt().prompt), ui_prompt.get_message()] @@ -119,6 +127,10 @@ class UIAgentArgs(VLAgentArgs): ui_prompt_args: UIPromptArgs max_retry: int + @property + def agent_name(self) -> str: + return f"UIAgent-{self.main_vl_model_args.model_name}-{self.auxiliary_vl_model_args.model_name}" + def make_agent(self) -> UIAgent: return UIAgent( main_vl_model_args=self.main_vl_model_args, diff --git a/src/agentlab/agents/vl_agent/vl_model.py b/src/agentlab/agents/vl_agent/vl_model.py index eccb1ca8..8dcb73d5 100644 --- a/src/agentlab/agents/vl_agent/vl_model.py +++ b/src/agentlab/agents/vl_agent/vl_model.py @@ -17,9 +17,11 @@ def get_stats(self) -> dict: raise NotImplementedError -@dataclass class VLModelArgs(ABC): - model_name: str + @property + @abstractmethod + def model_name(self) -> str: + raise NotImplementedError @abstractmethod def make_model(self) -> VLModel: @@ -81,6 +83,10 @@ class LlamaModelArgs(VLModelArgs): max_new_tokens: int 
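
A usage sketch for the image_to_image_url helper added in utils.py above; the tiny array is a stand-in for a real screenshot, and the import path assumes this series' package layout:

    import numpy as np

    from agentlab.agents.vl_agent.utils import image_to_image_url

    # BrowserGym screenshots arrive as HxWx3 uint8 arrays; a 2x2 white
    # image is enough to exercise the conversion.
    fake_screenshot = np.full((2, 2, 3), 255, dtype=np.uint8)
    url = image_to_image_url(fake_screenshot)
    assert url.startswith("data:image/jpeg;base64,")
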
reproducibility_config: dict + @property + def model_name(self) -> str: + return self.model_path.split("/")[-1].replace("-", "_").replace(".", "") + def make_model(self) -> LlamaModel: return LlamaModel( model_path=self.model_path, diff --git a/src/agentlab/agents/vl_agent/vl_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt.py index 76c1bda4..347e2b4f 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt.py @@ -1,32 +1,241 @@ from abc import ABC, abstractmethod from agentlab.agents import dynamic_prompting as dp -from agentlab.llm.llm_utils import HumanMessage +from agentlab.llm.llm_utils import HumanMessage, SystemMessage +from browsergym.experiments.benchmark.base import HighLevelActionSetArgs +from browsergym.core.action.highlevel import HighLevelActionSet from dataclasses import dataclass from PIL import Image from typing import Optional, Union +from .utils import image_to_image_url +import numpy as np +import time + + +class VLPromptPart(ABC): + @abstractmethod + def get_message_items(self) -> list[dict]: + raise NotImplementedError class VLPrompt(ABC): @abstractmethod - def get_message(self) -> HumanMessage: + def get_messages(self) -> list[Union[SystemMessage, HumanMessage]]: raise NotImplementedError @abstractmethod - def answer_parser(self, answer_text: str) -> dict: + def parse_answer(self, answer_text: str) -> dict: raise NotImplementedError -@dataclass class VLPromptArgs(ABC): - prompt_name: str - @abstractmethod - def make_prompt( - self, obs_history: list[dict], actions: list[str], thoughts: list[str] - ) -> VLPrompt: + def make_prompt(self, obs: dict, actions: list[str], thoughts: list[str]) -> VLPrompt: raise NotImplementedError +class SystemPromptPart(VLPromptPart): + def __init__(self, text: Optional[str]): + if text is None: + text = """\ +You are an agent trying to solve a web task based on the content of the page and user instructions. \ +You can interact with the page and explore, and send messages to the user. \ +Each time you submit an action it will be sent to the browser and you will receive a new page. +""" + self.text = text + + def get_message_items(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + + +class ChatInstructionPromptPart(VLPromptPart): + def __init__( + self, + chat_messages: list[dict], + extra_instruction: Optional[str], + ): + text = """\ +# Instruction +Your goal is to help the user perform tasks using a web browser. \ +You can communicate with the user via a chat, in which the user gives you instructions and in which you can send back messages. \ +Review the current state of the page and all other information to find the best possible next action to accomplish your goal. \ +Your answer will be interpreted and executed by a program, make sure to follow the formatting instructions. +## Chat Messages +""" + for chat_message in chat_messages: + text += f"""\ +[{time.asctime(time.localtime(chat_message['timestamp']))}] {chat_message['role']}: {chat_message['message']} +""" + if extra_instruction is not None: + text += f"""\ +## Extra Instruction +{extra_instruction} +""" + self.text = text + + def get_message_items(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + + +class GoalInstructionPromptPart(VLPromptPart): + def __init__( + self, + goal_object: list[dict], + extra_instruction: Optional[str], + ): + text = """\ +# Instruction +Review the current state of the page and all other information to find the best possible next action to accomplish your goal. 
\ +Your answer will be interpreted and executed by a program, make sure to follow the formatting instructions. +## Goal +""" + for item in goal_object: + if item["type"] == "text": + text += f"""\ +{item['text']} +""" + if extra_instruction is not None: + text += f"""\ +## Extra Instruction +{extra_instruction} +""" + self.text = text + + def get_message_items(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + + +class ScreenshotPromptPart(VLPromptPart): + def __init__(self, screenshot: Union[Image.Image, np.ndarray]): + self.text = """\ +# Screenshot +""" + self.image_url = image_to_image_url(screenshot) + + def get_message_items(self) -> list[dict]: + return [ + {"type": "text", "text": self.text}, + {"type": "image_url", "image_url": {"url": self.image_url}}, + ] + + +class TabsPromptPart(VLPromptPart): + def __init__( + self, open_pages_titles: list[str], open_pages_urls: list[str], active_page_index: int + ): + text = """\ +# Open Tabs +""" + for index, (title, url) in enumerate(zip(open_pages_titles, open_pages_urls)): + text += f"""\ +Tab {index}{' (active tab)' if index == active_page_index else ''}: {title} ({url}) +""" + self.text = text + + def get_message_items(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + + +class ErrorPromptPart(VLPromptPart): + def __init__(self, last_action_error: str): + text = """\ +# Error from Last Action +""" + separator = "Call log:" + if separator in last_action_error: + error, logs = last_action_error.split(separator) + error = error.strip() + logs = logs.split("\n") + text += f"""\ +{error} +{separator} +""" + for log in logs[:10]: + text += f"""\ +{log} +""" + else: + text += f"""\ +{last_action_error} +""" + self.text = text + + def get_message_items(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + + +class HistoryPromptPart(VLPromptPart): + def __init__(self, thoughts: list[str], actions: list[str]): + text = """\ +# Thoughts and Actions of Previous Steps +""" + for index, (thought, action) in enumerate(zip(thoughts, actions)): + text += f""" +## Step {index} +### Thought +{thought} +### Action +{action} +""" + self.text = text + + def get_message_items(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + + +class AnswerPromptPart(VLPromptPart): + def __init__( + self, + action_set: HighLevelActionSet, + use_abstract_example: bool, + use_concrete_example: bool, + preliminary_answer: Optional[dict], + ): + text = f"""\ +# Answer Format Requirements +## Action Space +These actions allow you to interact with your environment. \ +Most of them are python functions executing playwright code. +{action_set.describe(with_long_description=True, with_examples=False)} +""" + if use_abstract_example: + text += """ +## Abstract Example + +The thought about which action to take at the current step. + + +One single action to be executed. You can only use one action at a time. + +""" + if use_concrete_example: + text += """ +## Concrete Example + +From previous action I tried to set the value of year to "2022", using select_option, but it doesn't appear to be in the form. \ +It may be a dynamic dropdown, I will try using click with the bid "a324" and look at the response from the page. + + +click('a324') + +""" + if preliminary_answer is not None: + text += f""" +## Preliminary Anser +Here is a preliminary anser, which might be incorrect or inaccurate. \ +Refine it to get the final answer. 
+ +{preliminary_answer['thought']} + + +{preliminary_answer['action']} + +""" + self.text = text + + def get_message_items(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + + @dataclass class UIPrompt(VLPrompt): instructions: Union[dp.ChatInstructions, dp.GoalInstructions] @@ -37,8 +246,9 @@ class UIPrompt(VLPrompt): action_prompt: dp.ActionPrompt abstract_example: Optional[str] concrete_example: Optional[str] + preliminary_answer: Optional[str] - def get_message(self) -> HumanMessage: + def get_messages(self) -> list[Union[SystemMessage, HumanMessage]]: message = HumanMessage(self.instructions.prompt) if self.screenshot is not None: message.add_text("# Screenshot:\n") @@ -51,9 +261,11 @@ def get_message(self) -> HumanMessage: message.add_text(self.abstract_example) if self.concrete_example is not None: message.add_text(self.concrete_example) + if self.preliminary_answer is not None: + message.add_text(self.preliminary_answer) return message - def answer_parser(self, answer_text: str) -> dict: + def parse_answer(self, answer_text: str) -> dict: answer_dict = {} answer_dict.update(self.think.parse_answer(answer_text)) answer_dict.update(self.action_prompt.parse_answer(answer_text)) @@ -62,67 +274,69 @@ def answer_parser(self, answer_text: str) -> dict: @dataclass class UIPromptArgs(VLPromptArgs): - obs_flags: dp.ObsFlags - action_flags: dp.ActionFlags - extra_instructions: Optional[str] + action_set_args: HighLevelActionSetArgs enable_chat: bool - use_thinking: bool + use_screenshot: bool + use_som: bool + use_tabs: bool + use_error: bool + use_history: bool use_abstract_example: bool use_concrete_example: bool def make_prompt( - self, obs_history: list[dict], actions: list[str], thoughts: list[str] + self, + obs: dict, + actions: list[str], + thoughts: list[str], + extra_instruction: Optional[str], + preliminary_answer: Optional[dict], ) -> UIPrompt: if self.enable_chat: - instructions = dp.ChatInstructions( - chat_messages=obs_history[-1]["chat_messages"], - extra_instructions=self.extra_instructions, + instruction_prompt_part = ChatInstructionPromptPart( + chat_messages=obs["chat_messages"], + extra_instruction=extra_instruction, ) else: - instructions = dp.GoalInstructions( - goal_object=obs_history[-1]["goal_object"], - extra_instructions=self.extra_instructions, + instruction_prompt_part = GoalInstructionPromptPart( + goal_object=obs["goal_object"], + extra_instruction=extra_instruction, ) - if self.obs_flags.use_screenshot: - if self.obs_flags.use_som: - screenshot = obs_history[-1]["screenshot_som"] + if self.use_screenshot: + if self.use_som: + screenshot = obs["screenshot_som"] else: - screenshot = obs_history[-1]["screenshot"] + screenshot = obs["screenshot"] + screenshot_prompt_part = ScreenshotPromptPart(screenshot) else: - screenshot = None - observation = dp.Observation(obs=obs_history[-1], flags=self.obs_flags) - history = dp.History( - history_obs=obs_history, - actions=actions, - memories=None, - thoughts=thoughts, - flags=self.obs_flags, - ) - think = dp.Think(visible=self.use_thinking) - action_prompt = dp.ActionPrompt( - action_set=self.action_flags.action_set.make_action_set(), - action_flags=self.action_flags, - ) - if self.use_abstract_example: - abstract_example = ( - f"# Abstract Example:\n{think.abstract_ex}{action_prompt.abstract_ex}" + screenshot_prompt_part = None + if self.use_tabs: + tabs_prompt_part = TabsPromptPart( + open_pages_titles=obs["open_pages_titles"], + open_pages_urls=obs["open_pages_urls"], + 
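
A rough sketch of the text that TabsPromptPart (defined earlier in this patch) builds, re-implemented standalone with made-up page data:

    open_pages_titles = ["Home", "Search results"]
    open_pages_urls = ["https://example.com/", "https://example.com/search"]
    active_page_index = 1

    # Mirrors the loop in TabsPromptPart.__init__ as added by this patch.
    text = "# Open Tabs\n"
    for index, (title, url) in enumerate(zip(open_pages_titles, open_pages_urls)):
        text += f"Tab {index}{' (active tab)' if index == active_page_index else ''}: {title} ({url})\n"

    # Resulting text:
    # # Open Tabs
    # Tab 0: Home (https://example.com/)
    # Tab 1 (active tab): Search results (https://example.com/search)
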
active_page_index=obs["active_page_index"], ) else: - abstract_example = None - if self.use_concrete_example: - concrete_example = ( - f"# Concrete Example:\n{think.concrete_ex}{action_prompt.concrete_ex}" - ) + tabs_prompt_part = None + if self.use_error and obs["last_action_error"]: + error_prompt_part = ErrorPromptPart(obs["last_action_error"]) else: - concrete_example = None - + error_prompt_part = None + if self.use_history: + history_prompt_part = HistoryPromptPart(thoughts=thoughts, actions=actions) + else: + history_prompt_part = None + answer_prompt_part = AnswerPromptPart( + action_set=self.action_set_args.make_action_set(), + use_abstract_example=self.use_abstract_example, + use_concrete_example=self.use_concrete_example, + preliminary_answer=preliminary_answer, + ) return UIPrompt( - instructions=instructions, - screenshot=screenshot, - observation=observation, - history=history, - think=think, - action_prompt=action_prompt, - abstract_example=abstract_example, - concrete_example=concrete_example, + instruction_prompt_part=instruction_prompt_part, + screenshot_prompt_part=screenshot_prompt_part, + tabs_prompt_part=tabs_prompt_part, + error_prompt_part=error_prompt_part, + history_prompt_part=history_prompt_part, + answer_prompt_part=answer_prompt_part, ) From c536f4fd18e6577b447577fb893b1f748fe85d7c Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Sun, 18 May 2025 01:34:16 +0000 Subject: [PATCH 11/29] update --- src/agentlab/agents/vl_agent/vl_prompt.py | 171 +++++++++------------- 1 file changed, 69 insertions(+), 102 deletions(-) diff --git a/src/agentlab/agents/vl_agent/vl_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt.py index 347e2b4f..d5fe0aa4 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from agentlab.agents import dynamic_prompting as dp from agentlab.llm.llm_utils import HumanMessage, SystemMessage from browsergym.experiments.benchmark.base import HighLevelActionSetArgs from browsergym.core.action.highlevel import HighLevelActionSet @@ -8,7 +7,6 @@ from typing import Optional, Union from .utils import image_to_image_url import numpy as np -import time class VLPromptPart(ABC): @@ -37,9 +35,9 @@ class SystemPromptPart(VLPromptPart): def __init__(self, text: Optional[str]): if text is None: text = """\ -You are an agent trying to solve a web task based on the content of the page and user instructions. \ -You can interact with the page and explore, and send messages to the user. \ -Each time you submit an action it will be sent to the browser and you will receive a new page. +You are an agent working to address a web-based task through step-by-step interactions with the browser. \ +At each step, you need to submit an action according to the current state of the browser. \ +This action will be executed and the state of the browser will be updated. """ self.text = text @@ -47,36 +45,7 @@ def get_message_items(self) -> list[dict]: return [{"type": "text", "text": self.text}] -class ChatInstructionPromptPart(VLPromptPart): - def __init__( - self, - chat_messages: list[dict], - extra_instruction: Optional[str], - ): - text = """\ -# Instruction -Your goal is to help the user perform tasks using a web browser. \ -You can communicate with the user via a chat, in which the user gives you instructions and in which you can send back messages. \ -Review the current state of the page and all other information to find the best possible next action to accomplish your goal. 
\ -Your answer will be interpreted and executed by a program, make sure to follow the formatting instructions. -## Chat Messages -""" - for chat_message in chat_messages: - text += f"""\ -[{time.asctime(time.localtime(chat_message['timestamp']))}] {chat_message['role']}: {chat_message['message']} -""" - if extra_instruction is not None: - text += f"""\ -## Extra Instruction -{extra_instruction} -""" - self.text = text - - def get_message_items(self) -> list[dict]: - return [{"type": "text", "text": self.text}] - - -class GoalInstructionPromptPart(VLPromptPart): +class InstructionPromptPart(VLPromptPart): def __init__( self, goal_object: list[dict], @@ -84,8 +53,7 @@ def __init__( ): text = """\ # Instruction -Review the current state of the page and all other information to find the best possible next action to accomplish your goal. \ -Your answer will be interpreted and executed by a program, make sure to follow the formatting instructions. +Review the current state of the browser and all other information to find the next action to achieve the goal. ## Goal """ for item in goal_object: @@ -107,7 +75,7 @@ def get_message_items(self) -> list[dict]: class ScreenshotPromptPart(VLPromptPart): def __init__(self, screenshot: Union[Image.Image, np.ndarray]): self.text = """\ -# Screenshot +# The Screenshot of the Current Page """ self.image_url = image_to_image_url(screenshot) @@ -123,7 +91,7 @@ def __init__( self, open_pages_titles: list[str], open_pages_urls: list[str], active_page_index: int ): text = """\ -# Open Tabs +# The Open Tabs of the Browser """ for index, (title, url) in enumerate(zip(open_pages_titles, open_pages_urls)): text += f"""\ @@ -138,7 +106,7 @@ def get_message_items(self) -> list[dict]: class ErrorPromptPart(VLPromptPart): def __init__(self, last_action_error: str): text = """\ -# Error from Last Action +# The Error from the Last Action """ separator = "Call log:" if separator in last_action_error: @@ -166,7 +134,7 @@ def get_message_items(self) -> list[dict]: class HistoryPromptPart(VLPromptPart): def __init__(self, thoughts: list[str], actions: list[str]): text = """\ -# Thoughts and Actions of Previous Steps +# The Previous Steps """ for index, (thought, action) in enumerate(zip(thoughts, actions)): text += f""" @@ -182,34 +150,45 @@ def get_message_items(self) -> list[dict]: return [{"type": "text", "text": self.text}] +class ActionPromptPart(VLPromptPart): + def __init__(self, action_set: HighLevelActionSet): + text = f"""\ +# Action Space +Here are the actions you can take to interact with the browser. \ +They are Python functions based on the Playwright library. +{action_set.describe(with_long_description=True, with_examples=False)} +""" + self.text = text + + def get_message_items(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + + class AnswerPromptPart(VLPromptPart): def __init__( self, - action_set: HighLevelActionSet, use_abstract_example: bool, use_concrete_example: bool, preliminary_answer: Optional[dict], ): - text = f"""\ -# Answer Format Requirements -## Action Space -These actions allow you to interact with your environment. \ -Most of them are python functions executing playwright code. -{action_set.describe(with_long_description=True, with_examples=False)} + text = """\ +# Answer Format +Think about the next action, and choose it from the action space. \ +Your answer should include both the thought and the next action. 
""" if use_abstract_example: text += """ -## Abstract Example +## An Abstract Example of the Answer -The thought about which action to take at the current step. +The thought about the next action. -One single action to be executed. You can only use one action at a time. +The next action to take. """ if use_concrete_example: text += """ -## Concrete Example +## A Concrete Example of the Answer From previous action I tried to set the value of year to "2022", using select_option, but it doesn't appear to be in the form. \ It may be a dynamic dropdown, I will try using click with the bid "a324" and look at the response from the page. @@ -220,9 +199,9 @@ def __init__( """ if preliminary_answer is not None: text += f""" -## Preliminary Anser +## A Preliminary Answer to Refine Here is a preliminary anser, which might be incorrect or inaccurate. \ -Refine it to get the final answer. +You can refine it to obtain your answer. {preliminary_answer['thought']} @@ -238,32 +217,29 @@ def get_message_items(self) -> list[dict]: @dataclass class UIPrompt(VLPrompt): - instructions: Union[dp.ChatInstructions, dp.GoalInstructions] - screenshot: Optional[Image.Image] - observation: dp.Observation - history: dp.History - think: dp.Think - action_prompt: dp.ActionPrompt - abstract_example: Optional[str] - concrete_example: Optional[str] - preliminary_answer: Optional[str] + system_prompt_part: SystemPromptPart + instruction_prompt_part: InstructionPromptPart + screenshot_prompt_part: Optional[ScreenshotPromptPart] + tabs_prompt_part: Optional[TabsPromptPart] + history_prompt_part: Optional[HistoryPromptPart] + error_prompt_part: Optional[ErrorPromptPart] + action_prompt_part: ActionPromptPart + answer_prompt_part: AnswerPromptPart def get_messages(self) -> list[Union[SystemMessage, HumanMessage]]: - message = HumanMessage(self.instructions.prompt) - if self.screenshot is not None: - message.add_text("# Screenshot:\n") - message.add_image(self.screenshot) - message.add_text(self.observation.prompt) - message.add_text(self.history.prompt) - message.add_text(self.think.prompt) - message.add_text(self.action_prompt.prompt) - if self.abstract_example is not None: - message.add_text(self.abstract_example) - if self.concrete_example is not None: - message.add_text(self.concrete_example) - if self.preliminary_answer is not None: - message.add_text(self.preliminary_answer) - return message + system_message_items = self.system_prompt_part.get_message_items() + human_message_items = self.instruction_prompt_part.get_message_items() + if self.screenshot_prompt_part is not None: + human_message_items.extend(self.screenshot_prompt_part.get_message_items()) + if self.tabs_prompt_part is not None: + human_message_items.extend(self.tabs_prompt_part.get_message_items()) + if self.history_prompt_part is not None: + human_message_items.extend(self.history_prompt_part.get_message_items()) + if self.error_prompt_part is not None: + human_message_items.extend(self.error_prompt_part.get_message_items()) + human_message_items.extend(self.action_prompt_part.get_message_items()) + human_message_items.extend(self.answer_prompt_part.get_message_items()) + return [SystemMessage(system_message_items), HumanMessage(human_message_items)] def parse_answer(self, answer_text: str) -> dict: answer_dict = {} @@ -275,12 +251,10 @@ def parse_answer(self, answer_text: str) -> dict: @dataclass class UIPromptArgs(VLPromptArgs): action_set_args: HighLevelActionSetArgs - enable_chat: bool use_screenshot: bool - use_som: bool use_tabs: bool - use_error: bool 
use_history: bool + use_error: bool use_abstract_example: bool use_concrete_example: bool @@ -292,22 +266,13 @@ def make_prompt( extra_instruction: Optional[str], preliminary_answer: Optional[dict], ) -> UIPrompt: - if self.enable_chat: - instruction_prompt_part = ChatInstructionPromptPart( - chat_messages=obs["chat_messages"], - extra_instruction=extra_instruction, - ) - else: - instruction_prompt_part = GoalInstructionPromptPart( - goal_object=obs["goal_object"], - extra_instruction=extra_instruction, - ) + system_prompt_part = SystemPromptPart(None) + instruction_prompt_part = InstructionPromptPart( + goal_object=obs["goal_object"], + extra_instruction=extra_instruction, + ) if self.use_screenshot: - if self.use_som: - screenshot = obs["screenshot_som"] - else: - screenshot = obs["screenshot"] - screenshot_prompt_part = ScreenshotPromptPart(screenshot) + screenshot_prompt_part = ScreenshotPromptPart(obs["screenshot"]) else: screenshot_prompt_part = None if self.use_tabs: @@ -318,25 +283,27 @@ def make_prompt( ) else: tabs_prompt_part = None - if self.use_error and obs["last_action_error"]: - error_prompt_part = ErrorPromptPart(obs["last_action_error"]) - else: - error_prompt_part = None if self.use_history: history_prompt_part = HistoryPromptPart(thoughts=thoughts, actions=actions) else: history_prompt_part = None + if self.use_error and obs["last_action_error"]: + error_prompt_part = ErrorPromptPart(obs["last_action_error"]) + else: + error_prompt_part = None + action_prompt_part = ActionPromptPart(self.action_set_args.make_action_set()) answer_prompt_part = AnswerPromptPart( - action_set=self.action_set_args.make_action_set(), use_abstract_example=self.use_abstract_example, use_concrete_example=self.use_concrete_example, preliminary_answer=preliminary_answer, ) return UIPrompt( + system_prompt_part=system_prompt_part, instruction_prompt_part=instruction_prompt_part, screenshot_prompt_part=screenshot_prompt_part, tabs_prompt_part=tabs_prompt_part, - error_prompt_part=error_prompt_part, history_prompt_part=history_prompt_part, + error_prompt_part=error_prompt_part, + action_prompt_part=action_prompt_part, answer_prompt_part=answer_prompt_part, ) From cbbad61eadc0165606a869eaccb8a3ff52c8420d Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Sun, 18 May 2025 01:42:56 +0000 Subject: [PATCH 12/29] update --- src/agentlab/agents/vl_agent/vl_prompt.py | 38 +++++++++++------------ 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/src/agentlab/agents/vl_agent/vl_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt.py index d5fe0aa4..6c76227d 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt.py @@ -11,7 +11,7 @@ class VLPromptPart(ABC): @abstractmethod - def get_message_items(self) -> list[dict]: + def get_message_content(self) -> list[dict]: raise NotImplementedError @@ -41,7 +41,7 @@ def __init__(self, text: Optional[str]): """ self.text = text - def get_message_items(self) -> list[dict]: + def get_message_content(self) -> list[dict]: return [{"type": "text", "text": self.text}] @@ -68,7 +68,7 @@ def __init__( """ self.text = text - def get_message_items(self) -> list[dict]: + def get_message_content(self) -> list[dict]: return [{"type": "text", "text": self.text}] @@ -79,7 +79,7 @@ def __init__(self, screenshot: Union[Image.Image, np.ndarray]): """ self.image_url = image_to_image_url(screenshot) - def get_message_items(self) -> list[dict]: + def get_message_content(self) -> list[dict]: return [ {"type": "text", "text": self.text}, 
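
Every part emits content items of the same shape, which is what lets get_messages concatenate them into one system message and one human message; a condensed sketch with placeholder values:

    # Text and image items share the {"type": ...} convention used above.
    text_item = {"type": "text", "text": "# The Screenshot of the Current Page\n"}
    image_item = {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}

    # Assembly is then plain list concatenation, as in UIPrompt.get_messages:
    human_message_content = []
    human_message_content.extend([text_item, image_item])
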
{"type": "image_url", "image_url": {"url": self.image_url}}, @@ -99,7 +99,7 @@ def __init__( """ self.text = text - def get_message_items(self) -> list[dict]: + def get_message_content(self) -> list[dict]: return [{"type": "text", "text": self.text}] @@ -127,7 +127,7 @@ def __init__(self, last_action_error: str): """ self.text = text - def get_message_items(self) -> list[dict]: + def get_message_content(self) -> list[dict]: return [{"type": "text", "text": self.text}] @@ -146,7 +146,7 @@ def __init__(self, thoughts: list[str], actions: list[str]): """ self.text = text - def get_message_items(self) -> list[dict]: + def get_message_content(self) -> list[dict]: return [{"type": "text", "text": self.text}] @@ -160,7 +160,7 @@ def __init__(self, action_set: HighLevelActionSet): """ self.text = text - def get_message_items(self) -> list[dict]: + def get_message_content(self) -> list[dict]: return [{"type": "text", "text": self.text}] @@ -211,7 +211,7 @@ def __init__( """ self.text = text - def get_message_items(self) -> list[dict]: + def get_message_content(self) -> list[dict]: return [{"type": "text", "text": self.text}] @@ -227,24 +227,22 @@ class UIPrompt(VLPrompt): answer_prompt_part: AnswerPromptPart def get_messages(self) -> list[Union[SystemMessage, HumanMessage]]: - system_message_items = self.system_prompt_part.get_message_items() - human_message_items = self.instruction_prompt_part.get_message_items() + system_message_content = self.system_prompt_part.get_message_content() + human_message_content = self.instruction_prompt_part.get_message_content() if self.screenshot_prompt_part is not None: - human_message_items.extend(self.screenshot_prompt_part.get_message_items()) + human_message_content.extend(self.screenshot_prompt_part.get_message_content()) if self.tabs_prompt_part is not None: - human_message_items.extend(self.tabs_prompt_part.get_message_items()) + human_message_content.extend(self.tabs_prompt_part.get_message_content()) if self.history_prompt_part is not None: - human_message_items.extend(self.history_prompt_part.get_message_items()) + human_message_content.extend(self.history_prompt_part.get_message_content()) if self.error_prompt_part is not None: - human_message_items.extend(self.error_prompt_part.get_message_items()) - human_message_items.extend(self.action_prompt_part.get_message_items()) - human_message_items.extend(self.answer_prompt_part.get_message_items()) - return [SystemMessage(system_message_items), HumanMessage(human_message_items)] + human_message_content.extend(self.error_prompt_part.get_message_content()) + human_message_content.extend(self.action_prompt_part.get_message_content()) + human_message_content.extend(self.answer_prompt_part.get_message_content()) + return [SystemMessage(system_message_content), HumanMessage(human_message_content)] def parse_answer(self, answer_text: str) -> dict: answer_dict = {} - answer_dict.update(self.think.parse_answer(answer_text)) - answer_dict.update(self.action_prompt.parse_answer(answer_text)) return answer_dict From 28dd3407cd112e19641d154c533bf33c5f8b4609 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Tue, 20 May 2025 17:55:13 +0000 Subject: [PATCH 13/29] update --- src/agentlab/agents/vl_agent/vl_agent.py | 54 ++++---- src/agentlab/agents/vl_agent/vl_prompt.py | 155 +++++++++++++--------- 2 files changed, 124 insertions(+), 85 deletions(-) diff --git a/src/agentlab/agents/vl_agent/vl_agent.py b/src/agentlab/agents/vl_agent/vl_agent.py index 6f6bccca..744d4972 100644 --- a/src/agentlab/agents/vl_agent/vl_agent.py +++ 
b/src/agentlab/agents/vl_agent/vl_agent.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -from agentlab.agents import dynamic_prompting as dp -from agentlab.llm.llm_utils import Discussion, ParseError, retry, SystemMessage +from agentlab.llm.llm_utils import Discussion, ParseError, retry from agentlab.llm.tracking import cost_tracker_decorator from browsergym.experiments.agent import AgentInfo from browsergym.experiments.benchmark import Benchmark @@ -65,59 +64,66 @@ def __init__( self.auxiliary_vl_model = auxiliary_vl_model_args.make_model() self.ui_prompt_args = ui_prompt_args self.max_retry = max_retry - self.obs_history = [] self.actions = [] self.thoughts = [] @cost_tracker_decorator def get_action(self, obs: dict) -> tuple[str, dict]: - self.obs_history.append(obs) ui_prompt = self.ui_prompt_args.make_prompt( - obs_history=self.obs_history, + obs=obs, actions=self.actions, thoughts=self.thoughts, extra_instructions=None, + preliminary_answer=None, ) try: - messages = Discussion( - [SystemMessage(dp.SystemPrompt().prompt), ui_prompt.get_message()] - ) - answer = retry( + messages = Discussion(ui_prompt.get_messages()) + messages.merge() + preliminary_answer = retry( chat=self.main_vl_model, messages=messages, n_retry=self.max_retry, - parser=ui_prompt.answer_parser, + parser=ui_prompt.parse_answer, ) stats = {"num_main_retries": (len(messages) - 3) // 2} except ParseError: - answer = {"think": None, "action": None} + preliminary_answer = {"thought": None, "action": None} stats = {"num_main_retries": self.max_retry} stats.update(self.main_vl_model.get_stats()) - if self.auxiliary_vl_model is not None: + if self.auxiliary_vl_model is None: + final_answer = preliminary_answer + else: try: - messages = Discussion( - [SystemMessage(dp.SystemPrompt().prompt), ui_prompt.get_message()] + ui_prompt = self.ui_prompt_args.make_prompt( + obs=obs, + actions=self.actions, + thoughts=self.thoughts, + extra_instructions=None, + preliminary_answer=preliminary_answer, ) - messages.add_text(f"{answer['think']}\n{answer['action']}\n") - answer = retry( + messages = Discussion(ui_prompt.get_messages()) + messages.merge() + final_answer = retry( chat=self.auxiliary_vl_model, messages=messages, n_retry=self.max_retry, - parser=ui_prompt.answer_parser, + parser=ui_prompt.parse_answer, ) stats["num_auxiliary_retries"] = (len(messages) - 3) // 2 except ParseError: - answer = {"action": None, "think": None} + final_answer = preliminary_answer stats["num_auxiliary_retries"] = self.max_retry stats.update(self.auxiliary_vl_model.get_stats()) - self.thoughts.append(answer["think"]) - self.actions.append(answer["action"]) - agent_info = AgentInfo(think=answer["think"], chat_messages=messages, stats=stats) - return answer["action"], asdict(agent_info) + self.thoughts.append(str(final_answer["thought"])) + self.actions.append(str(final_answer["action"])) + agent_info = AgentInfo( + think=str(final_answer["thought"]), chat_messages=messages, stats=stats + ) + return final_answer["action"], asdict(agent_info) @property - def obs_preprocessor(self) -> callable: - return dp.make_obs_preprocessor(self.ui_prompt_args.obs_flags) + def obs_preprocessor(self, obs: dict) -> dict: + return obs @dataclass diff --git a/src/agentlab/agents/vl_agent/vl_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt.py index 6c76227d..38999ad2 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt.py @@ -1,7 +1,12 @@ from abc import ABC, abstractmethod -from agentlab.llm.llm_utils import 
HumanMessage, SystemMessage +from agentlab.llm.llm_utils import ( + extract_code_blocks, + HumanMessage, + ParseError, + parse_html_tags_raise, + SystemMessage, +) from browsergym.experiments.benchmark.base import HighLevelActionSetArgs -from browsergym.core.action.highlevel import HighLevelActionSet from dataclasses import dataclass from PIL import Image from typing import Optional, Union @@ -17,7 +22,7 @@ def get_message_content(self) -> list[dict]: class VLPrompt(ABC): @abstractmethod - def get_messages(self) -> list[Union[SystemMessage, HumanMessage]]: + def get_messages(self) -> list[Union[HumanMessage, SystemMessage]]: raise NotImplementedError @abstractmethod @@ -27,19 +32,24 @@ def parse_answer(self, answer_text: str) -> dict: class VLPromptArgs(ABC): @abstractmethod - def make_prompt(self, obs: dict, actions: list[str], thoughts: list[str]) -> VLPrompt: + def make_prompt( + self, + obs: dict, + actions: list[str], + thoughts: list[str], + extra_instruction: Optional[str], + preliminary_answer: Optional[dict], + ) -> VLPrompt: raise NotImplementedError class SystemPromptPart(VLPromptPart): - def __init__(self, text: Optional[str]): - if text is None: - text = """\ + def __init__(self): + self.text = """\ You are an agent working to address a web-based task through step-by-step interactions with the browser. \ At each step, you need to submit an action according to the current state of the browser. \ This action will be executed and the state of the browser will be updated. """ - self.text = text def get_message_content(self) -> list[dict]: return [{"type": "text", "text": self.text}] @@ -53,7 +63,7 @@ def __init__( ): text = """\ # Instruction -Review the current state of the browser and all other information to find the next action to achieve the goal. +Review the current state of the browser and all other information to find the best next action to achieve the goal. 
## Goal """ for item in goal_object: @@ -75,7 +85,7 @@ def get_message_content(self) -> list[dict]: class ScreenshotPromptPart(VLPromptPart): def __init__(self, screenshot: Union[Image.Image, np.ndarray]): self.text = """\ -# The Screenshot of the Current Page +# The Screenshot of the Current Web Page """ self.image_url = image_to_image_url(screenshot) @@ -94,8 +104,31 @@ def __init__( # The Open Tabs of the Browser """ for index, (title, url) in enumerate(zip(open_pages_titles, open_pages_urls)): - text += f"""\ -Tab {index}{' (active tab)' if index == active_page_index else ''}: {title} ({url}) + text += f""" +## Tab {index}{' (active tab)' if index == active_page_index else ''} +### Title +{title} +### URL +{url} +""" + self.text = text + + def get_message_content(self) -> list[dict]: + return [{"type": "text", "text": self.text}] + + +class HistoryPromptPart(VLPromptPart): + def __init__(self, thoughts: list[str], actions: list[str]): + text = """\ +# The Previous Steps +""" + for index, (thought, action) in enumerate(zip(thoughts, actions)): + text += f""" +## Step {index} +### Thought +{thought} +### Action +{action} """ self.text = text @@ -131,50 +164,23 @@ def get_message_content(self) -> list[dict]: return [{"type": "text", "text": self.text}] -class HistoryPromptPart(VLPromptPart): - def __init__(self, thoughts: list[str], actions: list[str]): - text = """\ -# The Previous Steps -""" - for index, (thought, action) in enumerate(zip(thoughts, actions)): - text += f""" -## Step {index} -### Thought -{thought} -### Action -{action} -""" - self.text = text - - def get_message_content(self) -> list[dict]: - return [{"type": "text", "text": self.text}] - - -class ActionPromptPart(VLPromptPart): - def __init__(self, action_set: HighLevelActionSet): - text = f"""\ -# Action Space -Here are the actions you can take to interact with the browser. \ -They are Python functions based on the Playwright library. -{action_set.describe(with_long_description=True, with_examples=False)} -""" - self.text = text - - def get_message_content(self) -> list[dict]: - return [{"type": "text", "text": self.text}] - - class AnswerPromptPart(VLPromptPart): def __init__( self, + action_set_description: str, use_abstract_example: bool, use_concrete_example: bool, preliminary_answer: Optional[dict], ): - text = """\ -# Answer Format -Think about the next action, and choose it from the action space. \ -Your answer should include both the thought and the next action. + text = f"""\ +# Answer Requirements +## Action Space +Here are all the actions you can take to interact with the browser. \ +They are Python functions based on the Playwright library. +{action_set_description} +## Answer Format +Think about the next action to take, and choose it from the action space. \ +Your answer should include both the thought and the action. """ if use_abstract_example: text += """ @@ -183,7 +189,7 @@ def __init__( The thought about the next action. -The next action to take. +The next action. """ if use_concrete_example: @@ -199,9 +205,9 @@ def __init__( """ if preliminary_answer is not None: text += f""" -## A Preliminary Answer to Refine -Here is a preliminary anser, which might be incorrect or inaccurate. \ -You can refine it to obtain your answer. +## A Preliminary Answer +Here is a preliminary answer, which might be incorrect or inaccurate. \ +You can refine it to get your answer. 
{preliminary_answer['thought']} @@ -223,10 +229,10 @@ class UIPrompt(VLPrompt): tabs_prompt_part: Optional[TabsPromptPart] history_prompt_part: Optional[HistoryPromptPart] error_prompt_part: Optional[ErrorPromptPart] - action_prompt_part: ActionPromptPart answer_prompt_part: AnswerPromptPart + action_validator: callable - def get_messages(self) -> list[Union[SystemMessage, HumanMessage]]: + def get_messages(self) -> list[Union[HumanMessage, SystemMessage]]: system_message_content = self.system_prompt_part.get_message_content() human_message_content = self.instruction_prompt_part.get_message_content() if self.screenshot_prompt_part is not None: @@ -237,12 +243,36 @@ def get_messages(self) -> list[Union[SystemMessage, HumanMessage]]: human_message_content.extend(self.history_prompt_part.get_message_content()) if self.error_prompt_part is not None: human_message_content.extend(self.error_prompt_part.get_message_content()) - human_message_content.extend(self.action_prompt_part.get_message_content()) human_message_content.extend(self.answer_prompt_part.get_message_content()) return [SystemMessage(system_message_content), HumanMessage(human_message_content)] def parse_answer(self, answer_text: str) -> dict: answer_dict = {} + try: + answer_dict.update( + parse_html_tags_raise(answer_text, keys=["thought"], merge_multiple=True) + ) + except ParseError as error: + answer_dict["thought"] = answer_text + answer_dict["thought_parse_error"] = str(error) + try: + answer_dict.update( + parse_html_tags_raise(answer_text, keys=["action"], merge_multiple=True) + ) + except ParseError as error: + code_blocks = extract_code_blocks(answer_text) + if len(code_blocks) == 0: + raise error + else: + answer_dict["action"] = "\n".join([block for _, block in code_blocks]) + answer_dict["action_parse_error"] = str(error) + if answer_dict["action"] == "None": + answer_dict["action"] = None + else: + try: + self.action_validator(answer_dict["action"]) + except Exception as error: + raise ParseError(str(error)) return answer_dict @@ -264,10 +294,9 @@ def make_prompt( extra_instruction: Optional[str], preliminary_answer: Optional[dict], ) -> UIPrompt: - system_prompt_part = SystemPromptPart(None) + system_prompt_part = SystemPromptPart() instruction_prompt_part = InstructionPromptPart( - goal_object=obs["goal_object"], - extra_instruction=extra_instruction, + goal_object=obs["goal_object"], extra_instruction=extra_instruction ) if self.use_screenshot: screenshot_prompt_part = ScreenshotPromptPart(obs["screenshot"]) @@ -289,12 +318,16 @@ def make_prompt( error_prompt_part = ErrorPromptPart(obs["last_action_error"]) else: error_prompt_part = None - action_prompt_part = ActionPromptPart(self.action_set_args.make_action_set()) + action_set = self.action_set_args.make_action_set() answer_prompt_part = AnswerPromptPart( + action_set_description=action_set.describe( + with_long_description=True, with_examples=False + ), use_abstract_example=self.use_abstract_example, use_concrete_example=self.use_concrete_example, preliminary_answer=preliminary_answer, ) + action_validator = action_set.to_python_code return UIPrompt( system_prompt_part=system_prompt_part, instruction_prompt_part=instruction_prompt_part, @@ -302,6 +335,6 @@ def make_prompt( tabs_prompt_part=tabs_prompt_part, history_prompt_part=history_prompt_part, error_prompt_part=error_prompt_part, - action_prompt_part=action_prompt_part, answer_prompt_part=answer_prompt_part, + action_validator=action_validator, ) From 7767826534d9f2a2d8e87e44cd1e33a4d5efacf3 Mon Sep 17 
00:00:00 2001 From: Chao Wang Date: Wed, 21 May 2025 15:48:08 +0000 Subject: [PATCH 14/29] update --- src/agentlab/agents/vl_agent/vl_agent.py | 59 +++++++++++--------- src/agentlab/agents/vl_agent/vl_model.py | 68 +++++++++++++++++++++++ src/agentlab/agents/vl_agent/vl_prompt.py | 32 ++++++----- 3 files changed, 119 insertions(+), 40 deletions(-) diff --git a/src/agentlab/agents/vl_agent/vl_agent.py b/src/agentlab/agents/vl_agent/vl_agent.py index 744d4972..ba21b999 100644 --- a/src/agentlab/agents/vl_agent/vl_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent.py @@ -1,9 +1,10 @@ from abc import ABC, abstractmethod -from agentlab.llm.llm_utils import Discussion, ParseError, retry +from agentlab.llm.llm_utils import ParseError, retry from agentlab.llm.tracking import cost_tracker_decorator from browsergym.experiments.agent import AgentInfo from browsergym.experiments.benchmark import Benchmark -from copy import deepcopy +from browsergym.utils.obs import overlay_som +from copy import copy, deepcopy from dataclasses import asdict, dataclass from typing import Optional from .vl_model import VLModelArgs @@ -77,9 +78,8 @@ def get_action(self, obs: dict) -> tuple[str, dict]: preliminary_answer=None, ) try: - messages = Discussion(ui_prompt.get_messages()) - messages.merge() - preliminary_answer = retry( + messages = ui_prompt.get_messages() + answer = retry( chat=self.main_vl_model, messages=messages, n_retry=self.max_retry, @@ -87,23 +87,21 @@ def get_action(self, obs: dict) -> tuple[str, dict]: ) stats = {"num_main_retries": (len(messages) - 3) // 2} except ParseError: - preliminary_answer = {"thought": None, "action": None} + answer = {"thought": None, "action": None} stats = {"num_main_retries": self.max_retry} stats.update(self.main_vl_model.get_stats()) - if self.auxiliary_vl_model is None: - final_answer = preliminary_answer - else: + if self.auxiliary_vl_model is not None: + preliminary_answer = answer + ui_prompt = self.ui_prompt_args.make_prompt( + obs=obs, + actions=self.actions, + thoughts=self.thoughts, + extra_instructions=None, + preliminary_answer=preliminary_answer, + ) try: - ui_prompt = self.ui_prompt_args.make_prompt( - obs=obs, - actions=self.actions, - thoughts=self.thoughts, - extra_instructions=None, - preliminary_answer=preliminary_answer, - ) - messages = Discussion(ui_prompt.get_messages()) - messages.merge() - final_answer = retry( + messages = ui_prompt.get_messages() + answer = retry( chat=self.auxiliary_vl_model, messages=messages, n_retry=self.max_retry, @@ -111,18 +109,25 @@ def get_action(self, obs: dict) -> tuple[str, dict]: ) stats["num_auxiliary_retries"] = (len(messages) - 3) // 2 except ParseError: - final_answer = preliminary_answer + answer = {"thought": None, "action": None} stats["num_auxiliary_retries"] = self.max_retry stats.update(self.auxiliary_vl_model.get_stats()) - self.thoughts.append(str(final_answer["thought"])) - self.actions.append(str(final_answer["action"])) + else: + preliminary_answer = None + self.thoughts.append(str(answer["thought"])) + self.actions.append(str(answer["action"])) agent_info = AgentInfo( - think=str(final_answer["thought"]), chat_messages=messages, stats=stats + think=str(answer["thought"]), stats=stats, extra_info=preliminary_answer ) - return final_answer["action"], asdict(agent_info) + return answer["action"], asdict(agent_info) @property def obs_preprocessor(self, obs: dict) -> dict: + obs = copy(obs) + if self.ui_prompt_args.use_screenshot and self.ui_prompt_args.use_screenshot_som: + obs["screenshot"] = 
overlay_som( + obs["screenshot"], extra_properties=obs["extra_element_properties"] + ) return obs @@ -161,7 +166,7 @@ def set_reproducibility_mode(self): self.auxiliary_vl_model_args.set_reproducibility_mode() def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): - self.ui_prompt_args.obs_flags.use_tabs = benchmark.is_multi_tab - self.ui_prompt_args.action_flags.action_set = deepcopy(benchmark.high_level_action_set_args) + self.ui_prompt_args.use_tabs = benchmark.is_multi_tab + self.ui_prompt_args.action_set_args = deepcopy(benchmark.high_level_action_set_args) if demo_mode: - self.ui_prompt_args.action_flags.action_set.demo_mode = "all_blue" + self.ui_prompt_args.action_set_args.demo_mode = "all_blue" diff --git a/src/agentlab/agents/vl_agent/vl_model.py b/src/agentlab/agents/vl_agent/vl_model.py index 8dcb73d5..0b3f7380 100644 --- a/src/agentlab/agents/vl_agent/vl_model.py +++ b/src/agentlab/agents/vl_agent/vl_model.py @@ -1,8 +1,12 @@ from abc import ABC, abstractmethod from accelerate.utils.modeling import load_checkpoint_in_model +from agentlab.llm.llm_utils import AIMessage from dataclasses import dataclass +from openai import AsyncOpenAI, RateLimitError from transformers import AutoProcessor, MllamaForConditionalGeneration from typing import Optional +import asyncio +import backoff import fnmatch import os @@ -105,3 +109,67 @@ def close(self): def set_reproducibility_mode(self): self.reproducibility_config = {"do_sample": False} + + +class OpenRouterAPIModel(VLModel): + def __init__( + self, + base_url: str, + model_id: str, + max_tokens: int, + reproducibility_config: dict, + ): + self.client = AsyncOpenAI(base_url=base_url, api_key=os.getenv("OPENROUTER_API_KEY")) + self.model_id = model_id + self.max_tokens = max_tokens + self.reproducibility_config = reproducibility_config + + def __call__(self, messages: list[dict]) -> dict: + @backoff.on_exception(backoff.expo, RateLimitError) + async def get_response(messages: list[dict], max_tokens: int, **kwargs): + completion = await self.client.chat.completions.create( + model=self.model_id, messages=messages, max_tokens=max_tokens, **kwargs + ) + try: + response = AIMessage(completion.choices[0].message.content) + except: + response = AIMessage("") + return response + + return asyncio.run( + get_response( + messages=messages, max_tokens=self.max_tokens, **self.reproducibility_config + ) + ) + + def get_stats(self) -> dict: + return {} + + +@dataclass +class OpenRouterAPIModelArgs(VLModelArgs): + model_id: str + base_url: str + max_tokens: int + reproducibility_config: dict + + @property + def model_name(self) -> str: + return self.model_id.split("/")[-1].replace("-", "_").replace(".", "") + + def make_model(self) -> OpenRouterAPIModel: + return OpenRouterAPIModel( + model_id=self.model_id, + base_url=self.base_url, + max_tokens=self.max_tokens, + reproducibility_config=self.reproducibility_config, + ) + + def prepare(self): + pass + + def close(self): + pass + + def set_reproducibility_mode(self): + self.reproducibility_config = {"temperature": 0.0} diff --git a/src/agentlab/agents/vl_agent/vl_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt.py index 38999ad2..f8df9e6d 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt.py @@ -1,11 +1,13 @@ from abc import ABC, abstractmethod from agentlab.llm.llm_utils import ( + Discussion, extract_code_blocks, HumanMessage, ParseError, parse_html_tags_raise, SystemMessage, ) +from browsergym.core.action.highlevel import HighLevelActionSet 
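The get_action rewrite earlier in this patch amounts to a draft-then-refine loop: the main model answers first, and, when an auxiliary model is configured, that answer is fed back in as preliminary_answer for a second pass. A condensed, self-contained view of the control flow (retry handling and stats omitted; the callables are stand-ins for the models and the prompt builder):

    def two_pass_answer(main_model, auxiliary_model, make_prompt):
        # First pass: the main model drafts an answer from the plain prompt.
        answer = main_model(make_prompt(preliminary_answer=None))
        # Second pass (optional): the auxiliary model refines the draft.
        if auxiliary_model is not None:
            answer = auxiliary_model(make_prompt(preliminary_answer=answer))
        return answer
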
from browsergym.experiments.benchmark.base import HighLevelActionSetArgs from dataclasses import dataclass from PIL import Image @@ -22,7 +24,7 @@ def get_message_content(self) -> list[dict]: class VLPrompt(ABC): @abstractmethod - def get_messages(self) -> list[Union[HumanMessage, SystemMessage]]: + def get_messages(self) -> Discussion: raise NotImplementedError @abstractmethod @@ -104,7 +106,7 @@ def __init__( # The Open Tabs of the Browser """ for index, (title, url) in enumerate(zip(open_pages_titles, open_pages_urls)): - text += f""" + text += f"""\ ## Tab {index}{' (active tab)' if index == active_page_index else ''} ### Title {title} @@ -123,7 +125,7 @@ def __init__(self, thoughts: list[str], actions: list[str]): # The Previous Steps """ for index, (thought, action) in enumerate(zip(thoughts, actions)): - text += f""" + text += f"""\ ## Step {index} ### Thought {thought} @@ -183,7 +185,7 @@ def __init__( Your answer should include both the thought and the action. """ if use_abstract_example: - text += """ + text += """\ ## An Abstract Example of the Answer The thought about the next action. @@ -193,7 +195,7 @@ def __init__( """ if use_concrete_example: - text += """ + text += """\ ## A Concrete Example of the Answer From previous action I tried to set the value of year to "2022", using select_option, but it doesn't appear to be in the form. \ @@ -204,7 +206,7 @@ def __init__( """ if preliminary_answer is not None: - text += f""" + text += f"""\ ## A Preliminary Answer Here is a preliminary answer, which might be incorrect or inaccurate. \ You can refine it to get your answer. @@ -223,6 +225,7 @@ def get_message_content(self) -> list[dict]: @dataclass class UIPrompt(VLPrompt): + action_set: HighLevelActionSet system_prompt_part: SystemPromptPart instruction_prompt_part: InstructionPromptPart screenshot_prompt_part: Optional[ScreenshotPromptPart] @@ -230,9 +233,8 @@ class UIPrompt(VLPrompt): history_prompt_part: Optional[HistoryPromptPart] error_prompt_part: Optional[ErrorPromptPart] answer_prompt_part: AnswerPromptPart - action_validator: callable - def get_messages(self) -> list[Union[HumanMessage, SystemMessage]]: + def get_messages(self) -> Discussion: system_message_content = self.system_prompt_part.get_message_content() human_message_content = self.instruction_prompt_part.get_message_content() if self.screenshot_prompt_part is not None: @@ -244,7 +246,11 @@ def get_messages(self) -> list[Union[HumanMessage, SystemMessage]]: if self.error_prompt_part is not None: human_message_content.extend(self.error_prompt_part.get_message_content()) human_message_content.extend(self.answer_prompt_part.get_message_content()) - return [SystemMessage(system_message_content), HumanMessage(human_message_content)] + messages = Discussion( + [SystemMessage(system_message_content), HumanMessage(human_message_content)] + ) + messages.merge() + return messages def parse_answer(self, answer_text: str) -> dict: answer_dict = {} @@ -270,7 +276,7 @@ def parse_answer(self, answer_text: str) -> dict: answer_dict["action"] = None else: try: - self.action_validator(answer_dict["action"]) + self.action_set.to_python_code(answer_dict["action"]) except Exception as error: raise ParseError(str(error)) return answer_dict @@ -280,6 +286,7 @@ def parse_answer(self, answer_text: str) -> dict: class UIPromptArgs(VLPromptArgs): action_set_args: HighLevelActionSetArgs use_screenshot: bool + use_screenshot_som: bool use_tabs: bool use_history: bool use_error: bool @@ -294,6 +301,7 @@ def make_prompt( extra_instruction: 
Optional[str], preliminary_answer: Optional[dict], ) -> UIPrompt: + action_set = self.action_set_args.make_action_set() system_prompt_part = SystemPromptPart() instruction_prompt_part = InstructionPromptPart( goal_object=obs["goal_object"], extra_instruction=extra_instruction @@ -318,7 +326,6 @@ def make_prompt( error_prompt_part = ErrorPromptPart(obs["last_action_error"]) else: error_prompt_part = None - action_set = self.action_set_args.make_action_set() answer_prompt_part = AnswerPromptPart( action_set_description=action_set.describe( with_long_description=True, with_examples=False @@ -327,8 +334,8 @@ def make_prompt( use_concrete_example=self.use_concrete_example, preliminary_answer=preliminary_answer, ) - action_validator = action_set.to_python_code return UIPrompt( + action_set=action_set, system_prompt_part=system_prompt_part, instruction_prompt_part=instruction_prompt_part, screenshot_prompt_part=screenshot_prompt_part, @@ -336,5 +343,4 @@ def make_prompt( history_prompt_part=history_prompt_part, error_prompt_part=error_prompt_part, answer_prompt_part=answer_prompt_part, - action_validator=action_validator, ) From 7fb9d3996cf3453cf5997d90aacf4ce28cd77fb0 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Wed, 21 May 2025 19:44:29 +0000 Subject: [PATCH 15/29] update --- src/agentlab/agents/vl_agent/config.py | 45 +++-- src/agentlab/agents/vl_agent/vl_agent/base.py | 42 +++++ .../{vl_agent.py => vl_agent/ui_agent.py} | 45 +---- src/agentlab/agents/vl_agent/vl_model.py | 175 ------------------ src/agentlab/agents/vl_agent/vl_model/base.py | 35 ++++ .../agents/vl_agent/vl_model/llama_model.py | 75 ++++++++ .../vl_agent/vl_model/openrouter_api_model.py | 71 +++++++ .../agents/vl_agent/vl_prompt/base.py | 32 ++++ .../{vl_prompt.py => vl_prompt/ui_prompt.py} | 70 +++---- 9 files changed, 301 insertions(+), 289 deletions(-) create mode 100644 src/agentlab/agents/vl_agent/vl_agent/base.py rename src/agentlab/agents/vl_agent/{vl_agent.py => vl_agent/ui_agent.py} (83%) delete mode 100644 src/agentlab/agents/vl_agent/vl_model.py create mode 100644 src/agentlab/agents/vl_agent/vl_model/base.py create mode 100644 src/agentlab/agents/vl_agent/vl_model/llama_model.py create mode 100644 src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py create mode 100644 src/agentlab/agents/vl_agent/vl_prompt/base.py rename src/agentlab/agents/vl_agent/{vl_prompt.py => vl_prompt/ui_prompt.py} (88%) diff --git a/src/agentlab/agents/vl_agent/config.py b/src/agentlab/agents/vl_agent/config.py index d29b274e..8bbd583c 100644 --- a/src/agentlab/agents/vl_agent/config.py +++ b/src/agentlab/agents/vl_agent/config.py @@ -1,11 +1,17 @@ from browsergym.experiments.benchmark import HighLevelActionSetArgs -from .vl_agent import UIAgentArgs -from .vl_model import LlamaModelArgs -from .vl_prompt import UIPromptArgs -import agentlab.agents.dynamic_prompting as dp +from .vl_agent.ui_agent import UIAgentArgs +from .vl_model.llama_model import LlamaModelArgs +from .vl_model.openrouter_api_model import OpenRouterAPIModelArgs +from .vl_prompt.ui_prompt import UIPromptArgs VL_MODEL_ARGS_DICT = { + "gpt_4o": OpenRouterAPIModelArgs( + base_url="https://openrouter.ai/api/v1", + model_id="openai/gpt-4o-2024-11-20", + max_tokens=8192, + reproducibility_config={"temperature": 0.1}, + ), "llama_32_11b": LlamaModelArgs( model_path="meta-llama/Llama-3.2-11B-Vision-Instruct", torch_dtype="bfloat16", @@ -13,36 +19,27 @@ max_length=32768, max_new_tokens=8192, reproducibility_config={"temperature": 0.1}, - ) + ), } VL_PROMPT_ARGS_DICT = { - 
"ui_prompt-default": UIPromptArgs( - obs_flags=dp.ObsFlags( - use_tabs=True, - use_error_logs=True, - use_past_error_logs=False, - use_screenshot=True, - use_som=False, - ), - action_flags=dp.ActionFlags( - action_set=HighLevelActionSetArgs(subsets=["coord"]), - long_description=True, - individual_examples=False, - ), - extra_instructions=None, - enable_chat=False, - use_thinking=True, + "ui_prompt": UIPromptArgs( + action_set_args=HighLevelActionSetArgs(subsets=["coord"]), + use_screenshot=True, + use_screenshot_som=False, + use_tabs=True, + use_history=True, + use_error=True, use_abstract_example=True, use_concrete_example=False, ) } VL_AGENT_ARGS_DICT = { - "ui_agent-default": UIAgentArgs( - main_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], + "ui_agent": UIAgentArgs( + main_vl_model_args=VL_MODEL_ARGS_DICT["gpt_4o"], auxiliary_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], - ui_prompt_args=VL_PROMPT_ARGS_DICT["ui_prompt-default"], + ui_prompt_args=VL_PROMPT_ARGS_DICT["ui_prompt"], max_retry=4, ) } diff --git a/src/agentlab/agents/vl_agent/vl_agent/base.py b/src/agentlab/agents/vl_agent/vl_agent/base.py new file mode 100644 index 00000000..482ce2ce --- /dev/null +++ b/src/agentlab/agents/vl_agent/vl_agent/base.py @@ -0,0 +1,42 @@ +from abc import ABC, abstractmethod +from browsergym.experiments.benchmark import Benchmark +from dataclasses import dataclass + + +class VLAgent(ABC): + @abstractmethod + def get_action(self, obs: dict) -> tuple[str, dict]: + raise NotImplementedError + + @property + @abstractmethod + def obs_preprocessor(self) -> callable: + raise NotImplementedError + + +@dataclass +class VLAgentArgs(ABC): + @property + @abstractmethod + def agent_name(self) -> str: + raise NotImplementedError + + @abstractmethod + def make_agent(self) -> VLAgent: + raise NotImplementedError + + @abstractmethod + def prepare(self): + raise NotImplementedError + + @abstractmethod + def close(self): + raise NotImplementedError + + @abstractmethod + def set_reproducibility_mode(self): + raise NotImplementedError + + @abstractmethod + def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): + raise NotImplementedError diff --git a/src/agentlab/agents/vl_agent/vl_agent.py b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py similarity index 83% rename from src/agentlab/agents/vl_agent/vl_agent.py rename to src/agentlab/agents/vl_agent/vl_agent/ui_agent.py index ba21b999..6334c3c1 100644 --- a/src/agentlab/agents/vl_agent/vl_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod from agentlab.llm.llm_utils import ParseError, retry from agentlab.llm.tracking import cost_tracker_decorator from browsergym.experiments.agent import AgentInfo @@ -7,47 +6,9 @@ from copy import copy, deepcopy from dataclasses import asdict, dataclass from typing import Optional -from .vl_model import VLModelArgs -from .vl_prompt import UIPromptArgs - - -class VLAgent(ABC): - @abstractmethod - def get_action(self, obs: dict) -> tuple[str, dict]: - raise NotImplementedError - - @property - @abstractmethod - def obs_preprocessor(self) -> callable: - raise NotImplementedError - - -@dataclass -class VLAgentArgs(ABC): - @property - @abstractmethod - def agent_name(self) -> str: - raise NotImplementedError - - @abstractmethod - def make_agent(self) -> VLAgent: - raise NotImplementedError - - @abstractmethod - def prepare(self): - raise NotImplementedError - - @abstractmethod - def close(self): - raise NotImplementedError - - @abstractmethod - def 
set_reproducibility_mode(self): - raise NotImplementedError - - @abstractmethod - def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): - raise NotImplementedError +from .base import VLAgent, VLAgentArgs +from ..vl_model import VLModelArgs +from ..vl_prompt import UIPromptArgs class UIAgent(VLAgent): diff --git a/src/agentlab/agents/vl_agent/vl_model.py b/src/agentlab/agents/vl_agent/vl_model.py deleted file mode 100644 index 0b3f7380..00000000 --- a/src/agentlab/agents/vl_agent/vl_model.py +++ /dev/null @@ -1,175 +0,0 @@ -from abc import ABC, abstractmethod -from accelerate.utils.modeling import load_checkpoint_in_model -from agentlab.llm.llm_utils import AIMessage -from dataclasses import dataclass -from openai import AsyncOpenAI, RateLimitError -from transformers import AutoProcessor, MllamaForConditionalGeneration -from typing import Optional -import asyncio -import backoff -import fnmatch -import os - - -class VLModel(ABC): - @abstractmethod - def __call__(self, messages: list[dict]) -> dict: - raise NotImplementedError - - @abstractmethod - def get_stats(self) -> dict: - raise NotImplementedError - - -class VLModelArgs(ABC): - @property - @abstractmethod - def model_name(self) -> str: - raise NotImplementedError - - @abstractmethod - def make_model(self) -> VLModel: - raise NotImplementedError - - @abstractmethod - def prepare(self): - raise NotImplementedError - - @abstractmethod - def close(self): - raise NotImplementedError - - @abstractmethod - def set_reproducibility_mode(self): - raise NotImplementedError - - -class LlamaModel(VLModel): - def __init__( - self, - model_path: str, - torch_dtype: str, - checkpoint_dir: str, - max_length: int, - max_new_tokens: int, - reproducibility_config: dict, - ): - self.model = MllamaForConditionalGeneration.from_pretrained( - model_path, torch_dtype=torch_dtype - ) - if checkpoint_dir is not None: - checkpoint_file = None - for item in os.listdir(checkpoint_dir): - if fnmatch.fnmatch(item, "pytorch_model*.bin") or fnmatch.fnmatch( - item, "model*.safetensors" - ): - checkpoint_file = os.path.join(checkpoint_dir, item) - break - load_checkpoint_in_model(self.model, checkpoint_file) - self.processor = AutoProcessor.from_pretrained(model_path) - self.max_length = max_length - self.max_new_tokens = max_new_tokens - self.reproducibility_config = reproducibility_config - - def __call__(self, messages: list[dict]) -> dict: - return {} - - def get_stats(self) -> dict: - return {} - - -@dataclass -class LlamaModelArgs(VLModelArgs): - model_path: str - torch_dtype: str - checkpoint_dir: Optional[str] - max_length: int - max_new_tokens: int - reproducibility_config: dict - - @property - def model_name(self) -> str: - return self.model_path.split("/")[-1].replace("-", "_").replace(".", "") - - def make_model(self) -> LlamaModel: - return LlamaModel( - model_path=self.model_path, - torch_dtype=self.torch_dtype, - checkpoint_dir=self.checkpoint_dir, - max_length=self.max_length, - max_new_tokens=self.max_new_tokens, - reproducibility_config=self.reproducibility_config, - ) - - def prepare(self): - pass - - def close(self): - pass - - def set_reproducibility_mode(self): - self.reproducibility_config = {"do_sample": False} - - -class OpenRouterAPIModel(VLModel): - def __init__( - self, - base_url: str, - model_id: str, - max_tokens: int, - reproducibility_config: dict, - ): - self.client = AsyncOpenAI(base_url=base_url, api_key=os.getenv("OPENROUTER_API_KEY")) - self.model_id = model_id - self.max_tokens = max_tokens - self.reproducibility_config 
= reproducibility_config - - def __call__(self, messages: list[dict]) -> dict: - @backoff.on_exception(backoff.expo, RateLimitError) - async def get_response(messages: list[dict], max_tokens: int, **kwargs): - completion = await self.client.chat.completions.create( - model=self.model_id, messages=messages, max_tokens=max_tokens, **kwargs - ) - try: - response = AIMessage(completion.choices[0].message.content) - except: - response = AIMessage("") - return response - - return asyncio.run( - get_response( - messages=messages, max_tokens=self.max_tokens, **self.reproducibility_config - ) - ) - - def get_stats(self) -> dict: - return {} - - -@dataclass -class OpenRouterAPIModelArgs(VLModelArgs): - model_id: str - base_url: str - max_tokens: int - reproducibility_config: dict - - @property - def model_name(self) -> str: - return self.model_id.split("/")[-1].replace("-", "_").replace(".", "") - - def make_model(self) -> OpenRouterAPIModel: - return OpenRouterAPIModel( - model_id=self.model_id, - base_url=self.base_url, - max_tokens=self.max_tokens, - reproducibility_config=self.reproducibility_config, - ) - - def prepare(self): - pass - - def close(self): - pass - - def set_reproducibility_mode(self): - self.reproducibility_config = {"temperature": 0.0} diff --git a/src/agentlab/agents/vl_agent/vl_model/base.py b/src/agentlab/agents/vl_agent/vl_model/base.py new file mode 100644 index 00000000..ce188183 --- /dev/null +++ b/src/agentlab/agents/vl_agent/vl_model/base.py @@ -0,0 +1,35 @@ +from abc import ABC, abstractmethod +from agentlab.llm.llm_utils import AIMessage, Discussion + + +class VLModel(ABC): + @abstractmethod + def __call__(self, messages: Discussion) -> AIMessage: + raise NotImplementedError + + @abstractmethod + def get_stats(self) -> dict: + raise NotImplementedError + + +class VLModelArgs(ABC): + @property + @abstractmethod + def model_name(self) -> str: + raise NotImplementedError + + @abstractmethod + def make_model(self) -> VLModel: + raise NotImplementedError + + @abstractmethod + def prepare(self): + raise NotImplementedError + + @abstractmethod + def close(self): + raise NotImplementedError + + @abstractmethod + def set_reproducibility_mode(self): + raise NotImplementedError diff --git a/src/agentlab/agents/vl_agent/vl_model/llama_model.py b/src/agentlab/agents/vl_agent/vl_model/llama_model.py new file mode 100644 index 00000000..87ecd13e --- /dev/null +++ b/src/agentlab/agents/vl_agent/vl_model/llama_model.py @@ -0,0 +1,75 @@ +from accelerate.utils.modeling import load_checkpoint_in_model +from agentlab.llm.llm_utils import AIMessage, Discussion +from dataclasses import dataclass +from transformers import AutoProcessor, MllamaForConditionalGeneration +from typing import Optional +from .base import VLModel, VLModelArgs +import fnmatch +import os + + +class LlamaModel(VLModel): + def __init__( + self, + model_path: str, + torch_dtype: str, + checkpoint_dir: str, + max_length: int, + max_new_tokens: int, + reproducibility_config: dict, + ): + self.model = MllamaForConditionalGeneration.from_pretrained( + model_path, torch_dtype=torch_dtype + ) + if checkpoint_dir is not None: + checkpoint_file = None + for item in os.listdir(checkpoint_dir): + if fnmatch.fnmatch(item, "pytorch_model*.bin") or fnmatch.fnmatch( + item, "model*.safetensors" + ): + checkpoint_file = os.path.join(checkpoint_dir, item) + break + load_checkpoint_in_model(self.model, checkpoint_file) + self.processor = AutoProcessor.from_pretrained(model_path) + self.max_length = max_length + self.max_new_tokens = 
max_new_tokens + self.reproducibility_config = reproducibility_config + + def __call__(self, messages: Discussion) -> AIMessage: + return AIMessage([{}]) + + def get_stats(self) -> dict: + return {} + + +@dataclass +class LlamaModelArgs(VLModelArgs): + model_path: str + torch_dtype: str + checkpoint_dir: Optional[str] + max_length: int + max_new_tokens: int + reproducibility_config: dict + + @property + def model_name(self) -> str: + return self.model_path.split("/")[-1].replace("-", "_").replace(".", "") + + def make_model(self) -> LlamaModel: + return LlamaModel( + model_path=self.model_path, + torch_dtype=self.torch_dtype, + checkpoint_dir=self.checkpoint_dir, + max_length=self.max_length, + max_new_tokens=self.max_new_tokens, + reproducibility_config=self.reproducibility_config, + ) + + def prepare(self): + pass + + def close(self): + pass + + def set_reproducibility_mode(self): + self.reproducibility_config = {"do_sample": False} diff --git a/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py new file mode 100644 index 00000000..870f0843 --- /dev/null +++ b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py @@ -0,0 +1,71 @@ +from agentlab.llm.llm_utils import AIMessage, Discussion +from dataclasses import dataclass +from openai import AsyncOpenAI, RateLimitError +from .base import VLModel, VLModelArgs +import asyncio +import backoff +import os + + +class OpenRouterAPIModel(VLModel): + def __init__( + self, + base_url: str, + model_id: str, + max_tokens: int, + reproducibility_config: dict, + ): + self.client = AsyncOpenAI(base_url=base_url, api_key=os.getenv("OPENROUTER_API_KEY")) + self.model_id = model_id + self.max_tokens = max_tokens + self.reproducibility_config = reproducibility_config + + def __call__(self, messages: Discussion) -> AIMessage: + @backoff.on_exception(backoff.expo, RateLimitError) + async def get_response(messages: list[dict], max_tokens: int, **kwargs): + completion = await self.client.chat.completions.create( + model=self.model_id, messages=messages, max_tokens=max_tokens, **kwargs + ) + try: + response = AIMessage(completion.choices[0].message.content) + except: + response = AIMessage("") + return response + + return asyncio.run( + get_response( + messages=messages, max_tokens=self.max_tokens, **self.reproducibility_config + ) + ) + + def get_stats(self) -> dict: + return {} + + +@dataclass +class OpenRouterAPIModelArgs(VLModelArgs): + base_url: str + model_id: str + max_tokens: int + reproducibility_config: dict + + @property + def model_name(self) -> str: + return self.model_id.split("/")[-1].replace("-", "_").replace(".", "") + + def make_model(self) -> OpenRouterAPIModel: + return OpenRouterAPIModel( + base_url=self.base_url, + model_id=self.model_id, + max_tokens=self.max_tokens, + reproducibility_config=self.reproducibility_config, + ) + + def prepare(self): + pass + + def close(self): + pass + + def set_reproducibility_mode(self): + self.reproducibility_config = {"temperature": 0.0} diff --git a/src/agentlab/agents/vl_agent/vl_prompt/base.py b/src/agentlab/agents/vl_agent/vl_prompt/base.py new file mode 100644 index 00000000..414b68ab --- /dev/null +++ b/src/agentlab/agents/vl_agent/vl_prompt/base.py @@ -0,0 +1,32 @@ +from abc import ABC, abstractmethod +from agentlab.llm.llm_utils import Discussion +from typing import Optional + + +class VLPromptPart(ABC): + @abstractmethod + def get_message_content(self) -> list[dict]: + raise NotImplementedError + + +class 
VLPrompt(ABC): + @abstractmethod + def get_messages(self) -> Discussion: + raise NotImplementedError + + @abstractmethod + def parse_answer(self, answer_text: str) -> dict: + raise NotImplementedError + + +class VLPromptArgs(ABC): + @abstractmethod + def make_prompt( + self, + obs: dict, + thoughts: list[str], + actions: list[str], + extra_instruction: Optional[str] = None, + preliminary_answer: Optional[dict] = None, + ) -> VLPrompt: + raise NotImplementedError diff --git a/src/agentlab/agents/vl_agent/vl_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py similarity index 88% rename from src/agentlab/agents/vl_agent/vl_prompt.py rename to src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py index f8df9e6d..d6309c17 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod from agentlab.llm.llm_utils import ( Discussion, extract_code_blocks, @@ -12,39 +11,11 @@ from dataclasses import dataclass from PIL import Image from typing import Optional, Union -from .utils import image_to_image_url +from .base import VLPrompt, VLPromptArgs, VLPromptPart +from ..utils import image_to_image_url import numpy as np -class VLPromptPart(ABC): - @abstractmethod - def get_message_content(self) -> list[dict]: - raise NotImplementedError - - -class VLPrompt(ABC): - @abstractmethod - def get_messages(self) -> Discussion: - raise NotImplementedError - - @abstractmethod - def parse_answer(self, answer_text: str) -> dict: - raise NotImplementedError - - -class VLPromptArgs(ABC): - @abstractmethod - def make_prompt( - self, - obs: dict, - actions: list[str], - thoughts: list[str], - extra_instruction: Optional[str], - preliminary_answer: Optional[dict], - ) -> VLPrompt: - raise NotImplementedError - - class SystemPromptPart(VLPromptPart): def __init__(self): self.text = """\ @@ -61,7 +32,7 @@ class InstructionPromptPart(VLPromptPart): def __init__( self, goal_object: list[dict], - extra_instruction: Optional[str], + extra_instruction: Optional[str] = None, ): text = """\ # Instruction @@ -139,20 +110,23 @@ def get_message_content(self) -> list[dict]: class ErrorPromptPart(VLPromptPart): - def __init__(self, last_action_error: str): + def __init__( + self, + last_action_error: str, + logs_separator: str = "Call log:", + logs_limit: int = 5, + ): text = """\ -# The Error from the Last Action +# The Error from Last Action """ - separator = "Call log:" - if separator in last_action_error: - error, logs = last_action_error.split(separator) - error = error.strip() - logs = logs.split("\n") + if logs_separator in last_action_error: + error, logs = last_action_error.split(logs_separator) + logs = logs.split("\n")[:logs_limit] text += f"""\ {error} -{separator} +{logs_separator} """ - for log in logs[:10]: + for log in logs: text += f"""\ {log} """ @@ -172,7 +146,7 @@ def __init__( action_set_description: str, use_abstract_example: bool, use_concrete_example: bool, - preliminary_answer: Optional[dict], + preliminary_answer: Optional[dict] = None, ): text = f"""\ # Answer Requirements @@ -186,17 +160,17 @@ def __init__( """ if use_abstract_example: text += """\ -## An Abstract Example of the Answer +### An Abstract Example of the Answer The thought about the next action. -The next action. +The next action to take. 
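The <thought> and <action> tags shown above are what UIPrompt.parse_answer extracts; the module relies on parse_html_tags_raise for this, with extract_code_blocks as a fallback when the action tags are missing. A reduced illustration of the tag-extraction idea, not the library function itself:

    import re

    def extract_tag(answer_text: str, key: str):
        # Grab the content of the first <key>...</key> pair, if present.
        match = re.search(rf"<{key}>(.*?)</{key}>", answer_text, re.DOTALL)
        return match.group(1).strip() if match else None

    extract_tag("<thought>t</thought><action>click('a324')</action>", "action")
    # -> "click('a324')"
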
""" if use_concrete_example: text += """\ -## A Concrete Example of the Answer +### A Concrete Example of the Answer From previous action I tried to set the value of year to "2022", using select_option, but it doesn't appear to be in the form. \ It may be a dynamic dropdown, I will try using click with the bid "a324" and look at the response from the page. @@ -296,10 +270,10 @@ class UIPromptArgs(VLPromptArgs): def make_prompt( self, obs: dict, - actions: list[str], thoughts: list[str], - extra_instruction: Optional[str], - preliminary_answer: Optional[dict], + actions: list[str], + extra_instruction: Optional[str] = None, + preliminary_answer: Optional[dict] = None, ) -> UIPrompt: action_set = self.action_set_args.make_action_set() system_prompt_part = SystemPromptPart() From 8a727fdc43f1ef75c21cd3ee52c8f5cacaa87c25 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Thu, 22 May 2025 02:53:45 +0000 Subject: [PATCH 16/29] update --- src/agentlab/agents/vl_agent/utils.py | 8 ++++ .../agents/vl_agent/vl_model/llama_model.py | 37 ++++++++++++++++++- .../vl_agent/vl_model/openrouter_api_model.py | 7 ++-- 3 files changed, 48 insertions(+), 4 deletions(-) diff --git a/src/agentlab/agents/vl_agent/utils.py b/src/agentlab/agents/vl_agent/utils.py index e23714e8..ad704c0b 100644 --- a/src/agentlab/agents/vl_agent/utils.py +++ b/src/agentlab/agents/vl_agent/utils.py @@ -14,3 +14,11 @@ def image_to_image_url(image: Union[Image.Image, np.ndarray]): image.save(buffer, format="JPEG") image_base64 = base64.b64encode(buffer.getvalue()).decode() return f"data:image/jpeg;base64,{image_base64}" + + +def image_url_to_image(image_url: str) -> Image.Image: + image_base64 = image_url.replace("data:image/jpeg;base64,", "") + image_data = base64.b64decode(image_base64.encode()) + buffer = io.BytesIO(image_data) + image = Image.open(buffer) + return image diff --git a/src/agentlab/agents/vl_agent/vl_model/llama_model.py b/src/agentlab/agents/vl_agent/vl_model/llama_model.py index 87ecd13e..6d1c1ae7 100644 --- a/src/agentlab/agents/vl_agent/vl_model/llama_model.py +++ b/src/agentlab/agents/vl_agent/vl_model/llama_model.py @@ -4,6 +4,7 @@ from transformers import AutoProcessor, MllamaForConditionalGeneration from typing import Optional from .base import VLModel, VLModelArgs +from ..utils import image_url_to_image import fnmatch import os @@ -36,7 +37,41 @@ def __init__( self.reproducibility_config = reproducibility_config def __call__(self, messages: Discussion) -> AIMessage: - return AIMessage([{}]) + input_messages = [] + input_images = [] + for message in messages: + input_message = {"role": message["role"], "content": []} + for item in message["content"]: + if item["type"] == "text": + input_message["content"].append(item) + elif item["type"] == "image_url": + input_message["content"].append({"type": "image"}) + input_images.append(image_url_to_image(item["image_url"]["url"])) + input_messages.append(input_message) + input_text = self.processor.apply_chat_template( + input_messages, add_generation_prompt=True, tokenize=False + ) + input = self.processor( + images=input_images, + text=input_text, + add_special_tokens=False, + return_tensors="pt", + truncation=True, + max_length=self.max_length, + ).to(self.model.device) + output = self.model.generate( + **input, + eos_token_id=self.processor.tokenizer.eos_token_id, + max_new_tokens=self.max_new_tokens, + use_cache=True, + **self.reproducibility_config, + ) + output_text = self.processor.tokenizer.batch_decode( + output[:, input["input_ids"].shape[1] :], + 
skip_special_tokens=True, + clean_up_tokenization_spaces=False, + )[0] + return AIMessage([{"type": "text", "text": output_text}]) def get_stats(self) -> dict: return {} diff --git a/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py index 870f0843..6625e0fc 100644 --- a/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py +++ b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py @@ -27,16 +27,17 @@ async def get_response(messages: list[dict], max_tokens: int, **kwargs): model=self.model_id, messages=messages, max_tokens=max_tokens, **kwargs ) try: - response = AIMessage(completion.choices[0].message.content) + response = completion.choices[0].message.content except: - response = AIMessage("") + response = "" return response - return asyncio.run( + response = asyncio.run( get_response( messages=messages, max_tokens=self.max_tokens, **self.reproducibility_config ) ) + return AIMessage([{"type": "text", "text": response}]) def get_stats(self) -> dict: return {} From acfd6a898a15fd2d4ce98bc7c45239bb9cde6f1e Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Thu, 22 May 2025 04:34:54 +0000 Subject: [PATCH 17/29] update --- .../agents/vl_agent/vl_model/llama_model.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/agentlab/agents/vl_agent/vl_model/llama_model.py b/src/agentlab/agents/vl_agent/vl_model/llama_model.py index 6d1c1ae7..819f2e00 100644 --- a/src/agentlab/agents/vl_agent/vl_model/llama_model.py +++ b/src/agentlab/agents/vl_agent/vl_model/llama_model.py @@ -41,12 +41,15 @@ def __call__(self, messages: Discussion) -> AIMessage: input_images = [] for message in messages: input_message = {"role": message["role"], "content": []} - for item in message["content"]: - if item["type"] == "text": - input_message["content"].append(item) - elif item["type"] == "image_url": - input_message["content"].append({"type": "image"}) - input_images.append(image_url_to_image(item["image_url"]["url"])) + if isinstance(message["content"], str): + input_message["content"].append({"type": "text", "text": message["content"]}) + else: + for item in message["content"]: + if item["type"] == "text": + input_message["content"].append(item) + elif item["type"] == "image_url": + input_message["content"].append({"type": "image"}) + input_images.append(image_url_to_image(item["image_url"]["url"])) input_messages.append(input_message) input_text = self.processor.apply_chat_template( input_messages, add_generation_prompt=True, tokenize=False From e2624c1c2e9d32e459c8f229fef7d897a11f53ab Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Fri, 23 May 2025 02:37:09 +0000 Subject: [PATCH 18/29] update --- src/agentlab/agents/vl_agent/utils.py | 17 ++++- src/agentlab/agents/vl_agent/vl_model/base.py | 3 +- .../agents/vl_agent/vl_model/llama_model.py | 66 +++++++++++-------- 3 files changed, 57 insertions(+), 29 deletions(-) diff --git a/src/agentlab/agents/vl_agent/utils.py b/src/agentlab/agents/vl_agent/utils.py index ad704c0b..36db2dd9 100644 --- a/src/agentlab/agents/vl_agent/utils.py +++ b/src/agentlab/agents/vl_agent/utils.py @@ -1,4 +1,7 @@ +from accelerate import dispatch_model, infer_auto_device_map +from accelerate.utils.modeling import get_balanced_memory from PIL import Image +from torch.nn import Module from typing import Union import base64 import io @@ -13,7 +16,8 @@ def image_to_image_url(image: Union[Image.Image, np.ndarray]): buffer = io.BytesIO() image.save(buffer, format="JPEG") 
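Together with image_url_to_image from patch 16, the helper above gives a JPEG data-URL round trip. A quick check, assuming numpy and Pillow are installed and the import path matches the src layout:

    import numpy as np
    from PIL import Image
    from agentlab.agents.vl_agent.utils import image_to_image_url, image_url_to_image

    url = image_to_image_url(np.zeros((4, 4, 3), dtype=np.uint8))  # arbitrary test image
    assert url.startswith("data:image/jpeg;base64,")
    restored = image_url_to_image(url)
    assert isinstance(restored, Image.Image)  # JPEG is lossy; pixels may differ slightly
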
image_base64 = base64.b64encode(buffer.getvalue()).decode() - return f"data:image/jpeg;base64,{image_base64}" + image_url = f"data:image/jpeg;base64,{image_base64}" + return image_url def image_url_to_image(image_url: str) -> Image.Image: @@ -22,3 +26,14 @@ def image_url_to_image(image_url: str) -> Image.Image: buffer = io.BytesIO(image_data) image = Image.open(buffer) return image + + +def auto_dispatch_model(model: Module, no_split_module_classes: list[str]) -> Module: + max_memory = get_balanced_memory(model, no_split_module_classes=no_split_module_classes) + device_map = infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + ) + model = dispatch_model(model, device_map=device_map) + return model diff --git a/src/agentlab/agents/vl_agent/vl_model/base.py b/src/agentlab/agents/vl_agent/vl_model/base.py index ce188183..266d3331 100644 --- a/src/agentlab/agents/vl_agent/vl_model/base.py +++ b/src/agentlab/agents/vl_agent/vl_model/base.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod from agentlab.llm.llm_utils import AIMessage, Discussion +from torch.nn import Module -class VLModel(ABC): +class VLModel(ABC, Module): @abstractmethod def __call__(self, messages: Discussion) -> AIMessage: raise NotImplementedError diff --git a/src/agentlab/agents/vl_agent/vl_model/llama_model.py b/src/agentlab/agents/vl_agent/vl_model/llama_model.py index 819f2e00..f15215f5 100644 --- a/src/agentlab/agents/vl_agent/vl_model/llama_model.py +++ b/src/agentlab/agents/vl_agent/vl_model/llama_model.py @@ -1,12 +1,11 @@ +from accelerate import Accelerator from accelerate.utils.modeling import load_checkpoint_in_model from agentlab.llm.llm_utils import AIMessage, Discussion from dataclasses import dataclass from transformers import AutoProcessor, MllamaForConditionalGeneration from typing import Optional from .base import VLModel, VLModelArgs -from ..utils import image_url_to_image -import fnmatch -import os +from ..utils import auto_dispatch_model, image_url_to_image class LlamaModel(VLModel): @@ -14,27 +13,19 @@ def __init__( self, model_path: str, torch_dtype: str, - checkpoint_dir: str, + accelerator_config: dict, + reproducibility_config: dict, max_length: int, max_new_tokens: int, - reproducibility_config: dict, ): self.model = MllamaForConditionalGeneration.from_pretrained( model_path, torch_dtype=torch_dtype ) - if checkpoint_dir is not None: - checkpoint_file = None - for item in os.listdir(checkpoint_dir): - if fnmatch.fnmatch(item, "pytorch_model*.bin") or fnmatch.fnmatch( - item, "model*.safetensors" - ): - checkpoint_file = os.path.join(checkpoint_dir, item) - break - load_checkpoint_in_model(self.model, checkpoint_file) self.processor = AutoProcessor.from_pretrained(model_path) + self.accelerator = Accelerator(**accelerator_config) + self.reproducibility_config = reproducibility_config self.max_length = max_length self.max_new_tokens = max_new_tokens - self.reproducibility_config = reproducibility_config def __call__(self, messages: Discussion) -> AIMessage: input_messages = [] @@ -62,13 +53,14 @@ def __call__(self, messages: Discussion) -> AIMessage: truncation=True, max_length=self.max_length, ).to(self.model.device) - output = self.model.generate( - **input, - eos_token_id=self.processor.tokenizer.eos_token_id, - max_new_tokens=self.max_new_tokens, - use_cache=True, - **self.reproducibility_config, - ) + with self.accelerator.autocast(): + output = self.model.generate( + **input, + eos_token_id=self.processor.tokenizer.eos_token_id, + 
max_new_tokens=self.max_new_tokens, + use_cache=True, + **self.reproducibility_config, + ) output_text = self.processor.tokenizer.batch_decode( output[:, input["input_ids"].shape[1] :], skip_special_tokens=True, @@ -84,24 +76,44 @@ def get_stats(self) -> dict: class LlamaModelArgs(VLModelArgs): model_path: str torch_dtype: str - checkpoint_dir: Optional[str] + accelerator_config: dict + reproducibility_config: dict max_length: int max_new_tokens: int - reproducibility_config: dict + checkpoint_file: Optional[str] + device: Optional[str] @property def model_name(self) -> str: return self.model_path.split("/")[-1].replace("-", "_").replace(".", "") def make_model(self) -> LlamaModel: - return LlamaModel( + llama_model = LlamaModel( model_path=self.model_path, torch_dtype=self.torch_dtype, - checkpoint_dir=self.checkpoint_dir, + accelerator_config=self.accelerator_config, + reproducibility_config=self.reproducibility_config, max_length=self.max_length, max_new_tokens=self.max_new_tokens, - reproducibility_config=self.reproducibility_config, ) + if self.checkpoint_file is not None: + load_checkpoint_in_model(llama_model.model, checkpoint=self.checkpoint_file) + if self.device is None: + layer_classes = set() + for layer in llama_model.model.language_model.model.layers: + layer_classes.add(layer.__class__) + for layer in llama_model.model.vision_model.transformer.layers: + layer_classes.add(layer.__class__) + for layer in llama_model.model.vision_model.global_transformer.layers: + layer_classes.add(layer.__class__) + llama_model.model = auto_dispatch_model( + llama_model.model, + no_split_module_classes=[layer_class.__name__ for layer_class in layer_classes], + ) + else: + llama_model.model = llama_model.model.to(self.device) + llama_model.model.eval() + return llama_model def prepare(self): pass From 41cbc3bda80c93fd7f807ddde15d328e3d754c7f Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Fri, 23 May 2025 16:16:29 +0000 Subject: [PATCH 19/29] update --- src/agentlab/agents/vl_agent/config.py | 6 ++++-- .../agents/vl_agent/vl_model/openrouter_api_model.py | 6 ++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/agentlab/agents/vl_agent/config.py b/src/agentlab/agents/vl_agent/config.py index 8bbd583c..cacca8e4 100644 --- a/src/agentlab/agents/vl_agent/config.py +++ b/src/agentlab/agents/vl_agent/config.py @@ -15,10 +15,12 @@ "llama_32_11b": LlamaModelArgs( model_path="meta-llama/Llama-3.2-11B-Vision-Instruct", torch_dtype="bfloat16", - checkpoint_dir=None, + accelerator_config={"mixed_precision": "bf16", "cpu": False}, + reproducibility_config={"temperature": 0.1}, max_length=32768, max_new_tokens=8192, - reproducibility_config={"temperature": 0.1}, + checkpoint_file=None, + device=None, ), } diff --git a/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py index 6625e0fc..56db8a7b 100644 --- a/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py +++ b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py @@ -22,7 +22,7 @@ def __init__( def __call__(self, messages: Discussion) -> AIMessage: @backoff.on_exception(backoff.expo, RateLimitError) - async def get_response(messages: list[dict], max_tokens: int, **kwargs): + async def get_response(messages, max_tokens, **kwargs): completion = await self.client.chat.completions.create( model=self.model_id, messages=messages, max_tokens=max_tokens, **kwargs ) @@ -33,9 +33,7 @@ async def get_response(messages: list[dict], max_tokens: int, **kwargs): 
return response response = asyncio.run( - get_response( - messages=messages, max_tokens=self.max_tokens, **self.reproducibility_config - ) + get_response(messages, self.max_tokens, **self.reproducibility_config) ) return AIMessage([{"type": "text", "text": response}]) From 77ed33b35b9124d1c1bd55b05727fd2f3e613526 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Sat, 24 May 2025 01:00:57 +0000 Subject: [PATCH 20/29] update --- src/agentlab/agents/vl_agent/vl_agent/base.py | 3 +-- src/agentlab/agents/vl_agent/vl_agent/ui_agent.py | 5 ++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/agentlab/agents/vl_agent/vl_agent/base.py b/src/agentlab/agents/vl_agent/vl_agent/base.py index 482ce2ce..a9f82119 100644 --- a/src/agentlab/agents/vl_agent/vl_agent/base.py +++ b/src/agentlab/agents/vl_agent/vl_agent/base.py @@ -8,9 +8,8 @@ class VLAgent(ABC): def get_action(self, obs: dict) -> tuple[str, dict]: raise NotImplementedError - @property @abstractmethod - def obs_preprocessor(self) -> callable: + def obs_preprocessor(self, obs: dict) -> dict: raise NotImplementedError diff --git a/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py index 6334c3c1..bca99f0f 100644 --- a/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py @@ -7,8 +7,8 @@ from dataclasses import asdict, dataclass from typing import Optional from .base import VLAgent, VLAgentArgs -from ..vl_model import VLModelArgs -from ..vl_prompt import UIPromptArgs +from ..vl_model.base import VLModelArgs +from ..vl_prompt.ui_prompt import UIPromptArgs class UIAgent(VLAgent): @@ -82,7 +82,6 @@ def get_action(self, obs: dict) -> tuple[str, dict]: ) return answer["action"], asdict(agent_info) - @property def obs_preprocessor(self, obs: dict) -> dict: obs = copy(obs) if self.ui_prompt_args.use_screenshot and self.ui_prompt_args.use_screenshot_som: From fd559585cd390a4d86d1106467ed7e48ec078255 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Sat, 24 May 2025 01:53:35 +0000 Subject: [PATCH 21/29] update --- src/agentlab/agents/vl_agent/vl_agent/ui_agent.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py index bca99f0f..5dff4896 100644 --- a/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py @@ -26,15 +26,15 @@ def __init__( self.auxiliary_vl_model = auxiliary_vl_model_args.make_model() self.ui_prompt_args = ui_prompt_args self.max_retry = max_retry - self.actions = [] self.thoughts = [] + self.actions = [] @cost_tracker_decorator def get_action(self, obs: dict) -> tuple[str, dict]: ui_prompt = self.ui_prompt_args.make_prompt( obs=obs, - actions=self.actions, thoughts=self.thoughts, + actions=self.actions, extra_instructions=None, preliminary_answer=None, ) @@ -55,8 +55,8 @@ def get_action(self, obs: dict) -> tuple[str, dict]: preliminary_answer = answer ui_prompt = self.ui_prompt_args.make_prompt( obs=obs, - actions=self.actions, thoughts=self.thoughts, + actions=self.actions, extra_instructions=None, preliminary_answer=preliminary_answer, ) @@ -94,13 +94,16 @@ def obs_preprocessor(self, obs: dict) -> dict: @dataclass class UIAgentArgs(VLAgentArgs): main_vl_model_args: VLModelArgs - auxiliary_vl_model_args: VLModelArgs + auxiliary_vl_model_args: Optional[VLModelArgs] ui_prompt_args: UIPromptArgs max_retry: int @property def agent_name(self) -> str: - 
return f"UIAgent-{self.main_vl_model_args.model_name}-{self.auxiliary_vl_model_args.model_name}" + if self.auxiliary_vl_model_args is None: + return f"UIAgent-{self.main_vl_model_args.model_name}" + else: + return f"UIAgent-{self.main_vl_model_args.model_name}-{self.auxiliary_vl_model_args.model_name}" def make_agent(self) -> UIAgent: return UIAgent( From 03d777c0ecb2000178124d8ff17e56fefe27da60 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Sat, 24 May 2025 21:02:03 +0000 Subject: [PATCH 22/29] update --- src/agentlab/agents/vl_agent/config.py | 2 +- src/agentlab/agents/vl_agent/vl_agent/ui_agent.py | 11 +++++++++-- src/agentlab/agents/vl_agent/vl_prompt/base.py | 2 ++ src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py | 10 ++++------ 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/agentlab/agents/vl_agent/config.py b/src/agentlab/agents/vl_agent/config.py index cacca8e4..ba3656ad 100644 --- a/src/agentlab/agents/vl_agent/config.py +++ b/src/agentlab/agents/vl_agent/config.py @@ -26,7 +26,6 @@ VL_PROMPT_ARGS_DICT = { "ui_prompt": UIPromptArgs( - action_set_args=HighLevelActionSetArgs(subsets=["coord"]), use_screenshot=True, use_screenshot_som=False, use_tabs=True, @@ -41,6 +40,7 @@ "ui_agent": UIAgentArgs( main_vl_model_args=VL_MODEL_ARGS_DICT["gpt_4o"], auxiliary_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], + action_set_args=HighLevelActionSetArgs(subsets=["coord"]), ui_prompt_args=VL_PROMPT_ARGS_DICT["ui_prompt"], max_retry=4, ) diff --git a/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py index 5dff4896..5bb970b0 100644 --- a/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py @@ -2,6 +2,7 @@ from agentlab.llm.tracking import cost_tracker_decorator from browsergym.experiments.agent import AgentInfo from browsergym.experiments.benchmark import Benchmark +from browsergym.experiments.benchmark.base import HighLevelActionSetArgs from browsergym.utils.obs import overlay_som from copy import copy, deepcopy from dataclasses import asdict, dataclass @@ -16,6 +17,7 @@ def __init__( self, main_vl_model_args: VLModelArgs, auxiliary_vl_model_args: Optional[VLModelArgs], + action_set_args: HighLevelActionSetArgs, ui_prompt_args: UIPromptArgs, max_retry: int, ): @@ -24,6 +26,7 @@ def __init__( self.auxiliary_vl_model = None else: self.auxiliary_vl_model = auxiliary_vl_model_args.make_model() + self.action_set = action_set_args.make_action_set() self.ui_prompt_args = ui_prompt_args self.max_retry = max_retry self.thoughts = [] @@ -35,6 +38,7 @@ def get_action(self, obs: dict) -> tuple[str, dict]: obs=obs, thoughts=self.thoughts, actions=self.actions, + action_set=self.action_set, extra_instructions=None, preliminary_answer=None, ) @@ -57,6 +61,7 @@ def get_action(self, obs: dict) -> tuple[str, dict]: obs=obs, thoughts=self.thoughts, actions=self.actions, + action_set=self.action_set, extra_instructions=None, preliminary_answer=preliminary_answer, ) @@ -95,6 +100,7 @@ def obs_preprocessor(self, obs: dict) -> dict: class UIAgentArgs(VLAgentArgs): main_vl_model_args: VLModelArgs auxiliary_vl_model_args: Optional[VLModelArgs] + action_set_args: HighLevelActionSetArgs ui_prompt_args: UIPromptArgs max_retry: int @@ -109,6 +115,7 @@ def make_agent(self) -> UIAgent: return UIAgent( main_vl_model_args=self.main_vl_model_args, auxiliary_vl_model_args=self.auxiliary_vl_model_args, + action_set_args=self.action_set_args, ui_prompt_args=self.ui_prompt_args, max_retry=self.max_retry, 
) @@ -130,6 +137,6 @@ def set_reproducibility_mode(self): def set_benchmark(self, benchmark: Benchmark, demo_mode: bool): self.ui_prompt_args.use_tabs = benchmark.is_multi_tab - self.ui_prompt_args.action_set_args = deepcopy(benchmark.high_level_action_set_args) + self.action_set_args = deepcopy(benchmark.high_level_action_set_args) if demo_mode: - self.ui_prompt_args.action_set_args.demo_mode = "all_blue" + self.action_set_args.demo_mode = "all_blue" diff --git a/src/agentlab/agents/vl_agent/vl_prompt/base.py b/src/agentlab/agents/vl_agent/vl_prompt/base.py index 414b68ab..3658ef4d 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt/base.py +++ b/src/agentlab/agents/vl_agent/vl_prompt/base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from agentlab.llm.llm_utils import Discussion +from browsergym.core.action.highlevel import HighLevelActionSet from typing import Optional @@ -26,6 +27,7 @@ def make_prompt( obs: dict, thoughts: list[str], actions: list[str], + action_set: HighLevelActionSet, extra_instruction: Optional[str] = None, preliminary_answer: Optional[dict] = None, ) -> VLPrompt: diff --git a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py index d6309c17..8152ed32 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py @@ -7,7 +7,6 @@ SystemMessage, ) from browsergym.core.action.highlevel import HighLevelActionSet -from browsergym.experiments.benchmark.base import HighLevelActionSetArgs from dataclasses import dataclass from PIL import Image from typing import Optional, Union @@ -199,7 +198,6 @@ def get_message_content(self) -> list[dict]: @dataclass class UIPrompt(VLPrompt): - action_set: HighLevelActionSet system_prompt_part: SystemPromptPart instruction_prompt_part: InstructionPromptPart screenshot_prompt_part: Optional[ScreenshotPromptPart] @@ -207,6 +205,7 @@ class UIPrompt(VLPrompt): history_prompt_part: Optional[HistoryPromptPart] error_prompt_part: Optional[ErrorPromptPart] answer_prompt_part: AnswerPromptPart + action_validator: callable def get_messages(self) -> Discussion: system_message_content = self.system_prompt_part.get_message_content() @@ -250,7 +249,7 @@ def parse_answer(self, answer_text: str) -> dict: answer_dict["action"] = None else: try: - self.action_set.to_python_code(answer_dict["action"]) + self.action_validator(answer_dict["action"]) except Exception as error: raise ParseError(str(error)) return answer_dict @@ -258,7 +257,6 @@ def parse_answer(self, answer_text: str) -> dict: @dataclass class UIPromptArgs(VLPromptArgs): - action_set_args: HighLevelActionSetArgs use_screenshot: bool use_screenshot_som: bool use_tabs: bool @@ -272,10 +270,10 @@ def make_prompt( obs: dict, thoughts: list[str], actions: list[str], + action_set: HighLevelActionSet, extra_instruction: Optional[str] = None, preliminary_answer: Optional[dict] = None, ) -> UIPrompt: - action_set = self.action_set_args.make_action_set() system_prompt_part = SystemPromptPart() instruction_prompt_part = InstructionPromptPart( goal_object=obs["goal_object"], extra_instruction=extra_instruction @@ -309,7 +307,6 @@ def make_prompt( preliminary_answer=preliminary_answer, ) return UIPrompt( - action_set=action_set, system_prompt_part=system_prompt_part, instruction_prompt_part=instruction_prompt_part, screenshot_prompt_part=screenshot_prompt_part, @@ -317,4 +314,5 @@ def make_prompt( history_prompt_part=history_prompt_part, error_prompt_part=error_prompt_part, 
answer_prompt_part=answer_prompt_part, + action_validator=action_set.to_python_code, ) From 7d19c074e333bd68f815430d626d05044e264fa6 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Sun, 25 May 2025 00:56:58 +0000 Subject: [PATCH 23/29] update --- src/agentlab/agents/vl_agent/config.py | 1 + src/agentlab/agents/vl_agent/vl_agent/ui_agent.py | 8 +------- src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py | 4 ++-- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/agentlab/agents/vl_agent/config.py b/src/agentlab/agents/vl_agent/config.py index ba3656ad..efbff5ba 100644 --- a/src/agentlab/agents/vl_agent/config.py +++ b/src/agentlab/agents/vl_agent/config.py @@ -33,6 +33,7 @@ use_error=True, use_abstract_example=True, use_concrete_example=False, + extra_instruction=None, ) } diff --git a/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py index 5bb970b0..0c49a085 100644 --- a/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py @@ -35,12 +35,7 @@ def __init__( @cost_tracker_decorator def get_action(self, obs: dict) -> tuple[str, dict]: ui_prompt = self.ui_prompt_args.make_prompt( - obs=obs, - thoughts=self.thoughts, - actions=self.actions, - action_set=self.action_set, - extra_instructions=None, - preliminary_answer=None, + obs=obs, thoughts=self.thoughts, actions=self.actions, action_set=self.action_set ) try: messages = ui_prompt.get_messages() @@ -62,7 +57,6 @@ def get_action(self, obs: dict) -> tuple[str, dict]: thoughts=self.thoughts, actions=self.actions, action_set=self.action_set, - extra_instructions=None, preliminary_answer=preliminary_answer, ) try: diff --git a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py index 8152ed32..487874e2 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py @@ -264,6 +264,7 @@ class UIPromptArgs(VLPromptArgs): use_error: bool use_abstract_example: bool use_concrete_example: bool + extra_instruction: Optional[str] def make_prompt( self, @@ -271,12 +272,11 @@ def make_prompt( thoughts: list[str], actions: list[str], action_set: HighLevelActionSet, - extra_instruction: Optional[str] = None, preliminary_answer: Optional[dict] = None, ) -> UIPrompt: system_prompt_part = SystemPromptPart() instruction_prompt_part = InstructionPromptPart( - goal_object=obs["goal_object"], extra_instruction=extra_instruction + goal_object=obs["goal_object"], extra_instruction=self.extra_instruction ) if self.use_screenshot: screenshot_prompt_part = ScreenshotPromptPart(obs["screenshot"]) From 294e526f4e0ef4078d810923f0700649964323b4 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Sun, 25 May 2025 22:13:09 +0000 Subject: [PATCH 24/29] update --- src/agentlab/agents/vl_agent/main.py | 33 +++++++++++++++++++ src/agentlab/agents/vl_agent/vl_model/base.py | 3 +- .../vl_agent/vl_model/openrouter_api_model.py | 13 +++----- 3 files changed, 39 insertions(+), 10 deletions(-) create mode 100644 src/agentlab/agents/vl_agent/main.py diff --git a/src/agentlab/agents/vl_agent/main.py b/src/agentlab/agents/vl_agent/main.py new file mode 100644 index 00000000..eeb1184c --- /dev/null +++ b/src/agentlab/agents/vl_agent/main.py @@ -0,0 +1,33 @@ +from agentlab.agents.vl_agent.config import VL_AGENT_ARGS_DICT +from agentlab.experiments.study import Study +import logging +import os + + +logging.getLogger().setLevel(logging.INFO) + +vl_agent_args_list = 
[VL_AGENT_ARGS_DICT["ui_agent"]] +benchmark = "miniwob" +os.environ["MINIWOB_URL"] = "file:///mnt/home/miniwob-plusplus/miniwob/html/miniwob/" +reproducibility_mode = False +relaunch = False +n_jobs = 1 + + +if __name__ == "__main__": + if reproducibility_mode: + for vl_agent_args in vl_agent_args_list: + vl_agent_args.set_reproducibility_mode() + if relaunch: + study = Study.load_most_recent(contains=None) + study.find_incomplete(include_errors=True) + else: + study = Study(vl_agent_args_list, benchmark=benchmark, logging_level_stdout=logging.WARNING) + study.run( + n_jobs=n_jobs, + parallel_backend="sequential", + strict_reproducibility=reproducibility_mode, + n_relaunch=3, + ) + if reproducibility_mode: + study.append_to_journal(strict_reproducibility=True) diff --git a/src/agentlab/agents/vl_agent/vl_model/base.py b/src/agentlab/agents/vl_agent/vl_model/base.py index 266d3331..ce188183 100644 --- a/src/agentlab/agents/vl_agent/vl_model/base.py +++ b/src/agentlab/agents/vl_agent/vl_model/base.py @@ -1,9 +1,8 @@ from abc import ABC, abstractmethod from agentlab.llm.llm_utils import AIMessage, Discussion -from torch.nn import Module -class VLModel(ABC, Module): +class VLModel(ABC): @abstractmethod def __call__(self, messages: Discussion) -> AIMessage: raise NotImplementedError diff --git a/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py index 56db8a7b..896cfcaa 100644 --- a/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py +++ b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py @@ -1,8 +1,7 @@ from agentlab.llm.llm_utils import AIMessage, Discussion from dataclasses import dataclass -from openai import AsyncOpenAI, RateLimitError +from openai import OpenAI, RateLimitError from .base import VLModel, VLModelArgs -import asyncio import backoff import os @@ -15,15 +14,15 @@ def __init__( max_tokens: int, reproducibility_config: dict, ): - self.client = AsyncOpenAI(base_url=base_url, api_key=os.getenv("OPENROUTER_API_KEY")) + self.client = OpenAI(base_url=base_url, api_key=os.getenv("OPENROUTER_API_KEY")) self.model_id = model_id self.max_tokens = max_tokens self.reproducibility_config = reproducibility_config def __call__(self, messages: Discussion) -> AIMessage: @backoff.on_exception(backoff.expo, RateLimitError) - async def get_response(messages, max_tokens, **kwargs): - completion = await self.client.chat.completions.create( + def get_response(messages, max_tokens, **kwargs): + completion = self.client.chat.completions.create( model=self.model_id, messages=messages, max_tokens=max_tokens, **kwargs ) try: @@ -32,9 +31,7 @@ async def get_response(messages, max_tokens, **kwargs): response = "" return response - response = asyncio.run( - get_response(messages, self.max_tokens, **self.reproducibility_config) - ) + response = get_response(messages, self.max_tokens, **self.reproducibility_config) return AIMessage([{"type": "text", "text": response}]) def get_stats(self) -> dict: From a7b9999a2b82461751aba0c6ee2c01233339689c Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Mon, 26 May 2025 03:09:01 +0000 Subject: [PATCH 25/29] update --- src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py index 487874e2..1abdecd6 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py +++ 
b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py @@ -225,7 +225,8 @@ def get_messages(self) -> Discussion: messages.merge() return messages - def parse_answer(self, answer_text: str) -> dict: + def parse_answer(self, answer_content: list[dict]) -> dict: + answer_text = answer_content[0]["text"] answer_dict = {} try: answer_dict.update( From 30bc57b9599c9d56208fd7d55d0e5a7b639568b9 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Wed, 28 May 2025 05:15:09 +0000 Subject: [PATCH 26/29] update --- src/agentlab/agents/vl_agent/config.py | 2 +- src/agentlab/agents/vl_agent/vl_agent/base.py | 6 + .../agents/vl_agent/vl_agent/ui_agent.py | 19 ++- .../agents/vl_agent/vl_model/llama_model.py | 7 +- .../vl_agent/vl_model/openrouter_api_model.py | 5 +- .../agents/vl_agent/vl_prompt/ui_prompt.py | 133 ++++++------------ 6 files changed, 75 insertions(+), 97 deletions(-) diff --git a/src/agentlab/agents/vl_agent/config.py b/src/agentlab/agents/vl_agent/config.py index efbff5ba..c818e60a 100644 --- a/src/agentlab/agents/vl_agent/config.py +++ b/src/agentlab/agents/vl_agent/config.py @@ -41,8 +41,8 @@ "ui_agent": UIAgentArgs( main_vl_model_args=VL_MODEL_ARGS_DICT["gpt_4o"], auxiliary_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"], - action_set_args=HighLevelActionSetArgs(subsets=["coord"]), ui_prompt_args=VL_PROMPT_ARGS_DICT["ui_prompt"], + action_set_args=HighLevelActionSetArgs(subsets=["coord"]), max_retry=4, ) } diff --git a/src/agentlab/agents/vl_agent/vl_agent/base.py b/src/agentlab/agents/vl_agent/vl_agent/base.py index a9f82119..8dde21a7 100644 --- a/src/agentlab/agents/vl_agent/vl_agent/base.py +++ b/src/agentlab/agents/vl_agent/vl_agent/base.py @@ -1,9 +1,15 @@ from abc import ABC, abstractmethod +from browsergym.core.action.highlevel import HighLevelActionSet from browsergym.experiments.benchmark import Benchmark from dataclasses import dataclass class VLAgent(ABC): + @property + @abstractmethod + def action_set(self) -> HighLevelActionSet: + raise NotImplementedError + @abstractmethod def get_action(self, obs: dict) -> tuple[str, dict]: raise NotImplementedError diff --git a/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py index 0c49a085..c850894e 100644 --- a/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py +++ b/src/agentlab/agents/vl_agent/vl_agent/ui_agent.py @@ -1,11 +1,13 @@ from agentlab.llm.llm_utils import ParseError, retry from agentlab.llm.tracking import cost_tracker_decorator +from browsergym.core.action.highlevel import HighLevelActionSet from browsergym.experiments.agent import AgentInfo from browsergym.experiments.benchmark import Benchmark from browsergym.experiments.benchmark.base import HighLevelActionSetArgs from browsergym.utils.obs import overlay_som from copy import copy, deepcopy from dataclasses import asdict, dataclass +from functools import cache from typing import Optional from .base import VLAgent, VLAgentArgs from ..vl_model.base import VLModelArgs @@ -17,8 +19,8 @@ def __init__( self, main_vl_model_args: VLModelArgs, auxiliary_vl_model_args: Optional[VLModelArgs], - action_set_args: HighLevelActionSetArgs, ui_prompt_args: UIPromptArgs, + action_set_args: HighLevelActionSetArgs, max_retry: int, ): self.main_vl_model = main_vl_model_args.make_model() @@ -26,12 +28,17 @@ def __init__( self.auxiliary_vl_model = None else: self.auxiliary_vl_model = auxiliary_vl_model_args.make_model() - self.action_set = action_set_args.make_action_set() self.ui_prompt_args = ui_prompt_args + self.action_set_args = 
action_set_args
         self.max_retry = max_retry
         self.thoughts = []
         self.actions = []
 
+    @property
+    @cache
+    def action_set(self) -> HighLevelActionSet:
+        return self.action_set_args.make_action_set()
+
     @cost_tracker_decorator
     def get_action(self, obs: dict) -> tuple[str, dict]:
         ui_prompt = self.ui_prompt_args.make_prompt(
@@ -94,11 +101,11 @@ def obs_preprocessor(self, obs: dict) -> dict:
 class UIAgentArgs(VLAgentArgs):
     main_vl_model_args: VLModelArgs
     auxiliary_vl_model_args: Optional[VLModelArgs]
-    action_set_args: HighLevelActionSetArgs
     ui_prompt_args: UIPromptArgs
+    action_set_args: HighLevelActionSetArgs
     max_retry: int
 
     @property
     def agent_name(self) -> str:
         if self.auxiliary_vl_model_args is None:
             return f"UIAgent-{self.main_vl_model_args.model_name}"
@@ -106,13 +113,14 @@ def agent_name(self) -> str:
         return f"UIAgent-{self.main_vl_model_args.model_name}-{self.auxiliary_vl_model_args.model_name}"
 
     def make_agent(self) -> UIAgent:
-        return UIAgent(
+        self.ui_agent = UIAgent(
             main_vl_model_args=self.main_vl_model_args,
             auxiliary_vl_model_args=self.auxiliary_vl_model_args,
-            action_set_args=self.action_set_args,
             ui_prompt_args=self.ui_prompt_args,
+            action_set_args=self.action_set_args,
             max_retry=self.max_retry,
         )
+        return self.ui_agent
 
     def prepare(self):
         self.main_vl_model_args.prepare()
diff --git a/src/agentlab/agents/vl_agent/vl_model/llama_model.py b/src/agentlab/agents/vl_agent/vl_model/llama_model.py
index f15215f5..968877df 100644
--- a/src/agentlab/agents/vl_agent/vl_model/llama_model.py
+++ b/src/agentlab/agents/vl_agent/vl_model/llama_model.py
@@ -2,6 +2,7 @@
 from accelerate.utils.modeling import load_checkpoint_in_model
 from agentlab.llm.llm_utils import AIMessage, Discussion
 from dataclasses import dataclass
+from functools import cached_property
 from transformers import AutoProcessor, MllamaForConditionalGeneration
 from typing import Optional
 from .base import VLModel, VLModelArgs
@@ -84,6 +85,6 @@ class LlamaModelArgs(VLModelArgs):
     device: Optional[str]
 
-    @property
+    @cached_property
     def model_name(self) -> str:
         return self.model_path.split("/")[-1].replace("-", "_").replace(".", "")
 
@@ -113,13 +114,14 @@ def make_model(self) -> LlamaModel:
         else:
             llama_model.model = llama_model.model.to(self.device)
         llama_model.model.eval()
-        return llama_model
+        self.llama_model = llama_model
+        return self.llama_model
 
     def prepare(self):
         pass
 
     def close(self):
-        pass
+        del self.llama_model.model
 
     def set_reproducibility_mode(self):
         self.reproducibility_config = {"do_sample": False}
diff --git a/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py
index 896cfcaa..d8c5b46f 100644
--- a/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py
+++ b/src/agentlab/agents/vl_agent/vl_model/openrouter_api_model.py
@@ -1,5 +1,6 @@
 from agentlab.llm.llm_utils import AIMessage, Discussion
 from dataclasses import dataclass
+from functools import cached_property
 from openai import OpenAI, RateLimitError
 from .base import VLModel, VLModelArgs
 import backoff
@@ -46,16 +47,17 @@ class OpenRouterAPIModelArgs(VLModelArgs):
     reproducibility_config: dict
 
-    @property
+    @cached_property
     def model_name(self) -> str:
         return self.model_id.split("/")[-1].replace("-", "_").replace(".", "")
 
     def make_model(self) -> OpenRouterAPIModel:
-        return OpenRouterAPIModel(
+        self.openrouter_api_model = OpenRouterAPIModel(
             base_url=self.base_url,
             model_id=self.model_id,
             max_tokens=self.max_tokens,
             reproducibility_config=self.reproducibility_config,
         )
+        return self.openrouter_api_model
 
     def 
prepare(self): pass diff --git a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py index 1abdecd6..c24f18ec 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py @@ -4,7 +4,6 @@ HumanMessage, ParseError, parse_html_tags_raise, - SystemMessage, ) from browsergym.core.action.highlevel import HighLevelActionSet from dataclasses import dataclass @@ -15,38 +14,27 @@ import numpy as np -class SystemPromptPart(VLPromptPart): +class IntroductionPromptPart(VLPromptPart): def __init__(self): self.text = """\ You are an agent working to address a web-based task through step-by-step interactions with the browser. \ -At each step, you need to submit an action according to the current state of the browser. \ -This action will be executed and the state of the browser will be updated. +To achieve the goal of the task, at each step, you need to submit an action according to the current state of the browser. \ +This action will be executed to update the state of the browser, and you will proceed to the next step. """ def get_message_content(self) -> list[dict]: return [{"type": "text", "text": self.text}] -class InstructionPromptPart(VLPromptPart): - def __init__( - self, - goal_object: list[dict], - extra_instruction: Optional[str] = None, - ): +class GoalPromptPart(VLPromptPart): + def __init__(self, goal_object: list[dict]): text = """\ -# Instruction -Review the current state of the browser and all other information to find the best next action to achieve the goal. -## Goal +# The goal of the task """ for item in goal_object: if item["type"] == "text": text += f"""\ {item['text']} -""" - if extra_instruction is not None: - text += f"""\ -## Extra Instruction -{extra_instruction} """ self.text = text @@ -57,7 +45,7 @@ def get_message_content(self) -> list[dict]: class ScreenshotPromptPart(VLPromptPart): def __init__(self, screenshot: Union[Image.Image, np.ndarray]): self.text = """\ -# The Screenshot of the Current Web Page +# The screenshot of the current web page """ self.image_url = image_to_image_url(screenshot) @@ -73,7 +61,7 @@ def __init__( self, open_pages_titles: list[str], open_pages_urls: list[str], active_page_index: int ): text = """\ -# The Open Tabs of the Browser +# The open tabs of the browser """ for index, (title, url) in enumerate(zip(open_pages_titles, open_pages_urls)): text += f"""\ @@ -92,7 +80,7 @@ def get_message_content(self) -> list[dict]: class HistoryPromptPart(VLPromptPart): def __init__(self, thoughts: list[str], actions: list[str]): text = """\ -# The Previous Steps +# The thoughts and actions of the previous steps """ for index, (thought, action) in enumerate(zip(thoughts, actions)): text += f"""\ @@ -116,7 +104,7 @@ def __init__( logs_limit: int = 5, ): text = """\ -# The Error from Last Action +# The error from the last action """ if logs_separator in last_action_error: error, logs = last_action_error.split(logs_separator) @@ -141,53 +129,39 @@ def get_message_content(self) -> list[dict]: class AnswerPromptPart(VLPromptPart): def __init__( - self, - action_set_description: str, - use_abstract_example: bool, - use_concrete_example: bool, - preliminary_answer: Optional[dict] = None, + self, action_set_description: str, use_abstract_example: bool, use_concrete_example: bool ): text = f"""\ -# Answer Requirements -## Action Space +# The action space Here are all the actions you can take to interact with the browser. 
\ They are Python functions based on the Playwright library. {action_set_description} -## Answer Format -Think about the next action to take, and choose it from the action space. \ -Your answer should include both the thought and the action. +# The format requirements for the answer +Think about the action to take, and choose it from the action space. \ +Your answer should include both the thought and the action. \ +Your answer should only include one thought and one action. """ if use_abstract_example: text += """\ -### An Abstract Example of the Answer +# An abstract example of the answer -The thought about the next action. +The thought about the action. -The next action to take. +The action to take. """ if use_concrete_example: text += """\ -### A Concrete Example of the Answer +# A concrete example of the answer -From previous action I tried to set the value of year to "2022", using select_option, but it doesn't appear to be in the form. \ -It may be a dynamic dropdown, I will try using click with the bid "a324" and look at the response from the page. +The goal is to click on the numbers in ascending order. \ +The smallest number visible on the screen is '1'. \ +Based on the screenshot, '1' is located in the top-left quadrant of the white area. \ +I will use the 'mouse_click' action to directly click on the visible '1' by specifying its coordinates. -click('a324') - -""" - if preliminary_answer is not None: - text += f"""\ -## A Preliminary Answer -Here is a preliminary answer, which might be incorrect or inaccurate. \ -You can refine it to get your answer. - -{preliminary_answer['thought']} - - -{preliminary_answer['action']} +mouse_click(50, 50) """ self.text = text @@ -198,8 +172,8 @@ def get_message_content(self) -> list[dict]: @dataclass class UIPrompt(VLPrompt): - system_prompt_part: SystemPromptPart - instruction_prompt_part: InstructionPromptPart + introduction_prompt_part: IntroductionPromptPart + goal_prompt_part: GoalPromptPart screenshot_prompt_part: Optional[ScreenshotPromptPart] tabs_prompt_part: Optional[TabsPromptPart] history_prompt_part: Optional[HistoryPromptPart] @@ -208,20 +182,18 @@ class UIPrompt(VLPrompt): action_validator: callable def get_messages(self) -> Discussion: - system_message_content = self.system_prompt_part.get_message_content() - human_message_content = self.instruction_prompt_part.get_message_content() + message_content = self.introduction_prompt_part.get_message_content() + message_content.extend(self.goal_prompt_part.get_message_content()) if self.screenshot_prompt_part is not None: - human_message_content.extend(self.screenshot_prompt_part.get_message_content()) + message_content.extend(self.screenshot_prompt_part.get_message_content()) if self.tabs_prompt_part is not None: - human_message_content.extend(self.tabs_prompt_part.get_message_content()) + message_content.extend(self.tabs_prompt_part.get_message_content()) if self.history_prompt_part is not None: - human_message_content.extend(self.history_prompt_part.get_message_content()) + message_content.extend(self.history_prompt_part.get_message_content()) if self.error_prompt_part is not None: - human_message_content.extend(self.error_prompt_part.get_message_content()) - human_message_content.extend(self.answer_prompt_part.get_message_content()) - messages = Discussion( - [SystemMessage(system_message_content), HumanMessage(human_message_content)] - ) + message_content.extend(self.error_prompt_part.get_message_content()) + message_content.extend(self.answer_prompt_part.get_message_content()) + 
messages = Discussion([HumanMessage(message_content)]) messages.merge() return messages @@ -229,23 +201,15 @@ def parse_answer(self, answer_content: list[dict]) -> dict: answer_text = answer_content[0]["text"] answer_dict = {} try: - answer_dict.update( - parse_html_tags_raise(answer_text, keys=["thought"], merge_multiple=True) - ) + answer_dict.update(parse_html_tags_raise(answer_text, keys=["thought", "action"])) except ParseError as error: + answer_dict["parse_error"] = str(error) answer_dict["thought"] = answer_text - answer_dict["thought_parse_error"] = str(error) - try: - answer_dict.update( - parse_html_tags_raise(answer_text, keys=["action"], merge_multiple=True) - ) - except ParseError as error: code_blocks = extract_code_blocks(answer_text) if len(code_blocks) == 0: raise error else: answer_dict["action"] = "\n".join([block for _, block in code_blocks]) - answer_dict["action_parse_error"] = str(error) if answer_dict["action"] == "None": answer_dict["action"] = None else: @@ -268,17 +232,10 @@ class UIPromptArgs(VLPromptArgs): extra_instruction: Optional[str] def make_prompt( - self, - obs: dict, - thoughts: list[str], - actions: list[str], - action_set: HighLevelActionSet, - preliminary_answer: Optional[dict] = None, + self, obs: dict, thoughts: list[str], actions: list[str], action_set: HighLevelActionSet ) -> UIPrompt: - system_prompt_part = SystemPromptPart() - instruction_prompt_part = InstructionPromptPart( - goal_object=obs["goal_object"], extra_instruction=self.extra_instruction - ) + introduction_prompt_part = IntroductionPromptPart() + goal_prompt_part = GoalPromptPart(obs["goal_object"]) if self.use_screenshot: screenshot_prompt_part = ScreenshotPromptPart(obs["screenshot"]) else: @@ -291,7 +248,7 @@ def make_prompt( ) else: tabs_prompt_part = None - if self.use_history: + if self.use_history and len(thoughts) == len(actions) > 0: history_prompt_part = HistoryPromptPart(thoughts=thoughts, actions=actions) else: history_prompt_part = None @@ -305,11 +262,10 @@ def make_prompt( ), use_abstract_example=self.use_abstract_example, use_concrete_example=self.use_concrete_example, - preliminary_answer=preliminary_answer, ) - return UIPrompt( - system_prompt_part=system_prompt_part, - instruction_prompt_part=instruction_prompt_part, + self.ui_prompt = UIPrompt( + introduction_prompt_part=introduction_prompt_part, + goal_prompt_part=goal_prompt_part, screenshot_prompt_part=screenshot_prompt_part, tabs_prompt_part=tabs_prompt_part, history_prompt_part=history_prompt_part, @@ -317,3 +273,4 @@ def make_prompt( answer_prompt_part=answer_prompt_part, action_validator=action_set.to_python_code, ) + return self.ui_prompt From ace3e413ad5c3488a565cda0e00fd5d17d1abfdf Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Tue, 3 Jun 2025 19:03:58 -0400 Subject: [PATCH 27/29] update --- src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py index c24f18ec..ce727f3e 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py @@ -104,7 +104,7 @@ def __init__( logs_limit: int = 5, ): text = """\ -# The error from the last action +# The error caused by the last action """ if logs_separator in last_action_error: error, logs = last_action_error.split(logs_separator) @@ -136,10 +136,9 @@ def __init__( Here are all the actions you can take to interact with the 
browser. \ They are Python functions based on the Playwright library. {action_set_description} -# The format requirements for the answer +# The format of the answer Think about the action to take, and choose it from the action space. \ -Your answer should include both the thought and the action. \ -Your answer should only include one thought and one action. +Your answer should include one thought and one action. """ if use_abstract_example: text += """\ @@ -157,8 +156,7 @@ def __init__( The goal is to click on the numbers in ascending order. \ The smallest number visible on the screen is '1'. \ -Based on the screenshot, '1' is located in the top-left quadrant of the white area. \ -I will use the 'mouse_click' action to directly click on the visible '1' by specifying its coordinates. +I will use the 'mouse_click' action to directly click on the number '1'. mouse_click(50, 50) From 12016cb17681d3c6797f9d1e065df0112a756e46 Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Tue, 3 Jun 2025 19:14:35 -0400 Subject: [PATCH 28/29] update --- .../agents/vl_agent/vl_prompt/ui_prompt.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py index ce727f3e..2fad8498 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py @@ -126,6 +126,45 @@ def __init__( def get_message_content(self) -> list[dict]: return [{"type": "text", "text": self.text}] +class PreliminaryAnswerPromptPart(VLPromptPart): + def __init__( + self, action_set_description: str, use_abstract_example: bool, use_concrete_example: bool + ): + text = f"""\ +# The action space +Here are all the actions you can take to interact with the browser. \ +They are Python functions based on the Playwright library. +{action_set_description} +# The format of the answer +Think about the action to take, and describe the location to take the action. \ +Your answer should include one thought and one location. +""" + if use_abstract_example: + text += """\ +# An abstract example of the answer + +The thought about the action. + + +The description of the location. + +""" + if use_concrete_example: + text += """\ +# A concrete example of the answer + +The goal is to click on the numbers in ascending order. \ +The smallest number visible on the screen is '1'. \ +I will use the 'mouse_click' action to directly click on the number '1'. + + +The number '1' in the top-left quadrant of the white area. + +""" + self.text = text + + def get_message_content(self) -> list[dict]: + return [{"type": "text", "text": self.text}] class AnswerPromptPart(VLPromptPart): def __init__( From 6ebcdd8e8da11f29c48bfa7c9c404f46f471769a Mon Sep 17 00:00:00 2001 From: Chao Wang Date: Tue, 3 Jun 2025 19:27:50 -0400 Subject: [PATCH 29/29] update --- src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py index 2fad8498..b50e3eea 100644 --- a/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py +++ b/src/agentlab/agents/vl_agent/vl_prompt/ui_prompt.py @@ -176,7 +176,7 @@ def __init__( They are Python functions based on the Playwright library. {action_set_description} # The format of the answer -Think about the action to take, and choose it from the action space. \ +Think about the action to take, and choose the action from the action space. 
\ Your answer should include one thought and one action. """ if use_abstract_example: