Skip to content

Commit 226cafb

Browse files
committed
fix typehints for client transport session
1 parent 0002f70 commit 226cafb

File tree

14 files changed

+73
-47
lines changed

14 files changed

+73
-47
lines changed

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from urllib.parse import parse_qs, urlparse
1818

1919
from mcp.client.auth import OAuthClientProvider, TokenStorage
20-
from mcp.client.session import ClientSession
20+
from mcp.client.session import ClientSession, ClientTransportSession
2121
from mcp.client.sse import sse_client
2222
from mcp.client.streamable_http import streamablehttp_client
2323
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
@@ -153,7 +153,7 @@ class SimpleAuthClient:
153153
def __init__(self, server_url: str, transport_type: str = "streamable-http"):
154154
self.server_url = server_url
155155
self.transport_type = transport_type
156-
self.session: ClientSession | None = None
156+
self.session: ClientTransportSession | None = None
157157

158158
async def connect(self):
159159
"""Connect to the MCP server."""

examples/clients/simple-chatbot/mcp_simple_chatbot/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from dotenv import load_dotenv
1111
from mcp import ClientSession, StdioServerParameters
1212
from mcp.client.stdio import stdio_client
13+
from mcp.client.transport_session import ClientTransportSession
1314

1415
# Configure logging
1516
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
@@ -67,7 +68,7 @@ def __init__(self, name: str, config: dict[str, Any]) -> None:
6768
self.name: str = name
6869
self.config: dict[str, Any] = config
6970
self.stdio_context: Any | None = None
70-
self.session: ClientSession | None = None
71+
self.session: ClientTransportSession | None = None
7172
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
7273
self.exit_stack: AsyncExitStack = AsyncExitStack()
7374

examples/snippets/clients/display_utilities.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from mcp import ClientSession, StdioServerParameters
1010
from mcp.client.stdio import stdio_client
11+
from mcp.client.transport_session import ClientTransportSession
1112
from mcp.shared.metadata_utils import get_display_name
1213

1314
# Create server parameters for stdio connection
@@ -18,7 +19,7 @@
1819
)
1920

2021

21-
async def display_tools(session: ClientSession):
22+
async def display_tools(session: ClientTransportSession):
2223
"""Display available tools with human-readable names"""
2324
tools_response = await session.list_tools()
2425

@@ -30,7 +31,7 @@ async def display_tools(session: ClientSession):
3031
print(f" {tool.description}")
3132

3233

33-
async def display_resources(session: ClientSession):
34+
async def display_resources(session: ClientTransportSession):
3435
"""Display available resources with human-readable names"""
3536
resources_response = await session.list_resources()
3637

examples/snippets/clients/stdio_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pydantic import AnyUrl
1010

1111
from mcp import ClientSession, StdioServerParameters, types
12+
from mcp.client.session import ClientTransportSession
1213
from mcp.client.stdio import stdio_client
1314
from mcp.shared.context import RequestContext
1415

@@ -22,7 +23,7 @@
2223

2324
# Optional: create a sampling callback
2425
async def handle_sampling_message(
25-
context: RequestContext[ClientSession, None], params: types.CreateMessageRequestParams
26+
context: RequestContext[ClientTransportSession, None], params: types.CreateMessageRequestParams
2627
) -> types.CreateMessageResult:
2728
print(f"Sampling request: {params.messages}")
2829
return types.CreateMessageResult(

src/mcp/client/session.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,22 @@
2323
class SamplingFnT(Protocol):
2424
async def __call__(
2525
self,
26-
context: RequestContext["ClientSession", Any],
26+
context: RequestContext["ClientTransportSession", Any],
2727
params: types.CreateMessageRequestParams,
2828
) -> types.CreateMessageResult | types.ErrorData: ...
2929

3030

3131
class ElicitationFnT(Protocol):
3232
async def __call__(
3333
self,
34-
context: RequestContext["ClientSession", Any],
34+
context: RequestContext["ClientTransportSession", Any],
3535
params: types.ElicitRequestParams,
3636
) -> types.ElicitResult | types.ErrorData: ...
3737

3838

3939
class ListRootsFnT(Protocol):
4040
async def __call__(
41-
self, context: RequestContext["ClientSession", Any]
41+
self, context: RequestContext["ClientTransportSession", Any]
4242
) -> types.ListRootsResult | types.ErrorData: ...
4343

4444

@@ -63,7 +63,7 @@ async def _default_message_handler(
6363

6464

6565
async def _default_sampling_callback(
66-
context: RequestContext["ClientSession", Any],
66+
context: RequestContext["ClientTransportSession", Any],
6767
params: types.CreateMessageRequestParams,
6868
) -> types.CreateMessageResult | types.ErrorData:
6969
return types.ErrorData(
@@ -73,7 +73,7 @@ async def _default_sampling_callback(
7373

7474

7575
async def _default_elicitation_callback(
76-
context: RequestContext["ClientSession", Any],
76+
context: RequestContext["ClientTransportSession", Any],
7777
params: types.ElicitRequestParams,
7878
) -> types.ElicitResult | types.ErrorData:
7979
return types.ErrorData(
@@ -83,7 +83,7 @@ async def _default_elicitation_callback(
8383

8484

8585
async def _default_list_roots_callback(
86-
context: RequestContext["ClientSession", Any],
86+
context: RequestContext["ClientTransportSession", Any],
8787
) -> types.ListRootsResult | types.ErrorData:
8888
return types.ErrorData(
8989
code=types.INVALID_REQUEST,
@@ -498,7 +498,7 @@ async def send_roots_list_changed(self) -> None:
498498
await self.send_notification(types.ClientNotification(types.RootsListChangedNotification()))
499499

500500
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
501-
ctx = RequestContext[ClientSession, Any](
501+
ctx = RequestContext[ClientTransportSession, Any](
502502
request_id=responder.request_id,
503503
meta=responder.request_meta,
504504
session=self,

src/mcp/client/session_group.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,10 @@ class _ComponentNames(BaseModel):
9696
_tools: dict[str, types.Tool]
9797

9898
# Client-server connection management.
99-
_sessions: dict[mcp.ClientSession, _ComponentNames]
100-
_tool_to_session: dict[str, mcp.ClientSession]
99+
_sessions: dict[mcp.ClientTransportSession, _ComponentNames]
100+
_tool_to_session: dict[str, mcp.ClientTransportSession]
101101
_exit_stack: contextlib.AsyncExitStack
102-
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
102+
_session_exit_stacks: dict[mcp.ClientTransportSession, contextlib.AsyncExitStack]
103103

104104
# Optional fn consuming (component_name, serverInfo) for custom names.
105105
# This is provide a means to mitigate naming conflicts across servers.
@@ -153,7 +153,7 @@ async def __aexit__(
153153
tg.start_soon(exit_stack.aclose)
154154

155155
@property
156-
def sessions(self) -> list[mcp.ClientSession]:
156+
def sessions(self) -> list[mcp.ClientTransportSession]:
157157
"""Returns the list of sessions being managed."""
158158
return list(self._sessions.keys())
159159

@@ -178,7 +178,7 @@ async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResu
178178
session_tool_name = self.tools[name].name
179179
return await session.call_tool(session_tool_name, args)
180180

181-
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
181+
async def disconnect_from_server(self, session: mcp.ClientTransportSession) -> None:
182182
"""Disconnects from a single MCP server."""
183183

184184
session_known_for_components = session in self._sessions
@@ -216,23 +216,23 @@ async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
216216
await session_stack_to_close.aclose()
217217

218218
async def connect_with_session(
219-
self, server_info: types.Implementation, session: mcp.ClientSession
220-
) -> mcp.ClientSession:
219+
self, server_info: types.Implementation, session: mcp.ClientTransportSession
220+
) -> mcp.ClientTransportSession:
221221
"""Connects to a single MCP server."""
222222
await self._aggregate_components(server_info, session)
223223
return session
224224

225225
async def connect_to_server(
226226
self,
227227
server_params: ServerParameters,
228-
) -> mcp.ClientSession:
228+
) -> mcp.ClientTransportSession:
229229
"""Connects to a single MCP server."""
230230
server_info, session = await self._establish_session(server_params)
231231
return await self.connect_with_session(server_info, session)
232232

233233
async def _establish_session(
234234
self, server_params: ServerParameters
235-
) -> tuple[types.Implementation, mcp.ClientSession]:
235+
) -> tuple[types.Implementation, mcp.ClientTransportSession]:
236236
"""Establish a client session to an MCP server."""
237237

238238
session_stack = contextlib.AsyncExitStack()
@@ -276,7 +276,9 @@ async def _establish_session(
276276
await session_stack.aclose()
277277
raise
278278

279-
async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None:
279+
async def _aggregate_components(
280+
self, server_info: types.Implementation, session: mcp.ClientTransportSession
281+
) -> None:
280282
"""Aggregates prompts, resources, and tools from a given session."""
281283

282284
# Create a reverse index so we can find all prompts, resources, and
@@ -289,7 +291,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session
289291
prompts_temp: dict[str, types.Prompt] = {}
290292
resources_temp: dict[str, types.Resource] = {}
291293
tools_temp: dict[str, types.Tool] = {}
292-
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
294+
tool_to_session_temp: dict[str, mcp.ClientTransportSession] = {}
293295

294296
# Query the server for its prompts and aggregate to list.
295297
try:

src/mcp/client/transport_session.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from abc import ABC, abstractmethod
22
from datetime import timedelta
3-
from typing import Any
3+
from typing import Any, overload
44

55
from pydantic import AnyUrl
6+
from typing_extensions import deprecated
67

78
from mcp import types
89
from mcp.shared.session import ProgressFnT
@@ -109,12 +110,29 @@ async def complete(
109110
"""Send a completion/complete request."""
110111
raise NotImplementedError
111112

113+
@overload
114+
@deprecated("Use list_tools(params=PaginatedRequestParams(...)) instead")
115+
async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ...
116+
117+
@overload
118+
async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ...
119+
120+
@overload
121+
async def list_tools(self) -> types.ListToolsResult: ...
122+
112123
@abstractmethod
113124
async def list_tools(
114125
self,
115126
cursor: str | None = None,
127+
*,
128+
params: types.PaginatedRequestParams | None = None,
116129
) -> types.ListToolsResult:
117-
"""Send a tools/list request."""
130+
"""Send a tools/list request.
131+
132+
Args:
133+
cursor: Simple cursor string for pagination (deprecated, use params instead)
134+
params: Full pagination parameters including cursor and any future fields
135+
"""
118136
raise NotImplementedError
119137

120138
@abstractmethod

src/mcp/shared/context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33

44
from typing_extensions import TypeVar
55

6+
from mcp.client.transport_session import ClientTransportSession
67
from mcp.shared.session import BaseSession
78
from mcp.types import RequestId, RequestParams
89

9-
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
10+
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession)
1011
LifespanContextT = TypeVar("LifespanContextT")
1112
RequestT = TypeVar("RequestT", default=Any)
1213

tests/client/test_list_roots_callback.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
from pydantic import FileUrl
33

4-
from mcp.client.session import ClientSession
4+
from mcp.client.session import ClientTransportSession
55
from mcp.server.fastmcp.server import Context
66
from mcp.server.session import ServerSession
77
from mcp.shared.context import RequestContext
@@ -31,7 +31,7 @@ async def test_list_roots_callback():
3131
)
3232

3333
async def list_roots_callback(
34-
context: RequestContext[ClientSession, None],
34+
context: RequestContext[ClientTransportSession, None],
3535
) -> ListRootsResult:
3636
return callback_return
3737

tests/client/test_sampling_callback.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from mcp.client.session import ClientSession
3+
from mcp.client.session import ClientTransportSession
44
from mcp.shared.context import RequestContext
55
from mcp.shared.memory import (
66
create_connected_server_and_client_session as create_session,
@@ -27,7 +27,7 @@ async def test_sampling_callback():
2727
)
2828

2929
async def sampling_callback(
30-
context: RequestContext[ClientSession, None],
30+
context: RequestContext[ClientTransportSession, None],
3131
params: CreateMessageRequestParams,
3232
) -> CreateMessageResult:
3333
return callback_return

0 commit comments

Comments
 (0)