From 94fa1ab7fb7a69e1ae75a6964a4d77098f0050f0 Mon Sep 17 00:00:00 2001 From: recursix Date: Thu, 4 Sep 2025 16:46:46 -0400 Subject: [PATCH 1/3] Add StepWiseQueriesPrompt for enhanced query handling in GenericAgent --- .../agents/generic_agent/generic_agent.py | 1 + .../generic_agent_hinter/generic_agent.py | 43 ++++++++++- .../generic_agent_prompt.py | 73 +++++++++++++++++++ 3 files changed, 114 insertions(+), 3 deletions(-) diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index d1f48f76..646a52b2 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -98,6 +98,7 @@ def obs_preprocessor(self, obs: dict) -> dict: def get_action(self, obs): self.obs_history.append(obs) + main_prompt = MainPrompt( action_set=self.action_set, obs_history=self.obs_history, diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index cfbd19bd..c8368039 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -23,7 +23,11 @@ from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry from agentlab.llm.tracking import cost_tracker_decorator -from .generic_agent_prompt import GenericPromptFlags, MainPrompt +from .generic_agent_prompt import ( + GenericPromptFlags, + MainPrompt, + StepWiseRetrievalPrompt, +) @dataclass @@ -102,6 +106,16 @@ def set_task_name(self, task_name: str): def get_action(self, obs): self.obs_history.append(obs) + + system_prompt = SystemMessage(dp.SystemPrompt().prompt) + + queries, think_queries = self._get_queries() + + # TODO + # use those queries to retreive from the database. e.g.: + # hints = self.hint_db.get_hints(queries) + # then add those hints to the main prompt + main_prompt = MainPrompt( action_set=self.action_set, obs_history=self.obs_history, @@ -120,8 +134,6 @@ def get_action(self, obs): max_prompt_tokens, max_trunc_itr = self._get_maxes() - system_prompt = SystemMessage(dp.SystemPrompt().prompt) - human_prompt = dp.fit_tokens( shrinkable=main_prompt, max_prompt_tokens=max_prompt_tokens, @@ -168,6 +180,31 @@ def get_action(self, obs): ) return ans_dict["action"], agent_info + def _get_queries(self): + """Retrieve queries for hinting.""" + system_prompt = SystemMessage(dp.SystemPrompt().prompt) + query_prompt = StepWiseRetrievalPrompt( + obs_history=self.obs_history, + actions=self.actions, + thoughts=self.thoughts, + obs_flags=self.flags.obs, + n_queries=self.flags.n_retrieval_queries, # TODO + ) + + chat_messages = Discussion([system_prompt, query_prompt.prompt]) + ans_dict = retry( + self.chat_llm, + chat_messages, + n_retry=self.max_retry, + parser=query_prompt._parse_answer, + ) + + queries = ans_dict.get("queries", []) + assert len(queries) == self.flags.n_retrieval_queries + + # TODO: we should probably propagate these chat_messages to be able to see them in xray + return queries, ans_dict.get("think", None) + def reset(self, seed=None): self.seed = seed self.plan = "No plan yet" diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 983c9d48..44d17845 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -4,6 +4,7 @@ It is based on the dynamic_prompting module from the agentlab package. """ +import json import logging from dataclasses import dataclass from pathlib import Path @@ -60,6 +61,7 @@ class GenericPromptFlags(dp.Flags): add_missparsed_messages: bool = True max_trunc_itr: int = 20 flag_group: str = None + n_retrieval_queries: int = 3 class MainPrompt(dp.Shrinkable): @@ -396,3 +398,74 @@ def get_hints_for_task(self, task_name: str) -> str: print(f"Warning: Error getting hints for task {task_name}: {e}") return "" + + +class StepWiseRetrievalPrompt(dp.Shrinkable): + def __init__( + self, + obs_history: list[dict], + actions: list[str], + thoughts: list[str], + obs_flags: dp.ObsFlags, + n_queries: int = 3, + ) -> None: + super().__init__() + self.obs_flags = obs_flags + self.n_queries = n_queries + self.history = dp.History(obs_history, actions, None, thoughts, obs_flags) + self.instructions = dp.GoalInstructions(obs_history[-1]["goal_object"]) + self.obs = dp.Observation(obs_history[-1], obs_flags) + + self.think = dp.Think(visible=True) # To replace with static text maybe + + @property + def _prompt(self) -> HumanMessage: + prompt = HumanMessage(self.instructions.prompt) + + prompt.add_text( + f"""\ +{self.obs.prompt}\ +{self.history.prompt}\ +""" + ) + + example_queries = [ + "How to sort with multiple columns on the ServiceNow platform?", + "What are the potential challenges of sorting by multiple columns?", + "How to handle sorting by multiple columns in a table?", + "Can I use the filter tool to sort by multiple columns?", + ] + + example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2) + + prompt.add_text( + f""" +# Querying memory + +Before choosing an action, let's search our available documentation and memory on how to approach this step. +This could provide valuable hints on how to properly solve this task. Return your answer as follow +chain of thought +json list of strings for the queries. Return exactly {self.n_queries} +queries in the list. + +# Concrete Example + + +I have to sort by client and country. I could use the built-in sort on each column but I'm not sure if +I will be able to sort by both at the same time. + + + +{example_queries_str} + +""" + ) + + return self.obs.add_screenshot(prompt) + + def _parse_answer(self, text_answer): + ans_dict = parse_html_tags_raise( + text_answer, keys=["think", "queries"], merge_multiple=True + ) + ans_dict["queries"] = json.loads(ans_dict.get("queries", "[]")) + return ans_dict From ee2653a0925b39da7ac9f1db3b3490ee3c7e8c17 Mon Sep 17 00:00:00 2001 From: Hadi Nekoei Date: Mon, 8 Sep 2025 19:41:27 -0400 Subject: [PATCH 2/3] stepwise hint retrieval --- .../generic_agent_hinter/generic_agent.py | 11 +++--- .../generic_agent_prompt.py | 36 ++++++++++++++----- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index c8368039..50e6d399 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -26,7 +26,7 @@ from .generic_agent_prompt import ( GenericPromptFlags, MainPrompt, - StepWiseRetrievalPrompt, + StepWiseContextIdentificationPrompt, ) @@ -111,10 +111,8 @@ def get_action(self, obs): queries, think_queries = self._get_queries() - # TODO - # use those queries to retreive from the database. e.g.: - # hints = self.hint_db.get_hints(queries) - # then add those hints to the main prompt + # use those queries to retrieve from the database and pass to prompt if step-level + queries_for_hints = queries if getattr(self.flags, "hint_level", "episode") == "step" else None main_prompt = MainPrompt( action_set=self.action_set, @@ -126,6 +124,7 @@ def get_action(self, obs): step=self.plan_step, flags=self.flags, llm=self.chat_llm, + queries=queries_for_hints, ) # Set task name for task hints if available @@ -183,7 +182,7 @@ def get_action(self, obs): def _get_queries(self): """Retrieve queries for hinting.""" system_prompt = SystemMessage(dp.SystemPrompt().prompt) - query_prompt = StepWiseRetrievalPrompt( + query_prompt = StepWiseContextIdentificationPrompt( obs_history=self.obs_history, actions=self.actions, thoughts=self.thoughts, diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 44d17845..cf87a326 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -62,6 +62,7 @@ class GenericPromptFlags(dp.Flags): max_trunc_itr: int = 20 flag_group: str = None n_retrieval_queries: int = 3 + hint_level: Literal["episode", "step"] = "episode" class MainPrompt(dp.Shrinkable): @@ -76,6 +77,7 @@ def __init__( step: int, flags: GenericPromptFlags, llm: ChatModel, + queries: list[str] | None = None, ) -> None: super().__init__() self.flags = flags @@ -118,6 +120,8 @@ def time_for_caution(): hint_retrieval_mode=flags.task_hint_retrieval_mode, llm=llm, skip_hints_for_current_task=flags.skip_hints_for_current_task, + hint_level=flags.hint_level, + queries=queries, ) self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan self.criticise = Criticise(visible=lambda: flags.use_criticise) @@ -306,6 +310,8 @@ def __init__( hint_retrieval_mode: Literal["direct", "llm", "emb"], skip_hints_for_current_task: bool, llm: ChatModel, + hint_level: Literal["episode", "step"] = "episode", + queries: list[str] | None = None, ) -> None: super().__init__(visible=use_task_hint) self.use_task_hint = use_task_hint @@ -315,6 +321,8 @@ def __init__( self.skip_hints_for_current_task = skip_hints_for_current_task self.goal = goal self.llm = llm + self.hint_level: Literal["episode", "step"] = hint_level + self.queries: list[str] | None = queries self._init() _prompt = "" # Task hints are added dynamically in MainPrompt @@ -352,6 +360,7 @@ def _init(self): else: print(f"Warning: Hint database not found at {hint_db_path}") self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + self.hints_source = HintsSource( hint_db_path=hint_db_path.as_posix(), hint_retrieval_mode=self.hint_retrieval_mode, @@ -380,7 +389,16 @@ def get_hints_for_task(self, task_name: str) -> str: return "" try: - task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal) + # When step-level, pass queries as goal string to fit the llm_prompt + goal_or_queries = self.goal + if self.hint_level == "step" and self.queries: + goal_or_queries = "\n".join(self.queries) + + task_hints = self.hints_source.choose_hints( + self.llm, + task_name, + goal_or_queries, + ) hints = [] for hint in task_hints: @@ -400,14 +418,14 @@ def get_hints_for_task(self, task_name: str) -> str: return "" -class StepWiseRetrievalPrompt(dp.Shrinkable): +class StepWiseContextIdentificationPrompt(dp.Shrinkable): def __init__( self, obs_history: list[dict], actions: list[str], thoughts: list[str], obs_flags: dp.ObsFlags, - n_queries: int = 3, + n_queries: int = 1, ) -> None: super().__init__() self.obs_flags = obs_flags @@ -430,10 +448,10 @@ def _prompt(self) -> HumanMessage: ) example_queries = [ - "How to sort with multiple columns on the ServiceNow platform?", - "What are the potential challenges of sorting by multiple columns?", - "How to handle sorting by multiple columns in a table?", - "Can I use the filter tool to sort by multiple columns?", + "The user has started sorting a table and needs to apply multiple column criteria simultaneously.", + "The user is attempting to configure advanced sorting options but the interface is unclear.", + "The user has selected the first sort column and is now looking for how to add a second sort criterion.", + "The user is in the middle of a multi-step sorting process and needs guidance on the next action.", ] example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2) @@ -442,8 +460,8 @@ def _prompt(self) -> HumanMessage: f""" # Querying memory -Before choosing an action, let's search our available documentation and memory on how to approach this step. -This could provide valuable hints on how to properly solve this task. Return your answer as follow +Before choosing an action, let's search our available documentation and memory for relevant context. +Generate a brief, general summary of the current status to help identify useful hints. Return your answer as follow chain of thought json list of strings for the queries. Return exactly {self.n_queries} queries in the list. From ca11170a7d7e7d44669bba2369c9b5dddf6a75ec Mon Sep 17 00:00:00 2001 From: Hadi Nekoei Date: Tue, 9 Sep 2025 00:05:14 -0400 Subject: [PATCH 3/3] added shrink method --- .../agents/generic_agent_hinter/generic_agent_prompt.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 21ed167d..10cfeef6 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -551,6 +551,10 @@ def _prompt(self) -> HumanMessage: return self.obs.add_screenshot(prompt) + def shrink(self): + self.history.shrink() + self.obs.shrink() + def _parse_answer(self, text_answer): ans_dict = parse_html_tags_raise( text_answer, keys=["think", "queries"], merge_multiple=True