From 7af2d152adeed3d48c7a074e175125823e76f203 Mon Sep 17 00:00:00 2001
From: Oleh Shliazhko
Date: Wed, 13 Aug 2025 12:16:50 +0200
Subject: [PATCH 1/4] task hint retrieval modes, cost tracking and model config fixes

---
 .../agents/tool_use_agent/tool_use_agent.py | 139 ++++++++++++++++--
 src/agentlab/analyze/agent_xray.py          |   2 +-
 src/agentlab/llm/tracking.py                |  12 +-
 3 files changed, 137 insertions(+), 16 deletions(-)

diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py
index 6ac61180..b1407a87 100644
--- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py
+++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py
@@ -1,10 +1,12 @@
 import fnmatch
 import json
+import logging
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from copy import copy
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal
 
 import bgym
 import pandas as pd
@@ -16,6 +18,7 @@
     overlay_som,
     prune_html,
 )
+from sentence_transformers import SentenceTransformer
 
 from agentlab.agents.agent_args import AgentArgs
 from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark
@@ -34,6 +37,8 @@
 )
 from agentlab.llm.tracking import cost_tracker_decorator
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class Block(ABC):
@@ -298,22 +303,45 @@ def apply_init(self, llm, discussion: StructuredDiscussion) -> dict:
 class TaskHint(Block):
     use_task_hint: bool = True
     hint_db_rel_path: str = "hint_db.csv"
+    hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
+    top_n: int = 4  # Number of top hints to return when using embedding retrieval
+    embedder_model: str = "Qwen/Qwen3-Embedding-0.6B"  # Model for embedding hints
+    llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
+You need to choose the most relevant hint topic from the following list:\n\nHint topics:\n{topics}\n
+Choose the hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1."""
 
     def _init(self):
        """Initialize the block."""
-        hint_db_path = Path(__file__).parent / self.hint_db_rel_path
+        if Path(self.hint_db_rel_path).is_absolute():
+            hint_db_path = Path(self.hint_db_rel_path)
+        else:
+            hint_db_path = Path(__file__).parent / self.hint_db_rel_path
         self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
+        if self.hint_retrieval_mode == "emb":
+            logger.info("Loading sentence transformer model for hint embeddings.")
+            self.emb_model = SentenceTransformer(
+                self.embedder_model, model_kwargs={"torch_dtype": "bfloat16"}
+            )
+            self.encode_hints()
+
+    def encode_hints(self):
+        self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
+        logger.info(
+            f"Encoding {len(self.uniq_hints)} unique hints using {self.embedder_model} model."
+        )
+        self.hint_embeddings = self.emb_model.encode(
+            self.uniq_hints["hint"].tolist(), prompt="task hint"
+        )
 
     def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
         if not self.use_task_hint:
-            return
+            return {}
 
-        task_hints = self.hint_db[
-            self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
-        ]
+        goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content])
+        task_hints = self.choose_hints(llm, task_name, goal)
 
         hints = []
-        for hint in task_hints["hint"]:
+        for hint in task_hints:
             hint = hint.strip()
             if hint:
                 hints.append(f"- {hint}")
@@ -327,6 +355,58 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
 
         discussion.append(msg)
 
+    def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
+        """Choose hints according to the configured retrieval mode."""
+        if self.hint_retrieval_mode == "llm":
+            return self.choose_hints_llm(llm, goal)
+        elif self.hint_retrieval_mode == "direct":
+            return self.choose_hints_direct(task_name)
+        elif self.hint_retrieval_mode == "emb":
+            return self.choose_hints_emb(goal)
+        else:
+            raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}")
+
+    def choose_hints_llm(self, llm, goal: str) -> list[str]:
+        """Choose hints by asking the LLM to pick the most relevant hint topic."""
+        topic_to_hints = defaultdict(list)
+        for i, row in self.hint_db.iterrows():
+            topic_to_hints[row["semantic_keys"]].append(i)
+        hint_topics = list(topic_to_hints.keys())
+        topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)])
+        prompt = self.llm_prompt.format(goal=goal, topics=topics)
+        response = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)]))
+        try:
+            hint_topic_idx = json.loads(response.think)
+            if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics):
+                logger.error(f"Out-of-range LLM hint topic index: {response.think}, no hints")
+                return []
+            hint_topic = hint_topics[hint_topic_idx]
+            hint_indices = topic_to_hints[hint_topic]
+            df = self.hint_db.iloc[hint_indices].copy()
+            df = df.drop_duplicates(subset=["hint"], keep="first")  # keep only unique hints
+            hints = df["hint"].tolist()
+            logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}")
+        except json.JSONDecodeError:
+            logger.error(f"Failed to parse LLM hint topic index: {response.think}, no hints")
+            hints = []
+        return hints
+
+    def choose_hints_emb(self, goal: str) -> list[str]:
+        """Choose hints using embeddings to filter the hints."""
+        goal_embeddings = self.emb_model.encode([goal], prompt="task description")
+        similarities = self.emb_model.similarity(goal_embeddings, self.hint_embeddings)
+        top_indices = similarities.argsort()[0][-self.top_n :].tolist()
+        logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
+        hints = self.uniq_hints.iloc[top_indices]
+        logger.info(f"Embedding-based hints chosen: {hints}")
+        return hints["hint"].tolist()
+
+    def choose_hints_direct(self, task_name: str) -> list[str]:
+        hints = self.hint_db[
+            self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
+        ]
+        return hints["hint"].tolist()
+
 
 @dataclass
 class PromptConfig:
@@ -510,6 +590,15 @@ def get_action(self, obs: Any) -> float:
     vision_support=True,
 )
 
+GPT_4_1_CC_API = OpenAIChatModelArgs(
+    model_name="gpt-4.1",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=0.1,
+    vision_support=True,
+)
+
 GPT_4_1_MINI = OpenAIResponseModelArgs(
     model_name="gpt-4.1-mini",
     max_total_tokens=200_000,
@@ -528,7 +617,7 @@ def get_action(self, obs: Any) -> float:
     vision_support=True,
 )
 
-CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs(
+CLAUDE_SONNET_37 = ClaudeResponseModelArgs(
     model_name="claude-3-7-sonnet-20250219",
     max_total_tokens=200_000,
     max_input_tokens=200_000,
@@ -537,6 +626,15 @@ def get_action(self, obs: Any) -> float:
     vision_support=True,
 )
 
+CLAUDE_SONNET_4 = ClaudeResponseModelArgs(
+    model_name="claude-sonnet-4-20250514",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=0.1,
+    vision_support=True,
+)
+
 O3_RESPONSE_MODEL = OpenAIResponseModelArgs(
     model_name="o3-2025-04-16",
     max_total_tokens=200_000,
@@ -554,6 +652,25 @@ def get_action(self, obs: Any) -> float:
     vision_support=True,
 )
 
+GPT_5 = OpenAIChatModelArgs(
+    model_name="gpt-5",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=None,
+    vision_support=True,
+)
+
+
+GPT_5_MINI = OpenAIChatModelArgs(
+    model_name="gpt-5-mini-2025-08-07",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=1.0,
+    vision_support=True,
+)
+
 GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs(
     model_name="openai/gpt-4.1",
     max_total_tokens=200_000,
@@ -580,12 +697,12 @@ def get_action(self, obs: Any) -> float:
     keep_last_n_obs=None,
     multiaction=True,  # whether to use multi-action or not
     # action_subsets=("bid",),
-    action_subsets=("coord"),
+    action_subsets=("coord",),
     # action_subsets=("coord", "bid"),
 )
 
 AGENT_CONFIG = ToolUseAgentArgs(
-    model_args=CLAUDE_MODEL_CONFIG,
+    model_args=CLAUDE_SONNET_37,
     config=DEFAULT_PROMPT_CONFIG,
 )
 
@@ -605,7 +722,7 @@ def get_action(self, obs: Any) -> float:
 )
 
 OSWORLD_CLAUDE = ToolUseAgentArgs(
-    model_args=CLAUDE_MODEL_CONFIG,
+    model_args=CLAUDE_SONNET_37,
     config=PromptConfig(
         tag_screenshot=True,
         goal=Goal(goal_as_system_msg=True),
diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py
index 84dc423d..37ead1c3 100644
--- a/src/agentlab/analyze/agent_xray.py
+++ b/src/agentlab/analyze/agent_xray.py
@@ -735,7 +735,7 @@ def dict_msg_to_markdown(d: dict):
             case _:
                 parts.append(f"\n```\n{str(item)}\n```\n")
 
-    markdown = f"### {d["role"].capitalize()}\n"
+    markdown = f"### {d['role'].capitalize()}\n"
     markdown += "\n".join(parts)
 
     return markdown
diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py
index e761a7f6..afcf5e07 100644
--- a/src/agentlab/llm/tracking.py
+++ b/src/agentlab/llm/tracking.py
@@ -178,9 +178,9 @@ def __call__(self, *args, **kwargs):
         # 'self' here calls ._call_api() method of the subclass
         response = self._call_api(*args, **kwargs)
         usage = dict(getattr(response, "usage", {}))
-        if "prompt_tokens_details" in usage:
+        if "prompt_tokens_details" in usage and usage["prompt_tokens_details"]:
             usage["cached_tokens"] = usage["prompt_tokens_details"].cached_tokens
-        if "input_tokens_details" in usage:
+        if "input_tokens_details" in usage and usage["input_tokens_details"]:
             usage["cached_tokens"] = usage["input_tokens_details"].cached_tokens
         usage = {f"usage_{k}": v for k, v in usage.items() if isinstance(v, (int, float))}
         usage |= {"n_api_calls": 1}
@@ -332,12 +332,16 @@ def get_effective_cost_from_openai_api(self, response) -> float:
         if api_type == "chatcompletion":
             total_input_tokens = usage.prompt_tokens  # (cache read tokens + new input tokens)
             output_tokens = usage.completion_tokens
-            cached_input_tokens = usage.prompt_tokens_details.cached_tokens
+            cached_input_tokens = (
+                usage.prompt_tokens_details.cached_tokens if usage.prompt_tokens_details else 0
+            )
             new_input_tokens = total_input_tokens - cached_input_tokens
         elif api_type == "response":
             total_input_tokens = usage.input_tokens  # (cache read tokens + new input tokens)
             output_tokens = usage.output_tokens
-            cached_input_tokens = usage.input_tokens_details.cached_tokens
+            cached_input_tokens = (
+                usage.input_tokens_details.cached_tokens if usage.input_tokens_details else 0
+            )
             new_input_tokens = total_input_tokens - cached_input_tokens
         else:
             logging.warning(f"Unsupported API type: {api_type}. Defaulting cost to 0.0.")

From 3f9e4a2191f81d1a177e9be3d6eea734924754cd Mon Sep 17 00:00:00 2001
From: Oleh Shliazhko
Date: Wed, 13 Aug 2025 12:16:57 +0200
Subject: [PATCH 2/4] add new deps

---
 requirements.txt | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/requirements.txt b/requirements.txt
index 6322ffd3..a2798f2e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,3 +27,5 @@ ray[default]
 python-slugify
 pillow
 gymnasium>=0.27
+sentence-transformers>=5.0.0
+python-dotenv>=1.1.1

From c88d7f3fd0f7942e700d5d79ee79555f45cf3f6b Mon Sep 17 00:00:00 2001
From: Oleh Shliazhko
Date: Tue, 19 Aug 2025 14:14:38 +0200
Subject: [PATCH 3/4] use external embedding service in task hint retrieval

---
 .../agents/tool_use_agent/tool_use_agent.py | 80 ++++++++++++++++----
 1 file changed, 66 insertions(+), 14 deletions(-)

diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py
index b1407a87..f6ace3a8 100644
--- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py
+++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py
@@ -1,6 +1,9 @@
 import fnmatch
 import json
 import logging
+import os
+import random
+import time
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from copy import copy
@@ -9,7 +12,9 @@
 from typing import Any, Literal
 
 import bgym
+import numpy as np
 import pandas as pd
+import requests
 from bgym import Benchmark as BgymBenchmark
 from browsergym.core.observation import extract_screenshot
 from browsergym.utils.obs import (
@@ -18,7 +23,6 @@
     overlay_som,
     prune_html,
 )
-from sentence_transformers import SentenceTransformer
 
 from agentlab.agents.agent_args import AgentArgs
 from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark
@@ -181,7 +185,6 @@ class Obs(Block):
     def apply(
         self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput
     ) -> dict:
-
         obs_msg = llm.msg.user()
         tool_calls = last_llm_output.tool_calls
         if self.use_last_error:
@@ -306,6 +309,7 @@ class TaskHint(Block):
     hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
     top_n: int = 4  # Number of top hints to return when using embedding retrieval
     embedder_model: str = "Qwen/Qwen3-Embedding-0.6B"  # Model for embedding hints
+    embedder_server: str = "http://localhost:5000"
     llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
 You need to choose the most relevant hint topic from the following list:\n\nHint topics:\n{topics}\n
 Choose the hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1."""
@@ -318,20 +322,31 @@ def _init(self):
             hint_db_path = Path(__file__).parent / self.hint_db_rel_path
         self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
+        self.hint_db_path = hint_db_path  # kept so encode_hints can find the sibling embeddings file
         if self.hint_retrieval_mode == "emb":
-            logger.info("Loading sentence transformer model for hint embeddings.")
-            self.emb_model = SentenceTransformer(
-                self.embedder_model, model_kwargs={"torch_dtype": "bfloat16"}
-            )
             self.encode_hints()
 
+    def oai_embed(self, text: str):
+        """Embed text with the OpenAI embeddings API; assumes OPENAI_API_KEY is set."""
+        if not hasattr(self, "_oai_emb"):
+            from openai import OpenAI  # lazy import; this helper is the only user
+            self._oai_emb = OpenAI().embeddings
+        response = self._oai_emb.create(input=text, model="text-embedding-3-small")
+        return response.data[0].embedding
+
     def encode_hints(self):
         self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
         logger.info(
-            f"Encoding {len(self.uniq_hints)} unique hints using {self.embedder_model} model."
-        )
-        self.hint_embeddings = self.emb_model.encode(
-            self.uniq_hints["hint"].tolist(), prompt="task hint"
+            f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model."
         )
+        hints = self.uniq_hints["hint"].tolist()
+        semantic_keys = self.uniq_hints["semantic_keys"].tolist()
+        lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)]
+        emb_path = f"{self.hint_db_path}.embs.npy"
+        assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}"
+        logger.info(f"Loading hint embeddings from: {emb_path}")
+        emb_dict = np.load(emb_path, allow_pickle=True).item()
+        self.hint_embeddings = np.array([emb_dict[k] for k in lines])
+        logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}")
 
     def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
         if not self.use_task_hint:
@@ -393,14 +408,50 @@ def choose_hints_llm(self, llm, goal: str) -> list[str]:
 
     def choose_hints_emb(self, goal: str) -> list[str]:
         """Choose hints using embeddings to filter the hints."""
-        goal_embeddings = self.emb_model.encode([goal], prompt="task description")
-        similarities = self.emb_model.similarity(goal_embeddings, self.hint_embeddings)
+        goal_embeddings = self._encode([goal], prompt="task description")
+        similarities = self._similarity(goal_embeddings.tolist(), self.hint_embeddings.tolist())
         top_indices = similarities.argsort()[0][-self.top_n :].tolist()
         logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
         hints = self.uniq_hints.iloc[top_indices]
         logger.info(f"Embedding-based hints chosen: {hints}")
         return hints["hint"].tolist()
 
+    def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5):
+        """Call the embedder server's /encode endpoint with timeout and retries."""
+        for attempt in range(max_retries):
+            try:
+                response = requests.post(
+                    f"{self.embedder_server}/encode",
+                    json={"texts": texts, "prompt": prompt},
+                    timeout=timeout,
+                )
+                embs = response.json()["embeddings"]
+                return np.asarray(embs)
+            except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
+                if attempt == max_retries - 1:
+                    raise e
+                time.sleep(random.uniform(1, timeout))
+                continue
+
+    def _similarity(
+        self, texts1: list[list[float]], texts2: list[list[float]], timeout: int = 2, max_retries: int = 5
+    ):
+        """Call the /similarity endpoint; texts1/texts2 are embedding vectors (named for the JSON keys)."""
+        for attempt in range(max_retries):
+            try:
+                response = requests.post(
+                    f"{self.embedder_server}/similarity",
+                    json={"texts1": texts1, "texts2": texts2},
+                    timeout=timeout,
+                )
+                similarities = response.json()["similarities"]
+                return np.asarray(similarities)
+            except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
+                if attempt == max_retries - 1:
+                    raise e
+                time.sleep(random.uniform(1, timeout))
+                continue
+
     def choose_hints_direct(self, task_name: str) -> list[str]:
         hints = self.hint_db[
             self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
@@ -466,7 +517,8 @@ def __init__(
         self.model_args = model_args
         self.config = config
         self.action_set: bgym.AbstractActionSet = action_set or bgym.HighLevelActionSet(
-            self.config.action_subsets, multiaction=self.config.multiaction  # type: ignore
+            self.config.action_subsets,
+            multiaction=self.config.multiaction,  # type: ignore
         )
         self.tools = self.action_set.to_tool_description(api=model_args.api)
 
@@ -656,7 +708,7 @@ def get_action(self, obs: Any) -> float:
     model_name="gpt-5",
     max_total_tokens=200_000,
     max_input_tokens=200_000,
-    max_new_tokens=2_000,
+    max_new_tokens=8_000,
     temperature=None,
     vision_support=True,
 )

From 74fc47f2820ec6dde79035a4d3bb5e5949d2c2bf Mon Sep 17 00:00:00 2001
From: Oleh Shliazhko
Date: Tue, 19 Aug 2025 14:14:49 +0200
Subject: [PATCH 4/4] GPT-5 chat API fixes

---
 src/agentlab/llm/chat_api.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py
index ff341356..dc9667b5 100644
--- a/src/agentlab/llm/chat_api.py
+++ b/src/agentlab/llm/chat_api.py
@@ -292,7 +292,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
             messages=messages,
             n=n_samples,
             temperature=temperature,
-            max_tokens=self.max_tokens,
+            max_completion_tokens=self.max_tokens,
             logprobs=self.log_probs,
         )
 
@@ -359,7 +359,7 @@ def __init__(
             min_retry_wait_time=min_retry_wait_time,
             api_key_env_var="OPENAI_API_KEY",
             client_class=OpenAI,
-            pricing_func=tracking.get_pricing_openai,
+            pricing_func=partial(tracking.get_pricing_litellm, model_name=model_name),
             log_probs=log_probs,
         )
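
---

Note for reviewers (not part of the patch series): patch 3 points TaskHint at an
external embedder service (embedder_server, default http://localhost:5000) but
does not include the service itself. The sketch below is a hypothetical minimal
implementation of the contract that TaskHint._encode and TaskHint._similarity
assume: POST /encode takes {"texts", "prompt"} and returns {"embeddings"}, and
POST /similarity takes {"texts1", "texts2"} (embedding vectors) and returns
{"similarities"}. Only the endpoint shapes and the Qwen/Qwen3-Embedding-0.6B
model come from the patches; Flask (not in requirements.txt), the host/port, and
the dtype choice are assumptions.

    import numpy as np
    from flask import Flask, jsonify, request
    from sentence_transformers import SentenceTransformer

    app = Flask(__name__)
    # Same model and dtype that patch 1 loaded in-process before patch 3 externalized it.
    model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B", model_kwargs={"torch_dtype": "bfloat16"})

    @app.post("/encode")
    def encode():
        payload = request.get_json()
        # "prompt" mirrors the prompt= argument that TaskHint._encode forwards.
        embeddings = model.encode(payload["texts"], prompt=payload.get("prompt", ""))
        return jsonify({"embeddings": embeddings.tolist()})

    @app.post("/similarity")
    def similarity():
        payload = request.get_json()
        # texts1/texts2 carry embedding vectors, not raw text (see TaskHint._similarity).
        sims = model.similarity(np.asarray(payload["texts1"]), np.asarray(payload["texts2"]))
        return jsonify({"similarities": sims.tolist()})

    if __name__ == "__main__":
        app.run(host="0.0.0.0", port=5000)  # matches the default embedder_server URL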
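
A second note: encode_hints() in patch 3 asserts that a precomputed
"<hint_db>.embs.npy" file already exists next to the hint CSV, but the series
does not include the script that produces it. A hypothetical one-off generator,
reusing the "{semantic_keys}: {hint}" key format and the "task hint" prompt from
the patches (the server URL, file location, and timeout are assumptions):

    import numpy as np
    import pandas as pd
    import requests

    hint_db_path = "hint_db.csv"  # the same CSV the TaskHint block reads
    df = pd.read_csv(hint_db_path, dtype=str).drop_duplicates(subset=["hint"], keep="first")
    keys = [f"{k}: {h}" for h, k in zip(df["hint"], df["semantic_keys"])]

    # Single call to the embedder service sketched above; /encode accepts a list of texts.
    resp = requests.post(
        "http://localhost:5000/encode",
        json={"texts": keys, "prompt": "task hint"},
        timeout=60,
    )
    embs = np.asarray(resp.json()["embeddings"])

    # np.save pickles the dict; encode_hints reads it back via np.load(..., allow_pickle=True).item().
    np.save(f"{hint_db_path}.embs", dict(zip(keys, embs)))  # writes hint_db.csv.embs.npy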