33from __future__ import annotations
44
55import logging
6- from collections .abc import AsyncGenerator
7- from contextlib import AsyncExitStack , asynccontextmanager
6+ from contextlib import AsyncExitStack
87from typing import Any
98
10- from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
119from pydantic import AnyUrl
1210
1311import mcp .types as types
2220from mcp .client .transports .memory import InMemoryTransport
2321from mcp .server import Server
2422from mcp .server .fastmcp import FastMCP
25- from mcp .shared .message import SessionMessage
2623from mcp .shared .session import ProgressFnT
2724
2825logger = logging .getLogger (__name__ )
2926
3027
3128class Client :
3229 """
33- A high-level MCP client that manages transport and session lifecycle .
30+ A high-level MCP client for connecting to MCP servers .
3431
35- The Client class provides a unified interface for connecting to MCP servers
36- using different transports ( in-memory, stdio, HTTP, etc.) and exposes
37- all ClientSession functionality with simpler lifecycle management .
32+ The Client class provides a simple interface for testing MCP servers
33+ using in-memory transport. Pass a Server or FastMCP instance directly
34+ to the constructor .
3835
39- Example with in-memory transport (for testing) :
36+ Example:
4037 server = FastMCP("test")
4138
42- async with Client.from_server(server) as client:
43- tools = await client.list_tools()
44- result = await client.call_tool("my_tool", {"arg": "value"})
39+ @server.tool()
40+ def add(a: int, b: int) -> int:
41+ return a + b
4542
46- Example with custom streams:
47- async with Client(read_stream, write_stream) as client:
48- await client.initialize()
49- # Use client...
43+ async with Client(server) as client:
44+ result = await client.call_tool("add", {"a": 1, "b": 2})
5045 """
5146
5247 def __init__ (
5348 self ,
54- read_stream : MemoryObjectReceiveStream [SessionMessage | Exception ],
55- write_stream : MemoryObjectSendStream [SessionMessage ],
49+ server : Server [Any ] | FastMCP ,
50+ * ,
51+ raise_exceptions : bool = False ,
5652 read_timeout_seconds : float | None = None ,
5753 sampling_callback : SamplingFnT | None = None ,
5854 list_roots_callback : ListRootsFnT | None = None ,
@@ -62,11 +58,11 @@ def __init__(
6258 elicitation_callback : ElicitationFnT | None = None ,
6359 ) -> None :
6460 """
65- Initialize the client with transport streams .
61+ Initialize the client with a server .
6662
6763 Args:
68- read_stream: Stream for receiving messages from the server
69- write_stream: Stream for sending messages to the server
64+ server: The MCP server to connect to (Server or FastMCP instance)
65+ raise_exceptions: Whether to raise exceptions from the server
7066 read_timeout_seconds: Timeout for read operations
7167 sampling_callback: Callback for handling sampling requests
7268 list_roots_callback: Callback for handling list roots requests
@@ -75,8 +71,8 @@ def __init__(
7571 client_info: Client implementation info to send to server
7672 elicitation_callback: Callback for handling elicitation requests
7773 """
78- self ._read_stream = read_stream
79- self ._write_stream = write_stream
74+ self ._server = server
75+ self ._raise_exceptions = raise_exceptions
8076 self ._read_timeout_seconds = read_timeout_seconds
8177 self ._sampling_callback = sampling_callback
8278 self ._list_roots_callback = list_roots_callback
@@ -88,86 +84,40 @@ def __init__(
8884 self ._session : ClientSession | None = None
8985 self ._exit_stack : AsyncExitStack | None = None
9086
91- @classmethod
92- @asynccontextmanager
93- async def from_server (
94- cls ,
95- server : Server [Any ] | FastMCP ,
96- read_timeout_seconds : float | None = None ,
97- sampling_callback : SamplingFnT | None = None ,
98- list_roots_callback : ListRootsFnT | None = None ,
99- logging_callback : LoggingFnT | None = None ,
100- message_handler : MessageHandlerFnT | None = None ,
101- client_info : types .Implementation | None = None ,
102- raise_exceptions : bool = False ,
103- elicitation_callback : ElicitationFnT | None = None ,
104- ) -> AsyncGenerator [Client , None ]:
105- """
106- Create a client connected to an in-memory server.
107-
108- This is a convenience method that creates an in-memory transport,
109- starts the server, and returns an initialized client.
110-
111- Args:
112- server: The MCP server to connect to (Server or FastMCP instance)
113- read_timeout_seconds: Timeout for read operations
114- sampling_callback: Callback for handling sampling requests
115- list_roots_callback: Callback for handling list roots requests
116- logging_callback: Callback for handling logging notifications
117- message_handler: Callback for handling raw messages
118- client_info: Client implementation info to send to server
119- raise_exceptions: Whether to raise exceptions from the server
120- elicitation_callback: Callback for handling elicitation requests
121-
122- Yields:
123- An initialized Client connected to the server
124-
125- Example:
126- server = FastMCP("test")
127-
128- @server.tool()
129- def my_tool(arg: str) -> str:
130- return f"Result: {arg}"
131-
132- async with Client.from_server(server) as client:
133- result = await client.call_tool("my_tool", {"arg": "value"})
134- """
135- transport = InMemoryTransport (server , raise_exceptions = raise_exceptions )
136- async with transport .connect () as (read_stream , write_stream ):
137- client = cls (
138- read_stream = read_stream ,
139- write_stream = write_stream ,
140- read_timeout_seconds = read_timeout_seconds ,
141- sampling_callback = sampling_callback ,
142- list_roots_callback = list_roots_callback ,
143- logging_callback = logging_callback ,
144- message_handler = message_handler ,
145- client_info = client_info ,
146- elicitation_callback = elicitation_callback ,
147- )
148- async with client :
149- await client .initialize ()
150- yield client
151-
15287 async def __aenter__ (self ) -> Client :
15388 """Enter the async context manager."""
15489 self ._exit_stack = AsyncExitStack ()
15590 await self ._exit_stack .__aenter__ ()
15691
157- self ._session = await self ._exit_stack .enter_async_context (
158- ClientSession (
159- read_stream = self ._read_stream ,
160- write_stream = self ._write_stream ,
161- read_timeout_seconds = self ._read_timeout_seconds ,
162- sampling_callback = self ._sampling_callback ,
163- list_roots_callback = self ._list_roots_callback ,
164- logging_callback = self ._logging_callback ,
165- message_handler = self ._message_handler ,
166- client_info = self ._client_info ,
167- elicitation_callback = self ._elicitation_callback ,
92+ try :
93+ # Create transport and connect
94+ transport = InMemoryTransport (self ._server , raise_exceptions = self ._raise_exceptions )
95+ read_stream , write_stream = await self ._exit_stack .enter_async_context (
96+ transport .connect ()
97+ )
98+
99+ # Create session
100+ self ._session = await self ._exit_stack .enter_async_context (
101+ ClientSession (
102+ read_stream = read_stream ,
103+ write_stream = write_stream ,
104+ read_timeout_seconds = self ._read_timeout_seconds ,
105+ sampling_callback = self ._sampling_callback ,
106+ list_roots_callback = self ._list_roots_callback ,
107+ logging_callback = self ._logging_callback ,
108+ message_handler = self ._message_handler ,
109+ client_info = self ._client_info ,
110+ elicitation_callback = self ._elicitation_callback ,
111+ )
168112 )
169- )
170- return self
113+
114+ # Initialize the session
115+ await self ._session .initialize ()
116+
117+ return self
118+ except Exception :
119+ await self ._exit_stack .__aexit__ (None , None , None )
120+ raise
171121
172122 async def __aexit__ (
173123 self ,
@@ -194,17 +144,6 @@ def session(self) -> ClientSession:
194144 raise RuntimeError ("Client must be used within an async context manager" )
195145 return self ._session
196146
197- async def initialize (self ) -> types .InitializeResult :
198- """
199- Initialize the MCP session with the server.
200-
201- This must be called before using other client methods.
202-
203- Returns:
204- The initialization result from the server
205- """
206- return await self .session .initialize ()
207-
208147 def get_server_capabilities (self ) -> types .ServerCapabilities | None :
209148 """
210149 Return the server capabilities received during initialization.
0 commit comments