4040GetSessionIdCallback = Callable [[], str | None ]
4141
4242MCP_SESSION_ID = "mcp-session-id"
43+ MCP_PROTOCOL_VERSION = "MCP-Protocol-Version"
4344LAST_EVENT_ID = "last-event-id"
4445CONTENT_TYPE = "content-type"
4546ACCEPT = "Accept"
@@ -100,17 +101,22 @@ def __init__(
100101 self .sse_read_timeout = sse_read_timeout
101102 self .auth = auth
102103 self .session_id : str | None = None
104+ self .protocol_version : str | None = None
103105 self .request_headers = {
104106 ACCEPT : f"{ JSON } , { SSE } " ,
105107 CONTENT_TYPE : JSON ,
106108 ** self .headers ,
107109 }
108110
109- def _update_headers_with_session (self , base_headers : dict [str , str ]) -> dict [str , str ]:
110- """Update headers with session ID if available."""
111+ def _update_headers_with_session (
112+ self , base_headers : dict [str , str ]
113+ ) -> dict [str , str ]:
114+ """Update headers with session ID and protocol version if available."""
111115 headers = base_headers .copy ()
112116 if self .session_id :
113117 headers [MCP_SESSION_ID ] = self .session_id
118+ if self .protocol_version :
119+ headers [MCP_PROTOCOL_VERSION ] = self .protocol_version
114120 return headers
115121
116122 def _is_initialization_request (self , message : JSONRPCMessage ) -> bool :
@@ -131,19 +137,36 @@ def _maybe_extract_session_id_from_response(
131137 self .session_id = new_session_id
132138 logger .info (f"Received session ID: { self .session_id } " )
133139
140+ def _maybe_extract_protocol_version_from_message (
141+ self ,
142+ message : JSONRPCMessage ,
143+ ) -> None :
144+ """Extract protocol version from initialization response message."""
145+ if isinstance (message .root , JSONRPCResponse ) and message .root .result :
146+ # Check if result has protocolVersion field
147+ result = message .root .result
148+ if "protocolVersion" in result :
149+ self .protocol_version = result ["protocolVersion" ]
150+ logger .info (f"Negotiated protocol version: { self .protocol_version } " )
151+
134152 async def _handle_sse_event (
135153 self ,
136154 sse : ServerSentEvent ,
137155 read_stream_writer : StreamWriter ,
138156 original_request_id : RequestId | None = None ,
139157 resumption_callback : Callable [[str ], Awaitable [None ]] | None = None ,
158+ is_initialization : bool = False ,
140159 ) -> bool :
141160 """Handle an SSE event, returning True if the response is complete."""
142161 if sse .event == "message" :
143162 try :
144163 message = JSONRPCMessage .model_validate_json (sse .data )
145164 logger .debug (f"SSE message: { message } " )
146165
166+ # Extract protocol version from initialization response
167+ if is_initialization :
168+ self ._maybe_extract_protocol_version_from_message (message )
169+
147170 # If this is a response and we have original_request_id, replace it
148171 if original_request_id is not None and isinstance (message .root , JSONRPCResponse | JSONRPCError ):
149172 message .root .id = original_request_id
@@ -265,9 +288,11 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
265288 content_type = response .headers .get (CONTENT_TYPE , "" ).lower ()
266289
267290 if content_type .startswith (JSON ):
268- await self ._handle_json_response (response , ctx .read_stream_writer )
291+ await self ._handle_json_response (
292+ response , ctx .read_stream_writer , is_initialization
293+ )
269294 elif content_type .startswith (SSE ):
270- await self ._handle_sse_response (response , ctx )
295+ await self ._handle_sse_response (response , ctx , is_initialization )
271296 else :
272297 await self ._handle_unexpected_content_type (
273298 content_type ,
@@ -278,26 +303,42 @@ async def _handle_json_response(
278303 self ,
279304 response : httpx .Response ,
280305 read_stream_writer : StreamWriter ,
306+ is_initialization : bool = False ,
281307 ) -> None :
282308 """Handle JSON response from the server."""
283309 try :
284310 content = await response .aread ()
285311 message = JSONRPCMessage .model_validate_json (content )
312+
313+ # Extract protocol version from initialization response
314+ if is_initialization :
315+ self ._maybe_extract_protocol_version_from_message (message )
316+
286317 session_message = SessionMessage (message )
287318 await read_stream_writer .send (session_message )
288319 except Exception as exc :
289320 logger .error (f"Error parsing JSON response: { exc } " )
290321 await read_stream_writer .send (exc )
291322
292- async def _handle_sse_response (self , response : httpx .Response , ctx : RequestContext ) -> None :
323+ async def _handle_sse_response (
324+ self ,
325+ response : httpx .Response ,
326+ ctx : RequestContext ,
327+ is_initialization : bool = False ,
328+ ) -> None :
293329 """Handle SSE response from the server."""
294330 try :
295331 event_source = EventSource (response )
296332 async for sse in event_source .aiter_sse ():
297333 is_complete = await self ._handle_sse_event (
298334 sse ,
299335 ctx .read_stream_writer ,
300- resumption_callback = (ctx .metadata .on_resumption_token_update if ctx .metadata else None ),
336+ resumption_callback = (
337+ ctx .metadata .on_resumption_token_update
338+ if ctx .metadata
339+ else None
340+ ),
341+ is_initialization = is_initialization ,
301342 )
302343 # If the SSE event indicates completion, like returning respose/error
303344 # break the loop
0 commit comments