Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,8 @@ results/
outputs/
miniwob-plusplus/
.miniwob-server.pid
debugging_results/
debugging_results/

# working files
main_miniwob_debug.py
main_workarena_debug.py
137 changes: 0 additions & 137 deletions src/agentlab/agents/agent_utils.py
Original file line number Diff line number Diff line change
@@ -1,96 +1,6 @@
from logging import warning
from typing import Optional, Tuple

import numpy as np
from PIL import Image, ImageDraw
from playwright.sync_api import Page

"""
This module contains utility functions for handling observations and actions in the context of agent interactions.
"""


def tag_screenshot_with_action(screenshot: Image.Image, action: str) -> Image.Image:
    """
    If action is a coordinate action, try to render it on the screenshot.

    Supported actions:
        - mouse_click(x, y[, ...]) or mouse_click(x=..., y=...): a blue dot of
          radius 5 is drawn at (x, y).
          e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot
        - mouse_drag_and_drop(x1, y1, x2, y2) or
          mouse_drag_and_drop(from_x=..., from_y=..., to_x=..., to_y=...): a red
          line is drawn from the start point to the end point, followed by an
          arrowhead at the end point.

    Any other action leaves the screenshot untouched. The screenshot is drawn on
    in place and the same object is returned.

    Args:
        screenshot: The screenshot to tag.
        action: The action to tag the screenshot with.

    Returns:
        The tagged screenshot.

    Raises:
        ValueError: Used internally to signal an unparseable action; it is
            always caught here and reported through a warning, so it does not
            propagate to the caller.
    """
    if action.startswith("mouse_click"):
        try:
            coords = action[action.index("(") + 1 : action.index(")")].split(",")
            coords = [c.strip() for c in coords]
            if len(coords) not in (2, 3):
                raise ValueError(f"Invalid coordinate format: {coords}")
            # Accept the keyword form mouse_click(x=120, y=130) as well.
            if coords[0].startswith("x="):
                coords[0] = coords[0][2:]
            if coords[1].startswith("y="):
                coords[1] = coords[1][2:]
            x, y = float(coords[0].strip()), float(coords[1].strip())
            draw = ImageDraw.Draw(screenshot)
            radius = 5
            draw.ellipse(
                (x - radius, y - radius, x + radius, y + radius), fill="blue", outline="blue"
            )
        except (ValueError, IndexError) as e:
            warning(f"Failed to parse action '{action}': {e}")

    elif action.startswith("mouse_drag_and_drop"):
        try:
            func_name, parsed_args = parse_func_call_string(action)
            if func_name != "mouse_drag_and_drop" or parsed_args is None:
                # Same reporting path as the click branch instead of silently
                # dropping the annotation.
                raise ValueError("not a valid mouse_drag_and_drop call")
            args, kwargs = parsed_args

            if args and len(args) >= 4:
                # Positional arguments: mouse_drag_and_drop(x1, y1, x2, y2)
                x1, y1, x2, y2 = map(float, args[:4])
            elif kwargs:
                # Keyword arguments: mouse_drag_and_drop(from_x=x1, from_y=y1, to_x=x2, to_y=y2)
                # Missing keywords fall back to 0.
                x1 = float(kwargs.get("from_x", 0))
                y1 = float(kwargs.get("from_y", 0))
                x2 = float(kwargs.get("to_x", 0))
                y2 = float(kwargs.get("to_y", 0))
            else:
                raise ValueError("expected four coordinates or from_x/from_y/to_x/to_y keywords")

            draw = ImageDraw.Draw(screenshot)
            # Draw the main line
            draw.line((x1, y1, x2, y2), fill="red", width=2)
            # Draw arrowhead at the end point using the helper function
            draw_arrowhead(draw, (x1, y1), (x2, y2))
        except (ValueError, IndexError, TypeError) as e:
            # TypeError: float() on a non-numeric literal such as None or a list,
            # which the literal parser accepts as a valid argument value.
            warning(f"Failed to parse action '{action}': {e}")
    return screenshot


def add_mouse_pointer_from_action(screenshot: Image, action: str) -> Image.Image:
    """
    Overlay a mouse pointer on the screenshot for a mouse_click action.

    The coordinates are read from the text between the first "(" and the first
    ")" of the action, e.g. mouse_click(120, 130) or mouse_click(x=120, y=130).
    Both coordinates must parse as integers. Actions that are not mouse clicks
    are ignored, and parsing failures are reported with a warning.

    Args:
        screenshot: The screenshot to draw on.
        action: The action string to read the click position from.

    Returns:
        The result of draw_mouse_pointer on success, otherwise the screenshot
        that was passed in.
    """
    if not action.startswith("mouse_click"):
        return screenshot

    try:
        inner = action[action.index("(") + 1 : action.index(")")]
        parts = [piece.strip() for piece in inner.split(",")]
        if len(parts) not in [2, 3]:
            raise ValueError(f"Invalid coordinate format: {parts}")

        raw_x, raw_y = parts[0], parts[1]
        # Drop the keyword prefix when the call uses x=... / y=...
        if raw_x.startswith("x="):
            raw_x = raw_x[2:]
        if raw_y.startswith("y="):
            raw_y = raw_y[2:]

        x = int(raw_x.strip())
        y = int(raw_y.strip())
        screenshot = draw_mouse_pointer(screenshot, x, y)
    except (ValueError, IndexError) as e:
        warning(f"Failed to parse action '{action}': {e}")

    return screenshot


def draw_mouse_pointer(image: Image.Image, x: int, y: int) -> Image.Image:
"""
Expand Down Expand Up @@ -218,50 +128,3 @@ def zoom_webpage(page: Page, zoom_factor: float = 1.5):

page.evaluate(f"document.documentElement.style.zoom='{zoom_factor*100}%'")
return page


def parse_func_call_string(call_str: str) -> Tuple[Optional[str], Optional[Tuple[list, dict]]]:
    """
    Split a textual function call into its name and literal arguments.

    The string is parsed as a single Python expression. Only calls to a plain
    name are accepted (no attribute access such as obj.method(...)), and every
    argument must be a Python literal.

    Args:
        call_str (str): A string like "mouse_click(100, 200)" or "mouse_drag_and_drop(x=10, y=20)"

    Returns:
        Tuple (func_name, (args, kwargs)) where args is a list of positional
        values and kwargs a dict of keyword values, or (None, None) if parsing fails
    """
    import ast

    failure = (None, None)

    try:
        node = ast.parse(call_str.strip(), mode="eval").body

        # Must be a call whose callee is a bare identifier.
        if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Name):
            return failure

        # literal_eval rejects anything that is not a literal (names, calls, ...).
        positional = [ast.literal_eval(item) for item in node.args]
        keyword = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
    except (SyntaxError, ValueError, TypeError):
        return failure

    return node.func.id, (positional, keyword)
4 changes: 4 additions & 0 deletions src/agentlab/agents/tool_use_agent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import sys

# for backward compatibility of unpickling:
# register this package a second time in sys.modules under the dotted name
# "<this package>.multi_tool_agent", so that looking up that module name
# resolves to this package object itself.
# NOTE(review): presumably older pickles reference their classes through the
# `multi_tool_agent` submodule path; confirm against the stored artifacts.
sys.modules[__name__ + ".multi_tool_agent"] = sys.modules[__name__]
102 changes: 64 additions & 38 deletions src/agentlab/analyze/agent_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
from attr import dataclass
from langchain.schema import BaseMessage, HumanMessage
from openai import OpenAI
from openai.types.responses import ResponseFunctionToolCall
from PIL import Image

from agentlab.agents import agent_utils
from agentlab.analyze import inspect_results
from agentlab.analyze.overlay_utils import annotate_action
from agentlab.experiments.exp_utils import RESULTS_DIR
from agentlab.experiments.loop import ExpResult, StepInfo
from agentlab.experiments.study import get_most_recent_study
Expand Down Expand Up @@ -351,7 +352,7 @@ def run_gradio(results_dir: Path):
pruned_html_code = gr.Code(language="html", **code_args)

with gr.Tab("AXTree") as tab_axtree:
axtree_code = gr.Code(language=None, **code_args)
axtree_code = gr.Markdown()

with gr.Tab("Chat Messages") as tab_chat:
chat_messages = gr.Markdown()
Expand Down Expand Up @@ -536,38 +537,45 @@ def wrapper(*args, **kwargs):

def update_screenshot(som_or_not: str):
global info
action = info.exp_result.steps_info[info.step].action
return agent_utils.tag_screenshot_with_action(
get_screenshot(info, som_or_not=som_or_not), action
)
img, action_str = get_screenshot(info, som_or_not=som_or_not, annotate=True)
return img


def get_screenshot(info: Info, step: int = None, som_or_not: str = "Raw Screenshots"):
def get_screenshot(
info: Info, step: int = None, som_or_not: str = "Raw Screenshots", annotate: bool = False
):
if step is None:
step = info.step
step_info = info.exp_result.steps_info[step]
try:
is_som = som_or_not == "SOM Screenshots"
return info.exp_result.get_screenshot(step, som=is_som)
img = info.exp_result.get_screenshot(step, som=is_som)
if annotate:
action_str = step_info.action
properties = step_info.obs.get("extra_element_properties", None)
action_colored = annotate_action(img, action_string=action_str, properties=properties)
else:
action_colored = None
return img, action_colored
except FileNotFoundError:
return None
return None, None


def update_screenshot_pair(som_or_not: str):
global info
s1 = get_screenshot(info, info.step, som_or_not)
s2 = get_screenshot(info, info.step + 1, som_or_not)

if s1 is not None:
s1 = agent_utils.tag_screenshot_with_action(
s1, info.exp_result.steps_info[info.step].action
)
s1, action_str = get_screenshot(info, info.step, som_or_not, annotate=True)
s2, action_str = get_screenshot(info, info.step + 1, som_or_not)
return s1, s2


def update_screenshot_gallery(som_or_not: str):
global info
screenshots = info.exp_result.get_screenshots(som=som_or_not == "SOM Screenshots")
max_steps = len(info.exp_result.steps_info)

screenshots = [get_screenshot(info, step=i, som_or_not=som_or_not)[0] for i in range(max_steps)]

This comment was marked as resolved.


screenshots_and_label = [(s, f"Step {i}") for i, s in enumerate(screenshots)]

gallery = gr.Gallery(
value=screenshots_and_label,
columns=2,
Expand Down Expand Up @@ -595,7 +603,8 @@ def update_pruned_html():


def update_axtree():
return get_obs(key="axtree_txt", default="No AXTree")
obs = get_obs(key="axtree_txt", default="No AXTree")
return f"```\n{obs}\n```"


def dict_to_markdown(d: dict):
Expand Down Expand Up @@ -645,7 +654,7 @@ def dict_msg_to_markdown(d: dict):
case "text":
parts.append(f"\n```\n{item['text']}\n```\n")
case "tool_use":
tool_use = f"Tool Use: {item['name']} {item['input']} (id = {item['id']})"
tool_use = _format_tool_call(item["name"], item["input"], item["call_id"])
parts.append(f"\n```\n{tool_use}\n```\n")
case _:
parts.append(f"\n```\n{str(item)}\n```\n")
Expand All @@ -655,27 +664,40 @@ def dict_msg_to_markdown(d: dict):
return markdown


def _format_tool_call(name: str, input: str, call_id: str):
    """
    Build the one-line text description of a tool call.

    Args:
        name: Name of the tool being called.
        input: Arguments passed to the tool; rendered between backticks.
        call_id: Identifier of the call.

    Returns:
        A string of the form "Tool Call: <name> `<input>` (call_id: <call_id>)".
    """
    summary = f"Tool Call: {name} `{input}` (call_id: {call_id})"
    return summary


def format_chat_message(message: BaseMessage | MessageBuilder | dict):
    """
    Format a single chat message to markdown.

    The rendering depends on the runtime type of the message:
        - BaseMessage: its `content` attribute is returned as is.
        - MessageBuilder: delegated to its `to_markdown()` method.
        - dict: delegated to `dict_msg_to_markdown`.
        - ResponseFunctionToolCall: rendered as a "### Tool Use" section whose
          fenced body is produced by `_format_tool_call`.
        - anything else: `str(message)`.

    Args:
        message: The message to format.

    Returns:
        The formatted representation of the message.
    """
    if isinstance(message, BaseMessage):
        return message.content
    if isinstance(message, MessageBuilder):
        return message.to_markdown()
    if isinstance(message, dict):
        return dict_msg_to_markdown(message)
    if isinstance(message, ResponseFunctionToolCall):  # type: ignore[return]
        tool_use_str = _format_tool_call(message.name, message.arguments, message.call_id)
        return f"### Tool Use\n```\n{tool_use_str}\n```\n"
    return str(message)


def update_chat_messages():
global info
agent_info = info.exp_result.steps_info[info.step].agent_info
chat_messages = agent_info.get("chat_messages", ["No Chat Messages"])
if isinstance(chat_messages, Discussion):
return chat_messages.to_markdown()

if isinstance(chat_messages, list) and isinstance(chat_messages[0], MessageBuilder):
chat_messages = [
m.to_markdown() if isinstance(m, MessageBuilder) else dict_msg_to_markdown(m)
for m in chat_messages
]
if isinstance(chat_messages, list):
chat_messages = [format_chat_message(m) for m in chat_messages]
return "\n\n".join(chat_messages)
messages = [] # TODO(ThibaultLSDC) remove this at some point
for i, m in enumerate(chat_messages):
if isinstance(m, BaseMessage): # TODO remove once langchain is deprecated
m = m.content
elif isinstance(m, dict):
m = m.get("content", "No Content")
messages.append(f"""# Message {i}\n```\n{m}\n```\n\n""")
return "\n".join(messages)


def update_task_error():
Expand Down Expand Up @@ -722,8 +744,8 @@ def update_agent_info_html():
global info
# screenshots from current and next step
try:
s1 = get_screenshot(info, info.step, False)
s2 = get_screenshot(info, info.step + 1, False)
s1, action_str = get_screenshot(info, info.step, False)
s2, action_str = get_screenshot(info, info.step + 1, False)
agent_info = info.exp_result.steps_info[info.step].agent_info
page = agent_info.get("html_page", ["No Agent Info"])
if page is None:
Expand Down Expand Up @@ -854,6 +876,8 @@ def get_episode_info(info: Info):

def get_action_info(info: Info):
steps_info = info.exp_result.steps_info
img, action_str = get_screenshot(info, step=info.step, annotate=True) # to update click_mapper

if len(steps_info) == 0:
return "No steps were taken"
if len(steps_info) <= info.step:
Expand All @@ -863,7 +887,7 @@ def get_action_info(info: Info):
action_info = f"""\
**Action:**

{code(step_info.action)}
{action_str}
"""
think = step_info.agent_info.get("think", None)
if think is not None:
Expand Down Expand Up @@ -1084,16 +1108,19 @@ def get_directory_contents(results_dir: Path):
most_recent_summary = max(summary_files, key=os.path.getctime)
summary_df = pd.read_csv(most_recent_summary)

if len(summary_df) == 0 or summary_df["avg_reward"].isna().all():
continue # skip if all avg_reward are NaN

# get row with max avg_reward
max_reward_row = summary_df.loc[summary_df["avg_reward"].idxmax()]
max_reward_row = summary_df.loc[summary_df["avg_reward"].idxmax(skipna=True)]
reward = max_reward_row["avg_reward"] * 100
completed = max_reward_row["n_completed"]
n_err = max_reward_row["n_err"]
exp_description += (
f" - avg-reward: {reward:.1f}% - completed: {completed} - errors: {n_err}"
)
except Exception as e:
print(f"Error while reading summary file: {e}")
print(f"Error while reading summary file {most_recent_summary}: {e}")

exp_descriptions.append(exp_description)

Expand Down Expand Up @@ -1219,7 +1246,6 @@ def plot_profiling(ax, step_info_list: list[StepInfo], summary_info: dict, progr
horizontalalignment="left",
rotation=0,
clip_on=True,
antialiased=True,
fontweight=1000,
backgroundcolor=colors[12],
)
Expand Down
Loading