Skip to content

Commit 6eb1b3f

Browse files
committed
Unify sampling and elicitation code paths with shared validation
This refactoring ensures all sampling and elicitation code paths use consistent validation and support the same features. Sampling changes: - Add shared validation module (mcp/server/validation.py) with validate_sampling_tools() and validate_tool_use_result_messages() - Add tools and tool_choice parameters to all sampling methods: - _build_create_message_request() - ExperimentalServerSessionFeatures.create_message_as_task() - ServerTaskContext.create_message() - ServerTaskContext.create_message_as_task() - Refactor ServerSession.create_message() to use shared validation Elicitation changes: - Rename _build_elicit_request to _build_elicit_form_request for clarity - Add _build_elicit_url_request() for URL mode elicitation - Add ServerTaskContext.elicit_url() so URL elicitation can be used from inside task-augmented tool calls (e.g., for OAuth flows) This fixes a gap where task-augmented code paths were missing: - tools/tool_choice parameters for sampling - URL mode for elicitation
1 parent b7d44fa commit 6eb1b3f

File tree

6 files changed

+351
-54
lines changed

6 files changed

+351
-54
lines changed

src/mcp/server/experimental/session_features.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import TYPE_CHECKING, Any, TypeVar
1212

1313
import mcp.types as types
14+
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
1415
from mcp.shared.experimental.tasks.capabilities import (
1516
require_task_augmented_elicitation,
1617
require_task_augmented_sampling,
@@ -156,6 +157,8 @@ async def create_message_as_task(
156157
stop_sequences: list[str] | None = None,
157158
metadata: dict[str, Any] | None = None,
158159
model_preferences: types.ModelPreferences | None = None,
160+
tools: list[types.Tool] | None = None,
161+
tool_choice: types.ToolChoice | None = None,
159162
) -> types.CreateMessageResult:
160163
"""
161164
Send a task-augmented sampling request and poll until complete.
@@ -173,15 +176,20 @@ async def create_message_as_task(
173176
stop_sequences: Stop sequences
174177
metadata: Additional metadata
175178
model_preferences: Model selection preferences
179+
tools: Optional list of tools the LLM can use during sampling
180+
tool_choice: Optional control over tool usage behavior
176181
177182
Returns:
178183
The sampling result from the client
179184
180185
Raises:
181-
McpError: If client doesn't support task-augmented sampling
186+
McpError: If client doesn't support task-augmented sampling or tools
187+
ValueError: If tool_use or tool_result message structure is invalid
182188
"""
183189
client_caps = self._session.client_params.capabilities if self._session.client_params else None
184190
require_task_augmented_sampling(client_caps)
191+
validate_sampling_tools(client_caps, tools, tool_choice)
192+
validate_tool_use_result_messages(messages)
185193

186194
create_result = await self._session.send_request(
187195
types.ServerRequest(
@@ -195,6 +203,8 @@ async def create_message_as_task(
195203
stopSequences=stop_sequences,
196204
metadata=metadata,
197205
modelPreferences=model_preferences,
206+
tools=tools,
207+
toolChoice=tool_choice,
198208
task=types.TaskMetadata(ttl=ttl),
199209
)
200210
)

src/mcp/server/experimental/task_context.py

Lines changed: 97 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from mcp.server.experimental.task_result_handler import TaskResultHandler
1515
from mcp.server.session import ServerSession
16+
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
1617
from mcp.shared.exceptions import McpError
1718
from mcp.shared.experimental.tasks.capabilities import (
1819
require_task_augmented_elicitation,
@@ -44,6 +45,8 @@
4445
TaskMetadata,
4546
TaskStatusNotification,
4647
TaskStatusNotificationParams,
48+
Tool,
49+
ToolChoice,
4750
)
4851

4952

@@ -231,7 +234,7 @@ async def elicit(
231234
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
232235

233236
# Build the request using session's helper
234-
request = self._session._build_elicit_request( # pyright: ignore[reportPrivateUsage]
237+
request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage]
235238
message=message,
236239
requestedSchema=requestedSchema,
237240
related_task_id=self.task_id,
@@ -263,6 +266,77 @@ async def elicit(
263266
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
264267
raise
265268

269+
async def elicit_url(
270+
self,
271+
message: str,
272+
url: str,
273+
elicitation_id: str,
274+
) -> ElicitResult:
275+
"""
276+
Send a URL mode elicitation request via the task message queue.
277+
278+
This directs the user to an external URL for out-of-band interactions
279+
like OAuth flows, credential collection, or payment processing.
280+
281+
This method:
282+
1. Checks client capability
283+
2. Updates task status to "input_required"
284+
3. Queues the elicitation request
285+
4. Waits for the response (delivered via tasks/result round-trip)
286+
5. Updates task status back to "working"
287+
6. Returns the result
288+
289+
Args:
290+
message: Human-readable explanation of why the interaction is needed
291+
url: The URL the user should navigate to
292+
elicitation_id: Unique identifier for tracking this elicitation
293+
294+
Returns:
295+
The client's response indicating acceptance, decline, or cancellation
296+
297+
Raises:
298+
McpError: If client doesn't support elicitation capability
299+
RuntimeError: If handler is not configured
300+
"""
301+
self._check_elicitation_capability()
302+
303+
if self._handler is None:
304+
raise RuntimeError("handler is required for elicit_url(). Pass handler= to ServerTaskContext.")
305+
306+
# Update status to input_required
307+
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
308+
309+
# Build the request using session's helper
310+
request = self._session._build_elicit_url_request( # pyright: ignore[reportPrivateUsage]
311+
message=message,
312+
url=url,
313+
elicitation_id=elicitation_id,
314+
related_task_id=self.task_id,
315+
)
316+
request_id: RequestId = request.id
317+
318+
# Create resolver and register with handler for response routing
319+
resolver: Resolver[dict[str, Any]] = Resolver()
320+
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
321+
322+
# Queue the request
323+
queued = QueuedMessage(
324+
type="request",
325+
message=request,
326+
resolver=resolver,
327+
original_request_id=request_id,
328+
)
329+
await self._queue.enqueue(self.task_id, queued)
330+
331+
try:
332+
# Wait for response (routed back via TaskResultHandler)
333+
response_data = await resolver.wait()
334+
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
335+
return ElicitResult.model_validate(response_data)
336+
except anyio.get_cancelled_exc_class(): # pragma: no cover
337+
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
338+
raise
339+
266340
async def create_message(
267341
self,
268342
messages: list[SamplingMessage],
@@ -274,6 +348,8 @@ async def create_message(
274348
stop_sequences: list[str] | None = None,
275349
metadata: dict[str, Any] | None = None,
276350
model_preferences: ModelPreferences | None = None,
351+
tools: list[Tool] | None = None,
352+
tool_choice: ToolChoice | None = None,
277353
) -> CreateMessageResult:
278354
"""
279355
Send a sampling request via the task message queue.
@@ -295,14 +371,20 @@ async def create_message(
295371
stop_sequences: Stop sequences
296372
metadata: Additional metadata
297373
model_preferences: Model selection preferences
374+
tools: Optional list of tools the LLM can use during sampling
375+
tool_choice: Optional control over tool usage behavior
298376
299377
Returns:
300378
The sampling result from the client
301379
302380
Raises:
303-
McpError: If client doesn't support sampling capability
381+
McpError: If client doesn't support sampling capability or tools
382+
ValueError: If tool_use or tool_result message structure is invalid
304383
"""
305384
self._check_sampling_capability()
385+
client_caps = self._session.client_params.capabilities if self._session.client_params else None
386+
validate_sampling_tools(client_caps, tools, tool_choice)
387+
validate_tool_use_result_messages(messages)
306388

307389
if self._handler is None:
308390
raise RuntimeError("handler is required for create_message(). Pass handler= to ServerTaskContext.")
@@ -320,6 +402,8 @@ async def create_message(
320402
stop_sequences=stop_sequences,
321403
metadata=metadata,
322404
model_preferences=model_preferences,
405+
tools=tools,
406+
tool_choice=tool_choice,
323407
related_task_id=self.task_id,
324408
)
325409
request_id: RequestId = request.id
@@ -386,7 +470,7 @@ async def elicit_as_task(
386470
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
387471

388472
# Build request WITH task field for task-augmented elicitation
389-
request = self._session._build_elicit_request( # pyright: ignore[reportPrivateUsage]
473+
request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage]
390474
message=message,
391475
requestedSchema=requestedSchema,
392476
related_task_id=self.task_id,
@@ -442,6 +526,8 @@ async def create_message_as_task(
442526
stop_sequences: list[str] | None = None,
443527
metadata: dict[str, Any] | None = None,
444528
model_preferences: ModelPreferences | None = None,
529+
tools: list[Tool] | None = None,
530+
tool_choice: ToolChoice | None = None,
445531
) -> CreateMessageResult:
446532
"""
447533
Send a task-augmented sampling request via the queue, then poll client.
@@ -461,16 +547,21 @@ async def create_message_as_task(
461547
stop_sequences: Stop sequences
462548
metadata: Additional metadata
463549
model_preferences: Model selection preferences
550+
tools: Optional list of tools the LLM can use during sampling
551+
tool_choice: Optional control over tool usage behavior
464552
465553
Returns:
466554
The sampling result from the client
467555
468556
Raises:
469-
McpError: If client doesn't support task-augmented sampling
557+
McpError: If client doesn't support task-augmented sampling or tools
558+
ValueError: If tool_use or tool_result message structure is invalid
470559
RuntimeError: If handler is not configured
471560
"""
472561
client_caps = self._session.client_params.capabilities if self._session.client_params else None
473562
require_task_augmented_sampling(client_caps)
563+
validate_sampling_tools(client_caps, tools, tool_choice)
564+
validate_tool_use_result_messages(messages)
474565

475566
if self._handler is None:
476567
raise RuntimeError("handler is required for create_message_as_task()")
@@ -488,6 +579,8 @@ async def create_message_as_task(
488579
stop_sequences=stop_sequences,
489580
metadata=metadata,
490581
model_preferences=model_preferences,
582+
tools=tools,
583+
tool_choice=tool_choice,
491584
related_task_id=self.task_id,
492585
task=TaskMetadata(ttl=ttl),
493586
)

src/mcp/server/session.py

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4848
import mcp.types as types
4949
from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures
5050
from mcp.server.models import InitializationOptions
51-
from mcp.shared.exceptions import McpError
51+
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
5252
from mcp.shared.experimental.tasks.capabilities import check_tasks_capability
5353
from mcp.shared.message import ServerMessageMetadata, SessionMessage
5454
from mcp.shared.response_router import ResponseRouter
@@ -293,47 +293,12 @@ async def create_message(
293293
The sampling result from the client.
294294
295295
Raises:
296-
McpError: If tool_use or tool_result blocks are misused when tools are provided.
296+
McpError: If tools are provided but client doesn't support them.
297+
ValueError: If tool_use or tool_result message structure is invalid.
297298
"""
298-
299-
if tools is not None or tool_choice is not None:
300-
has_tools_cap = self.check_client_capability(
301-
types.ClientCapabilities(sampling=types.SamplingCapability(tools=types.SamplingToolsCapability()))
302-
)
303-
if not has_tools_cap:
304-
raise McpError(
305-
types.ErrorData(
306-
code=types.INVALID_PARAMS,
307-
message="Client does not support sampling tools capability",
308-
)
309-
)
310-
311-
# Validate tool_use/tool_result message structure per SEP-1577:
312-
# https://github.com/modelcontextprotocol/modelcontextprotocol/issues/1577
313-
# This validation runs regardless of whether `tools` is in this request,
314-
# since a tool loop continuation may omit `tools` while still containing
315-
# tool_result content that must match previous tool_use.
316-
if messages:
317-
last_content = messages[-1].content_as_list
318-
has_tool_results = any(c.type == "tool_result" for c in last_content)
319-
320-
previous_content = messages[-2].content_as_list if len(messages) >= 2 else None
321-
has_previous_tool_use = previous_content and any(c.type == "tool_use" for c in previous_content)
322-
323-
if has_tool_results:
324-
# Per spec: "SamplingMessage with tool result content blocks
325-
# MUST NOT contain other content types."
326-
if any(c.type != "tool_result" for c in last_content):
327-
raise ValueError("The last message must contain only tool_result content if any is present")
328-
if previous_content is None:
329-
raise ValueError("tool_result requires a previous message containing tool_use")
330-
if not has_previous_tool_use:
331-
raise ValueError("tool_result blocks do not match any tool_use in the previous message")
332-
if has_previous_tool_use and previous_content:
333-
tool_use_ids = {c.id for c in previous_content if c.type == "tool_use"}
334-
tool_result_ids = {c.toolUseId for c in last_content if c.type == "tool_result"}
335-
if tool_use_ids != tool_result_ids:
336-
raise ValueError("ids of tool_result blocks and tool_use blocks from previous message do not match")
299+
client_caps = self._client_params.capabilities if self._client_params else None
300+
validate_sampling_tools(client_caps, tools, tool_choice)
301+
validate_tool_use_result_messages(messages)
337302

338303
return await self.send_request(
339304
request=types.ServerRequest(
@@ -525,14 +490,14 @@ async def send_elicit_complete(
525490
# by TaskContext to construct requests that will be queued instead of sent
526491
# directly, avoiding code duplication between ServerSession and TaskContext.
527492

528-
def _build_elicit_request(
493+
def _build_elicit_form_request(
529494
self,
530495
message: str,
531496
requestedSchema: types.ElicitRequestedSchema,
532497
related_task_id: str | None = None,
533498
task: types.TaskMetadata | None = None,
534499
) -> types.JSONRPCRequest:
535-
"""Build an elicitation request without sending it.
500+
"""Build a form mode elicitation request without sending it.
536501
537502
Args:
538503
message: The message to present to the user
@@ -567,6 +532,48 @@ def _build_elicit_request(
567532
params=params_data,
568533
)
569534

535+
def _build_elicit_url_request(
536+
self,
537+
message: str,
538+
url: str,
539+
elicitation_id: str,
540+
related_task_id: str | None = None,
541+
) -> types.JSONRPCRequest:
542+
"""Build a URL mode elicitation request without sending it.
543+
544+
Args:
545+
message: Human-readable explanation of why the interaction is needed
546+
url: The URL the user should navigate to
547+
elicitation_id: Unique identifier for tracking this elicitation
548+
related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata
549+
550+
Returns:
551+
A JSONRPCRequest ready to be sent or queued
552+
"""
553+
params = types.ElicitRequestURLParams(
554+
message=message,
555+
url=url,
556+
elicitationId=elicitation_id,
557+
)
558+
params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True)
559+
560+
# Add related-task metadata if associated with a parent task
561+
if related_task_id is not None:
562+
if "_meta" not in params_data:
563+
params_data["_meta"] = {}
564+
params_data["_meta"]["io.modelcontextprotocol/related-task"] = {"taskId": related_task_id}
565+
566+
request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id
567+
if related_task_id is None:
568+
self._request_id += 1
569+
570+
return types.JSONRPCRequest(
571+
jsonrpc="2.0",
572+
id=request_id,
573+
method="elicitation/create",
574+
params=params_data,
575+
)
576+
570577
def _build_create_message_request(
571578
self,
572579
messages: list[types.SamplingMessage],
@@ -578,6 +585,8 @@ def _build_create_message_request(
578585
stop_sequences: list[str] | None = None,
579586
metadata: dict[str, Any] | None = None,
580587
model_preferences: types.ModelPreferences | None = None,
588+
tools: list[types.Tool] | None = None,
589+
tool_choice: types.ToolChoice | None = None,
581590
related_task_id: str | None = None,
582591
task: types.TaskMetadata | None = None,
583592
) -> types.JSONRPCRequest:
@@ -592,6 +601,8 @@ def _build_create_message_request(
592601
stop_sequences: Optional stop sequences
593602
metadata: Optional metadata to pass through to the LLM provider
594603
model_preferences: Optional model selection preferences
604+
tools: Optional list of tools the LLM can use during sampling
605+
tool_choice: Optional control over tool usage behavior
595606
related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata
596607
task: If provided, makes this a task-augmented request
597608
@@ -607,6 +618,8 @@ def _build_create_message_request(
607618
stopSequences=stop_sequences,
608619
metadata=metadata,
609620
modelPreferences=model_preferences,
621+
tools=tools,
622+
toolChoice=tool_choice,
610623
task=task,
611624
)
612625
params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True)

0 commit comments

Comments
 (0)