Skip to content

Commit 35db624

Browse files
committed
Add output schema generation and related tests
1 parent 1a330ac commit 35db624

File tree

8 files changed

+949
-678
lines changed

8 files changed

+949
-678
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
GetPromptResult,
4242
ImageContent,
4343
TextContent,
44+
DataContent,
4445
ToolAnnotations,
4546
)
4647
from mcp.types import Prompt as MCPPrompt
@@ -169,14 +170,17 @@ def _setup_handlers(self) -> None:
169170
self._mcp_server.get_prompt()(self.get_prompt)
170171
self._mcp_server.list_resource_templates()(self.list_resource_templates)
171172

173+
172174
async def list_tools(self) -> list[MCPTool]:
173175
"""List all available tools."""
174176
tools = self._tool_manager.list_tools()
177+
logger.info(f"list tools: {tools}")
175178
return [
176179
MCPTool(
177180
name=info.name,
178181
description=info.description,
179182
inputSchema=info.parameters,
183+
outputSchema=info.outputSchema,
180184
annotations=info.annotations,
181185
)
182186
for info in tools
@@ -195,10 +199,11 @@ def get_context(self) -> Context[ServerSession, object]:
195199

196200
async def call_tool(
197201
self, name: str, arguments: dict[str, Any]
198-
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
202+
) -> Sequence[TextContent | ImageContent | DataContent | EmbeddedResource]:
199203
"""Call a tool by name with arguments."""
200204
context = self.get_context()
201205
result = await self._tool_manager.call_tool(name, arguments, context=context)
206+
logger.info(f"call tool: {name} with args: {arguments} -> {result}")
202207
converted_result = _convert_to_content(result)
203208
return converted_result
204209

@@ -260,7 +265,10 @@ def add_tool(
260265
annotations: Optional ToolAnnotations providing additional tool information
261266
"""
262267
self._tool_manager.add_tool(
263-
fn, name=name, description=description, annotations=annotations
268+
fn,
269+
name=name,
270+
description=description,
271+
annotations=annotations,
264272
)
265273

266274
def tool(
@@ -304,7 +312,10 @@ async def async_tool(x: int, context: Context) -> str:
304312

305313
def decorator(fn: AnyFunction) -> AnyFunction:
306314
self.add_tool(
307-
fn, name=name, description=description, annotations=annotations
315+
fn,
316+
name=name,
317+
description=description,
318+
annotations=annotations,
308319
)
309320
return fn
310321

@@ -547,12 +558,13 @@ async def get_prompt(
547558

548559
def _convert_to_content(
549560
result: Any,
550-
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
561+
) -> Sequence[TextContent | ImageContent | EmbeddedResource | DataContent]:
551562
"""Convert a result to a sequence of content objects."""
552563
if result is None:
553564
return []
554565

555-
if isinstance(result, TextContent | ImageContent | EmbeddedResource):
566+
# Handle existing content types
567+
if isinstance(result, TextContent | ImageContent | EmbeddedResource | DataContent):
556568
return [result]
557569

558570
if isinstance(result, Image):
@@ -561,9 +573,21 @@ def _convert_to_content(
561573
if isinstance(result, list | tuple):
562574
return list(chain.from_iterable(_convert_to_content(item) for item in result)) # type: ignore[reportUnknownVariableType]
563575

576+
# For non-string objects, convert to DataContent
564577
if not isinstance(result, str):
565-
result = pydantic_core.to_json(result, fallback=str, indent=2).decode()
578+
# Try to convert to a JSON-serializable structure
579+
try:
580+
# Get the data as a dict/list structure
581+
data = pydantic_core.to_jsonable_python(result)
582+
# Create DataContent with the data
583+
return [DataContent(type="data", data=data)]
584+
except Exception as e:
585+
logger.warning(f"Failed to convert result to DataContent: {e}")
586+
# Fall back to string representation
587+
result_str = pydantic_core.to_json(result, fallback=str, indent=2).decode()
588+
return [TextContent(type="text", text=result_str)]
566589

590+
# For strings, use TextContent
567591
return [TextContent(type="text", text=result)]
568592

569593

src/mcp/server/fastmcp/tools/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import inspect
44
from collections.abc import Callable
55
from typing import TYPE_CHECKING, Any, get_origin
6+
from mcp.server.fastmcp.utilities.logging import get_logger
67

78
from pydantic import BaseModel, Field
89

@@ -15,6 +16,7 @@
1516
from mcp.server.session import ServerSessionT
1617
from mcp.shared.context import LifespanContextT
1718

19+
logger = get_logger(__name__)
1820

1921
class Tool(BaseModel):
2022
"""Internal tool registration info."""
@@ -23,6 +25,9 @@ class Tool(BaseModel):
2325
name: str = Field(description="Name of the tool")
2426
description: str = Field(description="Description of what the tool does")
2527
parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
28+
outputSchema: dict[str, Any] | None = Field(
29+
None, description="Optional JSON schema for tool output"
30+
)
2631
fn_metadata: FuncMetadata = Field(
2732
description="Metadata about the function including a pydantic model for tool"
2833
" arguments"
@@ -70,6 +75,10 @@ def from_function(
7075
)
7176
parameters = func_arg_metadata.arg_model.model_json_schema()
7277

78+
output_schema = getattr(func_arg_metadata, "outputSchema", None)
79+
80+
logger.info(f"output schema: {output_schema}")
81+
7382
return cls(
7483
fn=fn,
7584
name=func_name,
@@ -78,6 +87,7 @@ def from_function(
7887
fn_metadata=func_arg_metadata,
7988
is_async=is_async,
8089
context_kwarg=context_kwarg,
90+
outputSchema=output_schema,
8191
annotations=annotations,
8292
)
8393

@@ -87,6 +97,7 @@ async def run(
8797
context: Context[ServerSessionT, LifespanContextT] | None = None,
8898
) -> Any:
8999
"""Run the tool with arguments."""
100+
90101
try:
91102
return await self.fn_metadata.call_fn_with_arg_validation(
92103
self.fn,

src/mcp/server/fastmcp/tools/tool_manager.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ def add_tool(
4040
) -> Tool:
4141
"""Add a tool to the server."""
4242
tool = Tool.from_function(
43-
fn, name=name, description=description, annotations=annotations
43+
fn,
44+
name=name,
45+
description=description,
46+
annotations=annotations,
4447
)
4548
existing = self._tools.get(tool.name)
4649
if existing:

src/mcp/server/fastmcp/utilities/func_metadata.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def model_dump_one_level(self) -> dict[str, Any]:
3838

3939
class FuncMetadata(BaseModel):
4040
arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)]
41+
outputSchema: dict[str, Any] | None = None
4142
# We can add things in the future like
4243
# - Maybe some args are excluded from attempting to parse from JSON
4344
# - Maybe some args are special (like context) for dependency injection
@@ -60,6 +61,11 @@ async def call_fn_with_arg_validation(
6061

6162
arguments_parsed_dict |= arguments_to_pass_directly or {}
6263

64+
logger.info(
65+
"Calling function with arguments: %s",
66+
arguments_parsed_dict,
67+
)
68+
logger.info(f"Function is async: ${fn}")
6369
if fn_is_async:
6470
if isinstance(fn, Awaitable):
6571
return await fn
@@ -172,7 +178,62 @@ def func_metadata(
172178
**dynamic_pydantic_model_params,
173179
__base__=ArgModelBase,
174180
)
175-
resp = FuncMetadata(arg_model=arguments_model)
181+
182+
# Generate output schema from return type annotation
183+
output_schema: dict[str, Any] | None = None
184+
return_annotation = sig.return_annotation
185+
186+
if return_annotation is not inspect.Signature.empty:
187+
try:
188+
# Handle forward references
189+
return_type = _get_typed_annotation(return_annotation, globalns)
190+
logger.info(f"return_type: {return_type}")
191+
# Special case for None
192+
if return_type is type(None): # noqa: E721
193+
output_schema = {"type": "null"}
194+
else:
195+
# Create a temporary model to get the schema
196+
class OutputModel(BaseModel):
197+
result: return_type # type: ignore
198+
199+
model_config = ConfigDict(
200+
arbitrary_types_allowed=True,
201+
)
202+
203+
# Extract the schema for the return type
204+
full_schema = OutputModel.model_json_schema()
205+
206+
# If the return type is a complex type, use its schema definition
207+
if "$defs" in full_schema and "result" in full_schema.get(
208+
"properties", {}
209+
):
210+
prop = full_schema["properties"]["result"]
211+
if isinstance(prop, dict) and "$ref" in prop:
212+
if isinstance(prop["$ref"], str):
213+
ref_name = prop["$ref"].split("/")[-1]
214+
else:
215+
raise TypeError("Expected to be a string")
216+
if ref_name in full_schema.get("$defs", {}):
217+
ref_schema = full_schema["$defs"][ref_name]
218+
output_schema = {
219+
"type": "object",
220+
"properties": ref_schema.get("properties", {}),
221+
"required": ref_schema.get("required", []),
222+
}
223+
# Optionally include title if present
224+
if "title" in ref_schema:
225+
output_schema["title"] = ref_schema["title"]
226+
else:
227+
output_schema = prop
228+
else:
229+
# For simple types
230+
output_schema = full_schema["properties"]["result"]
231+
232+
except Exception as e:
233+
# If we can't generate a schema, log the error but continue
234+
logger.warning(f"Failed to generate output schema for {func.__name__}: {e}")
235+
236+
resp = FuncMetadata(arg_model=arguments_model, outputSchema=output_schema)
176237
return resp
177238

178239

@@ -194,6 +255,10 @@ def try_eval_type(
194255
if status is False:
195256
raise InvalidSignature(f"Unable to evaluate type annotation {annotation}")
196257

258+
# If the annotation is already a valid type, return it directly
259+
if isinstance(annotation, type):
260+
return annotation
261+
197262
return annotation
198263

199264

@@ -210,5 +275,8 @@ def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
210275
)
211276
for param in signature.parameters.values()
212277
]
213-
typed_signature = inspect.Signature(typed_params)
278+
typed_signature = inspect.Signature(
279+
typed_params,
280+
return_annotation=_get_typed_annotation(signature.return_annotation, globalns),
281+
)
214282
return typed_signature

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def decorator(
398398
...,
399399
Awaitable[
400400
Iterable[
401-
types.TextContent | types.ImageContent | types.EmbeddedResource
401+
types.TextContent | types.ImageContent | types.DataContent | types.EmbeddedResource
402402
]
403403
],
404404
],

src/mcp/types.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,27 @@ class ImageContent(BaseModel):
646646
model_config = ConfigDict(extra="allow")
647647

648648

649+
class DataContent(BaseModel):
650+
"""Structured JSON content for a message or tool result."""
651+
652+
type: Literal["data"]
653+
data: dict[str, Any]
654+
"""
655+
The structured JSON data. This is a JSON serializable object.
656+
"""
657+
658+
schema_definition: dict[str, Any] | str | None = None
659+
"""
660+
An optional schema describing the structure of the data.
661+
- Can be a string (schema reference URI),
662+
- A dictionary (full schema definition),
663+
- Or omitted if no schema is provided.
664+
"""
665+
666+
annotations: Annotations | None = None
667+
model_config = ConfigDict(extra="allow")
668+
669+
649670
class SamplingMessage(BaseModel):
650671
"""Describes a message issued to or received from an LLM API."""
651672

@@ -762,6 +783,8 @@ class Tool(BaseModel):
762783
"""A human-readable description of the tool."""
763784
inputSchema: dict[str, Any]
764785
"""A JSON Schema object defining the expected parameters for the tool."""
786+
outputSchema: dict[str, Any] | None = None
787+
"""A JSON Schema object defining the expected structure of the tool's output."""
765788
annotations: ToolAnnotations | None = None
766789
"""Optional additional tool information."""
767790
model_config = ConfigDict(extra="allow")
@@ -791,7 +814,7 @@ class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
791814
class CallToolResult(Result):
792815
"""The server's response to a tool call."""
793816

794-
content: list[TextContent | ImageContent | EmbeddedResource]
817+
content: list[TextContent | ImageContent | DataContent | EmbeddedResource]
795818
isError: bool = False
796819

797820

0 commit comments

Comments
 (0)