diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py
index d1f48f76..646a52b2 100644
--- a/src/agentlab/agents/generic_agent/generic_agent.py
+++ b/src/agentlab/agents/generic_agent/generic_agent.py
@@ -98,6 +98,7 @@ def obs_preprocessor(self, obs: dict) -> dict:
def get_action(self, obs):
self.obs_history.append(obs)
+
main_prompt = MainPrompt(
action_set=self.action_set,
obs_history=self.obs_history,
diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py
index 5e04df0d..0cbdb6b3 100644
--- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py
+++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py
@@ -23,7 +23,11 @@
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
from agentlab.llm.tracking import cost_tracker_decorator
-from .generic_agent_prompt import GenericPromptFlags, MainPrompt
+from .generic_agent_prompt import (
+ GenericPromptFlags,
+ MainPrompt,
+ StepWiseContextIdentificationPrompt,
+)
@dataclass
@@ -102,6 +106,14 @@ def set_task_name(self, task_name: str):
def get_action(self, obs):
self.obs_history.append(obs)
+
+ system_prompt = SystemMessage(dp.SystemPrompt().prompt)
+
+ queries, think_queries = self._get_queries()
+
+ # use those queries to retrieve from the database and pass to prompt if step-level
+ queries_for_hints = queries if getattr(self.flags, "hint_level", "episode") == "step" else None
+
main_prompt = MainPrompt(
action_set=self.action_set,
obs_history=self.obs_history,
@@ -112,6 +124,7 @@ def get_action(self, obs):
step=self.plan_step,
flags=self.flags,
llm=self.chat_llm,
+ queries=queries_for_hints,
)
# Set task name for task hints if available
@@ -120,8 +133,6 @@ def get_action(self, obs):
max_prompt_tokens, max_trunc_itr = self._get_maxes()
- system_prompt = SystemMessage(dp.SystemPrompt().prompt)
-
human_prompt = dp.fit_tokens(
shrinkable=main_prompt,
max_prompt_tokens=max_prompt_tokens,
@@ -168,6 +179,31 @@ def get_action(self, obs):
)
return ans_dict["action"], agent_info
+ def _get_queries(self):
+ """Retrieve queries for hinting."""
+ system_prompt = SystemMessage(dp.SystemPrompt().prompt)
+ query_prompt = StepWiseContextIdentificationPrompt(
+ obs_history=self.obs_history,
+ actions=self.actions,
+ thoughts=self.thoughts,
+ obs_flags=self.flags.obs,
+ n_queries=self.flags.n_retrieval_queries, # TODO
+ )
+
+ chat_messages = Discussion([system_prompt, query_prompt.prompt])
+ ans_dict = retry(
+ self.chat_llm,
+ chat_messages,
+ n_retry=self.max_retry,
+ parser=query_prompt._parse_answer,
+ )
+
+ queries = ans_dict.get("queries", [])
+ assert len(queries) == self.flags.n_retrieval_queries
+
+ # TODO: we should probably propagate these chat_messages to be able to see them in xray
+ return queries, ans_dict.get("think", None)
+
def reset(self, seed=None):
self.seed = seed
self.plan = "No plan yet"
diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py
index 76205341..10cfeef6 100644
--- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py
+++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py
@@ -4,6 +4,7 @@
It is based on the dynamic_prompting module from the agentlab package.
"""
+import json
import logging
from dataclasses import dataclass
from pathlib import Path
@@ -67,6 +68,8 @@ class GenericPromptFlags(dp.Flags):
hint_index_path: str = None
hint_retriever_path: str = None
hint_num_results: int = 5
+ n_retrieval_queries: int = 3
+ hint_level: Literal["episode", "step"] = "episode"
class MainPrompt(dp.Shrinkable):
@@ -81,6 +84,7 @@ def __init__(
step: int,
flags: GenericPromptFlags,
llm: ChatModel,
+ queries: list[str] | None = None,
) -> None:
super().__init__()
self.flags = flags
@@ -130,6 +134,8 @@ def time_for_caution():
hint_index_path=flags.hint_index_path,
hint_retriever_path=flags.hint_retriever_path,
hint_num_results=flags.hint_num_results,
+ hint_level=flags.hint_level,
+ queries=queries,
)
self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan
self.criticise = Criticise(visible=lambda: flags.use_criticise)
@@ -324,6 +330,8 @@ def __init__(
hint_num_results: int = 5,
skip_hints_for_current_task: bool = False,
hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
+ hint_level: Literal["episode", "step"] = "episode",
+ queries: list[str] | None = None,
) -> None:
super().__init__(visible=use_task_hint)
self.use_task_hint = use_task_hint
@@ -339,6 +347,8 @@ def __init__(
self.skip_hints_for_current_task = skip_hints_for_current_task
self.goal = goal
self.llm = llm
+ self.hint_level: Literal["episode", "step"] = hint_level
+ self.queries: list[str] | None = queries
self._init()
_prompt = "" # Task hints are added dynamically in MainPrompt
@@ -394,6 +404,7 @@ def _init(self):
else:
print(f"Warning: Hint database not found at {hint_db_path}")
self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
+
self.hints_source = HintsSource(
hint_db_path=hint_db_path.as_posix(),
hint_retrieval_mode=self.hint_retrieval_mode,
@@ -448,7 +459,16 @@ def get_hints_for_task(self, task_name: str) -> str:
return ""
try:
- task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal)
+ # When step-level, pass queries as goal string to fit the llm_prompt
+ goal_or_queries = self.goal
+ if self.hint_level == "step" and self.queries:
+ goal_or_queries = "\n".join(self.queries)
+
+ task_hints = self.hints_source.choose_hints(
+ self.llm,
+ task_name,
+ goal_or_queries,
+ )
hints = []
for hint in task_hints:
@@ -466,3 +486,78 @@ def get_hints_for_task(self, task_name: str) -> str:
print(f"Warning: Error getting hints for task {task_name}: {e}")
return ""
+
+
+class StepWiseContextIdentificationPrompt(dp.Shrinkable):
+ def __init__(
+ self,
+ obs_history: list[dict],
+ actions: list[str],
+ thoughts: list[str],
+ obs_flags: dp.ObsFlags,
+ n_queries: int = 1,
+ ) -> None:
+ super().__init__()
+ self.obs_flags = obs_flags
+ self.n_queries = n_queries
+ self.history = dp.History(obs_history, actions, None, thoughts, obs_flags)
+ self.instructions = dp.GoalInstructions(obs_history[-1]["goal_object"])
+ self.obs = dp.Observation(obs_history[-1], obs_flags)
+
+ self.think = dp.Think(visible=True) # To replace with static text maybe
+
+ @property
+ def _prompt(self) -> HumanMessage:
+ prompt = HumanMessage(self.instructions.prompt)
+
+ prompt.add_text(
+ f"""\
+{self.obs.prompt}\
+{self.history.prompt}\
+"""
+ )
+
+ example_queries = [
+ "The user has started sorting a table and needs to apply multiple column criteria simultaneously.",
+ "The user is attempting to configure advanced sorting options but the interface is unclear.",
+ "The user has selected the first sort column and is now looking for how to add a second sort criterion.",
+ "The user is in the middle of a multi-step sorting process and needs guidance on the next action.",
+ ]
+
+ example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2)
+
+ prompt.add_text(
+ f"""
+# Querying memory
+
+Before choosing an action, let's search our available documentation and memory for relevant context.
+Generate a brief, general summary of the current status to help identify useful hints. Return your answer as follows:
+<think> chain of thought </think>
+<queries> json list of strings for the queries </queries>. Return exactly {self.n_queries}
+queries in the list.
+
+# Concrete Example
+
+<think>
+I have to sort by client and country. I could use the built-in sort on each column but I'm not sure if
+I will be able to sort by both at the same time.
+</think>
+
+<queries>
+{example_queries_str}
+</queries>
+"""
+ )
+
+ return self.obs.add_screenshot(prompt)
+
+ def shrink(self):
+ self.history.shrink()
+ self.obs.shrink()
+
+ def _parse_answer(self, text_answer):
+ ans_dict = parse_html_tags_raise(
+ text_answer, keys=["think", "queries"], merge_multiple=True
+ )
+ ans_dict["queries"] = json.loads(ans_dict.get("queries", "[]"))
+ return ans_dict