Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions dreadnode/agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)

if t.TYPE_CHECKING:
from dreadnode.agent.tools import execute, fs, memory, planning, reporting, tasking
from dreadnode.agent.tools import execute, fs, interaction, memory, planning, reporting, tasking

__all__ = [
"AnyTool",
Expand All @@ -30,6 +30,7 @@
"discover_tools_on_obj",
"execute",
"fs",
"interaction",
"memory",
"planning",
"reporting",
Expand All @@ -38,7 +39,15 @@
"tool_method",
]

__lazy_submodules__: list[str] = ["fs", "planning", "reporting", "tasking", "execute", "memory"]
__lazy_submodules__: list[str] = [
"execute",
"fs",
"interaction",
"memory",
"planning",
"reporting",
"tasking",
]
__lazy_components__: dict[str, str] = {}


Expand Down
80 changes: 80 additions & 0 deletions dreadnode/agent/tools/interaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import typing as t

from pydantic import BaseModel, Field

from dreadnode.agent.tools.base import tool


class QuestionOption(BaseModel):
label: str = Field(..., description="Display text for this option")
description: str = Field(..., description="Explanation of this option")


class Question(BaseModel):
id: str = Field(..., description="Unique identifier for this question")
question: str = Field(..., description="The question text")
options: list[QuestionOption] = Field(..., min_length=2, max_length=4)
multi_select: bool = Field(default=False)


@tool(catch=True)
def ask_user(
questions: t.Annotated[
list[dict[str, t.Any]],
"List of questions to ask. Each has: id, question, options (list of {label, description}), multi_select (optional)",
],
) -> str:
"""
Ask the user one or more multiple-choice questions during execution.

Use this when you need user input to make decisions, clarify requirements,
or get preferences that affect your work.

Each question should have 2-4 options. Users can select one option, or multiple
if multi_select is true.

Returns a formatted string with the user's responses.
"""
from dreadnode import log_metric, log_output

parsed = [Question(**q) for q in questions]

print("\n" + "=" * 80)
print("Agent needs your input:")
print("=" * 80 + "\n")

responses: dict[str, str | list[str]] = {}

for q_num, q in enumerate(parsed, 1):
print(f"Question {q_num}/{len(parsed)}: {q.question}\n")

for idx, opt in enumerate(q.options, 1):
print(f" {idx}. {opt.label}")
print(f" {opt.description}\n")

if q.multi_select:
choice_str = input("Your choice(s) (comma-separated): ").strip()
choices = [int(c.strip()) for c in choice_str.split(",")]
selected: list[str] = [q.options[i - 1].label for i in choices]
responses[q.id] = selected
else:
choice = int(input("Your choice: ").strip())
selected_option: str = q.options[choice - 1].label
responses[q.id] = selected_option

print()

print("=" * 80 + "\n")

log_output("user_questions", [q.model_dump() for q in parsed])
log_output("user_responses", responses)
log_metric("questions_asked", len(parsed))

result = "User responses:\n"
for q_id, answer in responses.items():
if isinstance(answer, list):
result += f"{q_id}: {', '.join(answer)}\n"
else:
result += f"{q_id}: {answer}\n"

return result
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,6 @@ skip-magic-trailing-comma = false
"dreadnode/transforms/language.py" = [
"RUF001", # intentional use of ambiguous unicode characters for airt
]
"dreadnode/agent/tools/interaction.py" = [
"T201", # print required for user interaction
]
56 changes: 56 additions & 0 deletions tests/test_interaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from unittest.mock import patch

from dreadnode.agent.tools.interaction import Question, QuestionOption, ask_user


def test_question_model() -> None:
q = Question(
id="test",
question="Test question?",
options=[
QuestionOption(label="Option 1", description="First option"),
QuestionOption(label="Option 2", description="Second option"),
],
)
assert q.id == "test"
assert len(q.options) == 2
assert q.multi_select is False


def test_ask_user_single_choice() -> None:
questions = [
{
"id": "choice",
"question": "Pick one",
"options": [
{"label": "A", "description": "First"},
{"label": "B", "description": "Second"},
],
}
]

with patch("builtins.input", return_value="1"), patch("builtins.print"):
result = ask_user(questions)

assert "choice: A" in result


def test_ask_user_multi_choice() -> None:
questions = [
{
"id": "choices",
"question": "Pick multiple",
"options": [
{"label": "A", "description": "First"},
{"label": "B", "description": "Second"},
{"label": "C", "description": "Third"},
],
"multi_select": True,
}
]

with patch("builtins.input", return_value="1,3"), patch("builtins.print"):
result = ask_user(questions)

assert "A" in result
assert "C" in result