-
Notifications
You must be signed in to change notification settings - Fork 105
Hints retrieval in generic agent #289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7af2d15
3f9e4a2
c88d7f3
74fc47f
1de1e51
2b4633a
380c69f
d3054cd
ecb8f11
bf0b6e7
f7d1545
24a14f2
55ce26a
cad1209
d920b8e
5315f14
26f0abb
5393a34
deddc50
b9d09d4
725e7a0
e93fde5
5604ac3
0e68bca
e4cad16
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,15 +6,16 @@ | |
|
|
||
| import logging | ||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
| from typing import Literal | ||
|
|
||
| from browsergym.core import action | ||
| import pandas as pd | ||
| from browsergym.core.action.base import AbstractActionSet | ||
|
|
||
| from agentlab.agents import dynamic_prompting as dp | ||
| from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource | ||
| from agentlab.llm.chat_api import ChatModel | ||
| from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise | ||
| import fnmatch | ||
| import pandas as pd | ||
| from pathlib import Path | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -49,6 +50,8 @@ class GenericPromptFlags(dp.Flags): | |
| use_abstract_example: bool = False | ||
| use_hints: bool = False | ||
| use_task_hint: bool = False | ||
| task_hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct" | ||
| skip_hints_for_current_task: bool = False | ||
| hint_db_path: str = None | ||
| enable_chat: bool = False | ||
| max_prompt_tokens: int = None | ||
|
|
@@ -70,10 +73,12 @@ def __init__( | |
| previous_plan: str, | ||
| step: int, | ||
| flags: GenericPromptFlags, | ||
| llm: ChatModel, | ||
| ) -> None: | ||
| super().__init__() | ||
| self.flags = flags | ||
| self.history = dp.History(obs_history, actions, memories, thoughts, flags.obs) | ||
| goal = obs_history[-1]["goal_object"] | ||
| if self.flags.enable_chat: | ||
| self.instructions = dp.ChatInstructions( | ||
| obs_history[-1]["chat_messages"], extra_instructions=flags.extra_instructions | ||
|
|
@@ -84,7 +89,7 @@ def __init__( | |
| "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." | ||
| ) | ||
| self.instructions = dp.GoalInstructions( | ||
| obs_history[-1]["goal_object"], extra_instructions=flags.extra_instructions | ||
| goal, extra_instructions=flags.extra_instructions | ||
| ) | ||
|
|
||
| self.obs = dp.Observation( | ||
|
|
@@ -103,9 +108,14 @@ def time_for_caution(): | |
| self.be_cautious = dp.BeCautious(visible=time_for_caution) | ||
| self.think = dp.Think(visible=lambda: flags.use_thinking) | ||
| self.hints = dp.Hints(visible=lambda: flags.use_hints) | ||
| goal_str: str = goal[0]["text"] | ||
| self.task_hint = TaskHint( | ||
| use_task_hint=flags.use_task_hint, | ||
| hint_db_path=flags.hint_db_path | ||
| hint_db_path=flags.hint_db_path, | ||
| goal=goal_str, | ||
| hint_retrieval_mode=flags.task_hint_retrieval_mode, | ||
| llm=llm, | ||
| skip_hints_for_current_task=flags.skip_hints_for_current_task, | ||
| ) | ||
| self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan | ||
| self.criticise = Criticise(visible=lambda: flags.use_criticise) | ||
|
|
@@ -114,12 +124,12 @@ def time_for_caution(): | |
| @property | ||
| def _prompt(self) -> HumanMessage: | ||
| prompt = HumanMessage(self.instructions.prompt) | ||
|
|
||
| # Add task hints if enabled | ||
| task_hints_text = "" | ||
| if self.flags.use_task_hint and hasattr(self, 'task_name'): | ||
| if self.flags.use_task_hint and hasattr(self, "task_name"): | ||
| task_hints_text = self.task_hint.get_hints_for_task(self.task_name) | ||
|
|
||
| prompt.add_text( | ||
| f"""\ | ||
| {self.obs.prompt}\ | ||
|
|
@@ -286,11 +296,23 @@ def _parse_answer(self, text_answer): | |
|
|
||
|
|
||
| class TaskHint(dp.PromptElement): | ||
| def __init__(self, use_task_hint: bool = True, hint_db_path: str = None) -> None: | ||
| def __init__( | ||
| self, | ||
| use_task_hint: bool, | ||
| hint_db_path: str, | ||
| goal: str, | ||
| hint_retrieval_mode: Literal["direct", "llm", "emb"], | ||
| skip_hints_for_current_task: bool, | ||
| llm: ChatModel, | ||
| ) -> None: | ||
| super().__init__(visible=use_task_hint) | ||
| self.use_task_hint = use_task_hint | ||
| self.hint_db_rel_path = "hint_db.csv" | ||
| self.hint_db_path = hint_db_path # Allow external path override | ||
| self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode | ||
| self.skip_hints_for_current_task = skip_hints_for_current_task | ||
| self.goal = goal | ||
| self.llm = llm | ||
| self._init() | ||
|
|
||
| _prompt = "" # Task hints are added dynamically in MainPrompt | ||
|
|
@@ -316,42 +338,50 @@ def _init(self): | |
| hint_db_path = Path(self.hint_db_path) | ||
| else: | ||
| hint_db_path = Path(__file__).parent / self.hint_db_rel_path | ||
|
|
||
| if hint_db_path.exists(): | ||
| self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str) | ||
| # Verify the expected columns exist | ||
| if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: | ||
| print(f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}") | ||
| print( | ||
| f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}" | ||
| ) | ||
| self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) | ||
| else: | ||
| print(f"Warning: Hint database not found at {hint_db_path}") | ||
| self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) | ||
| self.hints_source = HintsSource( | ||
| hint_db_path=hint_db_path.as_posix(), | ||
| hint_retrieval_mode=self.hint_retrieval_mode, | ||
| skip_hints_for_current_task=self.skip_hints_for_current_task, | ||
| ) | ||
| except Exception as e: | ||
| # Fallback to empty database on any error | ||
| print(f"Warning: Could not load hint database: {e}") | ||
| self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) | ||
|
|
||
|
|
||
| def get_hints_for_task(self, task_name: str) -> str: | ||
| """Get hints for a specific task.""" | ||
| if not self.use_task_hint: | ||
| return "" | ||
|
|
||
| # Ensure hint_db is initialized | ||
| if not hasattr(self, 'hint_db'): | ||
| if not hasattr(self, "hint_db"): | ||
| self._init() | ||
|
|
||
| # Check if hint_db has the expected structure | ||
| if self.hint_db.empty or "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns: | ||
| if ( | ||
| self.hint_db.empty | ||
| or "task_name" not in self.hint_db.columns | ||
| or "hint" not in self.hint_db.columns | ||
| ): | ||
| return "" | ||
|
Comment on lines
363
to
378
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Complex Validation Flow
Tell me more. What is the issue? Multiple early returns with empty strings and nested validation checks make the flow hard to follow. Why this matters: The multiple validation checks and early returns create a complex flow that makes it difficult to understand the main purpose of the function. Suggested change ∙ Feature Preview: Consolidate validation checks: def get_hints_for_task(self, task_name: str) -> str:
"""Get hints for a specific task."""
if not self._is_hints_retrieval_valid():
return ""
try:
task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal)
return self._format_hints(task_hints)
except Exception as e:
logging.warning(f"Error getting hints for task {task_name}: {e}")
return ""
def _is_hints_retrieval_valid(self) -> bool:
"""Check if hint retrieval is possible and properly configured."""
if not self.use_task_hint:
return False
if not hasattr(self, "hint_db"):
self._init()
return not (
self.hint_db.empty
or "task_name" not in self.hint_db.columns
or "hint" not in self.hint_db.columns
    ) Provide feedback to improve future suggestions. 💬 Looking for more details? Reply to this comment to chat with Korbit. |
||
|
|
||
| try: | ||
| task_hints = self.hint_db[ | ||
| self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name)) | ||
| ] | ||
| task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Unvalidated Input to LLM
Tell me more. What is the issue? The task_name and goal parameters are passed directly to an LLM without any input validation or sanitization. Why this matters: Malicious task names or goals could be crafted to perform prompt injection attacks against the LLM, potentially causing it to reveal sensitive information or execute harmful commands. Suggested change ∙ Feature Preview: Add input validation before passing task_name and goal to the LLM: def sanitize_input(text: str) -> str:
# Remove any potential prompt injection patterns
# This is a basic example - implement more comprehensive validation
if not isinstance(text, str):
raise ValueError("Input must be string")
text = text.strip()
if len(text) > 1000: # Reasonable length limit
raise ValueError("Input too long")
return text
task_name = sanitize_input(task_name)
goal = sanitize_input(goal)
task_hints = self.hints_source.choose_hints(self.llm, task_name, goal) Provide feedback to improve future suggestions. 💬 Looking for more details? Reply to this comment to chat with Korbit. |
||
|
|
||
| hints = [] | ||
| for hint in task_hints["hint"]: | ||
| for hint in task_hints: | ||
| hint = hint.strip() | ||
| if hint: | ||
| hints.append(f"- {hint}") | ||
|
|
@@ -364,5 +394,5 @@ def get_hints_for_task(self, task_name: str) -> str: | |
| return hints_str | ||
| except Exception as e: | ||
| print(f"Warning: Error getting hints for task {task_name}: {e}") | ||
|
|
||
| return "" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unsafe Goal Object Access
Tell me more
What is the issue?
Direct array access of goal[0] without validation assumes the goal object always has at least one element with a 'text' key.
Why this matters
This assumption could cause the agent to crash if the goal object is empty or malformed, preventing proper task execution.
Suggested change ∙ Feature Preview
Add proper validation for the goal object:
Provide feedback to improve future suggestions
💬 Looking for more details? Reply to this comment to chat with Korbit.