Skip to content

Commit 3ca0718

Browse files
committed
feat: propagate session_id from transport to tool context
Enables tools to access the transport-level session ID (e.g., from streamable-http) via Context.session_id property. This supports session-level tracing and addresses user requests in #485 and #942. Changes: - Add session_id field to InitializationOptions - Store session_id in ServerSession and expose via property - Add session_id field to RequestContext dataclass - Pass session_id when creating RequestContext - Update StreamableHTTPSessionManager to pass session_id at init - Add session_id property to FastMCP Context class - Add comprehensive tests for session_id propagation The session_id is available for streamable-http transport and None for stdio/stateless modes. Tools can now correlate all requests within a session for observability and tracing purposes. Fixes #485 Fixes #942 Related to #421 (OpenTelemetry tracing)
1 parent da4fce2 commit 3ca0718

File tree

7 files changed

+314
-1
lines changed

7 files changed

+314
-1
lines changed

src/mcp/server/fastmcp/server.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,6 +1237,15 @@ def session(self):
12371237
"""Access to the underlying session for advanced usage."""
12381238
return self.request_context.session
12391239

1240+
@property
1241+
def session_id(self) -> str | None:
1242+
"""Get the session ID if available.
1243+
1244+
Returns the transport-level session ID (e.g., from streamable-http),
1245+
or None if not available (e.g., stdio transport or stateless mode).
1246+
"""
1247+
return self.request_context.session_id if self._request_context else None
1248+
12401249
# Convenience methods for common log levels
12411250
async def debug(self, message: str, **extra: Any) -> None:
12421251
"""Send a debug log message."""

src/mcp/server/lowlevel/server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def create_initialization_options(
160160
self,
161161
notification_options: NotificationOptions | None = None,
162162
experimental_capabilities: dict[str, dict[str, Any]] | None = None,
163+
session_id: str | None = None,
163164
) -> InitializationOptions:
164165
"""Create initialization options from this server instance."""
165166

@@ -183,6 +184,7 @@ def pkg_version(package: str) -> str:
183184
instructions=self.instructions,
184185
website_url=self.website_url,
185186
icons=self.icons,
187+
session_id=session_id,
186188
)
187189

188190
def get_capabilities(
@@ -691,6 +693,7 @@ async def _handle_request(
691693
session,
692694
lifespan_context,
693695
request=request_data,
696+
session_id=session.session_id,
694697
)
695698
)
696699
response = await handler(req)

src/mcp/server/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ class InitializationOptions(BaseModel):
1818
instructions: str | None = None
1919
website_url: str | None = None
2020
icons: list[Icon] | None = None
21+
session_id: str | None = None

src/mcp/server/session.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(
9393
)
9494

9595
self._init_options = init_options
96+
self._session_id = init_options.session_id
9697
self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
9798
ServerRequestResponder
9899
](0)
@@ -102,6 +103,11 @@ def __init__(
102103
def client_params(self) -> types.InitializeRequestParams | None:
103104
return self._client_params
104105

106+
@property
107+
def session_id(self) -> str | None:
108+
"""Get the session ID if available."""
109+
return self._session_id
110+
105111
def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
106112
"""Check if the client supports a specific capability."""
107113
if self._client_params is None:

src/mcp/server/streamable_http_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
241241
await self.app.run(
242242
read_stream,
243243
write_stream,
244-
self.app.create_initialization_options(),
244+
self.app.create_initialization_options(session_id=new_session_id),
245245
stateless=False, # Stateful mode
246246
)
247247
except Exception as e:

src/mcp/shared/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
1818
session: SessionT
1919
lifespan_context: LifespanContextT
2020
request: RequestT | None = None
21+
session_id: str | None = None
Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
"""Tests for session_id propagation through the MCP stack."""
2+
3+
import json
4+
from typing import Any
5+
6+
import pytest
7+
from starlette.types import Message
8+
9+
from mcp.server.fastmcp import Context, FastMCP
10+
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
11+
12+
13+
@pytest.mark.anyio
14+
async def test_session_id_propagates_to_tool_context():
15+
"""Test that session_id from transport propagates to tool Context."""
16+
# Track session_id seen in tool
17+
captured_session_id: str | None = None
18+
19+
# Create FastMCP server with a tool that captures session_id
20+
mcp = FastMCP("test-session-id-server")
21+
22+
@mcp.tool()
23+
async def get_session_info(ctx: Context) -> dict[str, Any]:
24+
"""Tool that returns session information."""
25+
nonlocal captured_session_id
26+
captured_session_id = ctx.session_id
27+
return {
28+
"session_id": ctx.session_id,
29+
"request_id": ctx.request_id,
30+
}
31+
32+
# Create session manager
33+
manager = StreamableHTTPSessionManager(app=mcp._mcp_server, stateless=False)
34+
35+
async with manager.run():
36+
# Prepare ASGI scope and messages
37+
scope = {
38+
"type": "http",
39+
"method": "POST",
40+
"path": "/mcp",
41+
"headers": [
42+
(b"content-type", b"application/json"),
43+
(b"accept", b"application/json"),
44+
],
45+
}
46+
47+
# Create initialize request
48+
initialize_request = {
49+
"jsonrpc": "2.0",
50+
"id": 1,
51+
"method": "initialize",
52+
"params": {
53+
"protocolVersion": "2025-03-26",
54+
"capabilities": {},
55+
"clientInfo": {"name": "test-client", "version": "1.0.0"},
56+
},
57+
}
58+
59+
# Track sent messages
60+
sent_messages: list[Message] = []
61+
receive_calls = 0
62+
session_id_from_header: str | None = None
63+
64+
async def mock_receive():
65+
nonlocal receive_calls
66+
receive_calls += 1
67+
if receive_calls == 1:
68+
# First call: send initialize request
69+
return {
70+
"type": "http.request",
71+
"body": json.dumps(initialize_request).encode(),
72+
"more_body": False,
73+
}
74+
# Subsequent calls: end stream
75+
return {"type": "http.disconnect"}
76+
77+
async def mock_send(message: Message):
78+
sent_messages.append(message)
79+
# Capture session ID from response header
80+
if message["type"] == "http.response.start":
81+
nonlocal session_id_from_header
82+
headers = dict(message.get("headers", []))
83+
if b"mcp-session-id" in headers:
84+
session_id_from_header = headers[b"mcp-session-id"].decode()
85+
86+
# Handle request (initialize)
87+
await manager.handle_request(scope, mock_receive, mock_send)
88+
89+
# Verify session ID was set in response header
90+
assert session_id_from_header is not None, "Session ID should be in response header"
91+
92+
# Now make a tools/call request to test session_id in Context
93+
# Reset for second request
94+
receive_calls = 0
95+
sent_messages.clear()
96+
97+
tool_call_request = {
98+
"jsonrpc": "2.0",
99+
"id": 2,
100+
"method": "tools/call",
101+
"params": {"name": "get_session_info", "arguments": {}},
102+
}
103+
104+
scope_with_session = {
105+
**scope,
106+
"headers": [
107+
*scope["headers"],
108+
(b"mcp-session-id", session_id_from_header.encode()),
109+
],
110+
}
111+
112+
async def mock_receive_tool_call():
113+
nonlocal receive_calls
114+
receive_calls += 1
115+
if receive_calls == 1:
116+
return {
117+
"type": "http.request",
118+
"body": json.dumps(tool_call_request).encode(),
119+
"more_body": False,
120+
}
121+
return {"type": "http.disconnect"}
122+
123+
await manager.handle_request(scope_with_session, mock_receive_tool_call, mock_send)
124+
125+
# Verify session_id was captured in tool context
126+
assert captured_session_id is not None, "session_id should be available in Context"
127+
assert captured_session_id == session_id_from_header, (
128+
f"session_id in Context ({captured_session_id}) should match "
129+
f"session ID from header ({session_id_from_header})"
130+
)
131+
132+
133+
@pytest.mark.anyio
134+
async def test_session_id_is_none_for_stateless_mode():
135+
"""Test that session_id is None in stateless mode."""
136+
# Track session_id seen in tool
137+
captured_session_id: str | None = "not-set"
138+
139+
# Create FastMCP server
140+
mcp = FastMCP("test-stateless-server")
141+
142+
@mcp.tool()
143+
async def check_session(ctx: Context) -> dict[str, Any]:
144+
"""Tool that checks session_id."""
145+
nonlocal captured_session_id
146+
captured_session_id = ctx.session_id
147+
return {"has_session_id": ctx.session_id is not None}
148+
149+
# Create session manager in stateless mode
150+
manager = StreamableHTTPSessionManager(app=mcp._mcp_server, stateless=True)
151+
152+
async with manager.run():
153+
scope = {
154+
"type": "http",
155+
"method": "POST",
156+
"path": "/mcp",
157+
"headers": [
158+
(b"content-type", b"application/json"),
159+
(b"accept", b"application/json"),
160+
],
161+
}
162+
163+
initialize_request = {
164+
"jsonrpc": "2.0",
165+
"id": 1,
166+
"method": "initialize",
167+
"params": {
168+
"protocolVersion": "2025-03-26",
169+
"capabilities": {},
170+
"clientInfo": {"name": "test-client", "version": "1.0.0"},
171+
},
172+
}
173+
174+
sent_messages: list[Message] = []
175+
receive_calls = 0
176+
177+
async def mock_receive():
178+
nonlocal receive_calls
179+
receive_calls += 1
180+
if receive_calls == 1:
181+
return {
182+
"type": "http.request",
183+
"body": json.dumps(initialize_request).encode(),
184+
"more_body": False,
185+
}
186+
return {"type": "http.disconnect"}
187+
188+
async def mock_send(message: Message):
189+
sent_messages.append(message)
190+
191+
await manager.handle_request(scope, mock_receive, mock_send)
192+
193+
# In stateless mode, session_id should not be set
194+
# (Note: This test primarily verifies no errors occur;
195+
# we can't easily call a tool in stateless mode without a full integration test)
196+
197+
198+
@pytest.mark.anyio
199+
async def test_session_id_consistent_across_requests():
200+
"""Test that session_id remains consistent across multiple requests in same session."""
201+
# Track all session_ids seen
202+
seen_session_ids: list[str | None] = []
203+
204+
# Create FastMCP server
205+
mcp = FastMCP("test-consistency-server")
206+
207+
@mcp.tool()
208+
async def track_session(ctx: Context) -> dict[str, Any]:
209+
"""Tool that tracks session_id."""
210+
seen_session_ids.append(ctx.session_id)
211+
return {"session_id": ctx.session_id, "call_number": len(seen_session_ids)}
212+
213+
# Create session manager
214+
manager = StreamableHTTPSessionManager(app=mcp._mcp_server, stateless=False)
215+
216+
async with manager.run():
217+
# First request: initialize and get session ID
218+
scope = {
219+
"type": "http",
220+
"method": "POST",
221+
"path": "/mcp",
222+
"headers": [
223+
(b"content-type", b"application/json"),
224+
(b"accept", b"application/json"),
225+
],
226+
}
227+
228+
initialize_request = {
229+
"jsonrpc": "2.0",
230+
"id": 1,
231+
"method": "initialize",
232+
"params": {
233+
"protocolVersion": "2025-03-26",
234+
"capabilities": {},
235+
"clientInfo": {"name": "test-client", "version": "1.0.0"},
236+
},
237+
}
238+
239+
sent_messages: list[Message] = []
240+
session_id_from_header: str | None = None
241+
242+
async def mock_receive_init():
243+
return {
244+
"type": "http.request",
245+
"body": json.dumps(initialize_request).encode(),
246+
"more_body": False,
247+
}
248+
249+
async def mock_send(message: Message):
250+
sent_messages.append(message)
251+
if message["type"] == "http.response.start":
252+
nonlocal session_id_from_header
253+
headers = dict(message.get("headers", []))
254+
if b"mcp-session-id" in headers:
255+
session_id_from_header = headers[b"mcp-session-id"].decode()
256+
257+
await manager.handle_request(scope, mock_receive_init, mock_send)
258+
259+
assert session_id_from_header is not None
260+
261+
# Make multiple tool calls with same session ID
262+
for call_num in range(3):
263+
sent_messages.clear()
264+
265+
tool_call_request = {
266+
"jsonrpc": "2.0",
267+
"id": call_num + 2,
268+
"method": "tools/call",
269+
"params": {"name": "track_session", "arguments": {}},
270+
}
271+
272+
scope_with_session = {
273+
**scope,
274+
"headers": [
275+
*scope["headers"],
276+
(b"mcp-session-id", session_id_from_header.encode()),
277+
],
278+
}
279+
280+
async def mock_receive_tool():
281+
return {
282+
"type": "http.request",
283+
"body": json.dumps(tool_call_request).encode(),
284+
"more_body": False,
285+
}
286+
287+
await manager.handle_request(scope_with_session, mock_receive_tool, mock_send)
288+
289+
# Verify all calls saw the same session_id
290+
assert len(seen_session_ids) == 3, "Should have made 3 tool calls"
291+
assert all(sid == session_id_from_header for sid in seen_session_ids), (
292+
f"All session_ids should match: {seen_session_ids}"
293+
)

0 commit comments

Comments
 (0)