Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/agentlab/agents/generic_agent_hinter/generic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,11 @@ def get_action(self, obs):
previous_plan=self.plan,
step=self.plan_step,
flags=self.flags,
llm=self.chat_llm,
)

# Set task name for task hints if available
if self.flags.use_task_hint and hasattr(self, 'task_name'):
if self.flags.use_task_hint and hasattr(self, "task_name"):
main_prompt.set_task_name(self.task_name)

max_prompt_tokens, max_trunc_itr = self._get_maxes()
Expand Down
70 changes: 50 additions & 20 deletions src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

from browsergym.core import action
import pandas as pd
from browsergym.core.action.base import AbstractActionSet

from agentlab.agents import dynamic_prompting as dp
from agentlab.agents.tool_use_agent.tool_use_agent import HintsSource
from agentlab.llm.chat_api import ChatModel
from agentlab.llm.llm_utils import HumanMessage, parse_html_tags_raise
import fnmatch
import pandas as pd
from pathlib import Path


@dataclass
Expand Down Expand Up @@ -49,6 +50,8 @@ class GenericPromptFlags(dp.Flags):
use_abstract_example: bool = False
use_hints: bool = False
use_task_hint: bool = False
task_hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct"
skip_hints_for_current_task: bool = False
hint_db_path: str = None
enable_chat: bool = False
max_prompt_tokens: int = None
Expand All @@ -70,10 +73,12 @@ def __init__(
previous_plan: str,
step: int,
flags: GenericPromptFlags,
llm: ChatModel,
) -> None:
super().__init__()
self.flags = flags
self.history = dp.History(obs_history, actions, memories, thoughts, flags.obs)
goal = obs_history[-1]["goal_object"]
if self.flags.enable_chat:
self.instructions = dp.ChatInstructions(
obs_history[-1]["chat_messages"], extra_instructions=flags.extra_instructions
Expand All @@ -84,7 +89,7 @@ def __init__(
"Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`."
)
self.instructions = dp.GoalInstructions(
obs_history[-1]["goal_object"], extra_instructions=flags.extra_instructions
goal, extra_instructions=flags.extra_instructions
)

self.obs = dp.Observation(
Expand All @@ -103,9 +108,14 @@ def time_for_caution():
self.be_cautious = dp.BeCautious(visible=time_for_caution)
self.think = dp.Think(visible=lambda: flags.use_thinking)
self.hints = dp.Hints(visible=lambda: flags.use_hints)
goal_str: str = goal[0]["text"]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unsafe Goal Object Access category Functionality

Tell me more
What is the issue?

Indexing `goal[0]["text"]` directly, without validation, assumes the goal object is a non-empty list whose first element contains a 'text' key.

Why this matters

This assumption could cause the agent to crash if the goal object is empty or malformed, preventing proper task execution.

Suggested change ∙ Feature Preview

Add proper validation for the goal object:

if not goal or not isinstance(goal, list) or len(goal) == 0:
    raise ValueError("Invalid goal object: empty or not a list")
if "text" not in goal[0]:
    raise ValueError("Invalid goal object: missing 'text' field")
goal_str: str = goal[0]["text"]
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.

self.task_hint = TaskHint(
use_task_hint=flags.use_task_hint,
hint_db_path=flags.hint_db_path
hint_db_path=flags.hint_db_path,
goal=goal_str,
hint_retrieval_mode=flags.task_hint_retrieval_mode,
llm=llm,
skip_hints_for_current_task=flags.skip_hints_for_current_task,
)
self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan
self.criticise = Criticise(visible=lambda: flags.use_criticise)
Expand All @@ -114,12 +124,12 @@ def time_for_caution():
@property
def _prompt(self) -> HumanMessage:
prompt = HumanMessage(self.instructions.prompt)

# Add task hints if enabled
task_hints_text = ""
if self.flags.use_task_hint and hasattr(self, 'task_name'):
if self.flags.use_task_hint and hasattr(self, "task_name"):
task_hints_text = self.task_hint.get_hints_for_task(self.task_name)

prompt.add_text(
f"""\
{self.obs.prompt}\
Expand Down Expand Up @@ -286,11 +296,23 @@ def _parse_answer(self, text_answer):


class TaskHint(dp.PromptElement):
def __init__(self, use_task_hint: bool = True, hint_db_path: str = None) -> None:
def __init__(
self,
use_task_hint: bool,
hint_db_path: str,
goal: str,
hint_retrieval_mode: Literal["direct", "llm", "emb"],
skip_hints_for_current_task: bool,
llm: ChatModel,
) -> None:
super().__init__(visible=use_task_hint)
self.use_task_hint = use_task_hint
self.hint_db_rel_path = "hint_db.csv"
self.hint_db_path = hint_db_path # Allow external path override
self.hint_retrieval_mode: Literal["direct", "llm", "emb"] = hint_retrieval_mode
self.skip_hints_for_current_task = skip_hints_for_current_task
self.goal = goal
self.llm = llm
self._init()

_prompt = "" # Task hints are added dynamically in MainPrompt
Expand All @@ -316,42 +338,50 @@ def _init(self):
hint_db_path = Path(self.hint_db_path)
else:
hint_db_path = Path(__file__).parent / self.hint_db_rel_path

if hint_db_path.exists():
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
# Verify the expected columns exist
if "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns:
print(f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}")
print(
f"Warning: Hint database missing expected columns. Found: {list(self.hint_db.columns)}"
)
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
else:
print(f"Warning: Hint database not found at {hint_db_path}")
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
self.hints_source = HintsSource(
hint_db_path=hint_db_path.as_posix(),
hint_retrieval_mode=self.hint_retrieval_mode,
skip_hints_for_current_task=self.skip_hints_for_current_task,
)
except Exception as e:
# Fallback to empty database on any error
print(f"Warning: Could not load hint database: {e}")
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])


def get_hints_for_task(self, task_name: str) -> str:
"""Get hints for a specific task."""
if not self.use_task_hint:
return ""

# Ensure hint_db is initialized
if not hasattr(self, 'hint_db'):
if not hasattr(self, "hint_db"):
self._init()

# Check if hint_db has the expected structure
if self.hint_db.empty or "task_name" not in self.hint_db.columns or "hint" not in self.hint_db.columns:
if (
self.hint_db.empty
or "task_name" not in self.hint_db.columns
or "hint" not in self.hint_db.columns
):
return ""
Comment on lines 363 to 378
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Complex Validation Flow category Readability

Tell me more
What is the issue?

Multiple early returns with empty strings and nested validation checks make the flow hard to follow.

Why this matters

The multiple validation checks and early returns create a complex flow that makes it difficult to understand the main purpose of the function.

Suggested change ∙ Feature Preview

Consolidate validation checks:

def get_hints_for_task(self, task_name: str) -> str:
    """Return the formatted hints for ``task_name``, or ``""`` if unavailable.

    An empty string is returned when ``_is_hints_retrieval_valid()`` reports
    that retrieval is not possible, and also when choosing or formatting the
    hints raises: the failure is logged as a warning instead of propagating.
    """
    formatted = ""
    if self._is_hints_retrieval_valid():
        try:
            chosen = self.hints_source.choose_hints(self.llm, task_name, self.goal)
            formatted = self._format_hints(chosen)
        except Exception as e:
            # Best effort: a failing hint lookup must not break the caller.
            logging.warning(f"Error getting hints for task {task_name}: {e}")
            formatted = ""
    return formatted

def _is_hints_retrieval_valid(self) -> bool:
    """Tell whether hints can be retrieved with the current configuration.

    Returns ``False`` right away when ``use_task_hint`` is off. Otherwise
    ``_init()`` is called if no ``hint_db`` attribute exists yet, and the
    result is ``True`` only when ``hint_db`` is non-empty and exposes both
    the "task_name" and "hint" columns.
    """
    if not self.use_task_hint:
        return False

    # Run _init() when hint_db has not been set on this object yet.
    if not hasattr(self, "hint_db"):
        self._init()

    if self.hint_db.empty:
        return False
    if "task_name" not in self.hint_db.columns:
        return False
    return "hint" in self.hint_db.columns
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.


try:
task_hints = self.hint_db[
self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
]
task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unvalidated Input to LLM category Security

Tell me more
What is the issue?

The task_name and goal parameters are passed directly to an LLM without any input validation or sanitization.

Why this matters

Malicious task names or goals could be crafted to perform prompt injection attacks against the LLM, potentially causing it to reveal sensitive information or execute harmful commands.

Suggested change ∙ Feature Preview

Add input validation before passing task_name and goal to the LLM:

def sanitize_input(text: str, *, max_length: int = 1000) -> str:
    """Validate and normalize a string before it is placed in an LLM prompt.

    This performs only basic hygiene: a type check, whitespace trimming and
    a length cap. It does NOT detect or remove prompt-injection patterns;
    that needs a more comprehensive validation layer on top of this one.

    Args:
        text: The raw value (e.g. a task name or goal).
        max_length: Maximum allowed length of the stripped text.

    Returns:
        ``text`` with leading and trailing whitespace removed.

    Raises:
        ValueError: If ``text`` is not a string, or if the stripped text is
            longer than ``max_length``.
    """
    if not isinstance(text, str):
        raise ValueError("Input must be string")
    stripped = text.strip()
    # The limit applies after stripping, so surrounding whitespace never
    # counts against the caller.
    if len(stripped) > max_length:
        raise ValueError("Input too long")
    return stripped

task_name = sanitize_input(task_name)
goal = sanitize_input(goal)
task_hints = self.hints_source.choose_hints(self.llm, task_name, goal)
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.


hints = []
for hint in task_hints["hint"]:
for hint in task_hints:
hint = hint.strip()
if hint:
hints.append(f"- {hint}")
Expand All @@ -364,5 +394,5 @@ def get_hints_for_task(self, task_name: str) -> str:
return hints_str
except Exception as e:
print(f"Warning: Error getting hints for task {task_name}: {e}")

return ""
Loading
Loading