diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 91b2f70f..cfbd19bd 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -111,10 +111,11 @@ def get_action(self, obs): previous_plan=self.plan, step=self.plan_step, flags=self.flags, + llm=self.chat_llm, ) # Set task name for task hints if available - if self.flags.use_task_hint and hasattr(self, 'task_name'): + if self.flags.use_task_hint and hasattr(self, "task_name"): main_prompt.set_task_name(self.task_name) max_prompt_tokens, max_trunc_itr = self._get_maxes() diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index bc12cc2c..983c9d48 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -6,15 +6,16 @@ import logging from dataclasses import dataclass +from pathlib import Path +from typing import Literal -from browsergym.core import action +import pandas as pd from browsergym.core.action.base import AbstractActionSet from agentlab.agents import dynamic_prompting as dp +from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource +from agentlab.llm.chat_api import ChatModel from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise -import fnmatch -import pandas as pd -from pathlib import Path @dataclass @@ -49,6 +50,8 @@ class GenericPromptFlags(dp.Flags): use_abstract_example: bool = False use_hints: bool = False use_task_hint: bool = False + task_hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct" + skip_hints_for_current_task: bool = False hint_db_path: str = None enable_chat: bool = False max_prompt_tokens: int = None @@ -70,10 +73,12 @@ def __init__( previous_plan: str, step: int, flags: GenericPromptFlags, + llm: ChatModel, ) -> None: super().__init__() self.flags = flags self.history = dp.History(obs_history, actions, memories, thoughts, flags.obs) + goal = obs_history[-1]["goal_object"] if self.flags.enable_chat: self.instructions = dp.ChatInstructions( obs_history[-1]["chat_messages"], extra_instructions=flags.extra_instructions @@ -84,7 +89,7 @@ def __init__( "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." ) self.instructions = dp.GoalInstructions( - obs_history[-1]["goal_object"], extra_instructions=flags.extra_instructions + goal, extra_instructions=flags.extra_instructions ) self.obs = dp.Observation( @@ -103,9 +108,14 @@ def time_for_caution(): self.be_cautious = dp.BeCautious(visible=time_for_caution) self.think = dp.Think(visible=lambda: flags.use_thinking) self.hints = dp.Hints(visible=lambda: flags.use_hints) + goal_str: str = goal[0]["text"] self.task_hint = TaskHint( use_task_hint=flags.use_task_hint, - hint_db_path=flags.hint_db_path + hint_db_path=flags.hint_db_path, + goal=goal_str, + hint_retrieval_mode=flags.task_hint_retrieval_mode, + llm=llm, + skip_hints_for_current_task=flags.skip_hints_for_current_task, ) self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan self.criticise = Criticise(visible=lambda: flags.use_criticise) @@ -114,12 +124,12 @@ def time_for_caution(): @property def _prompt(self) -> HumanMessage: prompt = HumanMessage(self.instructions.prompt) - + # Add task hints if enabled task_hints_text = "" - if self.flags.use_task_hint and hasattr(self, 'task_name'): + if self.flags.use_task_hint and hasattr(self, "task_name"): task_hints_text = self.task_hint.get_hints_for_task(self.task_name) - + prompt.add_text( f"""\ {self.obs.prompt}\ @@ -286,11 +296,23 @@ def _parse_answer(self, text_answer): class TaskHint(dp.PromptElement): - def __init__(self, use_task_hint: bool = True, hint_db_path: str = None) -> None: + def __init__( + self, + use_task_hint: bool, + hint_db_path: str, + goal: str, + hint_retrieval_mode: Literal["direct", "llm", "emb"], + skip_hints_for_current_task: bool, + llm: ChatModel, + ) -> None: super().__init__(visible=use_task_hint) self.use_task_hint = use_task_hint self.hint_db_rel_path = "hint_db.csv" self.hint_db_path = hint_db_path # Allow external path override + self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode + self.skip_hints_for_current_task = skip_hints_for_current_task + self.goal = goal + self.llm = llm self._init() _prompt = "" # Task hints are added dynamically in MainPrompt @@ -316,42 +338,50 @@ def _init(self): hint_db_path = Path(self.hint_db_path) else: hint_db_path = Path(__file__).parent / self.hint_db_rel_path - + if hint_db_path.exists(): self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) # Verify the expected columns exist if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: - print(f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}") + print( + f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}" + ) self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) else: print(f"Warning: Hint database not found at {hint_db_path}") self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + self.hints_source = HintsSource( + hint_db_path=hint_db_path.as_posix(), + hint_retrieval_mode=self.hint_retrieval_mode, + skip_hints_for_current_task=self.skip_hints_for_current_task, + ) except Exception as e: # Fallback to empty database on any error print(f"Warning: Could not load hint database: {e}") self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) - def get_hints_for_task(self, task_name: str) -> str: """Get hints for a specific task.""" if not self.use_task_hint: return "" # Ensure hint_db is initialized - if not hasattr(self, 'hint_db'): + if not hasattr(self, "hint_db"): self._init() # Check if hint_db has the expected structure - if self.hint_db.empty or "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: + if ( + self.hint_db.empty + or "task_name" not in self.hint_db.columns + or "hint" not in self.hint_db.columns + ): return "" try: - task_hints = self.hint_db[ - self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) - ] + task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal) hints = [] - for hint in task_hints["hint"]: + for hint in task_hints: hint = hint.strip() if hint: hints.append(f"- {hint}") @@ -364,5 +394,5 @@ def get_hints_for_task(self, task_name: str) -> str: return hints_str except Exception as e: print(f"Warning: Error getting hints for task {task_name}: {e}") - + return "" diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index 375c829e..bd200da3 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -28,6 +28,7 @@ from agentlab.benchmarks.abstract_env import AbstractBenchmark as AgentLabBenchmark from agentlab.benchmarks.osworld import OSWorldActionSet from agentlab.llm.base_api import BaseModelArgs +from agentlab.llm.chat_api import ChatModel from agentlab.llm.llm_utils import image_to_png_base64_url from agentlab.llm.response_api import ( APIPayload, @@ -316,39 +317,21 @@ class TaskHint(Block): def _init(self): """Initialize the block.""" - if Path(self.hint_db_rel_path).is_absolute(): - hint_db_path = Path(self.hint_db_rel_path) - else: - hint_db_path = Path(__file__).parent / self.hint_db_rel_path - self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) - if self.hint_retrieval_mode == "emb": - self.encode_hints() - - def oai_embed(self, text: str): - response = self._oai_emb.create(input=text, model="text-embedding-3-small") - return response.data[0].embedding - - def encode_hints(self): - self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first") - logger.info( - f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model." + self.hints_source = HintsSource( + hint_db_path=self.hint_db_rel_path, + hint_retrieval_mode=self.hint_retrieval_mode, + top_n=self.top_n, + embedder_model=self.embedder_model, + embedder_server=self.embedder_server, + llm_prompt=self.llm_prompt, ) - hints = self.uniq_hints["hint"].tolist() - semantic_keys = self.uniq_hints["semantic_keys"].tolist() - lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)] - emb_path = f"{self.hint_db_rel_path}.embs.npy" - assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}" - logger.info(f"Loading hint embeddings from: {emb_path}") - emb_dict = np.load(emb_path, allow_pickle=True).item() - self.hint_embeddings = np.array([emb_dict[k] for k in lines]) - logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}") def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: if not self.use_task_hint: return {} goal = "\n".join([c.get("text", "") for c in discussion.groups[0].messages[1].content]) - task_hints = self.choose_hints(llm, task_name, goal) + task_hints = self.hints_source.choose_hints(llm, task_name, goal) hints = [] for hint in task_hints: @@ -358,58 +341,132 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: if len(hints) > 0: hints_str = ( - "# Hints:\nHere are some hints for the task you are working on:\n" + "\n# Hints:\nHere are some hints for the task you are working on:\n" + "\n".join(hints) ) msg = llm.msg.user().add_text(hints_str) discussion.append(msg) + +class HintsSource: + def __init__( + self, + hint_db_path: str, + hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", + skip_hints_for_current_task: bool = False, + top_n: int = 4, + embedder_model: str = "Qwen/Qwen3-Embedding-0.6B", + embedder_server: str = "http://localhost:5000", + llm_prompt: str = """We're choosing hints to help solve the following task:\n{goal}.\n +You need to choose the most relevant hints topic from the following list:\n\nHint topics:\n{topics}\n +Choose hint topic for the task and return only its number, e.g. 1. If you don't know the answer, return -1.""", + ) -> None: + self.hint_db_path = hint_db_path + self.hint_retrieval_mode = hint_retrieval_mode + self.skip_hints_for_current_task = skip_hints_for_current_task + self.top_n = top_n + self.embedder_model = embedder_model + self.embedder_server = embedder_server + self.llm_prompt = llm_prompt + + if Path(hint_db_path).is_absolute(): + self.hint_db_path = Path(hint_db_path).as_posix() + else: + self.hint_db_path = (Path(__file__).parent / self.hint_db_path).as_posix() + self.hint_db = pd.read_csv(self.hint_db_path, header=0, index_col=None, dtype=str) + logger.info(f"Loaded {len(self.hint_db)} hints from database {self.hint_db_path}") + if self.hint_retrieval_mode == "emb": + self.load_hint_vectors() + + def load_hint_vectors(self): + self.uniq_hints = self.hint_db.drop_duplicates(subset=["hint"], keep="first") + logger.info( + f"Encoding {len(self.uniq_hints)} unique hints with semantic keys using {self.embedder_model} model." + ) + hints = self.uniq_hints["hint"].tolist() + semantic_keys = self.uniq_hints["semantic_keys"].tolist() + lines = [f"{k}: {h}" for h, k in zip(hints, semantic_keys)] + emb_path = f"{self.hint_db_path}.embs.npy" + assert os.path.exists(emb_path), f"Embedding file not found: {emb_path}" + logger.info(f"Loading hint embeddings from: {emb_path}") + emb_dict = np.load(emb_path, allow_pickle=True).item() + self.hint_embeddings = np.array([emb_dict[k] for k in lines]) + logger.info(f"Loaded hint embeddings shape: {self.hint_embeddings.shape}") + def choose_hints(self, llm, task_name: str, goal: str) -> list[str]: """Choose hints based on the task name.""" + logger.info( + f"Choosing hints for task: {task_name}, goal: {goal} from db: {self.hint_db_path} using mode: {self.hint_retrieval_mode}" + ) if self.hint_retrieval_mode == "llm": - return self.choose_hints_llm(llm, goal) + return self.choose_hints_llm(llm, goal, task_name) elif self.hint_retrieval_mode == "direct": return self.choose_hints_direct(task_name) elif self.hint_retrieval_mode == "emb": - return self.choose_hints_emb(goal) + return self.choose_hints_emb(goal, task_name) else: raise ValueError(f"Unknown hint retrieval mode: {self.hint_retrieval_mode}") - def choose_hints_llm(self, llm, goal: str) -> list[str]: + def choose_hints_llm(self, llm, goal: str, task_name: str) -> list[str]: """Choose hints using LLM to filter the hints.""" topic_to_hints = defaultdict(list) - for i, row in self.hint_db.iterrows(): - topic_to_hints[row["semantic_keys"]].append(i) + skip_hints = [] + if self.skip_hints_for_current_task: + skip_hints = self.get_current_task_hints(task_name) + for _, row in self.hint_db.iterrows(): + hint = row["hint"] + if hint in skip_hints: + continue + topic_to_hints[row["semantic_keys"]].append(hint) + logger.info(f"Collected {len(topic_to_hints)} hint topics") hint_topics = list(topic_to_hints.keys()) topics = "\n".join([f"{i}. {h}" for i, h in enumerate(hint_topics)]) prompt = self.llm_prompt.format(goal=goal, topics=topics) - response = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])) + + if isinstance(llm, ChatModel): + response: str = llm(messages=[dict(role="user", content=prompt)])["content"] + else: + response: str = llm(APIPayload(messages=[llm.msg.user().add_text(prompt)])).think try: - hint_topic_idx = json.loads(response.think) - if hint_topic_idx < 0 or hint_topic_idx >= len(hint_topics): - logger.error(f"Wrong LLM hint id response: {response.think}, no hints") + topic_number = json.loads(response) + if topic_number < 0 or topic_number >= len(hint_topics): + logger.error(f"Wrong LLM hint id response: {response}, no hints") return [] - hint_topic = hint_topics[hint_topic_idx] - hint_indices = topic_to_hints[hint_topic] - df = self.hint_db.iloc[hint_indices].copy() - df = df.drop_duplicates(subset=["hint"], keep="first") # leave only unique hints - hints = df["hint"].tolist() - logger.debug(f"LLM hint topic {hint_topic_idx}, chosen hints: {df['hint'].tolist()}") - except json.JSONDecodeError: - logger.error(f"Failed to parse LLM hint id response: {response.think}, no hints") + hint_topic = hint_topics[topic_number] + hints = list(set(topic_to_hints[hint_topic])) + logger.info(f"LLM hint topic {topic_number}:'{hint_topic}', chosen hints: {hints}") + except Exception as e: + logger.exception(f"Failed to parse LLM hint id response: {response}:\n{e}") hints = [] return hints - def choose_hints_emb(self, goal: str) -> list[str]: + def choose_hints_emb(self, goal: str, task_name: str) -> list[str]: """Choose hints using embeddings to filter the hints.""" - goal_embeddings = self._encode([goal], prompt="task description") - similarities = self._similarity(goal_embeddings.tolist(), self.hint_embeddings.tolist()) - top_indices = similarities.argsort()[0][-self.top_n :].tolist() - logger.info(f"Top hint indices based on embedding similarity: {top_indices}") - hints = self.uniq_hints.iloc[top_indices] - logger.info(f"Embedding-based hints chosen: {hints}") - return hints["hint"].tolist() + try: + goal_embeddings = self._encode([goal], prompt="task description") + hint_embeddings = self.hint_embeddings.copy() + all_hints = self.uniq_hints["hint"].tolist() + skip_hints = [] + if self.skip_hints_for_current_task: + skip_hints = self.get_current_task_hints(task_name) + hint_embeddings = [] + id_to_hint = {} + for hint, emb in zip(all_hints, self.hint_embeddings): + if hint in skip_hints: + continue + hint_embeddings.append(emb.tolist()) + id_to_hint[len(hint_embeddings) - 1] = hint + logger.info(f"Prepared hint embeddings for {len(hint_embeddings)} hints") + similarities = self._similarity(goal_embeddings.tolist(), hint_embeddings) + top_indices = similarities.argsort()[0][-self.top_n :].tolist() + logger.info(f"Top hint indices based on embedding similarity: {top_indices}") + hints = [id_to_hint[idx] for idx in top_indices] + logger.info(f"Embedding-based hints chosen: {hints}") + except Exception as e: + logger.exception(f"Failed to choose hints using embeddings: {e}") + hints = [] + return hints def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_retries: int = 5): """Call the encode API endpoint with timeout and retries""" @@ -427,9 +484,14 @@ def _encode(self, texts: list[str], prompt: str = "", timeout: int = 10, max_ret raise e time.sleep(random.uniform(1, timeout)) continue + raise ValueError("Failed to encode hints") def _similarity( - self, texts1: list[str], texts2: list[str], timeout: int = 2, max_retries: int = 5 + self, + texts1: list, + texts2: list, + timeout: int = 2, + max_retries: int = 5, ): """Call the similarity API endpoint with timeout and retries""" for attempt in range(max_retries): @@ -446,12 +508,18 @@ def _similarity( raise e time.sleep(random.uniform(1, timeout)) continue + raise ValueError("Failed to compute similarity") def choose_hints_direct(self, task_name: str) -> list[str]: - hints = self.hint_db[ + hints = self.get_current_task_hints(task_name) + logger.info(f"Direct hints chosen: {hints}") + return hints + + def get_current_task_hints(self, task_name): + hints_df = self.hint_db[ self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) ] - return hints["hint"].tolist() + return hints_df["hint"].tolist() @dataclass diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index 8accbfd6..fed78b3e 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -537,9 +537,10 @@ def run_gradio(results_dir: Path): do_share = os.getenv("AGENTXRAY_SHARE_GRADIO", "false").lower() == "true" port = os.getenv("AGENTXRAY_APP_PORT", None) + server_name = "0.0.0.0" if os.getenv("AGENTXRAY_PUBLIC", "false") == "true" else "127.0.0.1" if isinstance(port, str): port = int(port) - demo.launch(server_port=port, share=do_share) + demo.launch(server_name=server_name, server_port=port, share=do_share) def handle_key_event(key_event, step_id: StepId): diff --git a/src/agentlab/experiments/graph_execution_ray.py b/src/agentlab/experiments/graph_execution_ray.py index f047f866..f7aad780 100644 --- a/src/agentlab/experiments/graph_execution_ray.py +++ b/src/agentlab/experiments/graph_execution_ray.py @@ -3,9 +3,8 @@ import bgym import ray -from ray.util import state - from agentlab.experiments.exp_utils import _episode_timeout, run_exp +from ray.util import state logger = logging.getLogger(__name__) @@ -79,6 +78,7 @@ def poll_for_timeout(tasks: dict[str, ray.ObjectRef], timeout: float, poll_inter try: result = ray.get(task) except Exception as e: + logger.exception(f"Task failed: {e}") result = e results.append(result) diff --git a/src/agentlab/llm/llm_configs.py b/src/agentlab/llm/llm_configs.py index 3d5828b9..7ac2450a 100644 --- a/src/agentlab/llm/llm_configs.py +++ b/src/agentlab/llm/llm_configs.py @@ -20,22 +20,6 @@ ] CHAT_MODEL_ARGS_DICT = { - "openai/gpt-5-nano-2025-08-07": OpenAIModelArgs( - model_name="gpt-5-nano-2025-08-07", - max_total_tokens=128_000, - max_input_tokens=128_000, - max_new_tokens=16_384, - temperature=1, # gpt-5 supports temperature of 1 only - vision_support=True, - ), - "openai/gpt-5-mini-2025-08-07": OpenAIModelArgs( - model_name="gpt-5-mini-2025-08-07", - max_total_tokens=128_000, - max_input_tokens=128_000, - max_new_tokens=16_384, - temperature=1, # gpt-5 supports temperature of 1 only - vision_support=True, - ), "openai/gpt-4.1-mini-2025-04-14": OpenAIModelArgs( model_name="gpt-4.1-mini-2025-04-14", max_total_tokens=128_000, @@ -117,6 +101,7 @@ max_input_tokens=400_000 - 4_000, max_new_tokens=4_000, temperature=1, # temperature param not supported by gpt-5 + vision_support=True, ), "openai/gpt-5-mini-2025-08-07": OpenAIModelArgs( model_name="gpt-5-mini-2025-08-07", @@ -124,6 +109,15 @@ max_input_tokens=400_000 - 4_000, max_new_tokens=4_000, temperature=1, # temperature param not supported by gpt-5 + vision_support=True, + ), + "openai/gpt-5-2025-08-07": OpenAIModelArgs( + model_name="gpt-5-2025-08-07", + max_total_tokens=400_000, + max_input_tokens=400_000 - 4_000, + max_new_tokens=4_000, + temperature=1, # temperature param not supported by gpt-5 + vision_support=True, ), "azure/gpt-35-turbo/gpt-35-turbo": AzureModelArgs( model_name="gpt-35-turbo",