Skip to content

Commit 6ec1f81

Browse files
Add stub Client and InMemoryTransport classes with tests
This commit adds: - Client class stub with from_server() class method - InMemoryTransport class stub - Comprehensive test suite for Client (TDD approach) All methods currently raise NotImplementedError. Github-Issue:#1728
1 parent 72a8631 commit 6ec1f81

File tree

6 files changed

+878
-0
lines changed

6 files changed

+878
-0
lines changed

src/mcp/client/client.py

Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
"""Unified MCP Client that wraps ClientSession with transport management."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
from collections.abc import AsyncGenerator
7+
from contextlib import asynccontextmanager
8+
from typing import Any
9+
10+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
11+
from pydantic import AnyUrl
12+
13+
import mcp.types as types
14+
from mcp.client.session import (
15+
ClientSession,
16+
ElicitationFnT,
17+
ListRootsFnT,
18+
LoggingFnT,
19+
MessageHandlerFnT,
20+
SamplingFnT,
21+
)
22+
from mcp.server import Server
23+
from mcp.server.fastmcp import FastMCP
24+
from mcp.shared.message import SessionMessage
25+
from mcp.shared.session import ProgressFnT
26+
27+
logger = logging.getLogger(__name__)
28+
29+
30+
class Client:
31+
"""
32+
A high-level MCP client that manages transport and session lifecycle.
33+
34+
The Client class provides a unified interface for connecting to MCP servers
35+
using different transports (in-memory, stdio, HTTP, etc.) and exposes
36+
all ClientSession functionality with simpler lifecycle management.
37+
38+
Example with in-memory transport:
39+
server = FastMCP("test")
40+
41+
async with Client.from_server(server) as client:
42+
tools = await client.list_tools()
43+
result = await client.call_tool("my_tool", {"arg": "value"})
44+
45+
Example with custom transport:
46+
async with Client(read_stream, write_stream) as client:
47+
await client.initialize()
48+
# Use client...
49+
"""
50+
51+
def __init__(
52+
self,
53+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
54+
write_stream: MemoryObjectSendStream[SessionMessage],
55+
read_timeout_seconds: float | None = None,
56+
sampling_callback: SamplingFnT | None = None,
57+
list_roots_callback: ListRootsFnT | None = None,
58+
logging_callback: LoggingFnT | None = None,
59+
message_handler: MessageHandlerFnT | None = None,
60+
client_info: types.Implementation | None = None,
61+
elicitation_callback: ElicitationFnT | None = None,
62+
) -> None:
63+
"""
64+
Initialize the client with transport streams.
65+
66+
Args:
67+
read_stream: Stream for receiving messages from the server
68+
write_stream: Stream for sending messages to the server
69+
read_timeout_seconds: Timeout for read operations
70+
sampling_callback: Callback for handling sampling requests
71+
list_roots_callback: Callback for handling list roots requests
72+
logging_callback: Callback for handling logging notifications
73+
message_handler: Callback for handling raw messages
74+
client_info: Client implementation info to send to server
75+
elicitation_callback: Callback for handling elicitation requests
76+
"""
77+
raise NotImplementedError("Client.__init__ is not yet implemented")
78+
79+
@classmethod
80+
@asynccontextmanager
81+
async def from_server(
82+
cls,
83+
server: Server[Any] | FastMCP,
84+
read_timeout_seconds: float | None = None,
85+
sampling_callback: SamplingFnT | None = None,
86+
list_roots_callback: ListRootsFnT | None = None,
87+
logging_callback: LoggingFnT | None = None,
88+
message_handler: MessageHandlerFnT | None = None,
89+
client_info: types.Implementation | None = None,
90+
raise_exceptions: bool = False,
91+
elicitation_callback: ElicitationFnT | None = None,
92+
) -> AsyncGenerator[Client, None]:
93+
"""
94+
Create a client connected to an in-memory server.
95+
96+
This is a convenience method that creates an in-memory transport,
97+
starts the server, and returns an initialized client.
98+
99+
Args:
100+
server: The MCP server to connect to (Server or FastMCP instance)
101+
read_timeout_seconds: Timeout for read operations
102+
sampling_callback: Callback for handling sampling requests
103+
list_roots_callback: Callback for handling list roots requests
104+
logging_callback: Callback for handling logging notifications
105+
message_handler: Callback for handling raw messages
106+
client_info: Client implementation info to send to server
107+
raise_exceptions: Whether to raise exceptions from the server
108+
elicitation_callback: Callback for handling elicitation requests
109+
110+
Yields:
111+
An initialized Client connected to the server
112+
113+
Example:
114+
server = FastMCP("test")
115+
116+
@server.tool()
117+
def my_tool(arg: str) -> str:
118+
return f"Result: {arg}"
119+
120+
async with Client.from_server(server) as client:
121+
result = await client.call_tool("my_tool", {"arg": "value"})
122+
"""
123+
# Silence unused parameter warnings in stub
124+
_ = (server, read_timeout_seconds, sampling_callback, list_roots_callback,
125+
logging_callback, message_handler, client_info, raise_exceptions, elicitation_callback)
126+
# Stub: yield fake value, actual implementation will provide real client
127+
yield None # type: ignore[misc]
128+
raise NotImplementedError("Client.from_server is not yet implemented")
129+
130+
async def __aenter__(self) -> Client:
131+
"""Enter the async context manager."""
132+
raise NotImplementedError("Client.__aenter__ is not yet implemented")
133+
134+
async def __aexit__(
135+
self,
136+
exc_type: type[BaseException] | None,
137+
exc_val: BaseException | None,
138+
exc_tb: Any,
139+
) -> None:
140+
"""Exit the async context manager."""
141+
raise NotImplementedError("Client.__aexit__ is not yet implemented")
142+
143+
async def initialize(self) -> types.InitializeResult:
144+
"""
145+
Initialize the MCP session with the server.
146+
147+
This must be called before using other client methods.
148+
149+
Returns:
150+
The initialization result from the server
151+
"""
152+
raise NotImplementedError("Client.initialize is not yet implemented")
153+
154+
def get_server_capabilities(self) -> types.ServerCapabilities | None:
155+
"""
156+
Return the server capabilities received during initialization.
157+
158+
Returns:
159+
The server capabilities, or None if not yet initialized
160+
"""
161+
raise NotImplementedError("Client.get_server_capabilities is not yet implemented")
162+
163+
async def send_ping(self) -> types.EmptyResult:
164+
"""Send a ping request to the server."""
165+
raise NotImplementedError("Client.send_ping is not yet implemented")
166+
167+
async def send_progress_notification(
168+
self,
169+
progress_token: str | int,
170+
progress: float,
171+
total: float | None = None,
172+
message: str | None = None,
173+
) -> None:
174+
"""Send a progress notification to the server."""
175+
raise NotImplementedError("Client.send_progress_notification is not yet implemented")
176+
177+
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
178+
"""Set the logging level on the server."""
179+
raise NotImplementedError("Client.set_logging_level is not yet implemented")
180+
181+
async def list_resources(
182+
self,
183+
cursor: str | None = None,
184+
*,
185+
params: types.PaginatedRequestParams | None = None,
186+
) -> types.ListResourcesResult:
187+
"""
188+
List available resources from the server.
189+
190+
Args:
191+
cursor: Pagination cursor (deprecated, use params instead)
192+
params: Full pagination parameters
193+
194+
Returns:
195+
List of available resources
196+
"""
197+
raise NotImplementedError("Client.list_resources is not yet implemented")
198+
199+
async def list_resource_templates(
200+
self,
201+
cursor: str | None = None,
202+
*,
203+
params: types.PaginatedRequestParams | None = None,
204+
) -> types.ListResourceTemplatesResult:
205+
"""
206+
List available resource templates from the server.
207+
208+
Args:
209+
cursor: Pagination cursor (deprecated, use params instead)
210+
params: Full pagination parameters
211+
212+
Returns:
213+
List of available resource templates
214+
"""
215+
raise NotImplementedError("Client.list_resource_templates is not yet implemented")
216+
217+
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
218+
"""
219+
Read a resource from the server.
220+
221+
Args:
222+
uri: The URI of the resource to read
223+
224+
Returns:
225+
The resource content
226+
"""
227+
raise NotImplementedError("Client.read_resource is not yet implemented")
228+
229+
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
230+
"""Subscribe to resource updates."""
231+
raise NotImplementedError("Client.subscribe_resource is not yet implemented")
232+
233+
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
234+
"""Unsubscribe from resource updates."""
235+
raise NotImplementedError("Client.unsubscribe_resource is not yet implemented")
236+
237+
async def call_tool(
238+
self,
239+
name: str,
240+
arguments: dict[str, Any] | None = None,
241+
read_timeout_seconds: float | None = None,
242+
progress_callback: ProgressFnT | None = None,
243+
*,
244+
meta: dict[str, Any] | None = None,
245+
) -> types.CallToolResult:
246+
"""
247+
Call a tool on the server.
248+
249+
Args:
250+
name: The name of the tool to call
251+
arguments: Arguments to pass to the tool
252+
read_timeout_seconds: Timeout for the tool call
253+
progress_callback: Callback for progress updates
254+
meta: Additional metadata for the request
255+
256+
Returns:
257+
The tool result
258+
"""
259+
raise NotImplementedError("Client.call_tool is not yet implemented")
260+
261+
async def list_prompts(
262+
self,
263+
cursor: str | None = None,
264+
*,
265+
params: types.PaginatedRequestParams | None = None,
266+
) -> types.ListPromptsResult:
267+
"""
268+
List available prompts from the server.
269+
270+
Args:
271+
cursor: Pagination cursor (deprecated, use params instead)
272+
params: Full pagination parameters
273+
274+
Returns:
275+
List of available prompts
276+
"""
277+
raise NotImplementedError("Client.list_prompts is not yet implemented")
278+
279+
async def get_prompt(
280+
self,
281+
name: str,
282+
arguments: dict[str, str] | None = None,
283+
) -> types.GetPromptResult:
284+
"""
285+
Get a prompt from the server.
286+
287+
Args:
288+
name: The name of the prompt
289+
arguments: Arguments to pass to the prompt
290+
291+
Returns:
292+
The prompt content
293+
"""
294+
raise NotImplementedError("Client.get_prompt is not yet implemented")
295+
296+
async def complete(
297+
self,
298+
ref: types.ResourceTemplateReference | types.PromptReference,
299+
argument: dict[str, str],
300+
context_arguments: dict[str, str] | None = None,
301+
) -> types.CompleteResult:
302+
"""
303+
Get completions for a prompt or resource template argument.
304+
305+
Args:
306+
ref: Reference to the prompt or resource template
307+
argument: The argument to complete
308+
context_arguments: Additional context arguments
309+
310+
Returns:
311+
Completion suggestions
312+
"""
313+
raise NotImplementedError("Client.complete is not yet implemented")
314+
315+
async def list_tools(
316+
self,
317+
cursor: str | None = None,
318+
*,
319+
params: types.PaginatedRequestParams | None = None,
320+
) -> types.ListToolsResult:
321+
"""
322+
List available tools from the server.
323+
324+
Args:
325+
cursor: Pagination cursor (deprecated, use params instead)
326+
params: Full pagination parameters
327+
328+
Returns:
329+
List of available tools
330+
"""
331+
raise NotImplementedError("Client.list_tools is not yet implemented")
332+
333+
async def send_roots_list_changed(self) -> None:
334+
"""Send a notification that the roots list has changed."""
335+
raise NotImplementedError("Client.send_roots_list_changed is not yet implemented")
336+
337+
@property
338+
def session(self) -> ClientSession:
339+
"""
340+
Get the underlying ClientSession.
341+
342+
This provides access to the full ClientSession API for advanced use cases.
343+
"""
344+
raise NotImplementedError("Client.session is not yet implemented")
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""In-memory transport implementations for MCP clients."""
2+
3+
from mcp.client.transports.memory import InMemoryTransport
4+
5+
__all__ = ["InMemoryTransport"]

0 commit comments

Comments
 (0)