Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 152 additions & 83 deletions src/agentlab/agents/tool_use_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,60 +5,28 @@

import bgym
import numpy as np
from browsergym.core.observation import extract_screenshot
from PIL import Image, ImageDraw

from agentlab.agents import agent_utils
from agentlab.agents.agent_args import AgentArgs
from agentlab.llm.llm_utils import image_to_png_base64_url
from agentlab.llm.response_api import (
BaseModelArgs,
ClaudeResponseModelArgs,
MessageBuilder,
OpenAIChatModelArgs,
OpenAIResponseModelArgs,
OpenRouterModelArgs,
ResponseLLMOutput,
VLLMModelArgs,
)
from agentlab.llm.tracking import cost_tracker_decorator
from browsergym.core.observation import extract_screenshot

if TYPE_CHECKING:
from openai.types.responses import Response


def tag_screenshot_with_action(screenshot: "Image.Image", action: str) -> "Image.Image":
    """If action is a coordinate click action, render it on the screenshot.

    e.g. mouse_click(120, 130) -> draw a red dot at (120, 130) on the screenshot.

    Args:
        screenshot: The PIL image to tag (drawn on in place).
        action: The action string to parse for click coordinates.

    Returns:
        The (possibly tagged) screenshot. Parsing failures are logged as a
        warning and the screenshot is returned untouched — no exception
        propagates to the caller.
    """
    if action.startswith("mouse_click"):
        try:
            coords = action[action.index("(") + 1 : action.index(")")].split(",")
            coords = [c.strip() for c in coords]
            if len(coords) != 2:
                raise ValueError(f"Invalid coordinate format: {coords}")
            # Accept both positional and keyword forms:
            # mouse_click(120, 130) and mouse_click(x=120, y=130).
            if coords[0].startswith("x="):
                coords[0] = coords[0][2:]
            if coords[1].startswith("y="):
                coords[1] = coords[1][2:]
            x, y = float(coords[0].strip()), float(coords[1].strip())
            draw = ImageDraw.Draw(screenshot)
            radius = 5
            draw.ellipse(
                (x - radius, y - radius, x + radius, y + radius), fill="red", outline="red"
            )
        except (ValueError, IndexError) as e:
            logging.warning(f"Failed to parse action '{action}': {e}")
    return screenshot


@dataclass
class ToolUseAgentArgs(AgentArgs):
model_args: OpenAIResponseModelArgs = None
Expand Down Expand Up @@ -97,19 +65,9 @@ def __init__(
self.model_args = model_args
self.use_first_obs = use_first_obs
self.tag_screenshot = tag_screenshot

self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)

self.tools = self.action_set.to_tool_description(api=model_args.api)

# count tools tokens
from agentlab.llm.llm_utils import count_tokens

tool_str = json.dumps(self.tools, indent=2)
print(f"Tool description: {tool_str}")
tool_tokens = count_tokens(tool_str, model_args.model_name)
print(f"Tool tokens: {tool_tokens}")

self.call_ids = []

# self.tools.append(
Expand All @@ -131,7 +89,7 @@ def __init__(
# )

self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})

self.msg_builder = model_args.get_message_builder()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe the static method from message builder could be in the llm object instead. this way we wouldn't need to implement this extra thing.

e.g. message = self.llm.tool().add_image

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But we can push like this and move it after

self.messages: list[MessageBuilder] = []

def obs_preprocessor(self, obs):
Expand All @@ -140,7 +98,7 @@ def obs_preprocessor(self, obs):
obs["screenshot"] = extract_screenshot(page)
if self.tag_screenshot:
screenshot = Image.fromarray(obs["screenshot"])
screenshot = tag_screenshot_with_action(screenshot, obs["last_action"])
screenshot = agent_utils.tag_screenshot_with_action(screenshot, obs["last_action"])
obs["screenshot_tag"] = np.array(screenshot)
else:
raise ValueError("No page found in the observation.")
Expand All @@ -150,56 +108,31 @@ def obs_preprocessor(self, obs):
@cost_tracker_decorator
def get_action(self, obs: Any) -> float:
if len(self.messages) == 0:
system_message = MessageBuilder.system().add_text(
"You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
)
self.messages.append(system_message)

goal_message = MessageBuilder.user()
for content in obs["goal_object"]:
if content["type"] == "text":
goal_message.add_text(content["text"])
elif content["type"] == "image_url":
goal_message.add_image(content["image_url"])
self.messages.append(goal_message)

extra_info = []

extra_info.append(
"""Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n"""
)

self.messages.append(MessageBuilder.user().add_text("\n".join(extra_info)))

if self.use_first_obs:
msg = "Here is the first observation."
screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
if self.tag_screenshot:
msg += " A red dot on screenshots indicate the previous click action."
message = MessageBuilder.user().add_text(msg)
message.add_image(image_to_png_base64_url(obs[screenshot_key]))
self.messages.append(message)
self.initalize_messages(obs)
else:
if obs["last_action_error"] == "":
if obs["last_action_error"] == "": # Check No error in the last action
screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
tool_message = MessageBuilder.tool().add_image(
tool_message = self.msg_builder.tool().add_image(
image_to_png_base64_url(obs[screenshot_key])
)
tool_message.update_last_raw_response(self.last_response)
tool_message.add_tool_id(self.previous_call_id)
self.messages.append(tool_message)
else:
tool_message = MessageBuilder.tool().add_text(
tool_message = self.msg_builder.tool().add_text(
f"Function call failed: {obs['last_action_error']}"
)
tool_message.add_tool_id(self.previous_call_id)
tool_message.update_last_raw_response(self.last_response)
self.messages.append(tool_message)

response: ResponseLLMOutput = self.llm(messages=self.messages)

action = response.action
think = response.think
self.last_response = response
self.previous_call_id = response.last_computer_call_id
self.messages.append(response.assistant_message)
self.messages.append(response.assistant_message) # this is tool call

return (
action,
Expand All @@ -210,6 +143,37 @@ def get_action(self, obs: Any) -> float:
),
)

def initalize_messages(self, obs: Any) -> None:
    """Seed self.messages with the system prompt, the task goal, usage
    hints, and (optionally) the first observation's screenshot.

    Args:
        obs: Observation dict; must contain "goal_object" and, when
            self.use_first_obs is set, the screenshot key(s).
    """
    # System prompt establishing the agent's role.
    self.messages.append(
        self.msg_builder.system().add_text(
            "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
        )
    )

    # Relay the (possibly multimodal) goal to the model.
    goal_msg = self.msg_builder.user()
    for part in obs["goal_object"]:
        if part["type"] == "text":
            goal_msg.add_text(part["text"])
        elif part["type"] == "image_url":
            goal_msg.add_image(part["image_url"])
    self.messages.append(goal_msg)

    # General usage hints, joined into a single user message.
    hints = [
        """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n"""
    ]
    self.messages.append(self.msg_builder.user().add_text("\n".join(hints)))

    if not self.use_first_obs:
        return

    # Attach the initial screenshot (tagged variant when enabled).
    caption = "Here is the first observation."
    screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
    if self.tag_screenshot:
        caption += " A red dot on screenshots indicate the previous click action."
    first_obs_msg = self.msg_builder.user().add_text(caption)
    first_obs_msg.add_image(image_to_png_base64_url(obs[screenshot_key]))
    self.messages.append(first_obs_msg)


OPENAI_MODEL_CONFIG = OpenAIResponseModelArgs(
model_name="gpt-4.1",
Expand All @@ -220,6 +184,14 @@ def get_action(self, obs: Any) -> float:
vision_support=True,
)

# OpenAI configuration targeting the Chat Completions API (as opposed to the
# Responses API used by OPENAI_MODEL_CONFIG above).
OPENAI_CHATAPI_MODEL_CONFIG = OpenAIChatModelArgs(
    model_name="gpt-4o-2024-08-06",
    max_total_tokens=200_000,
    max_input_tokens=200_000,
    max_new_tokens=2_000,
    temperature=0.1,
    vision_support=True,
)

CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs(
model_name="claude-3-7-sonnet-20250219",
Expand All @@ -231,6 +203,103 @@ def get_action(self, obs: Any) -> float:
)




# def get_openrouter_model(model_name: str, **open_router_args) -> OpenRouterModelArgs:
# default_model_args = {
# "max_total_tokens": 200_000,
# "max_input_tokens": 180_000,
# "max_new_tokens": 2_000,
# "temperature": 0.1,
# "vision_support": True,
# }
# merged_args = {**default_model_args, **open_router_args}

# return OpenRouterModelArgs(model_name=model_name, **merged_args)


# def get_openrouter_tool_use_agent(
# model_name: str,
# model_args: dict = {},
# use_first_obs=True,
# tag_screenshot=True,
# use_raw_page_output=True,
# ) -> ToolUseAgentArgs:
# # To Do : Check if OpenRouter endpoint specific args are working
# if not supports_tool_calling(model_name):
# raise ValueError(f"Model {model_name} does not support tool calling.")

# model_args = get_openrouter_model(model_name, **model_args)

# return ToolUseAgentArgs(
# model_args=model_args,
# use_first_obs=use_first_obs,
# tag_screenshot=tag_screenshot,
# use_raw_page_output=use_raw_page_output,
# )


# OPENROUTER_MODEL = get_openrouter_tool_use_agent("google/gemini-2.5-pro-preview")


# Default agent configuration: Claude with the tool-use action set.
AGENT_CONFIG = ToolUseAgentArgs(
    model_args=CLAUDE_MODEL_CONFIG,
)

# MT_TOOL_USE_AGENT = ToolUseAgentArgs(
#     model_args=OPENROUTER_MODEL,
# )
# Agent configuration exercising the OpenAI Chat Completions endpoint.
# NOTE(review): temperature 0.7 here differs from the 0.1 used by the other
# configs in this file — confirm this is intentional.
CHATAPI_AGENT_CONFIG = ToolUseAgentArgs(
    model_args=OpenAIChatModelArgs(
        model_name="gpt-4o-2024-11-20",
        max_total_tokens=200_000,
        max_input_tokens=200_000,
        max_new_tokens=2_000,
        temperature=0.7,
        vision_support=True,
    ),
)


# Chat-Completions tool-use agent with default model-args settings.
OAI_CHAT_TOOL_AGENT = ToolUseAgentArgs(
    model_args=OpenAIChatModelArgs(model_name="gpt-4o-2024-08-06")
)
# Backward-compatible alias for the original misspelled name ("TOOl");
# prefer OAI_CHAT_TOOL_AGENT in new code.
OAI_CHAT_TOOl_AGENT = OAI_CHAT_TOOL_AGENT
Comment on lines +264 to +266
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inconsistent Constant Naming category Readability

Tell me more
What is the issue?

Variable name contains a typo ('TOOl' instead of 'TOOL') and is inconsistent with other constant naming patterns.

Why this matters

Inconsistent naming makes the code harder to search for and understand, especially when the inconsistency is due to a typo.

Suggested change ∙ Feature Preview
OAI_CHAT_TOOL_AGENT = ToolUseAgentArgs(
    model_args=OpenAIChatModelArgs(model_name="gpt-4o-2024-08-06")
)
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.



# Maps an API-provider name to the BaseModelArgs factory used to build its
# model arguments. The "openai" entry is nested: an endpoint spec
# ("chatcompletion" or "response") selects the concrete factory.
PROVIDER_FACTORY_MAP = {
    "openai": {"chatcompletion": OpenAIChatModelArgs, "response": OpenAIResponseModelArgs},
    "openrouter": OpenRouterModelArgs,
    "vllm": VLLMModelArgs,
    "anthropic": ClaudeResponseModelArgs,
    # Misspelled key kept for backward compatibility; prefer "anthropic".
    "antrophic": ClaudeResponseModelArgs,
}
Comment on lines +269 to +274
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Provider Name Typo category Functionality

Tell me more
What is the issue?

The 'antrophic' key in PROVIDER_FACTORY_MAP contains a typo (should be 'anthropic').

Why this matters

This typo will cause runtime errors when trying to use Anthropic models as the provider key won't match.

Suggested change ∙ Feature Preview

Correct the spelling in the dictionary:

PROVIDER_FACTORY_MAP = {
    "openai": {"chatcompletion": OpenAIChatModelArgs, "response": OpenAIResponseModelArgs},
    "openrouter": OpenRouterModelArgs,
    "vllm": VLLMModelArgs,
    "anthropic": ClaudeResponseModelArgs,
}
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.



def get_tool_use_agent(
    api_provider: str,
    model_args: dict,
    tool_use_agent_args: dict = None,
    api_provider_spec=None,
) -> ToolUseAgentArgs:
    """Build a ToolUseAgentArgs for the given API provider.

    Args:
        api_provider: Key into PROVIDER_FACTORY_MAP (e.g. "openai",
            "openrouter", "vllm").
        model_args: Keyword arguments forwarded to the provider's
            BaseModelArgs factory.
        tool_use_agent_args: Optional extra keyword arguments for
            ToolUseAgentArgs.
        api_provider_spec: Endpoint specification, required for providers
            exposing several endpoints (for "openai": "chatcompletion"
            or "response").

    Returns:
        A configured ToolUseAgentArgs instance.

    Raises:
        ValueError: If the provider is unknown, or a required endpoint
            specification is missing.
    """
    try:
        model_args_factory = PROVIDER_FACTORY_MAP[api_provider]
    except KeyError:
        raise ValueError(
            f"Unknown API provider '{api_provider}'. "
            f"Expected one of: {sorted(PROVIDER_FACTORY_MAP)}."
        )

    # Providers with several endpoints (e.g. OpenAI) map to a nested dict and
    # need a spec to pick the concrete factory. Use an explicit ValueError
    # instead of `assert`, which is stripped under `python -O`.
    if isinstance(model_args_factory, dict):
        if api_provider_spec is None:
            raise ValueError(
                f"Endpoint specification is required for the '{api_provider}' "
                f"provider. Choose between: {sorted(model_args_factory)}."
            )
        model_args_factory = model_args_factory[api_provider_spec]

    # Create the agent with model arguments from the factory
    agent = ToolUseAgentArgs(
        model_args=model_args_factory(**model_args), **(tool_use_agent_args or {})
    )
    return agent


## We have three providers that we want to support.
# Anthropic
# OpenAI
# vllm (uses OpenAI API)
Loading
Loading