Skip to content

Commit 6c2f7de

Browse files
committed
client resumability
1 parent 901dc98 commit 6c2f7de

File tree

5 files changed

+502
-97
lines changed

5 files changed

+502
-97
lines changed

src/mcp/client/session.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77

88
import mcp.types as types
99
from mcp.shared.context import RequestContext
10-
from mcp.shared.message import SessionMessage
10+
from mcp.shared.message import (
11+
ClientMessageMetadata,
12+
ResumptionToken,
13+
ResumptionTokenUpdateCallback,
14+
SessionMessage,
15+
)
1116
from mcp.shared.session import BaseSession, RequestResponder
1217
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1318

@@ -255,9 +260,18 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
255260
)
256261

257262
async def call_tool(
258-
self, name: str, arguments: dict[str, Any] | None = None
263+
self,
264+
name: str,
265+
arguments: dict[str, Any] | None = None,
266+
on_resumption_token_update: ResumptionTokenUpdateCallback | None = None,
259267
) -> types.CallToolResult:
260268
"""Send a tools/call request."""
269+
metadata = None
270+
if on_resumption_token_update:
271+
metadata = ClientMessageMetadata(
272+
on_resumption_token_update=on_resumption_token_update,
273+
)
274+
261275
return await self.send_request(
262276
types.ClientRequest(
263277
types.CallToolRequest(
@@ -266,6 +280,28 @@ async def call_tool(
266280
)
267281
),
268282
types.CallToolResult,
283+
metadata=metadata,
284+
)
285+
286+
async def resume_tool(
287+
self,
288+
resumption_token: ResumptionToken,
289+
) -> types.CallToolResult:
290+
"""Send a tools/call request with resumtion token to resume the tool."""
291+
292+
return await self.send_request(
293+
types.ClientRequest(
294+
types.CallToolRequest(
295+
method="tools/call",
296+
params=types.CallToolRequestParams(
297+
name="resume_from_token", arguments={}
298+
),
299+
)
300+
),
301+
types.CallToolResult,
302+
metadata=ClientMessageMetadata(
303+
resumption_token=resumption_token,
304+
),
269305
)
270306

271307
async def list_prompts(self) -> types.ListPromptsResult:

src/mcp/client/streamable_http.py

Lines changed: 192 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
import httpx
1616
from httpx_sse import EventSource, aconnect_sse
1717

18-
from mcp.shared.message import SessionMessage
18+
from mcp.shared.message import ClientMessageMetadata, SessionMessage
1919
from mcp.types import (
2020
ErrorData,
2121
JSONRPCError,
2222
JSONRPCMessage,
2323
JSONRPCNotification,
2424
JSONRPCRequest,
25+
JSONRPCResponse,
2526
)
2627

2728
logger = logging.getLogger(__name__)
@@ -49,7 +50,7 @@ async def streamablehttp_client(
4950
event before disconnecting. All other HTTP operations are controlled by `timeout`.
5051
5152
Yields:
52-
Tuple of (read_stream, write_stream, terminate_callback)
53+
Tuple of (read_stream, write_stream, terminate_callback, get_session_id_callback)
5354
"""
5455

5556
read_stream_writer, read_stream = anyio.create_memory_object_stream[
@@ -104,11 +105,28 @@ async def post_writer(client: httpx.AsyncClient):
104105
async with write_stream_reader:
105106
async for session_message in write_stream_reader:
106107
message = session_message.message
108+
metadata = (
109+
session_message.metadata
110+
if isinstance(session_message.metadata, ClientMessageMetadata)
111+
else None
112+
)
113+
107114
# Add session ID to headers if we have one
108115
post_headers = request_headers.copy()
109116
if session_id:
110117
post_headers[MCP_SESSION_ID_HEADER] = session_id
111118

119+
# Check if this is a resumption request
120+
is_resumption = False
121+
original_request_id = None
122+
if metadata and metadata.resumption_token:
123+
# For resumption, use GET instead of POST
124+
is_resumption = True
125+
post_headers[LAST_EVENT_ID_HEADER] = metadata.resumption_token
126+
# Store the original request ID to map responses
127+
if isinstance(message.root, JSONRPCRequest):
128+
original_request_id = message.root.id
129+
112130
logger.debug(f"Sending client message: {message}")
113131

114132
# Handle initial initialization request
@@ -122,92 +140,176 @@ async def post_writer(client: httpx.AsyncClient):
122140
):
123141
tg.start_soon(get_stream)
124142

125-
async with client.stream(
126-
"POST",
127-
url,
128-
json=message.model_dump(
129-
by_alias=True, mode="json", exclude_none=True
130-
),
131-
headers=post_headers,
132-
) as response:
133-
if response.status_code == 202:
134-
logger.debug("Received 202 Accepted")
135-
continue
136-
# Check for 404 (session expired/invalid)
137-
if response.status_code == 404:
138-
if isinstance(message.root, JSONRPCRequest):
139-
jsonrpc_error = JSONRPCError(
140-
jsonrpc="2.0",
141-
id=message.root.id,
142-
error=ErrorData(
143-
code=32600,
144-
message="Session terminated",
145-
),
146-
)
147-
session_message = SessionMessage(
148-
JSONRPCMessage(jsonrpc_error)
149-
)
150-
await read_stream_writer.send(session_message)
151-
continue
152-
response.raise_for_status()
153-
154-
# Extract session ID from response headers
155-
if is_initialization:
156-
new_session_id = response.headers.get(MCP_SESSION_ID_HEADER)
157-
if new_session_id:
158-
session_id = new_session_id
159-
logger.info(f"Received session ID: {session_id}")
160-
161-
# Handle different response types
162-
content_type = response.headers.get("content-type", "").lower()
163-
164-
if content_type.startswith(CONTENT_TYPE_JSON):
165-
try:
166-
content = await response.aread()
167-
json_message = JSONRPCMessage.model_validate_json(
168-
content
169-
)
170-
session_message = SessionMessage(json_message)
171-
await read_stream_writer.send(session_message)
172-
except Exception as exc:
173-
logger.error(f"Error parsing JSON response: {exc}")
174-
await read_stream_writer.send(exc)
175-
176-
elif content_type.startswith(CONTENT_TYPE_SSE):
177-
# Parse SSE events from the response
178-
try:
179-
event_source = EventSource(response)
180-
async for sse in event_source.aiter_sse():
181-
if sse.event == "message":
182-
try:
183-
message = (
184-
JSONRPCMessage.model_validate_json(
185-
sse.data
186-
)
143+
if is_resumption:
144+
# For resumption, use GET with SSE
145+
async with aconnect_sse(
146+
client,
147+
"GET",
148+
url,
149+
headers=post_headers,
150+
timeout=httpx.Timeout(
151+
timeout.seconds, read=sse_read_timeout.seconds
152+
),
153+
) as event_source:
154+
event_source.response.raise_for_status()
155+
logger.debug("Resumption GET SSE connection established")
156+
157+
async for sse in event_source.aiter_sse():
158+
if sse.event == "message":
159+
try:
160+
message = JSONRPCMessage.model_validate_json(
161+
sse.data
162+
)
163+
logger.debug(
164+
f"Resumption GET message: {message}"
165+
)
166+
167+
# If this is a response and we have original_request_id, replace it
168+
if (
169+
original_request_id is not None
170+
and isinstance(
171+
message.root,
172+
(JSONRPCResponse, JSONRPCError),
187173
)
188-
session_message = SessionMessage(message)
189-
await read_stream_writer.send(
190-
session_message
174+
):
175+
message.root.id = original_request_id
176+
177+
session_message = SessionMessage(message)
178+
await read_stream_writer.send(session_message)
179+
180+
# Call resumption token callback if we have an ID
181+
if (
182+
sse.id
183+
and metadata
184+
and metadata.on_resumption_token_update
185+
):
186+
await metadata.on_resumption_token_update(
187+
sse.id
191188
)
192-
except Exception as exc:
193-
logger.exception("Error parsing message")
194-
await read_stream_writer.send(exc)
195-
else:
196-
logger.warning(f"Unknown event: {sse.event}")
197-
198-
except Exception as e:
199-
logger.exception("Error reading SSE stream:")
200-
await read_stream_writer.send(e)
201-
202-
else:
203-
# For 202 Accepted with no body
189+
190+
# If this is a response or error, we're done
191+
if isinstance(
192+
message.root,
193+
(JSONRPCResponse, JSONRPCError),
194+
):
195+
break
196+
except Exception as exc:
197+
logger.error(
198+
f"Error parsing resumption GET message: {exc}"
199+
)
200+
await read_stream_writer.send(exc)
201+
else:
202+
logger.warning(
203+
f"Unknown SSE event from resumption GET: {sse.event}"
204+
)
205+
else:
206+
# Normal POST request
207+
async with client.stream(
208+
"POST",
209+
url,
210+
json=message.model_dump(
211+
by_alias=True, mode="json", exclude_none=True
212+
),
213+
headers=post_headers,
214+
) as response:
204215
if response.status_code == 202:
205216
logger.debug("Received 202 Accepted")
206217
continue
218+
# Check for 404 (session expired/invalid)
219+
if response.status_code == 404:
220+
if isinstance(message.root, JSONRPCRequest):
221+
jsonrpc_error = JSONRPCError(
222+
jsonrpc="2.0",
223+
id=message.root.id,
224+
error=ErrorData(
225+
code=32600,
226+
message="Session terminated",
227+
),
228+
)
229+
session_message = SessionMessage(
230+
JSONRPCMessage(jsonrpc_error)
231+
)
232+
await read_stream_writer.send(session_message)
233+
continue
234+
response.raise_for_status()
235+
236+
# Extract session ID from response headers
237+
if is_initialization:
238+
new_session_id = response.headers.get(
239+
MCP_SESSION_ID_HEADER
240+
)
241+
if new_session_id:
242+
session_id = new_session_id
243+
logger.info(f"Received session ID: {session_id}")
244+
245+
# Handle different response types
246+
content_type = response.headers.get(
247+
"content-type", ""
248+
).lower()
249+
250+
if content_type.startswith(CONTENT_TYPE_JSON):
251+
try:
252+
content = await response.aread()
253+
json_message = JSONRPCMessage.model_validate_json(
254+
content
255+
)
256+
session_message = SessionMessage(json_message)
257+
await read_stream_writer.send(session_message)
258+
except Exception as exc:
259+
logger.error(f"Error parsing JSON response: {exc}")
260+
await read_stream_writer.send(exc)
261+
262+
elif content_type.startswith(CONTENT_TYPE_SSE):
263+
# Parse SSE events from the response
264+
try:
265+
event_source = EventSource(response)
266+
async for sse in event_source.aiter_sse():
267+
if sse.event == "message":
268+
try:
269+
message = (
270+
JSONRPCMessage.model_validate_json(
271+
sse.data
272+
)
273+
)
274+
session_message = SessionMessage(
275+
message
276+
)
277+
await read_stream_writer.send(
278+
session_message
279+
)
280+
281+
# Call the resumption token callback if we have an ID
282+
if (
283+
sse.id
284+
and metadata
285+
and metadata.on_resumption_token_update
286+
):
287+
await metadata.on_resumption_token_update(
288+
sse.id
289+
)
290+
except Exception as exc:
291+
logger.exception(
292+
"Error parsing message"
293+
)
294+
await read_stream_writer.send(exc)
295+
else:
296+
logger.warning(
297+
f"Unknown event: {sse.event}"
298+
)
299+
300+
except Exception as e:
301+
logger.exception("Error reading SSE stream:")
302+
await read_stream_writer.send(e)
207303

208-
error_msg = f"Unexpected content type: {content_type}"
209-
logger.error(error_msg)
210-
await read_stream_writer.send(ValueError(error_msg))
304+
else:
305+
# For 202 Accepted with no body
306+
if response.status_code == 202:
307+
logger.debug("Received 202 Accepted")
308+
continue
309+
310+
error_msg = f"Unexpected content type: {content_type}"
311+
logger.error(error_msg)
312+
await read_stream_writer.send(ValueError(error_msg))
211313

212314
except Exception as exc:
213315
logger.error(f"Error in post_writer: {exc}")
@@ -240,6 +342,13 @@ async def terminate_session():
240342
except Exception as exc:
241343
logger.warning(f"Session termination failed: {exc}")
242344

345+
def get_session_id() -> str | None:
346+
"""
347+
Get the current session ID.
348+
"""
349+
nonlocal session_id
350+
return session_id
351+
243352
async with anyio.create_task_group() as tg:
244353
try:
245354
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
@@ -259,7 +368,7 @@ async def terminate_session():
259368
) as client:
260369
tg.start_soon(post_writer, client)
261370
try:
262-
yield read_stream, write_stream, terminate_session
371+
yield read_stream, write_stream, terminate_session, get_session_id
263372
finally:
264373
tg.cancel_scope.cancel()
265374
finally:

0 commit comments

Comments
 (0)