Hints retrieval in tool use agent #277
Changes from all commits
@@ -1,13 +1,20 @@
import fnmatch
import json
import logging
import os
import random
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import copy
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any
from typing import Any, Literal

import bgym
import numpy as np
import pandas as pd
import requests
from bgym import Benchmark as BgymBenchmark
from browsergym.core.observation import extract_screenshot
from browsergym.utils.obs import (

@@ -34,6 +41,8 @@
)
from agentlab.llm.tracking import cost_tracker_decorator

logger = logging.getLogger(__name__)


@dataclass
class Block(ABC):

@@ -176,7 +185,6 @@ class Obs(Block):
    def apply(
        self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput
    ) -> dict:

        obs_msg = llm.msg.user()
        tool_calls = last_llm_output.tool_calls
        if self.use_last_error:

@@ -298,22 +306,52 @@ def apply_init(self, llm, discussion: StructuredDiscussion) -> dict:
class TaskHint(Block):
    use_task_hint: bool = True
    hint_db_rel_path: str = "hint_db.csv"
    hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
    top_n: int = 4  # Number of top hints to return when using embedding retrieval
    embedder_model: str = "Qwen/Qwen3-Embedding-0.6B"  # Model for embedding hints
    embedder_server: str = "http://localhost:5000"
    llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n
You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n
Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1."""

    def _init(self):
        """Initialize the block."""
        hint_db_path = Path(__file__).parent / self.hint_db_rel_path
        if Path(self.hint_db_rel_path).is_absolute():
            hint_db_path = Path(self.hint_db_rel_path)
        else:
            hint_db_path = Path(__file__).parent / self.hint_db_rel_path
        self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
        if self.hint_retrieval_mode == "emb":
            self.encode_hints()

    def oai_embed(self, text: str):
        response = self._oai_emb.create(input=text, model="text-embedding-3-small")
        return response.data[0].embedding

    def encode_hints(self):
        self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first")
        logger.info(
            f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model."
        )
        hints = self.uniq_hints["hint"].tolist()
        semantic_keys = self.uniq_hints["semantic_keys"].tolist()
        lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)]
        emb_path = f"{self.hint_db_rel_path}.embs.npy"
        assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}"
        logger.info(f"Loading hint embeddings from: {emb_path}")
        emb_dict = np.load(emb_path, allow_pickle=True).item()
        self.hint_embeddings = np.array([emb_dict[k] for k in lines])
        logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}")

    def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
        if not self.use_task_hint:
            return
            return {}

        task_hints = self.hint_db[
            self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
        ]
        goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content])
        task_hints = self.choose_hints(llm, task_name, goal)

        hints = []
        for hint in task_hints["hint"]:
        for hint in task_hints:
            hint = hint.strip()
            if hint:
                hints.append(f"- {hint}")

@@ -327,6 +365,94 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
        discussion.append(msg)

    def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
        """Choose hints based on the task name."""
        if self.hint_retrieval_mode == "llm":
            return self.choose_hints_llm(llm, goal)
        elif self.hint_retrieval_mode == "direct":
            return self.choose_hints_direct(task_name)
        elif self.hint_retrieval_mode == "emb":
            return self.choose_hints_emb(goal)
        else:
            raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}")

    def choose_hints_llm(self, llm, goal: str) -> list[str]:
        """Choose hints using LLM to filter the hints."""
        topic_to_hints = defaultdict(list)
        for i, row in self.hint_db.iterrows():
            topic_to_hints[row["semantic_keys"]].append(i)
        hint_topics = list(topic_to_hints.keys())
        topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)])
        prompt = self.llm_prompt.format(goal=goal, topics=topics)
        response = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)]))
        try:
            hint_topic_idx = json.loads(response.think)
            if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics):
                logger.error(f"Wrong LLM hint id response: {response.think}, no hints")
                return []
            hint_topic = hint_topics[hint_topic_idx]
            hint_indices = topic_to_hints[hint_topic]
            df = self.hint_db.iloc[hint_indices].copy()
            df = df.drop_duplicates(subset=["hint"], keep="first")  # leave only unique hints
            hints = df["hint"].tolist()
            logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}")
        except json.JSONDecodeError:
            logger.error(f"Failed to parse LLM hint id response: {response.think}, no hints")
            hints = []

Comment on lines +388 to +401

Korbit review: Incomplete Error Logging

What is the issue?
The error handling in choose_hints_llm() loses potentially useful error context by only logging the error message without including the original exception details.

Why this matters
Without the original exception details in the logs, debugging production issues will be more difficult as developers won't have access to the full stack trace and error context.

Suggested change
Include the exception details in the error logging using exc_info=True:

        try:
            hint_topic_idx = json.loads(response.think)
            if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics):
                logger.error(f"Wrong LLM hint id response: {response.think}, no hints")
                return []
            hint_topic = hint_topics[hint_topic_idx]
            hint_indices = topic_to_hints[hint_topic]
            df = self.hint_db.iloc[hint_indices].copy()
            df = df.drop_duplicates(subset=["hint"], keep="first")  # leave only unique hints
            hints = df["hint"].tolist()
            logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}")
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse LLM hint id response: {response.think}, no hints", exc_info=True)
            hints = []

        return hints

    def choose_hints_emb(self, goal: str) -> list[str]:
        """Choose hints using embeddings to filter the hints."""
        goal_embeddings = self._encode([goal], prompt="task description")
        similarities = self._similarity(goal_embeddings.tolist(), self.hint_embeddings.tolist())
        top_indices = similarities.argsort()[0][-self.top_n :].tolist()
        logger.info(f"Top hint indices based on embedding similarity: {top_indices}")
        hints = self.uniq_hints.iloc[top_indices]
        logger.info(f"Embedding-based hints chosen: {hints}")
        return hints["hint"].tolist()

    def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5):
        """Call the encode API endpoint with timeout and retries"""
        for attempt in range(max_retries):
            try:
                response = requests.post(
                    f"{self.embedder_server}/encode",
                    json={"texts": texts, "prompt": prompt},
                    timeout=timeout,
                )
                embs = response.json()["embeddings"]
                return np.asarray(embs)
            except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
                if attempt == max_retries - 1:
                    raise e
                time.sleep(random.uniform(1, timeout))
                continue

    def _similarity(
        self, texts1: list[str], texts2: list[str], timeout: int = 2, max_retries: int = 5
    ):
        """Call the similarity API endpoint with timeout and retries"""
        for attempt in range(max_retries):
            try:
                response = requests.post(
                    f"{self.embedder_server}/similarity",
                    json={"texts1": texts1, "texts2": texts2},
                    timeout=timeout,
                )
                similarities = response.json()["similarities"]
                return np.asarray(similarities)
            except (requests.exceptions.RequestException, requests.exceptions.Timeout) as e:
                if attempt == max_retries - 1:
                    raise e
                time.sleep(random.uniform(1, timeout))
                continue

    def choose_hints_direct(self, task_name: str) -> list[str]:
        hints = self.hint_db[
            self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
        ]
        return hints["hint"].tolist()
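
Reviewer note: the `_encode` and `_similarity` helpers above call an external embedding server (`embedder_server`, default `http://localhost:5000`) that must expose `POST /encode` and `POST /similarity`; that server is not part of this diff, and `encode_hints` additionally expects a pre-computed `<hint_db>.embs.npy` file keyed by the `"{semantic_keys}: {hint}"` strings produced in that method. A minimal sketch of a compatible server, assuming FastAPI and sentence-transformers (both assumptions, not necessarily the authors' actual setup):

```python
# Hypothetical embedding server matching the /encode and /similarity calls above.
# Assumptions: FastAPI + sentence-transformers are installed and the
# Qwen/Qwen3-Embedding-0.6B checkpoint (the embedder_model default) loads locally.
# Run with: uvicorn embedder_server:app --port 5000
import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

app = FastAPI()
model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")


class EncodeRequest(BaseModel):
    texts: list[str]
    prompt: str = ""


class SimilarityRequest(BaseModel):
    # choose_hints_emb sends embedding vectors here, despite the field names
    texts1: list[list[float]]
    texts2: list[list[float]]


@app.post("/encode")
def encode(req: EncodeRequest):
    # One possible interpretation of `prompt`: prepend it as an instruction prefix.
    texts = [f"{req.prompt}: {t}" if req.prompt else t for t in req.texts]
    embs = model.encode(texts, normalize_embeddings=True)
    return {"embeddings": embs.tolist()}


@app.post("/similarity")
def similarity(req: SimilarityRequest):
    # Cosine similarity matrix between the two sets of vectors; with a single goal
    # embedding the result has shape (1, N), matching the argsort()[0] usage above.
    a = np.asarray(req.texts1, dtype=float)
    b = np.asarray(req.texts2, dtype=float)
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return {"similarities": (a @ b.T).tolist()}
```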

@dataclass
class PromptConfig:

@@ -386,7 +512,8 @@ def __init__(
        self.model_args = model_args
        self.config = config
        self.action_set: bgym.AbstractActionSet = action_set or bgym.HighLevelActionSet(
            self.config.action_subsets, multiaction=self.config.multiaction  # type: ignore
            self.config.action_subsets,
            multiaction=self.config.multiaction,  # type: ignore
        )
        self.tools = self.action_set.to_tool_description(api=model_args.api)

@@ -510,6 +637,15 @@ def get_action(self, obs: Any) -> float:
    vision_support=True,
)

GPT_4_1_CC_API = OpenAIChatModelArgs(
    model_name="gpt-4.1",
    max_total_tokens=200_000,
    max_input_tokens=200_000,
    max_new_tokens=2_000,
    temperature=0.1,
    vision_support=True,
)

GPT_5_mini = OpenAIChatModelArgs(
    model_name="gpt-5-mini-2025-08-07",
    max_total_tokens=400_000,

@@ -548,7 +684,7 @@ def get_action(self, obs: Any) -> float:
    vision_support=True,
)

CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs(
CLAUDE_SONNET_37 = ClaudeResponseModelArgs(
    model_name="claude-3-7-sonnet-20250219",
    max_total_tokens=200_000,
    max_input_tokens=200_000,

@@ -557,6 +693,15 @@ def get_action(self, obs: Any) -> float:
    vision_support=True,
)

CLAUDE_SONNET_4 = ClaudeResponseModelArgs(
    model_name="claude-sonnet-4-20250514",
    max_total_tokens=200_000,
    max_input_tokens=200_000,
    max_new_tokens=2_000,
    temperature=0.1,
    vision_support=True,
)

O3_RESPONSE_MODEL = OpenAIResponseModelArgs(
    model_name="o3-2025-04-16",
    max_total_tokens=200_000,

@@ -574,6 +719,25 @@ def get_action(self, obs: Any) -> float:
    vision_support=True,
)

GPT_5 = OpenAIChatModelArgs(
    model_name="gpt-5",
    max_total_tokens=200_000,
    max_input_tokens=200_000,
    max_new_tokens=8_000,
    temperature=None,
    vision_support=True,
)


GPT_5_MINI = OpenAIChatModelArgs(
    model_name="gpt-5-mini-2025-08-07",
    max_total_tokens=200_000,
    max_input_tokens=200_000,
    max_new_tokens=2_000,
    temperature=1.0,
    vision_support=True,
)

GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs(
    model_name="openai/gpt-4.1",
    max_total_tokens=200_000,

@@ -600,12 +764,12 @@ def get_action(self, obs: Any) -> float:
    keep_last_n_obs=None,
    multiaction=False,  # whether to use multi-action or not
    # action_subsets=("bid",),
    action_subsets=("coord"),
    action_subsets=("coord",),
    # action_subsets=("coord", "bid"),
)

AGENT_CONFIG = ToolUseAgentArgs(
    model_args=CLAUDE_MODEL_CONFIG,
    model_args=CLAUDE_SONNET_37,
    config=DEFAULT_PROMPT_CONFIG,
)
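
The PR renames CLAUDE_MODEL_CONFIG to CLAUDE_SONNET_37 and adds CLAUDE_SONNET_4 and GPT_5 configs; wiring one of the new models into an agent config would presumably follow the same pattern as AGENT_CONFIG above (a sketch only, this variant is not defined in the diff):

```python
# Hypothetical variant following the AGENT_CONFIG pattern above; not part of the PR.
CLAUDE_4_AGENT_CONFIG = ToolUseAgentArgs(
    model_args=CLAUDE_SONNET_4,
    config=DEFAULT_PROMPT_CONFIG,
)
```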

@@ -633,7 +797,7 @@ def get_action(self, obs: Any) -> float:
)

OSWORLD_CLAUDE = ToolUseAgentArgs(
    model_args=CLAUDE_MODEL_CONFIG,
    model_args=CLAUDE_SONNET_37,
    config=PromptConfig(
        tag_screenshot=True,
        goal=Goal(goal_as_system_msg=True),

Korbit review: Incorrect TaskHint.choose_hints() docstring

What is the issue?
The docstring is inaccurate: the method chooses hints based on task_name or goal depending on hint_retrieval_mode, not just task_name.

Why this matters
A misleading docstring could cause confusion about the method's behavior when using different hint retrieval modes.
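
The reviewer's concrete suggestion is not shown above; a docstring along these lines would address the comment (a sketch, not the reviewer's actual text):

```python
    def choose_hints(self, llm, task_name: str, goal: str) -> list[str]:
        """Choose hints for the current task.

        Depending on `hint_retrieval_mode`, hints are matched directly against
        `task_name` ("direct"), selected by asking the LLM to pick a hint topic
        for `goal` ("llm"), or retrieved by embedding similarity between `goal`
        and the hint database ("emb").
        """
```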