|
14 | 14 | from mcp.server.fastmcp import Context, FastMCP |
15 | 15 | from mcp.server.fastmcp.prompts.base import UserMessage |
16 | 16 | from mcp.server.session import ServerSession |
| 17 | +from mcp.server.streamable_http import ( |
| 18 | + EventCallback, |
| 19 | + EventId, |
| 20 | + EventMessage, |
| 21 | + EventStore, |
| 22 | + StreamId, |
| 23 | +) |
17 | 24 | from mcp.types import ( |
18 | 25 | AudioContent, |
19 | 26 | Completion, |
20 | 27 | CompletionArgument, |
21 | 28 | CompletionContext, |
22 | 29 | EmbeddedResource, |
23 | 30 | ImageContent, |
| 31 | + JSONRPCMessage, |
24 | 32 | PromptReference, |
25 | 33 | ResourceTemplateReference, |
26 | 34 | SamplingMessage, |
27 | 35 | TextContent, |
28 | 36 | TextResourceContents, |
29 | 37 | ) |
30 | 38 | from pydantic import AnyUrl, BaseModel, Field |
| 39 | +from starlette.requests import Request |
31 | 40 |
|
32 | 41 | logger = logging.getLogger(__name__) |
33 | 42 |
|
|
39 | 48 | resource_subscriptions: set[str] = set() |
40 | 49 | watched_resource_content = "Watched resource content" |
41 | 50 |
|
| 51 | + |
| 52 | +# Simple in-memory event store for SSE polling resumability (SEP-1699) |
| 53 | +class SimpleEventStore(EventStore): |
| 54 | + """Simple in-memory event store for testing resumability.""" |
| 55 | + |
| 56 | + def __init__(self) -> None: |
| 57 | + self._events: list[tuple[StreamId, EventId, JSONRPCMessage]] = [] |
| 58 | + self._event_id_counter = 0 |
| 59 | + |
| 60 | + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId: |
| 61 | + """Store an event and return its ID.""" |
| 62 | + self._event_id_counter += 1 |
| 63 | + event_id = str(self._event_id_counter) |
| 64 | + self._events.append((stream_id, event_id, message)) |
| 65 | + return event_id |
| 66 | + |
| 67 | + async def replay_events_after( |
| 68 | + self, |
| 69 | + last_event_id: EventId, |
| 70 | + send_callback: EventCallback, |
| 71 | + ) -> StreamId | None: |
| 72 | + """Replay events after the specified ID.""" |
| 73 | + target_stream_id = None |
| 74 | + found = False |
| 75 | + for stream_id, event_id, message in self._events: |
| 76 | + if event_id == last_event_id: |
| 77 | + target_stream_id = stream_id |
| 78 | + found = True |
| 79 | + continue |
| 80 | + if found and stream_id == target_stream_id: |
| 81 | + await send_callback(EventMessage(message=message, event_id=event_id)) |
| 82 | + return target_stream_id |
| 83 | + |
| 84 | + |
| 85 | +# Create event store for resumability |
| 86 | +event_store = SimpleEventStore() |
| 87 | + |
42 | 88 | mcp = FastMCP( |
43 | 89 | name="mcp-conformance-test-server", |
| 90 | + event_store=event_store, |
44 | 91 | ) |
45 | 92 |
|
46 | 93 |
|
@@ -257,6 +304,33 @@ async def test_elicitation_sep1330_enums(ctx: Context[ServerSession, None]) -> s |
257 | 304 | return f"Elicitation not supported or error: {str(e)}" |
258 | 305 |
|
259 | 306 |
|
| 307 | +@mcp.tool() |
| 308 | +async def test_reconnection(ctx: Context[ServerSession, None]) -> str: |
| 309 | + """Tests SSE polling via server-initiated disconnect (SEP-1699) |
| 310 | +
|
| 311 | + This tool closes the SSE stream mid-call, requiring the client to reconnect |
| 312 | + with Last-Event-ID to receive the remaining events. |
| 313 | + """ |
| 314 | + # Send notification before disconnect |
| 315 | + await ctx.info("Notification before disconnect") |
| 316 | + |
| 317 | + # Get session_id from request headers |
| 318 | + request = ctx.request_context.request |
| 319 | + if isinstance(request, Request): |
| 320 | + session_id = request.headers.get("mcp-session-id") |
| 321 | + if session_id: |
| 322 | + # Trigger server-initiated SSE disconnect |
| 323 | + await mcp.session_manager.close_sse_stream(session_id, ctx.request_id) |
| 324 | + |
| 325 | + # Wait for client to reconnect |
| 326 | + await asyncio.sleep(0.2) |
| 327 | + |
| 328 | + # Send notification after disconnect (will be replayed via event store) |
| 329 | + await ctx.info("Notification after disconnect") |
| 330 | + |
| 331 | + return "Reconnection test completed successfully" |
| 332 | + |
| 333 | + |
260 | 334 | @mcp.tool() |
261 | 335 | def test_error_handling() -> str: |
262 | 336 | """Tests error response handling""" |
|
0 commit comments