diff --git a/src/agentlab/agents/generic_agent/generic_agent.py b/src/agentlab/agents/generic_agent/generic_agent.py index d1f48f76..646a52b2 100644 --- a/src/agentlab/agents/generic_agent/generic_agent.py +++ b/src/agentlab/agents/generic_agent/generic_agent.py @@ -98,6 +98,7 @@ def obs_preprocessor(self, obs: dict) -> dict: def get_action(self, obs): self.obs_history.append(obs) + main_prompt = MainPrompt( action_set=self.action_set, obs_history=self.obs_history, diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent.py b/src/agentlab/agents/generic_agent_hinter/generic_agent.py index 5e04df0d..0cbdb6b3 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent.py @@ -23,7 +23,11 @@ from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry from agentlab.llm.tracking import cost_tracker_decorator -from .generic_agent_prompt import GenericPromptFlags, MainPrompt +from .generic_agent_prompt import ( + GenericPromptFlags, + MainPrompt, + StepWiseContextIdentificationPrompt, +) @dataclass @@ -102,6 +106,14 @@ def set_task_name(self, task_name: str): def get_action(self, obs): self.obs_history.append(obs) + + system_prompt = SystemMessage(dp.SystemPrompt().prompt) + + queries, think_queries = self._get_queries() + + # use those queries to retrieve from the database and pass to prompt if step-level + queries_for_hints = queries if getattr(self.flags, "hint_level", "episode") == "step" else None + main_prompt = MainPrompt( action_set=self.action_set, obs_history=self.obs_history, @@ -112,6 +124,7 @@ def get_action(self, obs): step=self.plan_step, flags=self.flags, llm=self.chat_llm, + queries=queries_for_hints, ) # Set task name for task hints if available @@ -120,8 +133,6 @@ def get_action(self, obs): max_prompt_tokens, max_trunc_itr = self._get_maxes() - system_prompt = SystemMessage(dp.SystemPrompt().prompt) - human_prompt = dp.fit_tokens( shrinkable=main_prompt, max_prompt_tokens=max_prompt_tokens, @@ -168,6 +179,31 @@ def get_action(self, obs): ) return ans_dict["action"], agent_info + def _get_queries(self): + """Retrieve queries for hinting.""" + system_prompt = SystemMessage(dp.SystemPrompt().prompt) + query_prompt = StepWiseContextIdentificationPrompt( + obs_history=self.obs_history, + actions=self.actions, + thoughts=self.thoughts, + obs_flags=self.flags.obs, + n_queries=self.flags.n_retrieval_queries, # TODO + ) + + chat_messages = Discussion([system_prompt, query_prompt.prompt]) + ans_dict = retry( + self.chat_llm, + chat_messages, + n_retry=self.max_retry, + parser=query_prompt._parse_answer, + ) + + queries = ans_dict.get("queries", []) + assert len(queries) == self.flags.n_retrieval_queries + + # TODO: we should probably propagate these chat_messages to be able to see them in xray + return queries, ans_dict.get("think", None) + def reset(self, seed=None): self.seed = seed self.plan = "No plan yet" diff --git a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py index 76205341..10cfeef6 100644 --- a/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py +++ b/src/agentlab/agents/generic_agent_hinter/generic_agent_prompt.py @@ -4,6 +4,7 @@ It is based on the dynamic_prompting module from the agentlab package. """ +import json import logging from dataclasses import dataclass from pathlib import Path @@ -67,6 +68,8 @@ class GenericPromptFlags(dp.Flags): hint_index_path: str = None hint_retriever_path: str = None hint_num_results: int = 5 + n_retrieval_queries: int = 3 + hint_level: Literal["episode", "step"] = "episode" class MainPrompt(dp.Shrinkable): @@ -81,6 +84,7 @@ def __init__( step: int, flags: GenericPromptFlags, llm: ChatModel, + queries: list[str] | None = None, ) -> None: super().__init__() self.flags = flags @@ -130,6 +134,8 @@ def time_for_caution(): hint_index_path=flags.hint_index_path, hint_retriever_path=flags.hint_retriever_path, hint_num_results=flags.hint_num_results, + hint_level=flags.hint_level, + queries=queries, ) self.plan = Plan(previous_plan, step, lambda: flags.use_plan) # TODO add previous plan self.criticise = Criticise(visible=lambda: flags.use_criticise) @@ -324,6 +330,8 @@ def __init__( hint_num_results: int = 5, skip_hints_for_current_task: bool = False, hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct", + hint_level: Literal["episode", "step"] = "episode", + queries: list[str] | None = None, ) -> None: super().__init__(visible=use_task_hint) self.use_task_hint = use_task_hint @@ -339,6 +347,8 @@ def __init__( self.skip_hints_for_current_task = skip_hints_for_current_task self.goal = goal self.llm = llm + self.hint_level: Literal["episode", "step"] = hint_level + self.queries: list[str] | None = queries self._init() _prompt = "" # Task hints are added dynamically in MainPrompt @@ -394,6 +404,7 @@ def _init(self): else: print(f"Warning: Hint database not found at {hint_db_path}") self.hint_db = pd.DataFrame(columns=["task_name", "hint"]) + self.hints_source = HintsSource( hint_db_path=hint_db_path.as_posix(), hint_retrieval_mode=self.hint_retrieval_mode, @@ -448,7 +459,16 @@ def get_hints_for_task(self, task_name: str) -> str: return "" try: - task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal) + # When step-level, pass queries as goal string to fit the llm_prompt + goal_or_queries = self.goal + if self.hint_level == "step" and self.queries: + goal_or_queries = "\n".join(self.queries) + + task_hints = self.hints_source.choose_hints( + self.llm, + task_name, + goal_or_queries, + ) hints = [] for hint in task_hints: @@ -466,3 +486,78 @@ def get_hints_for_task(self, task_name: str) -> str: print(f"Warning: Error getting hints for task {task_name}: {e}") return "" + + +class StepWiseContextIdentificationPrompt(dp.Shrinkable): + def __init__( + self, + obs_history: list[dict], + actions: list[str], + thoughts: list[str], + obs_flags: dp.ObsFlags, + n_queries: int = 1, + ) -> None: + super().__init__() + self.obs_flags = obs_flags + self.n_queries = n_queries + self.history = dp.History(obs_history, actions, None, thoughts, obs_flags) + self.instructions = dp.GoalInstructions(obs_history[-1]["goal_object"]) + self.obs = dp.Observation(obs_history[-1], obs_flags) + + self.think = dp.Think(visible=True) # To replace with static text maybe + + @property + def _prompt(self) -> HumanMessage: + prompt = HumanMessage(self.instructions.prompt) + + prompt.add_text( + f"""\ +{self.obs.prompt}\ +{self.history.prompt}\ +""" + ) + + example_queries = [ + "The user has started sorting a table and needs to apply multiple column criteria simultaneously.", + "The user is attempting to configure advanced sorting options but the interface is unclear.", + "The user has selected the first sort column and is now looking for how to add a second sort criterion.", + "The user is in the middle of a multi-step sorting process and needs guidance on the next action.", + ] + + example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2) + + prompt.add_text( + f""" +# Querying memory + +Before choosing an action, let's search our available documentation and memory for relevant context. +Generate a brief, general summary of the current status to help identify useful hints. Return your answer as follow +chain of thought +json list of strings for the queries. Return exactly {self.n_queries} +queries in the list. + +# Concrete Example + + +I have to sort by client and country. I could use the built-in sort on each column but I'm not sure if +I will be able to sort by both at the same time. + + + +{example_queries_str} + +""" + ) + + return self.obs.add_screenshot(prompt) + + def shrink(self): + self.history.shrink() + self.obs.shrink() + + def _parse_answer(self, text_answer): + ans_dict = parse_html_tags_raise( + text_answer, keys=["think", "queries"], merge_multiple=True + ) + ans_dict["queries"] = json.loads(ans_dict.get("queries", "[]")) + return ans_dict