1 change: 1 addition & 0 deletions src/agentlab/agents/generic_agent/generic_agent.py
@@ -98,6 +98,7 @@ def obs_preprocessor(self, obs: dict) -> dict:
def get_action(self, obs):

self.obs_history.append(obs)

main_prompt = MainPrompt(
action_set=self.action_set,
obs_history=self.obs_history,
42 changes: 39 additions & 3 deletions src/agentlab/agents/generic_agent_hinter/generic_agent.py
@@ -23,7 +23,11 @@
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
from agentlab.llm.tracking import cost_tracker_decorator

from .generic_agent_prompt import GenericPromptFlags, MainPrompt
from .generic_agent_prompt import (
GenericPromptFlags,
MainPrompt,
StepWiseContextIdentificationPrompt,
)


@dataclass
@@ -102,6 +106,14 @@ def set_task_name(self, task_name: str):
def get_action(self, obs):

self.obs_history.append(obs)

system_prompt = SystemMessage(dp.SystemPrompt().prompt)

queries, think_queries = self._get_queries()

# use those queries to retrieve from the database and pass to prompt if step-level
queries_for_hints = queries if getattr(self.flags, "hint_level", "episode") == "step" else None

main_prompt = MainPrompt(
action_set=self.action_set,
obs_history=self.obs_history,
@@ -112,6 +124,7 @@ def get_action(self, obs):
step=self.plan_step,
flags=self.flags,
llm=self.chat_llm,
queries=queries_for_hints,
)

# Set task name for task hints if available
@@ -120,8 +133,6 @@ def get_action(self, obs):

max_prompt_tokens, max_trunc_itr = self._get_maxes()

system_prompt = SystemMessage(dp.SystemPrompt().prompt)

human_prompt = dp.fit_tokens(
shrinkable=main_prompt,
max_prompt_tokens=max_prompt_tokens,
@@ -168,6 +179,31 @@ def get_action(self, obs):
)
return ans_dict["action"], agent_info

def _get_queries(self):
"""Retrieve queries for hinting."""
system_prompt = SystemMessage(dp.SystemPrompt().prompt)
query_prompt = StepWiseContextIdentificationPrompt(
obs_history=self.obs_history,
actions=self.actions,
thoughts=self.thoughts,
obs_flags=self.flags.obs,
n_queries=self.flags.n_retrieval_queries, # TODO
)

chat_messages = Discussion([system_prompt, query_prompt.prompt])
ans_dict = retry(
self.chat_llm,
chat_messages,
n_retry=self.max_retry,
parser=query_prompt._parse_answer,
)

queries = ans_dict.get("queries", [])
assert len(queries) == self.flags.n_retrieval_queries

Unsafe Query Length Validation (category: Functionality)

What is the issue?

A hard assertion on the query count without first validating that self.flags.n_retrieval_queries exists and is not None.

Why this matters

The code will crash with an AttributeError if n_retrieval_queries is not defined in flags, or with an AssertionError if the number of queries doesn't match exactly.

Suggested change

Replace with a safer validation approach:

from warnings import warn

expected_queries = getattr(self.flags, 'n_retrieval_queries', None)
if expected_queries is not None and len(queries) != expected_queries:
    warn(f'Expected {expected_queries} queries but got {len(queries)}')
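
If downstream code still expects exactly n_retrieval_queries items, a warning alone leaves the length mismatch in place. A minimal sketch of a normalizing variant; the _normalize_queries helper below is hypothetical and not part of this PR:

from warnings import warn


def _normalize_queries(queries: list[str], expected: int | None) -> list[str]:
    # Hypothetical helper: warn on a count mismatch, then truncate extras.
    # A short list passes through unchanged so the caller can decide what to do.
    if expected is None or len(queries) == expected:
        return queries
    warn(f"Expected {expected} queries but got {len(queries)}")
    return queries[:expected]

Inside _get_queries, the call site would then read queries = _normalize_queries(ans_dict.get("queries", []), getattr(self.flags, "n_retrieval_queries", None)).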


# TODO: we should probably propagate these chat_messages to be able to see them in xray
return queries, ans_dict.get("think", None)

def reset(self, seed=None):
self.seed = seed
self.plan = "No plan yet"
97 changes: 96 additions & 1 deletion src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py
@@ -4,6 +4,7 @@
It is based on the dynamic_prompting module from the agentlab package.
"""

import json
import logging
from dataclasses import dataclass
from pathlib import Path
@@ -67,6 +68,8 @@ class GenericPromptFlags(dp.Flags):
hint_index_path: str = None
hint_retriever_path: str = None
hint_num_results: int = 5
n_retrieval_queries: int = 3
hint_level: Literal["episode", "step"] = "episode"


class MainPrompt(dp.Shrinkable):
@@ -81,6 +84,7 @@ def __init__(
step: int,
flags: GenericPromptFlags,
llm: ChatModel,
queries: list[str] | None = None,
) -> None:
super().__init__()
self.flags = flags
@@ -130,6 +134,8 @@ def time_for_caution():
hint_index_path=flags.hint_index_path,
hint_retriever_path=flags.hint_retriever_path,
hint_num_results=flags.hint_num_results,
hint_level=flags.hint_level,
queries=queries,
)
self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan
self.criticise = Criticise(visible=lambda: flags.use_criticise)
@@ -324,6 +330,8 @@ def __init__(
hint_num_results: int = 5,
skip_hints_for_current_task: bool = False,
hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
hint_level: Literal["episode", "step"] = "episode",
queries: list[str] | None = None,
) -> None:
super().__init__(visible=use_task_hint)
self.use_task_hint = use_task_hint
@@ -339,6 +347,8 @@
self.skip_hints_for_current_task = skip_hints_for_current_task
self.goal = goal
self.llm = llm
self.hint_level: Literal["episode", "step"] = hint_level
self.queries: list[str] | None = queries
self._init()

_prompt = "" # Task hints are added dynamically in MainPrompt
@@ -394,6 +404,7 @@ def _init(self):
else:
print(f"Warning: Hint database not found at {hint_db_path}")
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])

self.hints_source = HintsSource(
hint_db_path=hint_db_path.as_posix(),
hint_retrieval_mode=self.hint_retrieval_mode,
@@ -448,7 +459,16 @@ def get_hints_for_task(self, task_name: str) -> str:
return ""

try:
task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal)
# When step-level, pass queries as goal string to fit the llm_prompt
goal_or_queries = self.goal
if self.hint_level == "step" and self.queries:
goal_or_queries = "\n".join(self.queries)

task_hints = self.hints_source.choose_hints(
self.llm,
task_name,
goal_or_queries,
)

hints = []
for hint in task_hints:
@@ -466,3 +486,78 @@
print(f"Warning: Error getting hints for task {task_name}: {e}")

return ""


class StepWiseContextIdentificationPrompt(dp.Shrinkable):
def __init__(
self,
obs_history: list[dict],
actions: list[str],
thoughts: list[str],
obs_flags: dp.ObsFlags,
n_queries: int = 1,
) -> None:
super().__init__()
self.obs_flags = obs_flags
self.n_queries = n_queries
self.history = dp.History(obs_history, actions, None, thoughts, obs_flags)
self.instructions = dp.GoalInstructions(obs_history[-1]["goal_object"])
self.obs = dp.Observation(obs_history[-1], obs_flags)

self.think = dp.Think(visible=True) # To replace with static text maybe

@property
def _prompt(self) -> HumanMessage:
prompt = HumanMessage(self.instructions.prompt)

prompt.add_text(
f"""\
{self.obs.prompt}\
{self.history.prompt}\
"""
)

example_queries = [
"The user has started sorting a table and needs to apply multiple column criteria simultaneously.",
"The user is attempting to configure advanced sorting options but the interface is unclear.",
"The user has selected the first sort column and is now looking for how to add a second sort criterion.",
"The user is in the middle of a multi-step sorting process and needs guidance on the next action.",
]

example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2)

prompt.add_text(
f"""
# Querying memory

Before choosing an action, let's search our available documentation and memory for relevant context.
Generate a brief, general summary of the current status to help identify useful hints. Return your answer as follows:
<think>chain of thought</think>
<queries>json list of strings</queries> for the queries. Return exactly {self.n_queries}
queries in the list.

# Concrete Example

<think>
I have to sort by client and country. I could use the built-in sort on each column but I'm not sure if
I will be able to sort by both at the same time.
</think>

<queries>
{example_queries_str}
</queries>
"""
)

return self.obs.add_screenshot(prompt)

def shrink(self):
self.history.shrink()
self.obs.shrink()

def _parse_answer(self, text_answer):
ans_dict = parse_html_tags_raise(
text_answer, keys=["think", "queries"], merge_multiple=True
)
ans_dict["queries"] = json.loads(ans_dict.get("queries", "[]"))

Unsafe JSON Deserialization (category: Security)

What is the issue?

Unsafe JSON deserialization of untrusted data from the LLM response without validation.

Why this matters

A malformed or adversarial model response can make json.loads raise a ValueError, return an unexpected type, or exhaust resources on deeply nested payloads; none of these cases is handled here.

Suggested change
def validate_queries(queries):
    if not isinstance(queries, list):
        raise ValueError("Queries must be a list")
    if not all(isinstance(q, str) for q in queries):
        raise ValueError("All queries must be strings")
    return queries

ans_dict["queries"] = validate_queries(json.loads(ans_dict.get("queries", "[]")))
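
Since generic_agent.py in this PR already wraps this parser in retry(...) with a ParseError-based protocol (see the imports at the top of that file), converting malformed JSON into a ParseError would let the agent re-prompt the model instead of crashing. A minimal sketch under that assumption; parse_queries is a hypothetical helper, not part of this PR:

import json

from agentlab.llm.llm_utils import ParseError


def parse_queries(raw: str) -> list[str]:
    # Hypothetical helper: turn bad LLM output into a ParseError so that the
    # surrounding retry() call re-prompts instead of propagating a ValueError.
    try:
        queries = json.loads(raw)
    except json.JSONDecodeError as e:
        raise ParseError(f"<queries> is not valid JSON: {e}") from e
    if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
        raise ParseError("<queries> must be a JSON list of strings")
    return queries

_parse_answer would then end with ans_dict["queries"] = parse_queries(ans_dict.get("queries", "[]")).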

return ans_dict