Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/agentlab/agents/vl_agent/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from browsergym.experiments.benchmark import HighLevelActionSetArgs
from .vl_agent.ui_agent import UIAgentArgs
from .vl_model.llama_model import LlamaModelArgs
from .vl_model.openrouter_api_model import OpenRouterAPIModelArgs
from .vl_prompt.ui_prompt import UIPromptArgs


# Registry of vision-language model configurations, keyed by a short alias.
VL_MODEL_ARGS_DICT = {
    # GPT-4o (2024-11-20 snapshot) reached through the OpenRouter endpoint.
    "gpt_4o": OpenRouterAPIModelArgs(
        model_id="openai/gpt-4o-2024-11-20",
        base_url="https://openrouter.ai/api/v1",
        max_tokens=8192,
        reproducibility_config={"temperature": 0.1},
    ),
    # Llama 3.2 11B Vision Instruct, referenced by its model path.
    "llama_32_11b": LlamaModelArgs(
        model_path="meta-llama/Llama-3.2-11B-Vision-Instruct",
        checkpoint_file=None,
        device=None,
        torch_dtype="bfloat16",
        accelerator_config={"mixed_precision": "bf16", "cpu": False},
        max_length=32768,
        max_new_tokens=8192,
        reproducibility_config={"temperature": 0.1},
    ),
}

# Registry of prompt configurations, keyed by a short alias.
VL_PROMPT_ARGS_DICT = {
    "ui_prompt": UIPromptArgs(
        # Screenshot inputs: the plain screenshot is on, the set-of-marks variant is off.
        use_screenshot=True,
        use_screenshot_som=False,
        # Additional prompt sections.
        use_tabs=True,
        use_history=True,
        use_error=True,
        # Examples: only the abstract one is enabled.
        use_abstract_example=True,
        use_concrete_example=False,
        extra_instruction=None,
    ),
}

# Registry of agent configurations, keyed by a short alias. Entries reuse the
# model and prompt argument objects registered in the dictionaries above
# (the same instances, not copies).
VL_AGENT_ARGS_DICT = {
    "ui_agent": UIAgentArgs(
        main_vl_model_args=VL_MODEL_ARGS_DICT["gpt_4o"],
        auxiliary_vl_model_args=VL_MODEL_ARGS_DICT["llama_32_11b"],
        ui_prompt_args=VL_PROMPT_ARGS_DICT["ui_prompt"],
        action_set_args=HighLevelActionSetArgs(subsets=["coord"]),
        max_retry=4,
    ),
}
33 changes: 33 additions & 0 deletions src/agentlab/agents/vl_agent/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from agentlab.agents.vl_agent.config import VL_AGENT_ARGS_DICT
from agentlab.experiments.study import Study
import logging
import os


# Let INFO-level records through the root logger for the whole run.
logging.getLogger().setLevel(logging.INFO)

# Agent configurations handed to the study.
vl_agent_args_list = [VL_AGENT_ARGS_DICT["ui_agent"]]
benchmark = "miniwob"
# Location of the MiniWoB++ HTML files. The path below is machine-specific, so
# it is only a fallback: a MINIWOB_URL already exported in the environment is
# respected instead of being silently overwritten.
os.environ.setdefault(
    "MINIWOB_URL", "file:///mnt/home/miniwob-plusplus/miniwob/html/miniwob/"
)
# Run switches read by the ``__main__`` section below.
reproducibility_mode = False
relaunch = False
n_jobs = 1


if __name__ == "__main__":
    # Switch every agent configuration to reproducibility mode before the
    # study is created or reloaded.
    if reproducibility_mode:
        for agent_args in vl_agent_args_list:
            agent_args.set_reproducibility_mode()

    # Either start a fresh study or pick up the most recent one and look for
    # its incomplete (including errored) experiments.
    if not relaunch:
        study = Study(
            vl_agent_args_list,
            benchmark=benchmark,
            logging_level_stdout=logging.WARNING,
        )
    else:
        study = Study.load_most_recent(contains=None)
        study.find_incomplete(include_errors=True)

    run_options = {
        "n_jobs": n_jobs,
        "parallel_backend": "sequential",
        "strict_reproducibility": reproducibility_mode,
        "n_relaunch": 3,
    }
    study.run(**run_options)

    # Only reproducibility runs are appended to the journal.
    if reproducibility_mode:
        study.append_to_journal(strict_reproducibility=True)
39 changes: 39 additions & 0 deletions src/agentlab/agents/vl_agent/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils.modeling import get_balanced_memory
from PIL import Image
from torch.nn import Module
from typing import Union
import base64
import io
import numpy as np


def image_to_image_url(image: Union[Image.Image, np.ndarray]):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing return type hint category Readability

Tell me more
What is the issue?

Return type hint is missing for the function.

Why this matters

Missing return type hints make it harder for developers to understand the expected output type without reading the implementation details.

Suggested change ∙ Feature Preview
def image_to_image_url(image: Union[Image.Image, np.ndarray]) -> str:
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.

if isinstance(image, np.ndarray):
image = Image.fromarray(image)
Comment on lines +11 to +13
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unsafe Numpy Array Conversion category Functionality

Tell me more
What is the issue?

The function doesn't handle invalid numpy array inputs that can't be converted to PIL Images.

Why this matters

Invalid array shapes, data types, or value ranges could cause crashes when converting numpy arrays to PIL Images.

Suggested change ∙ Feature Preview

Add error handling for numpy array conversion:

def image_to_image_url(image: Union[Image.Image, np.ndarray]):
    if isinstance(image, np.ndarray):
        try:
            image = Image.fromarray(image)
        except TypeError as e:
            raise ValueError(f"Invalid numpy array format: {e}")
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.

if image.mode in ("RGBA", "LA"):
image = image.convert("RGB")
buffer = io.BytesIO()
image.save(buffer, format="JPEG")
image_base64 = base64.b64encode(buffer.getvalue()).decode()
image_url = f"data:image/jpeg;base64,{image_base64}"
return image_url


def image_url_to_image(image_url: str) -> Image.Image:
image_base64 = image_url.replace("data:image/jpeg;base64,", "")
Comment on lines +23 to +24
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inflexible Image Format Handling category Functionality

Tell me more
What is the issue?

The function assumes all image URLs will be JPEG format, but the input URL could contain other image formats (PNG, GIF, etc.).

Why this matters

This could cause conversion failures when processing non-JPEG images, leading to runtime errors in the vision-language agent.

Suggested change ∙ Feature Preview

Modify the function to handle different image formats by extracting the format from the URL:

def image_url_to_image(image_url: str) -> Image.Image:
    # Extract format from data URL
    format_match = image_url.split(';')[0]
    if not format_match.startswith('data:image/'):
        raise ValueError('Invalid image data URL')
    # Remove the data URL prefix
    image_base64 = image_url.split(',')[1]
    image_data = base64.b64decode(image_base64.encode())
    buffer = io.BytesIO(image_data)
    image = Image.open(buffer)
    return image
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.

image_data = base64.b64decode(image_base64.encode())
buffer = io.BytesIO(image_data)
image = Image.open(buffer)
Comment on lines +23 to +27
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unsafe image data processing without validation category Security

Tell me more
What is the issue?

The function accepts and processes base64 image data without proper input validation, which could lead to processing of malicious image data.

Why this matters

Malformed or malicious image data could lead to memory exhaustion, buffer overflows, or arbitrary code execution through image parsing vulnerabilities.

Suggested change ∙ Feature Preview
def image_url_to_image(image_url: str) -> Image.Image:
    if not image_url.startswith("data:image/jpeg;base64,"):
        raise ValueError("Invalid image URL format")
    try:
        image_base64 = image_url.replace("data:image/jpeg;base64,", "")
        image_data = base64.b64decode(image_base64.encode())
        buffer = io.BytesIO(image_data)
        image = Image.open(buffer)
        if image.format != "JPEG":
            raise ValueError("Invalid image format")
        return image
    except (ValueError, IOError) as e:
        raise ValueError("Invalid image data") from e
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.

return image


def auto_dispatch_model(model: Module, no_split_module_classes: list[str]) -> Module:
    """Dispatch ``model`` according to a device map inferred by ``accelerate``.

    Args:
        model: The module to dispatch.
        no_split_module_classes: Passed unchanged to both
            ``get_balanced_memory`` and ``infer_auto_device_map``.

    Returns:
        The value returned by ``dispatch_model`` for the inferred device map.
    """
    # Step 1: compute the memory budget used as ``max_memory`` below.
    memory_budget = get_balanced_memory(
        model, no_split_module_classes=no_split_module_classes
    )
    # Step 2: derive the device map under that budget.
    placement = infer_auto_device_map(
        model,
        max_memory=memory_budget,
        no_split_module_classes=no_split_module_classes,
    )
    # Step 3: dispatch the model with the inferred map.
    return dispatch_model(model, device_map=placement)
47 changes: 47 additions & 0 deletions src/agentlab/agents/vl_agent/vl_agent/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from abc import ABC, abstractmethod
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.experiments.benchmark import Benchmark
from dataclasses import dataclass


class VLAgent(ABC):
@property
@abstractmethod
def action_set(self) -> HighLevelActionSet:
    """Return the ``HighLevelActionSet`` this agent works with.

    Abstract: concrete agents must override this property; the base
    implementation only raises ``NotImplementedError``.
    """
    raise NotImplementedError

@abstractmethod
def get_action(self, obs: dict) -> tuple[str, dict]:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unclear Action Format Specification category Documentation

Tell me more
What is the issue?

The return type tuple[str, dict] lacks documentation about what the string and dictionary represent in the action tuple, which could lead to incorrect implementations by derived classes.

Why this matters

Without clear documentation of the expected format and content of the action tuple, implementations may return incompatible or incorrect action formats that could cause runtime errors or unexpected behavior.

Suggested change ∙ Feature Preview

Add a docstring explaining the expected format of the action tuple:

@abstractmethod
def get_action(self, obs: dict) -> tuple[str, dict]:
    """Get the next action based on the observation.
    
    Args:
        obs (dict): The observation from the environment
        
    Returns:
        tuple[str, dict]: A tuple containing:
            - action_type (str): The type of action to perform (e.g., 'click', 'type')
            - action_params (dict): Parameters specific to the action type
    """
    raise NotImplementedError
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.

raise NotImplementedError

@abstractmethod
def obs_preprocessor(self, obs: dict) -> dict:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unspecified Observation Preprocessing Requirements category Documentation

Tell me more
What is the issue?

The obs_preprocessor method lacks specification about what preprocessing should be performed and what the output format should be.

Why this matters

Implementations might process observations incorrectly or inconsistently, leading to unexpected agent behavior when processing environment observations.

Suggested change ∙ Feature Preview

Add a docstring specifying the preprocessing requirements:

@abstractmethod
def obs_preprocessor(self, obs: dict) -> dict:
    """Preprocess the raw observation from the environment.
    
    Args:
        obs (dict): Raw observation from the environment
        
    Returns:
        dict: Processed observation with standardized format:
            - Required fields should be documented here
            - Any transformations that must be applied
    """
    raise NotImplementedError
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.

raise NotImplementedError


@dataclass
class VLAgentArgs(ABC):
@property
@abstractmethod
def agent_name(self) -> str:
    """Return the name of the agent described by these arguments.

    Abstract: subclasses must override this property; the base
    implementation only raises ``NotImplementedError``.
    """
    raise NotImplementedError

@abstractmethod
def make_agent(self) -> VLAgent:
    """Build and return the ``VLAgent`` described by these arguments.

    Abstract: subclasses must override this method; the base
    implementation only raises ``NotImplementedError``.
    """
    raise NotImplementedError
Comment on lines +22 to +31
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mixed Responsibilities in VLAgentArgs category Design

Tell me more
What is the issue?

VLAgentArgs is mixing the Factory pattern (make_agent) with configuration (agent_name) and lifecycle management (prepare, close) responsibilities, violating the Single Responsibility Principle.

Why this matters

This design makes the class harder to maintain and test, while reducing reusability. Lifecycle management, configuration, and object creation should be handled by separate components.

Suggested change ∙ Feature Preview

Split the class into separate components:

@dataclass
class VLAgentConfig:
    agent_name: str
    # other configuration parameters

class VLAgentFactory(ABC):
    @abstractmethod
    def make_agent(self, config: VLAgentConfig) -> VLAgent:
        pass

class VLAgentLifecycle(ABC):
    @abstractmethod
    def prepare(self):
        pass
    @abstractmethod
    def close(self):
        pass
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.


@abstractmethod
def prepare(self):
    """Abstract preparation hook; subclasses must override it.

    NOTE(review): presumably invoked before an agent is created so that
    resources can be set up -- confirm against the caller.
    """
    raise NotImplementedError

@abstractmethod
def close(self):
    """Abstract teardown hook; subclasses must override it.

    NOTE(review): presumably the counterpart of ``prepare`` that releases
    resources -- confirm against the caller.
    """
    raise NotImplementedError

@abstractmethod
def set_reproducibility_mode(self):
    """Switch these arguments to reproducibility mode.

    Abstract: subclasses must override this method; what the mode changes
    is defined by the subclass.
    """
    raise NotImplementedError

@abstractmethod
def set_benchmark(self, benchmark: Benchmark, demo_mode: bool):
    """Adapt these arguments to ``benchmark``.

    Abstract: subclasses must override this method.

    Args:
        benchmark: The benchmark the agent is going to be run on.
        demo_mode: Flag forwarded by the caller; its effect is defined by
            the subclass.
    """
    raise NotImplementedError
145 changes: 145 additions & 0 deletions src/agentlab/agents/vl_agent/vl_agent/ui_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from agentlab.llm.llm_utils import ParseError, retry
from agentlab.llm.tracking import cost_tracker_decorator
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.experiments.agent import AgentInfo
from browsergym.experiments.benchmark import Benchmark
from browsergym.experiments.benchmark.base import HighLevelActionSetArgs
from browsergym.utils.obs import overlay_som
from copy import copy, deepcopy
from dataclasses import asdict, dataclass
from functools import cache
from typing import Optional
from .base import VLAgent, VLAgentArgs
from ..vl_model.base import VLModelArgs
from ..vl_prompt.ui_prompt import UIPromptArgs


class UIAgent(VLAgent):
    """Agent that asks a main model, and optionally an auxiliary one, for actions.

    At every step the main model is queried for a thought/action answer. When
    an auxiliary model is configured, it is queried afterwards with the main
    model's answer supplied as ``preliminary_answer`` and its answer becomes
    the final one. The stringified thought and action of every step are kept
    in ``self.thoughts`` / ``self.actions`` and passed into later prompts.
    """

    def __init__(
        self,
        main_vl_model_args: VLModelArgs,
        auxiliary_vl_model_args: Optional[VLModelArgs],
        ui_prompt_args: UIPromptArgs,
        action_set_args: HighLevelActionSetArgs,
        max_retry: int,
    ):
        """Instantiate the models and remember the prompt/action configuration.

        Args:
            main_vl_model_args: Arguments whose ``make_model()`` builds the main model.
            auxiliary_vl_model_args: Same for the auxiliary model, or ``None``
                to run with the main model only.
            ui_prompt_args: Factory for the prompts (``make_prompt``).
            action_set_args: Factory for the action set (``make_action_set``).
            max_retry: Value passed as ``n_retry`` to ``retry``.
        """
        self.main_vl_model = main_vl_model_args.make_model()
        if auxiliary_vl_model_args is None:
            self.auxiliary_vl_model = None
        else:
            self.auxiliary_vl_model = auxiliary_vl_model_args.make_model()
        self.ui_prompt_args = ui_prompt_args
        self.action_set_args = action_set_args
        self.max_retry = max_retry
        self.thoughts = []
        self.actions = []
        # Lazily built by the ``action_set`` property.
        self._action_set = None

    @property
    def action_set(self) -> HighLevelActionSet:
        """Return the action set, building it on first access.

        The result is memoized on the instance. A ``functools.cache`` wrapper
        on the method would instead key a module-global cache on ``self`` and
        keep every agent (and the models it holds) alive indefinitely.
        """
        if self._action_set is None:
            self._action_set = self.action_set_args.make_action_set()
        return self._action_set

    def _ask_model(self, vl_model, ui_prompt) -> tuple[dict, int]:
        """Query ``vl_model`` with ``ui_prompt``; return the answer and retry count.

        On ``ParseError`` the answer is ``{"thought": None, "action": None}``
        and the retry count is ``self.max_retry``.
        """
        try:
            messages = ui_prompt.get_messages()
            answer = retry(
                chat=vl_model,
                messages=messages,
                n_retry=self.max_retry,
                parser=ui_prompt.parse_answer,
            )
            # NOTE(review): presumably ``retry`` appends two messages per failed
            # attempt to ``messages`` and a successful first attempt leaves
            # three messages -- confirm against ``retry``.
            num_retries = (len(messages) - 3) // 2
        except ParseError:
            answer = {"thought": None, "action": None}
            num_retries = self.max_retry
        return answer, num_retries

    @cost_tracker_decorator
    def get_action(self, obs: dict) -> tuple[str, dict]:
        """Choose the next action for ``obs``.

        Args:
            obs: Observation passed to the prompt factory.

        Returns:
            A pair ``(action, agent_info)``: ``action`` is ``answer["action"]``
            of the last model queried (``None`` if its answer could not be
            parsed) and ``agent_info`` is an ``AgentInfo`` converted with
            ``asdict``. Its ``extra_info`` holds the main model's answer when
            an auxiliary model is used, otherwise ``None``.
        """
        ui_prompt = self.ui_prompt_args.make_prompt(
            obs=obs,
            thoughts=self.thoughts,
            actions=self.actions,
            action_set=self.action_set,
        )
        answer, num_retries = self._ask_model(self.main_vl_model, ui_prompt)
        stats = {"num_main_retries": num_retries}
        stats.update(self.main_vl_model.get_stats())

        if self.auxiliary_vl_model is not None:
            preliminary_answer = answer
            ui_prompt = self.ui_prompt_args.make_prompt(
                obs=obs,
                thoughts=self.thoughts,
                actions=self.actions,
                action_set=self.action_set,
                preliminary_answer=preliminary_answer,
            )
            answer, num_retries = self._ask_model(self.auxiliary_vl_model, ui_prompt)
            stats["num_auxiliary_retries"] = num_retries
            # NOTE(review): keys shared with the main model's stats are
            # overwritten here -- confirm that this is intended.
            stats.update(self.auxiliary_vl_model.get_stats())
        else:
            preliminary_answer = None

        # History is stored stringified, so a failed parse is recorded as "None".
        self.thoughts.append(str(answer["thought"]))
        self.actions.append(str(answer["action"]))
        agent_info = AgentInfo(
            think=str(answer["thought"]), stats=stats, extra_info=preliminary_answer
        )
        return answer["action"], asdict(agent_info)

    def obs_preprocessor(self, obs: dict) -> dict:
        """Return a shallow copy of ``obs``, optionally with a marked screenshot.

        When both ``use_screenshot`` and ``use_screenshot_som`` are enabled in
        the prompt arguments, ``obs["screenshot"]`` is replaced by the output
        of ``overlay_som`` using ``obs["extra_element_properties"]``.
        """
        obs = copy(obs)
        if self.ui_prompt_args.use_screenshot and self.ui_prompt_args.use_screenshot_som:
            obs["screenshot"] = overlay_som(
                obs["screenshot"], extra_properties=obs["extra_element_properties"]
            )
        return obs


@dataclass
class UIAgentArgs(VLAgentArgs):
    """Arguments that describe, build and manage a ``UIAgent``."""

    # Arguments for the main model (required).
    main_vl_model_args: VLModelArgs
    # Arguments for the auxiliary model, or None to use the main model only.
    auxiliary_vl_model_args: Optional[VLModelArgs]
    # Prompt factory handed to the agent.
    ui_prompt_args: UIPromptArgs
    # Action-set factory handed to the agent; replaced by ``set_benchmark``.
    action_set_args: HighLevelActionSetArgs
    # Retry budget handed to the agent.
    max_retry: int

    @property
    def agent_name(self) -> str:
        """Return ``UIAgent-<main>`` or ``UIAgent-<main>-<auxiliary>``.

        Computed on every access. ``functools.cache`` cannot be used here: a
        dataclass with the default ``eq=True`` is unhashable, so the cached
        lookup raised ``TypeError``, and the fields are mutable anyway.
        """
        name = f"UIAgent-{self.main_vl_model_args.model_name}"
        if self.auxiliary_vl_model_args is not None:
            name += f"-{self.auxiliary_vl_model_args.model_name}"
        return name

    def make_agent(self) -> UIAgent:
        """Build a ``UIAgent`` from these arguments, store it on ``self.ui_agent`` and return it."""
        self.ui_agent = UIAgent(
            main_vl_model_args=self.main_vl_model_args,
            auxiliary_vl_model_args=self.auxiliary_vl_model_args,
            ui_prompt_args=self.ui_prompt_args,
            action_set_args=self.action_set_args,
            max_retry=self.max_retry,
        )
        return self.ui_agent

    def prepare(self):
        """Call ``prepare()`` on the main model arguments and, if set, the auxiliary ones."""
        self.main_vl_model_args.prepare()
        if self.auxiliary_vl_model_args is not None:
            self.auxiliary_vl_model_args.prepare()

    def close(self):
        """Call ``close()`` on the main model arguments and, if set, the auxiliary ones."""
        self.main_vl_model_args.close()
        if self.auxiliary_vl_model_args is not None:
            self.auxiliary_vl_model_args.close()

    def set_reproducibility_mode(self):
        """Forward ``set_reproducibility_mode()`` to the main and, if set, auxiliary model arguments."""
        self.main_vl_model_args.set_reproducibility_mode()
        if self.auxiliary_vl_model_args is not None:
            self.auxiliary_vl_model_args.set_reproducibility_mode()

    def set_benchmark(self, benchmark: Benchmark, demo_mode: bool):
        """Align the prompt and action-set configuration with ``benchmark``.

        ``ui_prompt_args.use_tabs`` is set in place from
        ``benchmark.is_multi_tab`` and ``action_set_args`` is replaced by a
        deep copy of ``benchmark.high_level_action_set_args``. In demo mode the
        copy's ``demo_mode`` is set to ``"all_blue"``.
        """
        self.ui_prompt_args.use_tabs = benchmark.is_multi_tab
        # Deep copy so the benchmark's own argument object is never modified.
        self.action_set_args = deepcopy(benchmark.high_level_action_set_args)
        if demo_mode:
            self.action_set_args.demo_mode = "all_blue"
35 changes: 35 additions & 0 deletions src/agentlab/agents/vl_agent/vl_model/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
from agentlab.llm.llm_utils import AIMessage, Discussion


class VLModel(ABC):
    """Abstract interface of a model that can be called on a ``Discussion``."""

    @abstractmethod
    def __call__(self, messages: Discussion) -> AIMessage:
        """Answer ``messages`` with an ``AIMessage``; subclasses must implement this."""
        raise NotImplementedError

    @abstractmethod
    def get_stats(self) -> dict:
        """Return a dict of statistics; subclasses must implement this.

        The keys and their meaning are defined by the concrete model.
        """
        raise NotImplementedError


class VLModelArgs(ABC):
    """Abstract interface of the arguments that describe and build a ``VLModel``."""

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Return the name of the described model; subclasses must implement this."""
        raise NotImplementedError

    @abstractmethod
    def make_model(self) -> VLModel:
        """Build and return the described ``VLModel``; subclasses must implement this."""
        raise NotImplementedError

    @abstractmethod
    def prepare(self):
        """Abstract preparation hook; subclasses must implement this.

        NOTE(review): presumably sets up whatever ``make_model`` needs --
        confirm against the concrete implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def close(self):
        """Abstract teardown hook; subclasses must implement this.

        NOTE(review): presumably the counterpart of ``prepare`` -- confirm
        against the concrete implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def set_reproducibility_mode(self):
        """Switch these arguments to reproducibility mode; subclasses must implement this."""
        raise NotImplementedError
Loading
Loading