# Add StepWiseQueriesPrompt for enhanced query handling in GenericAgent #291
```diff
@@ -4,6 +4,7 @@
 It is based on the dynamic_prompting module from the agentlab package.
 """
 
+import json
 import logging
 from dataclasses import dataclass
 from pathlib import Path
@@ -67,6 +68,8 @@ class GenericPromptFlags(dp.Flags):
     hint_index_path: str = None
     hint_retriever_path: str = None
     hint_num_results: int = 5
+    n_retrieval_queries: int = 3
+    hint_level: Literal["episode", "step"] = "episode"
 
 
 class MainPrompt(dp.Shrinkable):
```
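The two new flags drive the step-wise behavior: `n_retrieval_queries` sets how many retrieval queries the agent generates per step, and `hint_level` switches hint retrieval from once per episode (keyed on the goal) to once per step (keyed on the generated queries). Below is a minimal runnable sketch of just these two fields; `HintFlagsSketch` is a hypothetical stand-in, not the real `GenericPromptFlags`, which carries many more fields.

```python
# Hypothetical stand-in for the two new flags; the real GenericPromptFlags
# in agentlab has many more fields than shown here.
from dataclasses import dataclass
from typing import Literal


@dataclass
class HintFlagsSketch:
    n_retrieval_queries: int = 3  # queries generated per step for hint retrieval
    hint_level: Literal["episode", "step"] = "episode"  # when hints are retrieved


flags = HintFlagsSketch(hint_level="step")
print(flags)  # HintFlagsSketch(n_retrieval_queries=3, hint_level='step')
```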
```diff
@@ -81,6 +84,7 @@ def __init__(
         step: int,
         flags: GenericPromptFlags,
         llm: ChatModel,
+        queries: list[str] | None = None,
     ) -> None:
         super().__init__()
         self.flags = flags
@@ -130,6 +134,8 @@ def time_for_caution():
             hint_index_path=flags.hint_index_path,
             hint_retriever_path=flags.hint_retriever_path,
             hint_num_results=flags.hint_num_results,
+            hint_level=flags.hint_level,
+            queries=queries,
         )
         self.plan = Plan(previous_plan, step, lambda: flags.use_plan)  # TODO add previous plan
         self.criticise = Criticise(visible=lambda: flags.use_criticise)
```
```diff
@@ -324,6 +330,8 @@ def __init__(
         hint_num_results: int = 5,
         skip_hints_for_current_task: bool = False,
         hint_retrieval_mode: Literal["direct", "llm", "emb"] = "direct",
+        hint_level: Literal["episode", "step"] = "episode",
+        queries: list[str] | None = None,
     ) -> None:
         super().__init__(visible=use_task_hint)
         self.use_task_hint = use_task_hint
@@ -339,6 +347,8 @@ def __init__(
         self.skip_hints_for_current_task = skip_hints_for_current_task
         self.goal = goal
         self.llm = llm
+        self.hint_level: Literal["episode", "step"] = hint_level
+        self.queries: list[str] | None = queries
         self._init()
 
     _prompt = ""  # Task hints are added dynamically in MainPrompt
@@ -394,6 +404,7 @@ def _init(self):
         else:
             print(f"Warning: Hint database not found at {hint_db_path}")
             self.hint_db = pd.DataFrame(columns=["task_name", "hint"])
+
         self.hints_source = HintsSource(
             hint_db_path=hint_db_path.as_posix(),
             hint_retrieval_mode=self.hint_retrieval_mode,
```
```diff
@@ -448,7 +459,16 @@ def get_hints_for_task(self, task_name: str) -> str:
             return ""
 
         try:
-            task_hints = self.hints_source.choose_hints(self.llm, task_name, self.goal)
+            # When step-level, pass queries as goal string to fit the llm_prompt
+            goal_or_queries = self.goal
+            if self.hint_level == "step" and self.queries:
+                goal_or_queries = "\n".join(self.queries)
+
+            task_hints = self.hints_source.choose_hints(
+                self.llm,
+                task_name,
+                goal_or_queries,
+            )
 
             hints = []
             for hint in task_hints:
```
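This dispatch is the core of the step-level mode: when `hint_level` is `"step"` and queries are present, hint retrieval searches on the joined queries instead of the episode goal, and otherwise falls back to the goal. A self-contained mirror of that logic, with an illustrative function name not taken from the PR:

```python
# Self-contained mirror of the goal_or_queries dispatch above; the standalone
# function and its name are illustrative, not part of the PR.
from typing import Literal


def retrieval_text(
    goal: str,
    queries: list[str] | None,
    hint_level: Literal["episode", "step"] = "episode",
) -> str:
    # Step-level retrieval searches on the per-step queries; episode-level
    # retrieval (the default, and the fallback when no queries exist) uses
    # the original goal.
    if hint_level == "step" and queries:
        return "\n".join(queries)
    return goal


goal = "Sort the table by client and country"
print(retrieval_text(goal, None))  # episode level: the goal itself
print(retrieval_text(goal, ["How to add a second sort criterion?"], "step"))
```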
```diff
@@ -466,3 +486,78 @@ def get_hints_for_task(self, task_name: str) -> str:
             print(f"Warning: Error getting hints for task {task_name}: {e}")
 
         return ""
+
+
+class StepWiseContextIdentificationPrompt(dp.Shrinkable):
+    def __init__(
+        self,
+        obs_history: list[dict],
+        actions: list[str],
+        thoughts: list[str],
+        obs_flags: dp.ObsFlags,
+        n_queries: int = 1,
+    ) -> None:
+        super().__init__()
+        self.obs_flags = obs_flags
+        self.n_queries = n_queries
+        self.history = dp.History(obs_history, actions, None, thoughts, obs_flags)
+        self.instructions = dp.GoalInstructions(obs_history[-1]["goal_object"])
+        self.obs = dp.Observation(obs_history[-1], obs_flags)
+
+        self.think = dp.Think(visible=True)  # To replace with static text maybe
+
+    @property
+    def _prompt(self) -> HumanMessage:
+        prompt = HumanMessage(self.instructions.prompt)
+
+        prompt.add_text(
+            f"""\
+{self.obs.prompt}\
+{self.history.prompt}\
+"""
+        )
+
+        example_queries = [
+            "The user has started sorting a table and needs to apply multiple column criteria simultaneously.",
+            "The user is attempting to configure advanced sorting options but the interface is unclear.",
+            "The user has selected the first sort column and is now looking for how to add a second sort criterion.",
+            "The user is in the middle of a multi-step sorting process and needs guidance on the next action.",
+        ]
+
+        example_queries_str = json.dumps(example_queries[: self.n_queries], indent=2)
+
+        prompt.add_text(
+            f"""
+# Querying memory
+
+Before choosing an action, let's search our available documentation and memory for relevant context.
+Generate a brief, general summary of the current status to help identify useful hints. Return your answer as follows:
+<think>chain of thought</think>
+<queries>json list of strings</queries> for the queries. Return exactly {self.n_queries}
+queries in the list.
+
+# Concrete Example
+
+<think>
+I have to sort by client and country. I could use the built-in sort on each column but I'm not sure if
+I will be able to sort by both at the same time.
+</think>
+
+<queries>
+{example_queries_str}
+</queries>
+"""
+        )
+
+        return self.obs.add_screenshot(prompt)
+
+    def shrink(self):
+        self.history.shrink()
+        self.obs.shrink()
+
+    def _parse_answer(self, text_answer):
+        ans_dict = parse_html_tags_raise(
+            text_answer, keys=["think", "queries"], merge_multiple=True
+        )
+        ans_dict["queries"] = json.loads(ans_dict.get("queries", "[]"))
```
**Korbit review comment: Unsafe JSON Deserialization**

What is the issue?
Unsafe JSON deserialization of untrusted data from the LLM response without validation.

Why this matters
Malicious JSON payloads could potentially lead to code execution or denial of service through carefully crafted inputs that exploit `json.loads` vulnerabilities.

Suggested change:

```python
def validate_queries(queries):
    if not isinstance(queries, list):
        raise ValueError("Queries must be a list")
    if not all(isinstance(q, str) for q in queries):
        raise ValueError("All queries must be strings")
    return queries


ans_dict["queries"] = validate_queries(json.loads(ans_dict.get("queries", "[]")))
```
```diff
+
+        return ans_dict
```
**Korbit review comment: Unsafe Query Length Validation**

What is the issue?
A hard assertion on query length, without first validating that `self.flags.n_retrieval_queries` exists and is not None.

Why this matters
The code will crash with an `AttributeError` if `n_retrieval_queries` is not defined in flags, or with an `AssertionError` if the number of queries doesn't match exactly.

Suggested change: replace with a safer validation approach.
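The concrete replacement code for this suggestion was not captured above. Purely as an illustration, a softer check on the query count might look like the following; `check_query_count` is a hypothetical helper, not Korbit's actual proposal, and assumes callers would read the expected count with `getattr(flags, "n_retrieval_queries", None)`.

```python
# Hypothetical sketch, not Korbit's actual suggestion: tolerate a missing or
# unset flag, and degrade gracefully instead of asserting on an exact count.
def check_query_count(queries: list[str], n_expected: int | None) -> list[str]:
    if n_expected is None:
        return queries  # flag absent or unset: accept whatever came back
    if len(queries) > n_expected:
        return queries[:n_expected]  # too many: truncate instead of crashing
    return queries  # too few: pass through and let retrieval work with less


# Callers would typically obtain n_expected via
# getattr(flags, "n_retrieval_queries", None).
print(check_query_count(["q1", "q2", "q3"], 2))  # ['q1', 'q2']
print(check_query_count(["q1"], None))           # ['q1'], flag not configured
```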