Skip to content

Commit 7cf0cc2

Browse files
Merge branch 'main' into feat/sampling-resources
2 parents 435483e + f2f4dbd commit 7cf0cc2

File tree

14 files changed

+412
-91
lines changed

14 files changed

+412
-91
lines changed

examples/clients/simple-chatbot/mcp_simple_chatbot/main.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def get_response(self, messages: list[dict[str, str]]) -> str:
245245
}
246246
payload = {
247247
"messages": messages,
248-
"model": "llama-3.2-90b-vision-preview",
248+
"model": "meta-llama/llama-4-scout-17b-16e-instruct",
249249
"temperature": 0.7,
250250
"max_tokens": 4096,
251251
"top_p": 1,
@@ -284,12 +284,9 @@ def __init__(self, servers: list[Server], llm_client: LLMClient) -> None:
284284

285285
async def cleanup_servers(self) -> None:
286286
"""Clean up all servers properly."""
287-
cleanup_tasks = [
288-
asyncio.create_task(server.cleanup()) for server in self.servers
289-
]
290-
if cleanup_tasks:
287+
for server in reversed(self.servers):
291288
try:
292-
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
289+
await server.cleanup()
293290
except Exception as e:
294291
logging.warning(f"Warning during final cleanup: {e}")
295292

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,13 @@ members = ["examples/servers/*"]
109109
mcp = { workspace = true }
110110

111111
[tool.pytest.ini_options]
112+
log_cli = true
112113
xfail_strict = true
114+
addopts = """
115+
--color=yes
116+
--capture=fd
117+
--numprocesses auto
118+
"""
113119
filterwarnings = [
114120
"error",
115121
# This should be fixed on Uvicorn's side.

src/mcp/client/session.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,9 @@ async def list_resources(
209209
types.ClientRequest(
210210
types.ListResourcesRequest(
211211
method="resources/list",
212-
cursor=cursor,
212+
params=types.PaginatedRequestParams(cursor=cursor)
213+
if cursor is not None
214+
else None,
213215
)
214216
),
215217
types.ListResourcesResult,
@@ -223,7 +225,9 @@ async def list_resource_templates(
223225
types.ClientRequest(
224226
types.ListResourceTemplatesRequest(
225227
method="resources/templates/list",
226-
cursor=cursor,
228+
params=types.PaginatedRequestParams(cursor=cursor)
229+
if cursor is not None
230+
else None,
227231
)
228232
),
229233
types.ListResourceTemplatesResult,
@@ -295,7 +299,9 @@ async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResu
295299
types.ClientRequest(
296300
types.ListPromptsRequest(
297301
method="prompts/list",
298-
cursor=cursor,
302+
params=types.PaginatedRequestParams(cursor=cursor)
303+
if cursor is not None
304+
else None,
299305
)
300306
),
301307
types.ListPromptsResult,
@@ -340,7 +346,9 @@ async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult:
340346
types.ClientRequest(
341347
types.ListToolsRequest(
342348
method="tools/list",
343-
cursor=cursor,
349+
params=types.PaginatedRequestParams(cursor=cursor)
350+
if cursor is not None
351+
else None,
344352
)
345353
),
346354
types.ListToolsResult,

src/mcp/client/session_group.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ class ClientSessionGroup:
7777
the client and can be accessed via the session.
7878
7979
Example Usage:
80-
name_fn = lambda name, server_info: f"{(server_info.name)}-{name}"
80+
name_fn = lambda name, server_info: f"{(server_info.name)}_{name}"
8181
async with ClientSessionGroup(component_name_hook=name_fn) as group:
8282
for server_params in server_params:
83-
group.connect_to_server(server_param)
83+
await group.connect_to_server(server_param)
8484
...
8585
8686
"""
@@ -145,14 +145,15 @@ async def __aexit__(
145145
) -> bool | None:
146146
"""Closes session exit stacks and main exit stack upon completion."""
147147

148+
# Only close the main exit stack if we created it
149+
if self._owns_exit_stack:
150+
await self._exit_stack.aclose()
151+
148152
# Concurrently close session stacks.
149153
async with anyio.create_task_group() as tg:
150154
for exit_stack in self._session_exit_stacks.values():
151155
tg.start_soon(exit_stack.aclose)
152156

153-
# Only close the main exit stack if we created it
154-
if self._owns_exit_stack:
155-
await self._exit_stack.aclose()
156157

157158
@property
158159
def sessions(self) -> list[mcp.ClientSession]:

src/mcp/client/sse.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from httpx_sse import aconnect_sse
1111

1212
import mcp.types as types
13-
from mcp.shared._httpx_utils import create_mcp_http_client
13+
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
1414
from mcp.shared.message import SessionMessage
1515

1616
logger = logging.getLogger(__name__)
@@ -26,6 +26,7 @@ async def sse_client(
2626
headers: dict[str, Any] | None = None,
2727
timeout: float = 5,
2828
sse_read_timeout: float = 60 * 5,
29+
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
2930
auth: httpx.Auth | None = None,
3031
):
3132
"""
@@ -53,7 +54,7 @@ async def sse_client(
5354
async with anyio.create_task_group() as tg:
5455
try:
5556
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
56-
async with create_mcp_http_client(headers=headers, auth=auth) as client:
57+
async with httpx_client_factory(headers=headers, auth=auth) as client:
5758
async with aconnect_sse(
5859
client,
5960
"GET",

src/mcp/client/streamable_http.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2020
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
2121

22-
from mcp.shared._httpx_utils import create_mcp_http_client
22+
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
2323
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2424
from mcp.types import (
2525
ErrorData,
@@ -430,6 +430,7 @@ async def streamablehttp_client(
430430
timeout: timedelta = timedelta(seconds=30),
431431
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
432432
terminate_on_close: bool = True,
433+
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
433434
auth: httpx.Auth | None = None,
434435
) -> AsyncGenerator[
435436
tuple[
@@ -464,7 +465,7 @@ async def streamablehttp_client(
464465
try:
465466
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
466467

467-
async with create_mcp_http_client(
468+
async with httpx_client_factory(
468469
headers=transport.request_headers,
469470
timeout=httpx.Timeout(
470471
transport.timeout.seconds, read=transport.sse_read_timeout.seconds

src/mcp/server/fastmcp/tools/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations as _annotations
22

3+
import functools
34
import inspect
45
from collections.abc import Callable
56
from typing import TYPE_CHECKING, Any, get_origin
@@ -53,7 +54,7 @@ def from_function(
5354
raise ValueError("You must provide a name for lambda functions")
5455

5556
func_doc = description or fn.__doc__ or ""
56-
is_async = inspect.iscoroutinefunction(fn)
57+
is_async = _is_async_callable(fn)
5758

5859
if context_kwarg is None:
5960
sig = inspect.signature(fn)
@@ -98,3 +99,12 @@ async def run(
9899
)
99100
except Exception as e:
100101
raise ToolError(f"Error executing tool {self.name}: {e}") from e
102+
103+
104+
def _is_async_callable(obj: Any) -> bool:
105+
while isinstance(obj, functools.partial):
106+
obj = obj.func
107+
108+
return inspect.iscoroutinefunction(obj) or (
109+
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
110+
)

src/mcp/shared/_httpx_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
"""Utilities for creating standardized httpx AsyncClient instances."""
22

3-
from typing import Any
3+
from typing import Any, Protocol
44

55
import httpx
66

77
__all__ = ["create_mcp_http_client"]
88

99

10+
class McpHttpClientFactory(Protocol):
11+
def __call__(
12+
self,
13+
headers: dict[str, str] | None = None,
14+
timeout: httpx.Timeout | None = None,
15+
auth: httpx.Auth | None = None,
16+
) -> httpx.AsyncClient: ...
17+
18+
1019
def create_mcp_http_client(
1120
headers: dict[str, str] | None = None,
1221
timeout: httpx.Timeout | None = None,

src/mcp/types.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ class Meta(BaseModel):
5353
meta: Meta | None = Field(alias="_meta", default=None)
5454

5555

56+
class PaginatedRequestParams(RequestParams):
57+
cursor: Cursor | None = None
58+
"""
59+
An opaque token representing the current pagination position.
60+
If provided, the server should return results starting after this cursor.
61+
"""
62+
63+
5664
class NotificationParams(BaseModel):
5765
class Meta(BaseModel):
5866
model_config = ConfigDict(extra="allow")
@@ -79,12 +87,13 @@ class Request(BaseModel, Generic[RequestParamsT, MethodT]):
7987
model_config = ConfigDict(extra="allow")
8088

8189

82-
class PaginatedRequest(Request[RequestParamsT, MethodT]):
83-
cursor: Cursor | None = None
84-
"""
85-
An opaque token representing the current pagination position.
86-
If provided, the server should return results starting after this cursor.
87-
"""
90+
class PaginatedRequest(
91+
Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]
92+
):
93+
"""Base class for paginated requests,
94+
matching the schema's PaginatedRequest interface."""
95+
96+
params: PaginatedRequestParams | None = None
8897

8998

9099
class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
@@ -358,13 +367,10 @@ class ProgressNotification(
358367
params: ProgressNotificationParams
359368

360369

361-
class ListResourcesRequest(
362-
PaginatedRequest[RequestParams | None, Literal["resources/list"]]
363-
):
370+
class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]):
364371
"""Sent from the client to request a list of resources the server has."""
365372

366373
method: Literal["resources/list"]
367-
params: RequestParams | None = None
368374

369375

370376
class Annotations(BaseModel):
@@ -423,12 +429,11 @@ class ListResourcesResult(PaginatedResult):
423429

424430

425431
class ListResourceTemplatesRequest(
426-
PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]
432+
PaginatedRequest[Literal["resources/templates/list"]]
427433
):
428434
"""Sent from the client to request a list of resource templates the server has."""
429435

430436
method: Literal["resources/templates/list"]
431-
params: RequestParams | None = None
432437

433438

434439
class ListResourceTemplatesResult(PaginatedResult):
@@ -570,13 +575,10 @@ class ResourceUpdatedNotification(
570575
params: ResourceUpdatedNotificationParams
571576

572577

573-
class ListPromptsRequest(
574-
PaginatedRequest[RequestParams | None, Literal["prompts/list"]]
575-
):
578+
class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]):
576579
"""Sent from the client to request a list of prompts and prompt templates."""
577580

578581
method: Literal["prompts/list"]
579-
params: RequestParams | None = None
580582

581583

582584
class PromptArgument(BaseModel):
@@ -703,11 +705,10 @@ class PromptListChangedNotification(
703705
params: NotificationParams | None = None
704706

705707

706-
class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
708+
class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]):
707709
"""Sent from the client to request a list of tools the server has."""
708710

709711
method: Literal["tools/list"]
710-
params: RequestParams | None = None
711712

712713

713714
class ToolAnnotations(BaseModel):
@@ -741,7 +742,7 @@ class ToolAnnotations(BaseModel):
741742

742743
idempotentHint: bool | None = None
743744
"""
744-
If true, calling the tool repeatedly with the same arguments
745+
If true, calling the tool repeatedly with the same arguments
745746
will have no additional effect on the its environment.
746747
(This property is meaningful only when `readOnlyHint == false`)
747748
Default: false

0 commit comments

Comments
 (0)