Skip to content

Commit 6169282

Browse files
committed
Merge branch 'feature/resource_progress' into feature/617-async-call-tools
2 parents 05b7156 + d634e6a commit 6169282

File tree

6 files changed

+91
-22
lines changed

6 files changed

+91
-22
lines changed

src/mcp/client/session.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
from datetime import timedelta
2-
from typing import Any, Protocol
2+
from typing import Annotated, Any, Protocol
33

44
import anyio.lowlevel
55
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
6-
from pydantic import AnyUrl, TypeAdapter
6+
from pydantic import TypeAdapter
7+
from pydantic.networks import AnyUrl, UrlConstraints
78

89
import mcp.types as types
910
from mcp.shared.context import RequestContext
1011
from mcp.shared.message import SessionMessage
11-
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
12+
from mcp.shared.session import (
13+
BaseSession,
14+
ProgressFnT,
15+
RequestResponder,
16+
ResourceProgressFnT,
17+
)
1218
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1319

1420
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
@@ -179,6 +185,9 @@ async def send_progress_notification(
179185
progress: float,
180186
total: float | None = None,
181187
message: str | None = None,
188+
# TODO check whether MCP spec allows clients to create resources
189+
# for server and therefore whether resource notifications
190+
# would be required here too
182191
) -> None:
183192
"""Send a progress notification."""
184193
await self.send_notification(
@@ -208,7 +217,10 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul
208217
)
209218

210219
async def list_resources(
211-
self, cursor: str | None = None
220+
self,
221+
cursor: str | None = None,
222+
# TODO suggest in progress resources should be excluded by default?
223+
# possibly add an optional flag to include?
212224
) -> types.ListResourcesResult:
213225
"""Send a resources/list request."""
214226
return await self.send_request(
@@ -239,7 +251,9 @@ async def list_resource_templates(
239251
types.ListResourceTemplatesResult,
240252
)
241253

242-
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
254+
async def read_resource(
255+
self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
256+
) -> types.ReadResourceResult:
243257
"""Send a resources/read request."""
244258
return await self.send_request(
245259
types.ClientRequest(
@@ -251,7 +265,9 @@ async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
251265
types.ReadResourceResult,
252266
)
253267

254-
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
268+
async def subscribe_resource(
269+
self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
270+
) -> types.EmptyResult:
255271
"""Send a resources/subscribe request."""
256272
return await self.send_request(
257273
types.ClientRequest(
@@ -263,7 +279,9 @@ async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
263279
types.EmptyResult,
264280
)
265281

266-
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
282+
async def unsubscribe_resource(
283+
self, uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
284+
) -> types.EmptyResult:
267285
"""Send a resources/unsubscribe request."""
268286
return await self.send_request(
269287
types.ClientRequest(
@@ -280,7 +298,7 @@ async def call_tool(
280298
name: str,
281299
arguments: dict[str, Any] | None = None,
282300
read_timeout_seconds: timedelta | None = None,
283-
progress_callback: ProgressFnT | None = None,
301+
progress_callback: ProgressFnT | ResourceProgressFnT | None = None,
284302
) -> types.CallToolResult:
285303
"""Send a tools/call request with optional progress callback support."""
286304

src/mcp/server/fastmcp/server.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@
1010
asynccontextmanager,
1111
)
1212
from itertools import chain
13-
from typing import Any, Generic, Literal
13+
from typing import Annotated, Any, Generic, Literal
1414

1515
import anyio
1616
import pydantic_core
1717
from pydantic import BaseModel, Field
18-
from pydantic.networks import AnyUrl
18+
from pydantic.networks import AnyUrl, UrlConstraints
1919
from pydantic_settings import BaseSettings, SettingsConfigDict
2020
from starlette.applications import Starlette
2121
from starlette.middleware import Middleware
@@ -962,7 +962,12 @@ def request_context(
962962
return self._request_context
963963

964964
async def report_progress(
965-
self, progress: float, total: float | None = None, message: str | None = None
965+
self,
966+
progress: float,
967+
total: float | None = None,
968+
message: str | None = None,
969+
resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
970+
| None = None,
966971
) -> None:
967972
"""Report progress for the current operation.
968973
@@ -985,6 +990,7 @@ async def report_progress(
985990
progress=progress,
986991
total=total,
987992
message=message,
993+
resource_uri=resource_uri,
988994
)
989995

990996
async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:

src/mcp/server/session.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
3838
"""
3939

4040
from enum import Enum
41-
from typing import Any, TypeVar
41+
from typing import Annotated, Any, TypeVar
4242

4343
import anyio
4444
import anyio.lowlevel
4545
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
46-
from pydantic import AnyUrl
46+
from pydantic.networks import AnyUrl, UrlConstraints
4747

4848
import mcp.types as types
4949
from mcp.server.models import InitializationOptions
@@ -288,6 +288,8 @@ async def send_progress_notification(
288288
total: float | None = None,
289289
message: str | None = None,
290290
related_request_id: str | None = None,
291+
resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
292+
| None = None,
291293
) -> None:
292294
"""Send a progress notification."""
293295
await self.send_notification(
@@ -299,6 +301,7 @@ async def send_progress_notification(
299301
progress=progress,
300302
total=total,
301303
message=message,
304+
resource_uri=resource_uri,
302305
),
303306
)
304307
),

src/mcp/shared/session.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1+
import inspect
12
import logging
23
from collections.abc import Callable
34
from contextlib import AsyncExitStack
45
from datetime import timedelta
56
from types import TracebackType
6-
from typing import Any, Generic, Protocol, TypeVar
7+
from typing import Annotated, Any, Generic, Protocol, TypeVar, runtime_checkable
78

89
import anyio
910
import httpx
1011
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1112
from pydantic import BaseModel
13+
from pydantic.networks import AnyUrl, UrlConstraints
1214
from typing_extensions import Self
1315

1416
from mcp.shared.exceptions import McpError
@@ -44,6 +46,7 @@
4446
RequestId = str | int
4547

4648

49+
@runtime_checkable
4750
class ProgressFnT(Protocol):
4851
"""Protocol for progress notification callbacks."""
4952

@@ -52,6 +55,20 @@ async def __call__(
5255
) -> None: ...
5356

5457

58+
@runtime_checkable
59+
class ResourceProgressFnT(Protocol):
60+
"""Protocol for progress notification callbacks with resources."""
61+
62+
async def __call__(
63+
self,
64+
progress: float,
65+
total: float | None,
66+
message: str | None,
67+
resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
68+
| None = None,
69+
) -> None: ...
70+
71+
5572
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
5673
"""Handles responding to MCP requests and manages request lifecycle.
5774
@@ -182,6 +199,7 @@ class BaseSession(
182199
_request_id: int
183200
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
184201
_progress_callbacks: dict[RequestId, ProgressFnT]
202+
_resource_callbacks: dict[RequestId, ResourceProgressFnT]
185203

186204
def __init__(
187205
self,
@@ -201,6 +219,7 @@ def __init__(
201219
self._session_read_timeout_seconds = read_timeout_seconds
202220
self._in_flight = {}
203221
self._progress_callbacks = {}
222+
self._resource_callbacks = {}
204223
self._exit_stack = AsyncExitStack()
205224

206225
async def __aenter__(self) -> Self:
@@ -228,7 +247,7 @@ async def send_request(
228247
result_type: type[ReceiveResultT],
229248
request_read_timeout_seconds: timedelta | None = None,
230249
metadata: MessageMetadata = None,
231-
progress_callback: ProgressFnT | None = None,
250+
progress_callback: ProgressFnT | ResourceProgressFnT | None = None,
232251
) -> ReceiveResultT:
233252
"""
234253
Sends a request and wait for a response. Raises an McpError if the
@@ -255,8 +274,15 @@ async def send_request(
255274
if "_meta" not in request_data["params"]:
256275
request_data["params"]["_meta"] = {}
257276
request_data["params"]["_meta"]["progressToken"] = request_id
258-
# Store the callback for this request
259-
self._progress_callbacks[request_id] = progress_callback
277+
# note this is required to ensure backwards compatibility
278+
# for previous clients
279+
signature = inspect.signature(progress_callback.__call__)
280+
if len(signature.parameters) == 3:
281+
# Store the callback for this request
282+
self._resource_callbacks[request_id] = progress_callback # type: ignore
283+
else:
284+
# Store the callback for this request
285+
self._progress_callbacks[request_id] = progress_callback
260286

261287
try:
262288
jsonrpc_request = JSONRPCRequest(
@@ -401,6 +427,15 @@ async def _receive_loop(self) -> None:
401427
notification.root.params.total,
402428
notification.root.params.message,
403429
)
430+
elif progress_token in self._resource_callbacks:
431+
callback = self._resource_callbacks[progress_token]
432+
await callback(
433+
notification.root.params.progress,
434+
notification.root.params.total,
435+
notification.root.params.message,
436+
notification.root.params.resource_uri,
437+
)
438+
404439
await self._received_notification(notification)
405440
await self._handle_incoming(notification)
406441
except Exception as e:

src/mcp/types.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,12 +350,19 @@ class ProgressNotificationParams(NotificationParams):
350350
total is unknown.
351351
"""
352352
total: float | None = None
353+
"""Total number of items to process (or total progress required), if known."""
354+
message: str | None = None
353355
"""
354356
Message related to progress. This should provide relevant human readable
355357
progress information.
356358
"""
357-
message: str | None = None
358-
"""Total number of items to process (or total progress required), if known."""
359+
resource_uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] | None = None
360+
"""
361+
An optional reference to an ephemeral resource associated with this
362+
progress, servers may delete these at their descretion, but are encouraged
363+
to make them available for a reasonable time period to allow clients to
364+
retrieve and cache the resources locally
365+
"""
359366
model_config = ConfigDict(extra="allow")
360367

361368

tests/issues/test_176_progress_token.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ async def test_progress_token_zero_first_call():
3939
mock_session.send_progress_notification.call_count == 3
4040
), "All progress notifications should be sent"
4141
mock_session.send_progress_notification.assert_any_call(
42-
progress_token=0, progress=0.0, total=10.0, message=None
42+
progress_token=0, progress=0.0, total=10.0, message=None, resource_uri=None
4343
)
4444
mock_session.send_progress_notification.assert_any_call(
45-
progress_token=0, progress=5.0, total=10.0, message=None
45+
progress_token=0, progress=5.0, total=10.0, message=None, resource_uri=None
4646
)
4747
mock_session.send_progress_notification.assert_any_call(
48-
progress_token=0, progress=10.0, total=10.0, message=None
48+
progress_token=0, progress=10.0, total=10.0, message=None, resource_uri=None
4949
)

0 commit comments

Comments
 (0)