Commit c09ab58
validation options
1 parent ef4e167 commit c09ab58

3 files changed: +259 additions, -8 deletions

src/mcp/client/session.py

Lines changed: 35 additions & 8 deletions
@@ -5,7 +5,7 @@
 import anyio.lowlevel
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from jsonschema import SchemaError, ValidationError, validate
-from pydantic import AnyUrl, TypeAdapter
+from pydantic import AnyUrl, BaseModel, Field, TypeAdapter

 import mcp.types as types
 from mcp.shared.context import RequestContext
@@ -18,6 +18,17 @@
 logger = logging.getLogger("client")


+class ValidationOptions(BaseModel):
+    """Options for controlling validation behavior in MCP client sessions."""
+
+    strict_output_validation: bool = Field(
+        default=True,
+        description="Whether to raise exceptions when tools don't return structured "
+        "content as specified by their output schema. When False, validation "
+        "errors are logged as warnings and execution continues.",
+    )
+
+
 class SamplingFnT(Protocol):
     async def __call__(
         self,
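
Aside (not part of the diff): a minimal usage sketch of the new model, relying only on the field defined in the hunk above. The default is strict; lenient mode is opt-in:

    from mcp.client.session import ValidationOptions

    # Default: strict output validation; mismatches raise RuntimeError.
    strict = ValidationOptions()
    assert strict.strict_output_validation is True

    # Opt-in lenient mode: mismatches are logged as warnings instead.
    lenient = ValidationOptions(strict_output_validation=False)
    assert lenient.strict_output_validation is False
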
@@ -118,6 +129,7 @@ def __init__(
         logging_callback: LoggingFnT | None = None,
         message_handler: MessageHandlerFnT | None = None,
         client_info: types.Implementation | None = None,
+        validation_options: ValidationOptions | None = None,
     ) -> None:
         super().__init__(
             read_stream,
@@ -133,6 +145,7 @@ def __init__(
         self._logging_callback = logging_callback or _default_logging_callback
         self._message_handler = message_handler or _default_message_handler
         self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
+        self._validation_options = validation_options or ValidationOptions()

     async def initialize(self) -> types.InitializeResult:
         sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
@@ -324,13 +337,27 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) -

         if output_schema is not None:
             if result.structuredContent is None:
-                raise RuntimeError(f"Tool {name} has an output schema but did not return structured content")
-            try:
-                validate(result.structuredContent, output_schema)
-            except ValidationError as e:
-                raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}")
-            except SchemaError as e:
-                raise RuntimeError(f"Invalid schema for tool {name}: {e}")
+                if self._validation_options.strict_output_validation:
+                    raise RuntimeError(f"Tool {name} has an output schema but did not return structured content")
+                else:
+                    logger.warning(
+                        f"Tool {name} has an output schema but did not return structured content. "
+                        f"Continuing without structured content validation due to lenient validation mode."
+                    )
+            else:
+                try:
+                    validate(result.structuredContent, output_schema)
+                except ValidationError as e:
+                    if self._validation_options.strict_output_validation:
+                        raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}")
+                    else:
+                        logger.warning(
+                            f"Invalid structured content returned by tool {name}: {e}. "
+                            f"Continuing due to lenient validation mode."
+                        )
+                except SchemaError as e:
+                    # Schema errors are always raised - they indicate a problem with the schema itself
+                    raise RuntimeError(f"Invalid schema for tool {name}: {e}")

     async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult:
         """Send a prompts/list request."""

src/mcp/shared/memory.py

Lines changed: 2 additions & 0 deletions
@@ -61,6 +61,7 @@ async def create_connected_server_and_client_session(
     client_info: types.Implementation | None = None,
     raise_exceptions: bool = False,
     elicitation_callback: ElicitationFnT | None = None,
+    validation_options: Any | None = None,
 ) -> AsyncGenerator[ClientSession, None]:
     """Creates a ClientSession that is connected to a running MCP server."""
     async with create_client_server_memory_streams() as (
@@ -92,6 +93,7 @@
             message_handler=message_handler,
             client_info=client_info,
             elicitation_callback=elicitation_callback,
+            validation_options=validation_options,
         ) as client_session:
             await client_session.initialize()
             yield client_session
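
Aside (not part of the diff): once threaded through, the in-memory helper accepts the new option directly. A minimal sketch mirroring the tests below; the server here is illustrative and registers no tools:

    from mcp.client.session import ValidationOptions
    from mcp.server.lowlevel import Server
    from mcp.shared.memory import (
        create_connected_server_and_client_session as client_session,
    )

    async def demo() -> None:
        server = Server("demo-server")  # illustrative server for the wiring
        lenient = ValidationOptions(strict_output_validation=False)
        async with client_session(server, validation_options=lenient) as client:
            # The session is already initialized by the helper. Tool calls made
            # through `client` now log warnings instead of raising when a
            # tool's structured output fails validation.
            tools = await client.list_tools()
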
Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
+"""Tests for client-side validation options."""
+
+import logging
+from contextlib import contextmanager
+from unittest.mock import patch
+
+import pytest
+
+from mcp.client.session import ValidationOptions
+from mcp.server.lowlevel import Server
+from mcp.shared.memory import (
+    create_connected_server_and_client_session as client_session,
+)
+from mcp.types import Tool
+
+
+@contextmanager
+def bypass_server_output_validation():
+    """
+    Context manager that bypasses server-side output validation.
+    This simulates a non-compliant server that doesn't validate its outputs.
+    """
+    with patch("mcp.server.lowlevel.server.jsonschema.validate"):
+        yield
+
+
+class TestValidationOptions:
+    """Test validation options for MCP client sessions."""
+
+    @pytest.mark.anyio
+    async def test_strict_validation_default(self):
+        """Test that strict validation is enabled by default."""
+        server = Server("test-server")
+
+        output_schema = {
+            "type": "object",
+            "properties": {"result": {"type": "integer"}},
+            "required": ["result"],
+        }
+
+        @server.list_tools()
+        async def list_tools():
+            return [
+                Tool(
+                    name="test_tool",
+                    description="Test tool",
+                    inputSchema={"type": "object"},
+                    outputSchema=output_schema,
+                )
+            ]
+
+        @server.call_tool()
+        async def call_tool(name: str, arguments: dict):
+            # Return unstructured content instead of structured content
+            # This will trigger the validation error we want to test
+            return "This is unstructured text content"
+
+        with bypass_server_output_validation():
+            async with client_session(server) as client:
+                # Should raise by default
+                with pytest.raises(RuntimeError) as exc_info:
+                    await client.call_tool("test_tool", {})
+                assert "has an output schema but did not return structured content" in str(exc_info.value)
+
+    @pytest.mark.anyio
+    async def test_lenient_validation_missing_content(self, caplog):
+        """Test lenient validation when structured content is missing."""
+        server = Server("test-server")
+
+        output_schema = {
+            "type": "object",
+            "properties": {"result": {"type": "integer"}},
+            "required": ["result"],
+        }
+
+        @server.list_tools()
+        async def list_tools():
+            return [
+                Tool(
+                    name="test_tool",
+                    description="Test tool",
+                    inputSchema={"type": "object"},
+                    outputSchema=output_schema,
+                )
+            ]
+
+        @server.call_tool()
+        async def call_tool(name: str, arguments: dict):
+            # Return unstructured content instead of structured content
+            # This will trigger the validation error we want to test
+            return "This is unstructured text content"
+
+        # Set logging level to capture warnings
+        caplog.set_level(logging.WARNING)
+
+        # Create client with lenient validation
+        validation_options = ValidationOptions(strict_output_validation=False)
+
+        with bypass_server_output_validation():
+            async with client_session(server, validation_options=validation_options) as client:
+                # Should not raise with lenient validation
+                result = await client.call_tool("test_tool", {})
+
+                # Should have logged a warning
+                assert "has an output schema but did not return structured content" in caplog.text
+                assert "Continuing without structured content validation" in caplog.text
+
+                # Result should still be returned
+                assert result.isError is False
+                assert result.structuredContent is None
+
+    @pytest.mark.anyio
+    async def test_lenient_validation_invalid_content(self, caplog):
+        """Test lenient validation when structured content is invalid."""
+        server = Server("test-server")
+
+        output_schema = {
+            "type": "object",
+            "properties": {"result": {"type": "integer"}},
+            "required": ["result"],
+        }
+
+        @server.list_tools()
+        async def list_tools():
+            return [
+                Tool(
+                    name="test_tool",
+                    description="Test tool",
+                    inputSchema={"type": "object"},
+                    outputSchema=output_schema,
+                )
+            ]
+
+        @server.call_tool()
+        async def call_tool(name: str, arguments: dict):
+            # Return invalid structured content (string instead of integer)
+            return {"result": "not_an_integer"}
+
+        # Set logging level to capture warnings
+        caplog.set_level(logging.WARNING)
+
+        # Create client with lenient validation
+        validation_options = ValidationOptions(strict_output_validation=False)
+
+        with bypass_server_output_validation():
+            async with client_session(server, validation_options=validation_options) as client:
+                # Should not raise with lenient validation
+                result = await client.call_tool("test_tool", {})
+
+                # Should have logged a warning
+                assert "Invalid structured content returned by tool test_tool" in caplog.text
+                assert "Continuing due to lenient validation mode" in caplog.text
+
+                # Result should still be returned with the invalid content
+                assert result.isError is False
+                assert result.structuredContent == {"result": "not_an_integer"}
+
+    @pytest.mark.anyio
+    async def test_strict_validation_with_valid_content(self):
+        """Test that valid content passes with strict validation."""
+        server = Server("test-server")
+
+        output_schema = {
+            "type": "object",
+            "properties": {"result": {"type": "integer"}},
+            "required": ["result"],
+        }
+
+        @server.list_tools()
+        async def list_tools():
+            return [
+                Tool(
+                    name="test_tool",
+                    description="Test tool",
+                    inputSchema={"type": "object"},
+                    outputSchema=output_schema,
+                )
+            ]
+
+        @server.call_tool()
+        async def call_tool(name: str, arguments: dict):
+            # Return valid structured content
+            return {"result": 42}
+
+        async with client_session(server) as client:
+            # Should succeed with valid content
+            result = await client.call_tool("test_tool", {})
+            assert result.isError is False
+            assert result.structuredContent == {"result": 42}
+
+    @pytest.mark.anyio
+    async def test_schema_errors_always_raised(self):
+        """Test that schema errors are always raised regardless of validation mode."""
+        server = Server("test-server")
+
+        # Invalid schema (missing required 'type' field)
+        output_schema = {"properties": {"result": {}}}
+
+        @server.list_tools()
+        async def list_tools():
+            return [
+                Tool(
+                    name="test_tool",
+                    description="Test tool",
+                    inputSchema={"type": "object"},
+                    outputSchema=output_schema,
+                )
+            ]
+
+        @server.call_tool()
+        async def call_tool(name: str, arguments: dict):
+            return {"result": 42}
+
+        # Test with lenient validation
+        validation_options = ValidationOptions(strict_output_validation=False)
+
+        with bypass_server_output_validation():
+            async with client_session(server, validation_options=validation_options) as client:
+                # Should still raise for schema errors
+                with pytest.raises(RuntimeError) as exc_info:
+                    await client.call_tool("test_tool", {})
+                assert "Invalid schema for tool test_tool" in str(exc_info.value)
