VLAgent #250
New file: agentlab/agents/vl_agent/config.py (path per the launch script's import), registering model, prompt, and agent arguments.

@@ -0,0 +1,48 @@

```python
from browsergym.experiments.benchmark import HighLevelActionSetArgs
from .vl_agent.ui_agent import UIAgentArgs
from .vl_model.llama_model import LlamaModelArgs
from .vl_model.openrouter_api_model import OpenRouterAPIModelArgs
from .vl_prompt.ui_prompt import UIPromptArgs


VL_MODEL_ARGS_DICT = {
    "gpt_4o": OpenRouterAPIModelArgs(
        base_url="https://openrouter.ai/api/v1",
        model_id="openai/gpt-4o-2024-11-20",
        max_tokens=8192,
        reproducibility_config={"temperature": 0.1},
    ),
    "llama_32_11b": LlamaModelArgs(
        model_path="meta-llama/Llama-3.2-11B-Vision-Instruct",
        torch_dtype="bfloat16",
        accelerator_config={"mixed_precision": "bf16", "cpu": False},
        reproducibility_config={"temperature": 0.1},
        max_length=32768,
        max_new_tokens=8192,
        checkpoint_file=None,
        device=None,
    ),
}


VL_PROMPT_ARGS_DICT = {
    "ui_prompt": UIPromptArgs(
        use_screenshot=True,
        use_screenshot_som=False,
        use_tabs=True,
        use_history=True,
        use_error=True,
        use_abstract_example=True,
        use_concrete_example=False,
        extra_instruction=None,
    )
}


VL_AGENT_ARGS_DICT = {
    "ui_agent": UIAgentArgs(
        main_vl_model_args=VL_MODEL_ARGS_DICT["gpt_4o"],
        auxiliary_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"],
        ui_prompt_args=VL_PROMPT_ARGS_DICT["ui_prompt"],
        action_set_args=HighLevelActionSetArgs(subsets=["coord"]),
        max_retry=4,
    )
}
```
New file: a MiniWoB launch script for the VL agent study.

@@ -0,0 +1,33 @@

```python
from agentlab.agents.vl_agent.config import VL_AGENT_ARGS_DICT
from agentlab.experiments.study import Study
import logging
import os


logging.getLogger().setLevel(logging.INFO)


vl_agent_args_list = [VL_AGENT_ARGS_DICT["ui_agent"]]
benchmark = "miniwob"
os.environ["MINIWOB_URL"] = "file:///mnt/home/miniwob-plusplus/miniwob/html/miniwob/"
reproducibility_mode = False
relaunch = False
n_jobs = 1


if __name__ == "__main__":
    if reproducibility_mode:
        for vl_agent_args in vl_agent_args_list:
            vl_agent_args.set_reproducibility_mode()
    if relaunch:
        study = Study.load_most_recent(contains=None)
        study.find_incomplete(include_errors=True)
    else:
        study = Study(vl_agent_args_list, benchmark=benchmark, logging_level_stdout=logging.WARNING)
    study.run(
        n_jobs=n_jobs,
        parallel_backend="sequential",
        strict_reproducibility=reproducibility_mode,
        n_relaunch=3,
    )
    if reproducibility_mode:
        study.append_to_journal(strict_reproducibility=True)
```
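One portability note: the script hardcodes a machine-specific MINIWOB_URL. A hedged alternative, assuming nothing else relies on the unconditional override, is to let an existing environment setting win:

```python
import os

# Use the local checkout only when MINIWOB_URL is not already set;
# the fallback path is machine-specific, mirroring the original script.
os.environ.setdefault(
    "MINIWOB_URL", "file:///mnt/home/miniwob-plusplus/miniwob/html/miniwob/"
)
```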
New file: shared vision-language utilities (image data-URL helpers and multi-GPU model dispatch).

@@ -0,0 +1,39 @@

```python
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils.modeling import get_balanced_memory
from PIL import Image
from torch.nn import Module
from typing import Union
import base64
import io
import numpy as np


def image_to_image_url(image: Union[Image.Image, np.ndarray]):
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
```
Korbit review (on lines +11 to +13): Unsafe Numpy Array Conversion

What is the issue? The function doesn't handle invalid numpy array inputs that can't be converted to PIL Images.

Why this matters: Invalid array shapes, data types, or value ranges could cause crashes when converting numpy arrays to PIL Images.

Suggested change: add error handling for the numpy array conversion.

```python
def image_to_image_url(image: Union[Image.Image, np.ndarray]):
    if isinstance(image, np.ndarray):
        try:
            image = Image.fromarray(image)
        except TypeError as e:
            raise ValueError(f"Invalid numpy array format: {e}")
```
```python
    if image.mode in ("RGBA", "LA"):
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    image_base64 = base64.b64encode(buffer.getvalue()).decode()
    image_url = f"data:image/jpeg;base64,{image_base64}"
    return image_url


def image_url_to_image(image_url: str) -> Image.Image:
    image_base64 = image_url.replace("data:image/jpeg;base64,", "")
```
Korbit review (on lines +23 to +24): Inflexible Image Format Handling

What is the issue? The function assumes all image URLs are JPEG, but the input URL could carry other image formats (PNG, GIF, etc.).

Why this matters: This could cause conversion failures when processing non-JPEG images, leading to runtime errors in the vision-language agent.

Suggested change: handle different image formats by extracting the format from the URL.

```python
def image_url_to_image(image_url: str) -> Image.Image:
    # Extract format from data URL
    format_match = image_url.split(';')[0]
    if not format_match.startswith('data:image/'):
        raise ValueError('Invalid image data URL')
    # Remove the data URL prefix
    image_base64 = image_url.split(',')[1]
    image_data = base64.b64decode(image_base64.encode())
    buffer = io.BytesIO(image_data)
    image = Image.open(buffer)
    return image
```
```python
    image_data = base64.b64decode(image_base64.encode())
    buffer = io.BytesIO(image_data)
    image = Image.open(buffer)
```
Korbit review (on lines +23 to +27): Unsafe image data processing without validation

What is the issue? The function accepts and processes base64 image data without proper input validation, which could lead to processing of malicious image data.

Why this matters: Malformed or malicious image data could lead to memory exhaustion, buffer overflows, or arbitrary code execution through image-parsing vulnerabilities.

Suggested change:

```python
def image_url_to_image(image_url: str) -> Image.Image:
    if not image_url.startswith("data:image/jpeg;base64,"):
        raise ValueError("Invalid image URL format")
    try:
        image_base64 = image_url.replace("data:image/jpeg;base64,", "")
        image_data = base64.b64decode(image_base64.encode())
        buffer = io.BytesIO(image_data)
        image = Image.open(buffer)
        if image.format != "JPEG":
            raise ValueError("Invalid image format")
        return image
    except (ValueError, IOError) as e:
        raise ValueError("Invalid image data") from e
```
```python
    return image


def auto_dispatch_model(model: Module, no_split_module_classes: list[str]) -> Module:
    max_memory = get_balanced_memory(model, no_split_module_classes=no_split_module_classes)
    device_map = infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
    )
    model = dispatch_model(model, device_map=device_map)
    return model
```
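A quick round-trip sketch of the two image helpers above; note that JPEG re-encoding is lossy, so the decoded pixels approximate the input rather than match it exactly:

```python
# Round-trip sketch for image_to_image_url / image_url_to_image above;
# assumes it runs in (or imports from) the same module.
import numpy as np

frame = np.zeros((32, 32, 3), dtype=np.uint8)  # a black 32x32 RGB frame
url = image_to_image_url(frame)                # "data:image/jpeg;base64,..."
restored = image_url_to_image(url)             # PIL.Image.Image in RGB mode
assert url.startswith("data:image/jpeg;base64,")
assert restored.size == (32, 32)
```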
New file: vl_agent/base.py (per the relative import in ui_agent.py), abstract interfaces for vision-language agents.

@@ -0,0 +1,47 @@

```python
from abc import ABC, abstractmethod
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.experiments.benchmark import Benchmark
from dataclasses import dataclass


class VLAgent(ABC):
    @property
    @abstractmethod
    def action_set(self) -> HighLevelActionSet:
        raise NotImplementedError

    @abstractmethod
    def get_action(self, obs: dict) -> tuple[str, dict]:
```
Korbit review: Unclear Action Format Specification

What is the issue? The return type tuple[str, dict] lacks documentation about what the string and dictionary represent in the action tuple, which could lead to incorrect implementations by derived classes.

Why this matters: Without clear documentation of the expected format and content of the action tuple, implementations may return incompatible or incorrect action formats that could cause runtime errors or unexpected behavior.

Suggested change: add a docstring explaining the expected format of the action tuple.

```python
@abstractmethod
def get_action(self, obs: dict) -> tuple[str, dict]:
    """Get the next action based on the observation.

    Args:
        obs (dict): The observation from the environment

    Returns:
        tuple[str, dict]: A tuple containing:
            - action_type (str): The type of action to perform (e.g., 'click', 'type')
            - action_params (dict): Parameters specific to the action type
    """
    raise NotImplementedError
```
```python
        raise NotImplementedError

    @abstractmethod
    def obs_preprocessor(self, obs: dict) -> dict:
```
Korbit review: Unspecified Observation Preprocessing Requirements

What is the issue? The obs_preprocessor method lacks specification about what preprocessing should be performed and what the output format should be.

Why this matters: Implementations might process observations incorrectly or inconsistently, leading to unexpected agent behavior when processing environment observations.

Suggested change: add a docstring specifying the preprocessing requirements.

```python
@abstractmethod
def obs_preprocessor(self, obs: dict) -> dict:
    """Preprocess the raw observation from the environment.

    Args:
        obs (dict): Raw observation from the environment

    Returns:
        dict: Processed observation with standardized format:
            - Required fields should be documented here
            - Any transformations that must be applied
    """
    raise NotImplementedError
```
```python
        raise NotImplementedError


@dataclass
class VLAgentArgs(ABC):
    @property
    @abstractmethod
    def agent_name(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def make_agent(self) -> VLAgent:
        raise NotImplementedError
```
Korbit review (on lines +22 to +31): Mixed Responsibilities in VLAgentArgs

What is the issue? VLAgentArgs is mixing the Factory pattern (make_agent) with configuration (agent_name) and lifecycle management (prepare, close) responsibilities, violating the Single Responsibility Principle.

Why this matters: This design makes the class harder to maintain and test, while reducing reusability. Lifecycle management, configuration, and object creation should be handled by separate components.

Suggested change: split the class into separate components.

```python
@dataclass
class VLAgentConfig:
    agent_name: str
    # other configuration parameters


class VLAgentFactory(ABC):
    @abstractmethod
    def make_agent(self, config: VLAgentConfig) -> VLAgent:
        pass


class VLAgentLifecycle(ABC):
    @abstractmethod
    def prepare(self):
        pass

    @abstractmethod
    def close(self):
        pass
```
```python
    @abstractmethod
    def prepare(self):
        raise NotImplementedError

    @abstractmethod
    def close(self):
        raise NotImplementedError

    @abstractmethod
    def set_reproducibility_mode(self):
        raise NotImplementedError

    @abstractmethod
    def set_benchmark(self, benchmark: Benchmark, demo_mode: bool):
        raise NotImplementedError
```
New file: vl_agent/ui_agent.py (per the config import), the concrete UIAgent and its arguments.

@@ -0,0 +1,145 @@

```python
from agentlab.llm.llm_utils import ParseError, retry
from agentlab.llm.tracking import cost_tracker_decorator
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.experiments.agent import AgentInfo
from browsergym.experiments.benchmark import Benchmark
from browsergym.experiments.benchmark.base import HighLevelActionSetArgs
from browsergym.utils.obs import overlay_som
from copy import copy, deepcopy
from dataclasses import asdict, dataclass
from functools import cache
from typing import Optional
from .base import VLAgent, VLAgentArgs
from ..vl_model.base import VLModelArgs
from ..vl_prompt.ui_prompt import UIPromptArgs


class UIAgent(VLAgent):
    def __init__(
        self,
        main_vl_model_args: VLModelArgs,
        auxiliary_vl_model_args: Optional[VLModelArgs],
        ui_prompt_args: UIPromptArgs,
        action_set_args: HighLevelActionSetArgs,
        max_retry: int,
    ):
        self.main_vl_model = main_vl_model_args.make_model()
        if auxiliary_vl_model_args is None:
            self.auxiliary_vl_model = None
        else:
            self.auxiliary_vl_model = auxiliary_vl_model_args.make_model()
        self.ui_prompt_args = ui_prompt_args
        self.action_set_args = action_set_args
        self.max_retry = max_retry
        self.thoughts = []
        self.actions = []

    @property
    @cache
    def action_set(self) -> HighLevelActionSet:
        return self.action_set_args.make_action_set()

    @cost_tracker_decorator
    def get_action(self, obs: dict) -> tuple[str, dict]:
        ui_prompt = self.ui_prompt_args.make_prompt(
            obs=obs, thoughts=self.thoughts, actions=self.actions, action_set=self.action_set
        )
        try:
            messages = ui_prompt.get_messages()
            answer = retry(
                chat=self.main_vl_model,
                messages=messages,
                n_retry=self.max_retry,
                parser=ui_prompt.parse_answer,
            )
            stats = {"num_main_retries": (len(messages) - 3) // 2}
        except ParseError:
            answer = {"thought": None, "action": None}
            stats = {"num_main_retries": self.max_retry}
        stats.update(self.main_vl_model.get_stats())
        if self.auxiliary_vl_model is not None:
            preliminary_answer = answer
            ui_prompt = self.ui_prompt_args.make_prompt(
                obs=obs,
                thoughts=self.thoughts,
                actions=self.actions,
                action_set=self.action_set,
                preliminary_answer=preliminary_answer,
            )
            try:
                messages = ui_prompt.get_messages()
                answer = retry(
                    chat=self.auxiliary_vl_model,
                    messages=messages,
                    n_retry=self.max_retry,
                    parser=ui_prompt.parse_answer,
                )
                stats["num_auxiliary_retries"] = (len(messages) - 3) // 2
            except ParseError:
                answer = {"thought": None, "action": None}
                stats["num_auxiliary_retries"] = self.max_retry
            stats.update(self.auxiliary_vl_model.get_stats())
        else:
            preliminary_answer = None
        self.thoughts.append(str(answer["thought"]))
        self.actions.append(str(answer["action"]))
        agent_info = AgentInfo(
            think=str(answer["thought"]), stats=stats, extra_info=preliminary_answer
        )
        return answer["action"], asdict(agent_info)

    def obs_preprocessor(self, obs: dict) -> dict:
        obs = copy(obs)
        if self.ui_prompt_args.use_screenshot and self.ui_prompt_args.use_screenshot_som:
            obs["screenshot"] = overlay_som(
                obs["screenshot"], extra_properties=obs["extra_element_properties"]
            )
        return obs


@dataclass
class UIAgentArgs(VLAgentArgs):
    main_vl_model_args: VLModelArgs
    auxiliary_vl_model_args: Optional[VLModelArgs]
    ui_prompt_args: UIPromptArgs
    action_set_args: HighLevelActionSetArgs
    max_retry: int

    @property
    @cache
    def agent_name(self) -> str:
        if self.auxiliary_vl_model_args is None:
            return f"UIAgent-{self.main_vl_model_args.model_name}"
        else:
            return f"UIAgent-{self.main_vl_model_args.model_name}-{self.auxiliary_vl_model_args.model_name}"

    def make_agent(self) -> UIAgent:
        self.ui_agent = UIAgent(
            main_vl_model_args=self.main_vl_model_args,
            auxiliary_vl_model_args=self.auxiliary_vl_model_args,
            ui_prompt_args=self.ui_prompt_args,
            action_set_args=self.action_set_args,
            max_retry=self.max_retry,
        )
        return self.ui_agent

    def prepare(self):
        self.main_vl_model_args.prepare()
        if self.auxiliary_vl_model_args is not None:
            self.auxiliary_vl_model_args.prepare()

    def close(self):
        self.main_vl_model_args.close()
        if self.auxiliary_vl_model_args is not None:
            self.auxiliary_vl_model_args.close()

    def set_reproducibility_mode(self):
        self.main_vl_model_args.set_reproducibility_mode()
        if self.auxiliary_vl_model_args is not None:
            self.auxiliary_vl_model_args.set_reproducibility_mode()

    def set_benchmark(self, benchmark: Benchmark, demo_mode: bool):
        self.ui_prompt_args.use_tabs = benchmark.is_multi_tab
        self.action_set_args = deepcopy(benchmark.high_level_action_set_args)
        if demo_mode:
            self.action_set_args.demo_mode = "all_blue"
```
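The stats bookkeeping in get_action relies on retry() growing the message list by two entries per failed parse on top of a three-message prompt; both counts are assumptions about agentlab's retry helper rather than documented behavior. A toy sketch of that accounting:

```python
# Toy sketch of the accounting behind
#   stats = {"num_main_retries": (len(messages) - 3) // 2}
# Assumes a 3-message initial prompt and that each failed parse appends
# the model answer plus a corrective user message; this mimics, but is
# not, agentlab's actual retry() helper.
messages = ["system", "user_goal", "user_screenshot"]  # len == 3

def failed_round(msgs: list[str]) -> None:
    msgs.append("assistant_answer")  # unparseable model output
    msgs.append("user_parse_error")  # feedback requesting a retry

failed_round(messages)
failed_round(messages)
assert (len(messages) - 3) // 2 == 2  # two retries recorded
```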
New file: vl_model/base.py (per the import in ui_agent.py), abstract interfaces for vision-language models.

@@ -0,0 +1,35 @@

```python
from abc import ABC, abstractmethod
from agentlab.llm.llm_utils import AIMessage, Discussion


class VLModel(ABC):
    @abstractmethod
    def __call__(self, messages: Discussion) -> AIMessage:
        raise NotImplementedError

    @abstractmethod
    def get_stats(self) -> dict:
        raise NotImplementedError


class VLModelArgs(ABC):
    @property
    @abstractmethod
    def model_name(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def make_model(self) -> VLModel:
        raise NotImplementedError

    @abstractmethod
    def prepare(self):
        raise NotImplementedError

    @abstractmethod
    def close(self):
        raise NotImplementedError

    @abstractmethod
    def set_reproducibility_mode(self):
        raise NotImplementedError
```
Korbit review: Missing return type hint

What is the issue? The return type hint is missing for the function.

Why this matters: Missing return type hints make it harder for developers to understand the expected output type without reading the implementation details.

Suggested change: see the hedged sketch below.
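The captured comment names neither its target function nor its suggested diff; a minimal sketch of the likely fix, assuming it flags the None-returning lifecycle methods of VLModelArgs:

```python
from abc import ABC, abstractmethod


class VLModelArgs(ABC):
    ...  # other members as in vl_model/base.py above

    # Assumption: the review flags these None-returning methods; the fix
    # pattern is an explicit `-> None` annotation on each.
    @abstractmethod
    def prepare(self) -> None:
        raise NotImplementedError

    @abstractmethod
    def close(self) -> None:
        raise NotImplementedError

    @abstractmethod
    def set_reproducibility_mode(self) -> None:
        raise NotImplementedError
```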