Skip to content

Commit fa9546e

Browse files
committed
[minimcp] Add PromptManager for prompt template registration and execution
- Implement PromptManager class for managing MCP prompt handlers - Add prompt registration via decorator (@mcp.prompt()) or programmatically - Support prompt listing, getting, and removal operations - Add automatic argument inference from function signatures - Add support for multiple content types and annotations - Add comprehensive unit test suite
1 parent 6a350ca commit fa9546e

File tree

2 files changed

+1284
-0
lines changed

2 files changed

+1284
-0
lines changed
Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
import builtins
2+
import logging
3+
from collections.abc import Callable
4+
from functools import partial
5+
from typing import Any
6+
7+
import pydantic_core
8+
from typing_extensions import TypedDict, Unpack
9+
10+
from mcp.server.lowlevel.server import Server
11+
from mcp.server.minimcp.exceptions import InvalidArgumentsError, MCPRuntimeError, PrimitiveError
12+
from mcp.server.minimcp.utils.mcp_func import MCPFunc
13+
from mcp.types import AnyFunction, GetPromptResult, Prompt, PromptArgument, PromptMessage, TextContent
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class PromptDefinition(TypedDict, total=False):
19+
"""
20+
Type definition for prompt parameters.
21+
22+
Attributes:
23+
name: Optional unique identifier for the prompt. If not provided, the function name is used.
24+
Must be unique across all prompts in the server.
25+
title: Optional human-readable name for display purposes. Shows in client UIs (e.g., as slash commands).
26+
description: Optional human-readable description of what the prompt does. If not provided,
27+
the function's docstring is used.
28+
meta: Optional metadata dictionary for additional prompt information.
29+
"""
30+
31+
name: str | None
32+
title: str | None
33+
description: str | None
34+
meta: dict[str, Any] | None
35+
36+
37+
class PromptManager:
38+
"""
39+
PromptManager is responsible for registration and execution of MCP prompt handlers.
40+
41+
The Model Context Protocol (MCP) provides a standardized way for servers to expose prompt templates
42+
to clients. Prompts allow servers to provide structured messages and instructions for interacting
43+
with language models. Clients can discover available prompts, retrieve their contents, and provide
44+
arguments to customize them.
45+
46+
Prompts are designed to be user-controlled, exposed from servers to clients with the intention of
47+
the user being able to explicitly select them for use. Typically, prompts are triggered through
48+
user-initiated commands in the user interface, such as slash commands in chat applications.
49+
50+
The PromptManager can be used as a decorator (@mcp.prompt()) or programmatically via the mcp.prompt.add(),
51+
mcp.prompt.list(), mcp.prompt.get() and mcp.prompt.remove() methods.
52+
53+
When a prompt handler is added, its name (unique identifier) and description are automatically inferred
54+
from the handler function. You can override these by passing explicit parameters. The title field provides
55+
a human-readable name for display in client UIs. Prompt arguments are always inferred from the function
56+
signature. Type annotations are required in the function signature for correct argument extraction.
57+
58+
Prompt messages can contain different content types (text, image, audio, embedded resources) and support
59+
optional annotations for metadata. Handler functions typically return strings or PromptMessage objects,
60+
which are automatically converted to the appropriate message format with role ("user" or "assistant").
61+
62+
For more details, see: https://modelcontextprotocol.io/specification/2025-06-18/server/prompts
63+
64+
Example:
65+
@mcp.prompt()
66+
def problem_solving(problem_description: str) -> str:
67+
return f"You are a math problem solver. Solve: {problem_description}"
68+
69+
# With display title for UI (e.g., as slash command)
70+
@mcp.prompt(name="solver", title="💡 Problem Solver", description="Solve a math problem")
71+
def problem_solving(problem_description: str) -> str:
72+
return f"You are a math problem solver. Solve: {problem_description}"
73+
74+
# Or programmatically:
75+
mcp.prompt.add(problem_solving, name="solver", title="Problem Solver")
76+
"""
77+
78+
_prompts: dict[str, tuple[Prompt, MCPFunc]]
79+
80+
def __init__(self, core: Server):
81+
"""
82+
Args:
83+
core: The low-level MCP Server instance to hook into.
84+
"""
85+
self._prompts = {}
86+
self._hook_core(core)
87+
88+
def _hook_core(self, core: Server) -> None:
89+
"""Register prompt handlers with the MCP core server.
90+
91+
Args:
92+
core: The low-level MCP Server instance to hook into.
93+
"""
94+
core.list_prompts()(self._async_list)
95+
core.get_prompt()(self.get)
96+
# core.complete()(self._async_complete) # TODO: Implement completion for prompts
97+
98+
def __call__(self, **kwargs: Unpack[PromptDefinition]) -> Callable[[AnyFunction], Prompt]:
99+
"""Decorator to add/register a prompt handler at the time of handler function definition.
100+
101+
Prompt name and description are automatically inferred from the handler function. You can override
102+
these by passing explicit parameters (name, title, description, meta) as shown in the example below.
103+
Prompt arguments are always inferred from the function signature. Type annotations are required
104+
in the function signature for proper argument extraction.
105+
106+
Args:
107+
**kwargs: Optional prompt definition parameters (name, title, description, meta).
108+
Parameters are defined in the PromptDefinition class.
109+
110+
Returns:
111+
A decorator function that adds the prompt handler.
112+
113+
Example:
114+
@mcp.prompt(name="code_review", title="🔍 Request Code Review")
115+
def code_review(code: str) -> str:
116+
return f"Please review this code:\n{code}"
117+
"""
118+
return partial(self.add, **kwargs)
119+
120+
def add(self, func: AnyFunction, **kwargs: Unpack[PromptDefinition]) -> Prompt:
121+
"""To programmatically add/register a prompt handler function.
122+
123+
This is useful when the handler function is already defined and you have a function object
124+
that needs to be registered at runtime.
125+
126+
If not provided, the prompt name (unique identifier) and description are automatically inferred
127+
from the function's name and docstring. The title field should be provided for better display in
128+
client UIs. Arguments are always automatically inferred from the function signature. Type annotations
129+
are required in the function signature for proper argument extraction and validation.
130+
131+
Handler functions can return:
132+
- str: Converted to a user message with text content
133+
- PromptMessage: Used as-is with role ("user" or "assistant") and content
134+
- dict: Validated as PromptMessage
135+
- list/tuple: Multiple messages of any of the above types
136+
- Other types: JSON-serialized and converted to user messages
137+
138+
Args:
139+
func: The prompt handler function. Can be synchronous or asynchronous. Should return
140+
content that can be converted to PromptMessage objects.
141+
**kwargs: Optional prompt definition parameters to override inferred
142+
values (name, title, description, meta). Parameters are defined in
143+
the PromptDefinition class.
144+
145+
Returns:
146+
The registered Prompt object with unique identifier, optional title for display,
147+
and inferred arguments.
148+
149+
Raises:
150+
PrimitiveError: If a prompt with the same name is already registered or if the function
151+
isn't properly typed.
152+
"""
153+
154+
prompt_func = MCPFunc(func, kwargs.get("name"))
155+
if prompt_func.name in self._prompts:
156+
raise PrimitiveError(f"Prompt {prompt_func.name} already registered")
157+
158+
prompt = Prompt(
159+
name=prompt_func.name,
160+
title=kwargs.get("title", None),
161+
description=kwargs.get("description", prompt_func.doc),
162+
arguments=self._get_arguments(prompt_func),
163+
_meta=kwargs.get("meta", None),
164+
)
165+
166+
self._prompts[prompt_func.name] = (prompt, prompt_func)
167+
logger.debug("Prompt %s added", prompt_func.name)
168+
169+
return prompt
170+
171+
def _get_arguments(self, prompt_func: MCPFunc) -> list[PromptArgument]:
172+
"""Get the arguments for a prompt from the function signature per MCP specification.
173+
174+
Extracts parameter information from the function's input schema generated by MCPFunc,
175+
converting them to PromptArgument objects for MCP protocol compliance. Each argument
176+
includes a name, optional description, and required flag.
177+
178+
Arguments enable prompt customization and may be auto-completed through the MCP completion API.
179+
180+
Args:
181+
prompt_func: The MCPFunc wrapper containing the function's input schema.
182+
183+
Returns:
184+
A list of PromptArgument objects describing the prompt's parameters for customization.
185+
"""
186+
arguments: list[PromptArgument] = []
187+
188+
input_schema = prompt_func.input_schema
189+
if "properties" in input_schema:
190+
for param_name, param in input_schema["properties"].items():
191+
required = param_name in input_schema.get("required", [])
192+
arguments.append(
193+
PromptArgument(
194+
name=param_name,
195+
description=param.get("description"),
196+
required=required,
197+
)
198+
)
199+
200+
return arguments
201+
202+
def remove(self, name: str) -> Prompt:
203+
"""Remove a prompt by name.
204+
205+
Args:
206+
name: The name of the prompt to remove.
207+
208+
Returns:
209+
The removed Prompt object.
210+
211+
Raises:
212+
PrimitiveError: If the prompt is not found.
213+
"""
214+
if name not in self._prompts:
215+
# Raise INVALID_PARAMS as per MCP specification
216+
raise PrimitiveError(f"Unknown prompt: {name}")
217+
218+
return self._prompts.pop(name)[0]
219+
220+
def list(self) -> builtins.list[Prompt]:
221+
"""List all registered prompts.
222+
223+
Returns:
224+
A list of all registered Prompt objects.
225+
"""
226+
return [prompt[0] for prompt in self._prompts.values()]
227+
228+
async def _async_list(self) -> builtins.list[Prompt]:
229+
"""Async wrapper for list().
230+
231+
Returns:
232+
A list of all registered Prompt objects.
233+
"""
234+
return self.list()
235+
236+
async def get(self, name: str, args: dict[str, str] | None) -> GetPromptResult:
237+
"""Retrieve and execute a prompt by name, as specified in the MCP prompts/get protocol.
238+
239+
This method handles the MCP prompts/get request, executing the prompt handler function with
240+
the provided arguments. Arguments are validated against the prompt's argument definitions,
241+
and the result is converted to PromptMessage objects per the MCP specification.
242+
243+
PromptMessages include a role ("user" or "assistant") and content, which can be text, image,
244+
audio, or embedded resources. All content types support optional annotations for metadata.
245+
246+
Args:
247+
name: The unique identifier of the prompt to retrieve.
248+
args: Optional dictionary of arguments to pass to the prompt handler. Must include all
249+
required arguments as defined in the prompt. Arguments may be auto-completed through
250+
the completion API.
251+
252+
Returns:
253+
GetPromptResult containing:
254+
- description: Human-readable description of the prompt
255+
- messages: List of PromptMessage objects with role and content
256+
- _meta: Optional metadata
257+
258+
Raises:
259+
PrimitiveError: If the prompt is not found (maps to -32602 Invalid params per spec).
260+
MCPRuntimeError: If an error occurs during prompt execution or message conversion
261+
(maps to -32603 Internal error per spec).
262+
"""
263+
if name not in self._prompts:
264+
# Raise INVALID_PARAMS as per MCP specification
265+
raise PrimitiveError(f"Unknown prompt: {name}")
266+
267+
prompt, prompt_func = self._prompts[name]
268+
self._validate_args(prompt.arguments, args)
269+
270+
try:
271+
result = await prompt_func.execute(args)
272+
messages = self._convert_result(result)
273+
logger.debug("Prompt %s handled with args %s", name, args)
274+
275+
return GetPromptResult(
276+
description=prompt.description,
277+
messages=messages,
278+
_meta=prompt.meta,
279+
)
280+
except InvalidArgumentsError:
281+
raise
282+
except Exception as e:
283+
msg = f"Error getting prompt {name}: {e}"
284+
logger.exception(msg)
285+
raise MCPRuntimeError(msg) from e
286+
287+
def _validate_args(
288+
self, prompt_arguments: builtins.list[PromptArgument] | None, available_args: dict[str, Any] | None
289+
) -> None:
290+
"""Check for missing required arguments per MCP specification.
291+
292+
Args:
293+
prompt_arguments: The arguments for the prompt.
294+
available_args: The arguments provided by the client.
295+
296+
Raises:
297+
InvalidArgumentsError: If the required arguments are not provided.
298+
"""
299+
if prompt_arguments is None:
300+
return
301+
302+
required_arg_names = {arg.name for arg in prompt_arguments if arg.required}
303+
provided_arg_names = set(available_args or {})
304+
305+
missing_arg_names = required_arg_names - provided_arg_names
306+
if missing_arg_names:
307+
missing_arg_names_str = ", ".join(missing_arg_names)
308+
raise InvalidArgumentsError(
309+
f"Missing required arguments: Arguments {missing_arg_names_str} need to be provided"
310+
)
311+
312+
def _convert_result(self, result: Any) -> builtins.list[PromptMessage]:
313+
"""Convert prompt handler results to PromptMessage objects per MCP specification.
314+
315+
PromptMessages must include a role ("user" or "assistant") and content. Per the MCP spec,
316+
content can be:
317+
- Text content (type: "text") - most common for natural language interactions
318+
- Image content (type: "image") - base64-encoded with MIME type
319+
- Audio content (type: "audio") - base64-encoded with MIME type
320+
- Embedded resources (type: "resource") - server-side resources with URI
321+
322+
All content types support optional annotations for metadata about audience, priority,
323+
and modification times.
324+
325+
Supports multiple return types from handler functions:
326+
- PromptMessage objects (used as-is with role and content)
327+
- Dictionaries (validated as PromptMessage)
328+
- Strings (converted to user messages with text content)
329+
- Other types (JSON-serialized and converted to user messages with text content)
330+
- Lists/tuples of any of the above
331+
332+
Args:
333+
result: The return value from a prompt handler function.
334+
335+
Returns:
336+
A list of PromptMessage objects with role and content per MCP protocol.
337+
338+
Raises:
339+
MCPRuntimeError: If the result cannot be converted to valid messages.
340+
"""
341+
342+
if not isinstance(result, list | tuple):
343+
result = [result]
344+
345+
try:
346+
messages: list[PromptMessage] = []
347+
348+
for msg in result: # type: ignore[reportUnknownVariableType]
349+
if isinstance(msg, PromptMessage):
350+
messages.append(msg)
351+
elif isinstance(msg, dict):
352+
# Try to validate as PromptMessage
353+
messages.append(PromptMessage.model_validate(msg))
354+
elif isinstance(msg, str):
355+
# Create a user message with text content
356+
content = TextContent(type="text", text=msg)
357+
messages.append(PromptMessage(role="user", content=content))
358+
else:
359+
# Convert to JSON string and create user message
360+
content_text = pydantic_core.to_json(msg, fallback=str, indent=2).decode()
361+
content = TextContent(type="text", text=content_text)
362+
messages.append(PromptMessage(role="user", content=content))
363+
364+
return messages
365+
except Exception as e:
366+
raise MCPRuntimeError("Could not convert prompt result to message") from e

0 commit comments

Comments
 (0)