Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,5 +174,4 @@ miniwob-plusplus/
debugging_results/

# working files
main_miniwob_debug.py
main_workarena_debug.py
experiments/*
15 changes: 9 additions & 6 deletions src/agentlab/agents/generic_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,23 @@
from .agent_configs import (
AGENT_3_5,
AGENT_8B,
AGENT_37_SONNET,
AGENT_CLAUDE_SONNET_35,
AGENT_CLAUDE_SONNET_35_VISION,
AGENT_CUSTOM,
AGENT_LLAMA4_17B_INSTRUCT,
AGENT_LLAMA3_70B,
AGENT_LLAMA4_17B_INSTRUCT,
AGENT_LLAMA31_70B,
CHAT_MODEL_ARGS_DICT,
RANDOM_SEARCH_AGENT,
AGENT_4o,
AGENT_4o_MINI,
AGENT_CLAUDE_SONNET_35,
AGENT_37_SONNET,
AGENT_CLAUDE_SONNET_35_VISION,
AGENT_4o_VISION,
AGENT_4o_MINI_VISION,
AGENT_o3_MINI,
AGENT_4o_VISION,
AGENT_o1_MINI,
AGENT_o3_MINI,
FLAGS_GPT_4o,
GenericAgentArgs,
)

__all__ = [
Expand Down
2 changes: 2 additions & 0 deletions src/agentlab/agents/tool_use_agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import sys

from agentlab.agents.tool_use_agent.tool_use_agent import *

# Backward compatibility for unpickling: register this package in sys.modules
# under the additional name "<this package>.multi_tool_agent", so that a lookup
# of that module path resolves to this package instead of failing.
# NOTE(review): presumably older pickles reference classes through the former
# `multi_tool_agent` module path — confirm against the saved artifacts.
sys.modules[__name__ + ".multi_tool_agent"] = sys.modules[__name__]
13 changes: 11 additions & 2 deletions src/agentlab/agents/tool_use_agent/tool_use_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def apply(self, llm, discussion: StructuredDiscussion, obs: dict) -> dict:

# Explanatory note about the AXTree observation: what the tree contains and how
# the [bid] identifier at the start of each node line is meant to be used when
# calling tools.
# NOTE(review): presumably inserted into the agent's prompt alongside the
# AXTree — the consumer of this constant is not visible here.
AXTREE_NOTE = """
AXTree extracts most of the interactive elements of the DOM in a tree structure. It may also contain information that is not visible in the screenshot.
A line starting with [bid] is a node in the AXTree. It is a unique alpha-numeric identifier to be used when calling tools.
A line starting with [bid] is a node in the AXTree. It is a unique alpha-numeric identifier to be used when calling tools, e.g, click(bid="a253"). Make sure to include letters and numbers in the bid.
"""


Expand Down Expand Up @@ -347,7 +347,7 @@ class PromptConfig:
task_hint: TaskHint = None
keep_last_n_obs: int = 1
multiaction: bool = False
action_subsets: tuple[str] = field(default_factory=lambda: ("coord",))
action_subsets: tuple[str] = None


@dataclass
Expand Down Expand Up @@ -498,6 +498,15 @@ def get_action(self, obs: Any) -> float:
vision_support=True,
)

# Model arguments for "gpt-4.1-mini", built with OpenAIResponseModelArgs:
# vision enabled, temperature 0.1, at most 2_000 new tokens per call.
# NOTE(review): confirm the 200_000-token total/input limits against the
# model's actual context window.
GPT_4_1_MINI = OpenAIResponseModelArgs(
    model_name="gpt-4.1-mini",
    max_total_tokens=200_000,
    max_input_tokens=200_000,
    max_new_tokens=2_000,
    temperature=0.1,
    vision_support=True,
)

OPENAI_CHATAPI_MODEL_CONFIG = OpenAIChatModelArgs(
model_name="gpt-4o-2024-08-06",
max_total_tokens=200_000,
Expand Down
8 changes: 7 additions & 1 deletion src/agentlab/analyze/agent_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,13 @@ def get_screenshot(
if annotate:
action_str = step_info.action
properties = step_info.obs.get("extra_element_properties", None)
action_colored = annotate_action(img, action_string=action_str, properties=properties)
try:
action_colored = annotate_action(
img, action_string=action_str, properties=properties
)
except Exception as e:
warning(f"Failed to annotate action: {e}")
action_colored = action_str
else:
action_colored = None
return img, action_colored
Expand Down
122 changes: 122 additions & 0 deletions src/agentlab/analyze/archive_studies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import os
from dataclasses import dataclass
from pathlib import Path

import pandas as pd
from tqdm import tqdm

from agentlab.analyze import inspect_results
from agentlab.experiments.exp_utils import RESULTS_DIR
from agentlab.experiments.study import Study


@dataclass
class StudyInfo:
    """Archiving verdict for a single study directory.

    Instances are created by `search_for_reasons_to_archive`, which sets
    `should_delete` and accumulates the matching criteria in `reason`.
    """

    # Directory of the study (an entry of the scanned results directory).
    study_dir: Path
    # Result of `Study.load(study_dir)`; the scanner stores None when loading raised.
    study: Study
    # Summary table: read from the most recent summary*.csv, or recomputed
    # with `inspect_results` when no such file exists.
    summary_df: pd.DataFrame
    # True once at least one archiving criterion matched.
    should_delete: bool = False
    # Human-readable text listing the matched criteria.
    reason: str = ""


def _collect_archive_reasons(
    study_name: str, summary_df: pd.DataFrame, min_study_size: int
) -> list[str]:
    """Return one message per archiving criterion matched by a study (empty if none)."""
    reasons = []
    if len(summary_df) == 0:
        reasons.append("Empty summary DataFrame")

    n_completed, n_total, n_err = 0, 0, 0
    for _, row in summary_df.iterrows():
        # "n_completed" is stored as a "<completed>/<total>" string.
        n_comp, n_tot = row["n_completed"].split("/")
        n_completed += int(n_comp)
        n_total += int(n_tot)
        # A summary without an "n_err" column counts as zero errors instead of
        # crashing on int(None).
        n_err += int(row.get("n_err", 0))
    n_finished = n_completed - n_err

    if "miniwob-tiny-test" in study_name:
        reasons.append("Miniwob tiny test")
    if n_total == 0:
        reasons.append("No tasks")
    if n_completed == 0:
        reasons.append("No tasks completed")
    # Guard the ratio: with n_total == 0 the study is already flagged above and
    # the division would raise ZeroDivisionError.
    if n_total > 0 and n_finished / n_total < 0.5:
        reasons.append(
            f"Less than 50% tasks finished, n_err: {n_err}, n_total: {n_total}, "
            f"n_finished: {n_finished}, n_completed: {n_completed}"
        )
    if n_total <= min_study_size:
        reasons.append(
            f"Too few tasks. n_total ({n_total}) <= min_study_size ({min_study_size})"
        )
    return reasons


def search_for_reasons_to_archive(result_dir: Path, min_study_size: int = 0) -> list[StudyInfo]:
    """Scan the studies under `result_dir` and flag the ones that should be archived.

    For every sub-directory, the summary is read from the most recent
    ``summary*.csv`` file (by ``os.path.getctime``); when there is none it is
    recomputed with ``inspect_results``. A study is flagged when any of these
    holds: the summary is empty, the directory name contains
    "miniwob-tiny-test", it has no tasks, no task completed, fewer than 50% of
    its tasks finished without error, or its task count is
    ``<= min_study_size``.

    Args:
        result_dir: Directory whose entries are study directories. Entries
            that are not directories are skipped.
        min_study_size: Studies with at most this many tasks are flagged.

    Returns:
        One StudyInfo per study whose summary could be obtained. Flagged
        studies have ``should_delete=True`` and ``reason`` holding one line
        per matched criterion. Studies whose summary cannot be computed are
        reported on stdout and left out of the result.
    """
    study_info_list = []
    study_dirs = list(result_dir.iterdir())
    progress = tqdm(study_dirs, desc="Processing studies")
    for study_dir in progress:
        progress.set_postfix({"study_dir": study_dir})
        if not study_dir.is_dir():
            progress.set_postfix({"status": "skipped"})
            continue

        # Loading the Study object is best effort: the verdict only depends on
        # the summary table, so a failure here is recorded as None.
        try:
            study = Study.load(study_dir)
        except Exception:
            study = None

        # Prefer the most recent summary*.csv; otherwise rebuild the summary.
        summary_files = list(study_dir.glob("summary*.csv"))
        if summary_files:
            most_recent_summary = max(summary_files, key=os.path.getctime)
            summary_df = pd.read_csv(most_recent_summary)
        else:
            try:
                result_df = inspect_results.load_result_df(study_dir, progress_fn=None)
                summary_df = inspect_results.summarize_study(result_df)
            except Exception as e:
                print(f" Error processing {study_dir}: {e}")
                continue

        reasons = _collect_archive_reasons(study_dir.name, summary_df, min_study_size)
        study_info_list.append(
            StudyInfo(
                study_dir=study_dir,
                study=study,
                summary_df=summary_df,
                should_delete=bool(reasons),
                # One reason per line, so the text stays readable when several
                # criteria match.
                reason="".join(f"{reason}\n" for reason in reasons),
            )
        )
    return study_info_list


if __name__ == "__main__":
    study_list_info = search_for_reasons_to_archive(RESULTS_DIR, min_study_size=5)

    # Destination for archived studies: a sibling of the results directory.
    archive_dir = RESULTS_DIR.parent / "archived_agentlab_results"

    # Dry-run switch. While the line below is active, flagged studies are only
    # printed. Comment it out to actually move them into `archive_dir`.
    archive_dir = None

    # Create the destination only when studies are really going to be moved,
    # so a dry run leaves no empty directory behind.
    if archive_dir is not None:
        archive_dir.mkdir(parents=True, exist_ok=True)

    for study_info in study_list_info:
        if not study_info.should_delete:
            continue

        print(f"Study: {study_info.study_dir.name}")
        print(f" Reason: {study_info.reason}")
        print(study_info.summary_df)
        print()

        if archive_dir is not None:
            # Move the study, then record next to it why it was archived.
            new_path = archive_dir / study_info.study_dir.name
            study_info.study_dir.rename(new_path)
            reason_file = new_path / "reason_to_archive.txt"
            reason_file.write_text(study_info.reason)
26 changes: 16 additions & 10 deletions src/agentlab/experiments/graph_execution_ray.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# import os

# # Disable Ray log deduplication
# os.environ["RAY_DEDUP_LOGS"] = "0"
import logging
import time

Expand Down Expand Up @@ -90,12 +86,22 @@ def poll_for_timeout(tasks: dict[str, ray.ObjectRef], timeout: float, poll_inter


def get_elapsed_time(task_ref: ray.ObjectRef):
    """Return the seconds elapsed since the task behind `task_ref` started.

    The start time is read from the Ray state API. The lookup result is
    handled both as a single record and as a list of records; with several
    records the most recent start time is used.
    NOTE(review): presumably a list is returned when a task has several
    attempts — confirm against the Ray state API documentation.

    Args:
        task_ref: Object reference of the task to inspect.

    Returns:
        Elapsed wall-clock time in seconds, or None when the task is unknown,
        has not started yet, or the state lookup failed (the failure is logged
        as a warning and never propagated to the caller).
    """
    # Bound before the try block so the warning below can always be formatted,
    # even when computing the task id is the call that fails.
    task_id = None
    try:
        task_id = task_ref.task_id().hex()
        task_info = state.get_task(task_id, address="auto")
        if not task_info:
            return None
        if not isinstance(task_info, list):
            task_info = [task_info]

        start_times_ms = [getattr(t, "start_time_ms", None) for t in task_info]
        # -1 marks records without a start time; it can only win the max() when
        # no record has started.
        start_time_s = max(t / 1000.0 if t is not None else -1 for t in start_times_ms)
        if start_time_s < 0:
            return None  # Task has not started yet

        return time.time() - start_time_s
    except Exception as e:
        logger.warning(f"Could not get elapsed time for task {task_id}: {e}")
        return None
8 changes: 6 additions & 2 deletions src/agentlab/experiments/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
from PIL import Image
from tqdm import tqdm

from agentlab.agents.tapeagent import TapeAgent, save_tape
try:
from agentlab.agents.tapeagent import TapeAgent, save_tape
except ImportError:
TapeAgent = None


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -474,7 +478,7 @@ def run(self):
err_msg = f"Exception uncaught by agent or environment in task {self.env_args.task_name}.\n{type(e).__name__}:\n{e}"
logger.info("Saving experiment info.")
self.save_summary_info(episode_info, Path(self.exp_dir), err_msg, stack_trace)
if isinstance(agent, TapeAgent):
if TapeAgent is not None and isinstance(agent, TapeAgent):
task = getattr(env, "task", {})
save_tape(self.exp_dir, episode_info, task, agent.final_tape)
except Exception as e:
Expand Down
75 changes: 75 additions & 0 deletions src/agentlab/llm/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import partial
from typing import Optional

import anthropic
import openai
from huggingface_hub import InferenceClient
from openai import AzureOpenAI, OpenAI
Expand Down Expand Up @@ -471,3 +472,77 @@ def __init__(
client_args={"base_url": "http://0.0.0.0:8000/v1"},
pricing_func=None,
)


class AnthropicChatModel(AbstractChatModel):
    """Chat model that sends OpenAI-style message lists to the Anthropic Messages API."""

    def __init__(
        self,
        model_name,
        api_key=None,
        temperature=0.5,
        max_tokens=100,
        max_retry=4,
        log_probs=False,
        retry_delay=60,
    ):
        """Store the sampling settings and create the Anthropic client.

        Args:
            model_name: Name of the Anthropic model to query.
            api_key: API key; falls back to the ANTHROPIC_API_KEY environment
                variable when not given.
            temperature: Default sampling temperature.
            max_tokens: Maximum number of tokens to generate per call.
            max_retry: Number of attempts made before the last error is raised.
            log_probs: Stored on the instance; not used by `__call__`.
            retry_delay: Seconds to wait between two attempts.
        """
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_retry = max_retry
        self.log_probs = log_probs
        self.retry_delay = retry_delay

        api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        self.client = anthropic.Anthropic(api_key=api_key)

    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
        """Query the model and return its first content block as an AIMessage.

        Args:
            messages: Messages as dicts with "role" and "content" keys.
                Messages with role "system" are not sent in the message list;
                the content of the last one is passed as the `system` argument.
            n_samples: Accepted for interface compatibility; currently ignored,
                a single completion is always requested.
            temperature: Overrides the instance temperature when not None.

        Returns:
            An AIMessage wrapping the text of the first content block. When
            `max_retry` is not positive no request is made and None is
            returned.

        Raises:
            Exception: The error of the last attempt, once `max_retry`
                attempts have failed.
        """
        # Convert OpenAI format to Anthropic format.
        system_message = None
        anthropic_messages = []
        for msg in messages:
            if msg["role"] == "system":
                system_message = msg["content"]
            else:
                anthropic_messages.append({"role": msg["role"], "content": msg["content"]})

        temperature = temperature if temperature is not None else self.temperature

        # The request is identical for every attempt, so build it once.
        kwargs = {
            "model": self.model_name,
            "messages": anthropic_messages,
            "max_tokens": self.max_tokens,
            "temperature": temperature,
        }
        if system_message:
            kwargs["system"] = system_message

        for attempt in range(self.max_retry):
            try:
                response = self.client.messages.create(**kwargs)

                # Track usage only when a tracker is actually installed. The
                # attribute may exist but hold None, and calling it would be
                # mistaken for an API failure and retried.
                tracker = getattr(tracking.TRACKER, "instance", None)
                if tracker is not None:
                    tracker(
                        response.usage.input_tokens,
                        response.usage.output_tokens,
                        0,  # cost calculation would need pricing info
                    )

                return AIMessage(response.content[0].text)

            except Exception as e:
                if attempt == self.max_retry - 1:
                    raise
                logging.warning(f"Anthropic API error (attempt {attempt + 1}): {e}")
                time.sleep(self.retry_delay)  # Simple retry delay


@dataclass
class AnthropicModelArgs(BaseModelArgs):
    """Model arguments whose `make_model` builds an AnthropicChatModel."""

    def make_model(self):
        """Instantiate an AnthropicChatModel from the fields of this config.

        `max_new_tokens` is forwarded as the model's `max_tokens`; every other
        constructor argument of the model keeps its default.
        """
        model_kwargs = {
            "model_name": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_new_tokens,
            "log_probs": self.log_probs,
        }
        return AnthropicChatModel(**model_kwargs)
14 changes: 14 additions & 0 deletions src/agentlab/llm/llm_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@
]

CHAT_MODEL_ARGS_DICT = {
"openai/gpt-4.1-mini-2025-04-14": OpenAIModelArgs(
model_name="gpt-4.1-mini-2025-04-14",
max_total_tokens=128_000,
max_input_tokens=128_000,
max_new_tokens=16_384,
vision_support=True,
),
"openai/gpt-4.1-2025-04-14": OpenAIModelArgs(
model_name="gpt-4.1-2025-04-14",
max_total_tokens=128_000,
max_input_tokens=128_000,
max_new_tokens=16_384,
vision_support=True,
),
"openai/o3-mini-2025-01-31": OpenAIModelArgs(
model_name="o3-mini-2025-01-31",
max_total_tokens=200_000,
Expand Down
Loading
Loading