Merged
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -102,6 +102,9 @@ dev = [
"ipykernel>=6.30.1",
"pip>=25.2",
]
hint = [
"sentence-transformers>=5.0.0",
]


[project.scripts]
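Note: the new optional `hint` dependency group pulls in `sentence-transformers`, presumably for the embedding-based hint retrieval added below; it can be installed with e.g. `pip install -e ".[hint]"`.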
190 changes: 177 additions & 13 deletions src/agentlab/agents/tool_use_agent/tool_use_agent.py
@@ -1,13 +1,20 @@
import fnmatch
import json
import logging
import os
import random
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import copy
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any
from typing import Any, Literal

import bgym
import numpy as np
import pandas as pd
import requests
from bgym import Benchmark as BgymBenchmark
from browsergym.core.observation import extract_screenshot
from browsergym.utils.obs import (
@@ -34,6 +41,8 @@
)
from agentlab.llm.tracking import cost_tracker_decorator

logger = logging.getLogger(__name__)


@dataclass
class Block(ABC):
@@ -176,7 +185,6 @@ class Obs(Block):
def apply(
self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput
) -> dict:

obs_msg = llm.msg.user()
tool_calls = last_llm_output.tool_calls
if self.use_last_error:
Expand Down Expand Up @@ -298,22 +306,52 @@ def apply_init(self, llm, discussion: StructuredDiscussion) -> dict:
class TaskHint(Block):
use_task_hint: bool = True
hint_db_rel_path: str = "hint_db.csv"
hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
top_n: int = 4 # Number of top hints to return when using embedding retrieval
embedder_model: str = "Qwen/Qwen3-Embedding-0.6B" # Model for embedding hints
embedder_server: str = "http://localhost:5000"
llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n
Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1."""

def _init(self):
"""Initialize the block."""
-        hint_db_path = Path(__file__).parent / self.hint_db_rel_path
+        if Path(self.hint_db_rel_path).is_absolute():
+            hint_db_path = Path(self.hint_db_rel_path)
+        else:
+            hint_db_path = Path(__file__).parent / self.hint_db_rel_path
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
if self.hint_retrieval_mode == "emb":
self.encode_hints()

def oai_embed(self, text: str):
response = self._oai_emb.create(input=text, model="text-embedding-3-small")
return response.data[0].embedding

def encode_hints(self):
self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
logger.info(
f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model."
)
hints = self.uniq_hints["hint"].tolist()
semantic_keys = self.uniq_hints["semantic_keys"].tolist()
lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)]
emb_path = f"{self.hint_db_rel_path}.embs.npy"
assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}"
logger.info(f"Loading hint embeddings from: {emb_path}")
emb_dict = np.load(emb_path, allow_pickle=True).item()
self.hint_embeddings = np.array([emb_dict[k] for k in lines])
logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}")

def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
if not self.use_task_hint:
-            return
+            return {}

-        task_hints = self.hint_db[
-            self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
-        ]
+        goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content])
+        task_hints = self.choose_hints(llm, task_name, goal)

hints = []
-        for hint in task_hints["hint"]:
+        for hint in task_hints:
hint = hint.strip()
if hint:
hints.append(f"- {hint}")
@@ -327,6 +365,94 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:

discussion.append(msg)

def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
"""Choose hints based on the task name."""
Comment on lines +368 to +369 (Korbit)

Incorrect TaskHint.choose_hints() docstring (category: Documentation)

What is the issue?
The docstring is inaccurate: the method chooses hints based on task_name or goal, depending on hint_retrieval_mode, not just the task name.

Why this matters
A misleading docstring can cause confusion about the method's behavior under the different hint retrieval modes.

Suggested change
def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
    """Choose hints based on hint_retrieval_mode using task name or goal text."""

if self.hint_retrieval_mode == "llm":
return self.choose_hints_llm(llm, goal)
elif self.hint_retrieval_mode == "direct":
return self.choose_hints_direct(task_name)
elif self.hint_retrieval_mode == "emb":
return self.choose_hints_emb(goal)
else:
raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}")

def choose_hints_llm(self, llm, goal: str) -> list[str]:
"""Choose hints using LLM to filter the hints."""
topic_to_hints = defaultdict(list)
for i, row in self.hint_db.iterrows():
topic_to_hints[row["semantic_keys"]].append(i)
hint_topics = list(topic_to_hints.keys())
topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)])
prompt = self.llm_prompt.format(goal=goal, topics=topics)
response = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)]))
try:
hint_topic_idx = json.loads(response.think)
if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics):
logger.error(f"Wrong LLM hint id response: {response.think}, no hints")
return []
hint_topic = hint_topics[hint_topic_idx]
hint_indices = topic_to_hints[hint_topic]
df = self.hint_db.iloc[hint_indices].copy()
df = df.drop_duplicates(subset=["hint"], keep="first") # leave only unique hints
hints = df["hint"].tolist()
logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}")
except json.JSONDecodeError:
logger.error(f"Failed to parse LLM hint id response: {response.think}, no hints")
hints = []
Comment on lines +388 to +401 (Korbit)

Incomplete Error Logging (category: Error Handling)

What is the issue?
The error handling in choose_hints_llm() loses useful context by logging only the error message, without the original exception details.

Why this matters
Without the original exception details in the logs, debugging production issues is harder: developers have no access to the full stack trace and error context.

Suggested change
Include the exception details in the error logging using exc_info=True:

try:
    hint_topic_idx = json.loads(response.think)
    if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics):
        logger.error(f"Wrong LLM hint id response: {response.think}, no hints")
        return []
    hint_topic = hint_topics[hint_topic_idx]
    hint_indices = topic_to_hints[hint_topic]
    df = self.hint_db.iloc[hint_indices].copy()
    df = df.drop_duplicates(subset=["hint"], keep="first")  # leave only unique hints
    hints = df["hint"].tolist()
    logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}")
except json.JSONDecodeError as e:
    logger.error(f"Failed to parse LLM hint id response: {response.think}, no hints", exc_info=True)
    hints = []

return hints

def choose_hints_emb(self, goal: str) -> list[str]:
"""Choose hints using embeddings to filter the hints."""
goal_embeddings = self._encode([goal], prompt="task description")
similarities = self._similarity(goal_embeddings.tolist(), self.hint_embeddings.tolist())
top_indices = similarities.argsort()[0][-self.top_n :].tolist()
logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
hints = self.uniq_hints.iloc[top_indices]
logger.info(f"Embedding-based hints chosen: {hints}")
return hints["hint"].tolist()

def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5):
"""Call the encode API endpoint with timeout and retries"""
for attempt in range(max_retries):
try:
response = requests.post(
f"{self.embedder_server}/encode",
json={"texts": texts, "prompt": prompt},
timeout=timeout,
)
embs = response.json()["embeddings"]
return np.asarray(embs)
except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
if attempt == max_retries - 1:
raise e
time.sleep(random.uniform(1, timeout))
continue

def _similarity(
self, texts1: list[str], texts2: list[str], timeout: int = 2, max_retries: int = 5
):
"""Call the similarity API endpoint with timeout and retries"""
for attempt in range(max_retries):
try:
response = requests.post(
f"{self.embedder_server}/similarity",
json={"texts1": texts1, "texts2": texts2},
timeout=timeout,
)
similarities = response.json()["similarities"]
return np.asarray(similarities)
except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
if attempt == max_retries - 1:
raise e
time.sleep(random.uniform(1, timeout))
continue

def choose_hints_direct(self, task_name: str) -> list[str]:
hints = self.hint_db[
self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
]
return hints["hint"].tolist()


@dataclass
class PromptConfig:
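Note: encode_hints() above only loads a precomputed "{hint_db_rel_path}.embs.npy" file (a pickled dict mapping each "semantic_keys: hint" line to its embedding) and asserts that it exists; the diff does not include the script that produces it. Below is a minimal sketch of a compatible precompute step, assuming the sentence-transformers package from the new hint extra and the default Qwen/Qwen3-Embedding-0.6B embedder; the script name is hypothetical.

# precompute_hint_embeddings.py (hypothetical helper, not part of this PR)
# Builds the "<hint_db>.embs.npy" file that TaskHint.encode_hints() expects:
# a pickled dict mapping "semantic_keys: hint" lines to embedding vectors.
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

hint_db_path = "hint_db.csv"  # CSV with at least "task_name", "hint", "semantic_keys" columns
df = pd.read_csv(hint_db_path, dtype=str).drop_duplicates(subset=["hint"], keep="first")

# Same key construction as encode_hints(): "semantic_keys: hint"
lines = [f"{k}: {h}" for h, k in zip(df["hint"], df["semantic_keys"])]

model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")
embeddings = model.encode(lines)  # ndarray of shape (n_hints, dim)

# np.load(path, allow_pickle=True).item() on the loading side recovers this dict.
emb_dict = dict(zip(lines, embeddings))
np.save(f"{hint_db_path}.embs", emb_dict, allow_pickle=True)  # np.save appends ".npy"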
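Likewise, _encode() and _similarity() call out to a separate embedding server at embedder_server (default http://localhost:5000) that is not part of this diff. Below is a minimal Flask sketch of the contract implied by the client code; the framework and model choice are assumptions. Note that choose_hints_emb() posts embeddings, not raw texts, to /similarity.

# embedder_server.py (hypothetical sketch of the server assumed by _encode/_similarity)
import numpy as np
from flask import Flask, jsonify, request
from sentence_transformers import SentenceTransformer

app = Flask(__name__)
model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")  # matches the embedder_model default

@app.post("/encode")
def encode():
    # Client sends {"texts": [...], "prompt": "..."} and reads back {"embeddings": [...]}.
    data = request.get_json()
    embs = model.encode(data["texts"], prompt=data.get("prompt") or None)
    return jsonify({"embeddings": embs.tolist()})

@app.post("/similarity")
def similarity():
    # Despite the field names, the client posts embedding matrices here
    # (goal_embeddings.tolist() and hint_embeddings.tolist()), so compute
    # cosine similarity directly and return {"similarities": [[...], ...]}.
    data = request.get_json()
    a = np.asarray(data["texts1"], dtype=float)
    b = np.asarray(data["texts2"], dtype=float)
    a /= np.linalg.norm(a, axis=1, keepdims=True)
    b /= np.linalg.norm(b, axis=1, keepdims=True)
    return jsonify({"similarities": (a @ b.T).tolist()})

if __name__ == "__main__":
    app.run(port=5000)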
@@ -386,7 +512,8 @@ def __init__(
self.model_args = model_args
self.config = config
self.action_set: bgym.AbstractActionSet = action_set or bgym.HighLevelActionSet(
-            self.config.action_subsets, multiaction=self.config.multiaction  # type: ignore
+            self.config.action_subsets,
+            multiaction=self.config.multiaction,  # type: ignore
)
self.tools = self.action_set.to_tool_description(api=model_args.api)

@@ -510,6 +637,15 @@ def get_action(self, obs: Any) -> float:
vision_support=True,
)

GPT_4_1_CC_API = OpenAIChatModelArgs(
model_name="gpt-4.1",
max_total_tokens=200_000,
max_input_tokens=200_000,
max_new_tokens=2_000,
temperature=0.1,
vision_support=True,
)

GPT_5_mini = OpenAIChatModelArgs(
model_name="gpt-5-mini-2025-08-07",
max_total_tokens=400_000,
@@ -548,7 +684,7 @@ def get_action(self, obs: Any) -> float:
vision_support=True,
)

-CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs(
+CLAUDE_SONNET_37 = ClaudeResponseModelArgs(
model_name="claude-3-7-sonnet-20250219",
max_total_tokens=200_000,
max_input_tokens=200_000,
@@ -557,6 +693,15 @@ def get_action(self, obs: Any) -> float:
vision_support=True,
)

CLAUDE_SONNET_4 = ClaudeResponseModelArgs(
model_name="claude-sonnet-4-20250514",
max_total_tokens=200_000,
max_input_tokens=200_000,
max_new_tokens=2_000,
temperature=0.1,
vision_support=True,
)

O3_RESPONSE_MODEL = OpenAIResponseModelArgs(
model_name="o3-2025-04-16",
max_total_tokens=200_000,
@@ -574,6 +719,25 @@ def get_action(self, obs: Any) -> float:
vision_support=True,
)

GPT_5 = OpenAIChatModelArgs(
model_name="gpt-5",
max_total_tokens=200_000,
max_input_tokens=200_000,
max_new_tokens=8_000,
temperature=None,
vision_support=True,
)


GPT_5_MINI = OpenAIChatModelArgs(
model_name="gpt-5-mini-2025-08-07",
max_total_tokens=200_000,
max_input_tokens=200_000,
max_new_tokens=2_000,
temperature=1.0,
vision_support=True,
)

GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs(
model_name="openai/gpt-4.1",
max_total_tokens=200_000,
@@ -600,12 +764,12 @@ def get_action(self, obs: Any) -> float:
keep_last_n_obs=None,
multiaction=False, # whether to use multi-action or not
# action_subsets=("bid",),
action_subsets=("coord"),
action_subsets=("coord",),
# action_subsets=("coord", "bid"),
)

AGENT_CONFIG = ToolUseAgentArgs(
-    model_args=CLAUDE_MODEL_CONFIG,
+    model_args=CLAUDE_SONNET_37,
config=DEFAULT_PROMPT_CONFIG,
)

@@ -633,7 +797,7 @@ def get_action(self, obs: Any) -> float:
)

OSWORLD_CLAUDE = ToolUseAgentArgs(
-    model_args=CLAUDE_MODEL_CONFIG,
+    model_args=CLAUDE_SONNET_37,
config=PromptConfig(
tag_screenshot=True,
goal=Goal(goal_as_system_msg=True),
2 changes: 1 addition & 1 deletion src/agentlab/llm/chat_api.py
@@ -359,7 +359,7 @@ def __init__(
min_retry_wait_time=min_retry_wait_time,
api_key_env_var="OPENAI_API_KEY",
client_class=OpenAI,
-            pricing_func=tracking.get_pricing_openai,
+            pricing_func=partial(tracking.get_pricing_litellm, model_name=model_name),
log_probs=log_probs,
)

12 changes: 8 additions & 4 deletions src/agentlab/llm/tracking.py
@@ -178,9 +178,9 @@ def __call__(self, *args, **kwargs):
# 'self' here calls ._call_api() method of the subclass
response = self._call_api(*args, **kwargs)
usage = dict(getattr(response, "usage", {}))
if "prompt_tokens_details" in usage:
if "prompt_tokens_details" in usage and usage["prompt_tokens_details"]:
usage["cached_tokens"] = usage["prompt_tokens_details"].cached_tokens
if "input_tokens_details" in usage:
if "input_tokens_details" in usage and usage["input_tokens_details"]:
usage["cached_tokens"] = usage["input_tokens_details"].cached_tokens
usage = {f"usage_{k}": v for k, v in usage.items() if isinstance(v, (int, float))}
usage |= {"n_api_calls": 1}
@@ -338,12 +338,16 @@ def get_effective_cost_from_openai_api(self, response) -> float:
if api_type == "chatcompletion":
total_input_tokens = usage.prompt_tokens # (cache read tokens + new input tokens)
output_tokens = usage.completion_tokens
-            cached_input_tokens = usage.prompt_tokens_details.cached_tokens
+            cached_input_tokens = (
+                usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details else 0
+            )
new_input_tokens = total_input_tokens - cached_input_tokens
elif api_type == "response":
total_input_tokens = usage.input_tokens # (cache read tokens + new input tokens)
output_tokens = usage.output_tokens
-            cached_input_tokens = usage.input_tokens_details.cached_tokens
+            cached_input_tokens = (
+                usage.input_tokens_details.cached_tokens if usage.input_tokens_details else 0
+            )
new_input_tokens = total_input_tokens - cached_input_tokens
else:
logging.warning(f"Unsupported API type: {api_type}. Defaulting cost to 0.0.")