Skip to content

Commit f28be9c

Browse files
feat: track protocol version in StreamableHttpTransport
The client now tracks the negotiated protocol version from the server's response headers, enabling version-aware communication between client and server.
1 parent 89455a4 commit f28be9c

File tree

1 file changed

+47
-6
lines changed

1 file changed

+47
-6
lines changed

src/mcp/client/streamable_http.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
GetSessionIdCallback = Callable[[], str | None]
4141

4242
MCP_SESSION_ID = "mcp-session-id"
43+
MCP_PROTOCOL_VERSION = "MCP-Protocol-Version"
4344
LAST_EVENT_ID = "last-event-id"
4445
CONTENT_TYPE = "content-type"
4546
ACCEPT = "Accept"
@@ -100,17 +101,22 @@ def __init__(
100101
self.sse_read_timeout = sse_read_timeout
101102
self.auth = auth
102103
self.session_id: str | None = None
104+
self.protocol_version: str | None = None
103105
self.request_headers = {
104106
ACCEPT: f"{JSON}, {SSE}",
105107
CONTENT_TYPE: JSON,
106108
**self.headers,
107109
}
108110

109-
def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
110-
"""Update headers with session ID if available."""
111+
def _update_headers_with_session(
112+
self, base_headers: dict[str, str]
113+
) -> dict[str, str]:
114+
"""Update headers with session ID and protocol version if available."""
111115
headers = base_headers.copy()
112116
if self.session_id:
113117
headers[MCP_SESSION_ID] = self.session_id
118+
if self.protocol_version:
119+
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
114120
return headers
115121

116122
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
@@ -131,19 +137,36 @@ def _maybe_extract_session_id_from_response(
131137
self.session_id = new_session_id
132138
logger.info(f"Received session ID: {self.session_id}")
133139

140+
def _maybe_extract_protocol_version_from_message(
141+
self,
142+
message: JSONRPCMessage,
143+
) -> None:
144+
"""Extract protocol version from initialization response message."""
145+
if isinstance(message.root, JSONRPCResponse) and message.root.result:
146+
# Check if result has protocolVersion field
147+
result = message.root.result
148+
if "protocolVersion" in result:
149+
self.protocol_version = result["protocolVersion"]
150+
logger.info(f"Negotiated protocol version: {self.protocol_version}")
151+
134152
async def _handle_sse_event(
135153
self,
136154
sse: ServerSentEvent,
137155
read_stream_writer: StreamWriter,
138156
original_request_id: RequestId | None = None,
139157
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
158+
is_initialization: bool = False,
140159
) -> bool:
141160
"""Handle an SSE event, returning True if the response is complete."""
142161
if sse.event == "message":
143162
try:
144163
message = JSONRPCMessage.model_validate_json(sse.data)
145164
logger.debug(f"SSE message: {message}")
146165

166+
# Extract protocol version from initialization response
167+
if is_initialization:
168+
self._maybe_extract_protocol_version_from_message(message)
169+
147170
# If this is a response and we have original_request_id, replace it
148171
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
149172
message.root.id = original_request_id
@@ -265,9 +288,11 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
265288
content_type = response.headers.get(CONTENT_TYPE, "").lower()
266289

267290
if content_type.startswith(JSON):
268-
await self._handle_json_response(response, ctx.read_stream_writer)
291+
await self._handle_json_response(
292+
response, ctx.read_stream_writer, is_initialization
293+
)
269294
elif content_type.startswith(SSE):
270-
await self._handle_sse_response(response, ctx)
295+
await self._handle_sse_response(response, ctx, is_initialization)
271296
else:
272297
await self._handle_unexpected_content_type(
273298
content_type,
@@ -278,26 +303,42 @@ async def _handle_json_response(
278303
self,
279304
response: httpx.Response,
280305
read_stream_writer: StreamWriter,
306+
is_initialization: bool = False,
281307
) -> None:
282308
"""Handle JSON response from the server."""
283309
try:
284310
content = await response.aread()
285311
message = JSONRPCMessage.model_validate_json(content)
312+
313+
# Extract protocol version from initialization response
314+
if is_initialization:
315+
self._maybe_extract_protocol_version_from_message(message)
316+
286317
session_message = SessionMessage(message)
287318
await read_stream_writer.send(session_message)
288319
except Exception as exc:
289320
logger.error(f"Error parsing JSON response: {exc}")
290321
await read_stream_writer.send(exc)
291322

292-
async def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
323+
async def _handle_sse_response(
324+
self,
325+
response: httpx.Response,
326+
ctx: RequestContext,
327+
is_initialization: bool = False,
328+
) -> None:
293329
"""Handle SSE response from the server."""
294330
try:
295331
event_source = EventSource(response)
296332
async for sse in event_source.aiter_sse():
297333
is_complete = await self._handle_sse_event(
298334
sse,
299335
ctx.read_stream_writer,
300-
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
336+
resumption_callback=(
337+
ctx.metadata.on_resumption_token_update
338+
if ctx.metadata
339+
else None
340+
),
341+
is_initialization=is_initialization,
301342
)
302343
# If the SSE event indicates completion, like returning respose/error
303344
# break the loop

0 commit comments

Comments
 (0)