Skip to content

Commit c502663

Browse files
committed
fix type hints for serversession
1 parent 226cafb commit c502663

File tree

11 files changed

+29
-20
lines changed

11 files changed

+29
-20
lines changed

examples/snippets/servers/elicitation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pydantic import BaseModel, Field
22

33
from mcp.server.fastmcp import Context, FastMCP
4-
from mcp.server.session import ServerSession
4+
from mcp.server.session import ServerTransportSession
55

66
mcp = FastMCP(name="Elicitation Example")
77

@@ -17,7 +17,7 @@ class BookingPreferences(BaseModel):
1717

1818

1919
@mcp.tool()
20-
async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str:
20+
async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerTransportSession, None]) -> str:
2121
"""Book a table with date availability check."""
2222
# Check if date is available
2323
if date == "2024-12-25":

examples/snippets/servers/lifespan_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass
66

77
from mcp.server.fastmcp import Context, FastMCP
8-
from mcp.server.session import ServerSession
8+
from mcp.server.session import ServerTransportSession
99

1010

1111
# Mock database class for example
@@ -51,7 +51,7 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
5151

5252
# Access type-safe lifespan context in tools
5353
@mcp.tool()
54-
def query_db(ctx: Context[ServerSession, AppContext]) -> str:
54+
def query_db(ctx: Context[ServerTransportSession, AppContext]) -> str:
5555
"""Tool that uses initialized resources."""
5656
db = ctx.request_context.lifespan_context.db
5757
return db.query()

examples/snippets/servers/notifications.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from mcp.server.fastmcp import Context, FastMCP
2-
from mcp.server.session import ServerSession
2+
from mcp.server.session import ServerTransportSession
33

44
mcp = FastMCP(name="Notifications Example")
55

66

77
@mcp.tool()
8-
async def process_data(data: str, ctx: Context[ServerSession, None]) -> str:
8+
async def process_data(data: str, ctx: Context[ServerTransportSession, None]) -> str:
99
"""Process data with logging."""
1010
# Different log levels
1111
await ctx.debug(f"Debug: Processing '{data}'")

examples/snippets/servers/tool_progress.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from mcp.server.fastmcp import Context, FastMCP
2-
from mcp.server.session import ServerSession
2+
from mcp.server.session import ServerTransportSession
33

44
mcp = FastMCP(name="Progress Example")
55

66

77
@mcp.tool()
8-
async def long_running_task(task_name: str, ctx: Context[ServerSession, None], steps: int = 5) -> str:
8+
async def long_running_task(task_name: str, ctx: Context[ServerTransportSession, None], steps: int = 5) -> str:
99
"""Execute a task with progress updates."""
1010
await ctx.info(f"Starting: {task_name}")
1111

src/mcp/server/elicitation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pydantic import BaseModel
99
from pydantic.fields import FieldInfo
1010

11-
from mcp.server.session import ServerSession
11+
from mcp.server.session import ServerTransportSession
1212
from mcp.types import RequestId
1313

1414
ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)
@@ -74,7 +74,7 @@ def _is_primitive_field(field_info: FieldInfo) -> bool:
7474

7575

7676
async def elicit_with_validation(
77-
session: ServerSession,
77+
session: ServerTransportSession,
7878
message: str,
7979
schema: type[ElicitSchemaModelT],
8080
related_request_id: RequestId | None = None,

src/mcp/server/fastmcp/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from mcp.server.lowlevel.server import LifespanResultT
5555
from mcp.server.lowlevel.server import Server as MCPServer
5656
from mcp.server.lowlevel.server import lifespan as default_lifespan
57-
from mcp.server.session import ServerSession, ServerSessionT
57+
from mcp.server.session import ServerSessionT, ServerTransportSession
5858
from mcp.server.sse import SseServerTransport
5959
from mcp.server.stdio import stdio_server
6060
from mcp.server.streamable_http import EventStore
@@ -315,7 +315,7 @@ async def list_tools(self) -> list[MCPTool]:
315315
for info in tools
316316
]
317317

318-
def get_context(self) -> Context[ServerSession, LifespanResultT, Request]:
318+
def get_context(self) -> Context[ServerTransportSession, LifespanResultT, Request]:
319319
"""
320320
Returns a Context object. Note that the context will only be valid
321321
during a request; outside a request, most methods will error.

src/mcp/server/lowlevel/server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ async def main():
8585
from mcp.server.lowlevel.func_inspection import create_call_wrapper
8686
from mcp.server.lowlevel.helper_types import ReadResourceContents
8787
from mcp.server.models import InitializationOptions
88-
from mcp.server.session import ServerSession
88+
from mcp.server.session import ServerSession, ServerTransportSession
8989
from mcp.shared.context import RequestContext
9090
from mcp.shared.exceptions import McpError
9191
from mcp.shared.message import ServerMessageMetadata, SessionMessage
@@ -102,7 +102,7 @@ async def main():
102102
CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent]
103103

104104
# This will be properly typed in each Server instance's context
105-
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx")
105+
request_ctx: contextvars.ContextVar[RequestContext[ServerTransportSession, Any, Any]] = contextvars.ContextVar("request_ctx")
106106

107107

108108
class NotificationOptions:
@@ -231,7 +231,7 @@ def get_capabilities(
231231
@property
232232
def request_context(
233233
self,
234-
) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
234+
) -> RequestContext[ServerTransportSession, LifespanResultT, RequestT]:
235235
"""If called outside of a request context, this will raise a LookupError."""
236236
return request_ctx.get()
237237

src/mcp/server/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class InitializationState(Enum):
6262
Initialized = 3
6363

6464

65-
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
65+
ServerSessionT = TypeVar("ServerSessionT", bound="ServerTransportSession")
6666

6767
ServerRequestResponder = (
6868
RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception

src/mcp/shared/context.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
from typing_extensions import TypeVar
55

66
from mcp.client.transport_session import ClientTransportSession
7+
from mcp.server.transport_session import ServerTransportSession
78
from mcp.shared.session import BaseSession
89
from mcp.types import RequestId, RequestParams
910

10-
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | ClientTransportSession)
11+
SessionT = TypeVar("SessionT",
12+
bound=BaseSession[Any, Any, Any, Any, Any] |
13+
ClientTransportSession |
14+
ServerTransportSession)
1115
LifespanContextT = TypeVar("LifespanContextT")
1216
RequestT = TypeVar("RequestT", default=Any)
1317

tests/client/test_sampling_callback.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pytest
22

33
from mcp.client.session import ClientTransportSession
4+
from mcp.server.session import ServerSession
5+
from typing import cast
46
from mcp.shared.context import RequestContext
57
from mcp.shared.memory import (
68
create_connected_server_and_client_session as create_session,
@@ -34,7 +36,8 @@ async def sampling_callback(
3436

3537
@server.tool("test_sampling")
3638
async def test_sampling_tool(message: str):
37-
value = await server.get_context().session.create_message(
39+
session = cast(ServerSession, server.get_context().session)
40+
value = await session.create_message(
3841
messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
3942
max_tokens=100,
4043
)

0 commit comments

Comments
 (0)