@@ -64,6 +64,7 @@ class RequestContext:
6464
6565 client : httpx .AsyncClient
6666 headers : dict [str , str ]
67+ extensions : dict [str , str ] | None
6768 session_id : str | None
6869 session_message : SessionMessage
6970 metadata : ClientMessageMetadata | None
@@ -78,6 +79,7 @@ def __init__(
7879 self ,
7980 url : str ,
8081 headers : dict [str , str ] | None = None ,
82+ extensions : dict [str , str ] | None = None ,
8183 timeout : float | timedelta = 30 ,
8284 sse_read_timeout : float | timedelta = 60 * 5 ,
8385 auth : httpx .Auth | None = None ,
@@ -87,12 +89,14 @@ def __init__(
8789 Args:
8890 url: The endpoint URL.
8991 headers: Optional headers to include in requests.
92+ extensions: Optional extensions to include in requests.
9093 timeout: HTTP timeout for regular operations.
9194 sse_read_timeout: Timeout for SSE read operations.
9295 auth: Optional HTTPX authentication handler.
9396 """
9497 self .url = url
9598 self .headers = headers or {}
99+ self .extensions = extensions .copy () if extensions else {}
96100 self .timeout = timeout .total_seconds () if isinstance (timeout , timedelta ) else timeout
97101 self .sse_read_timeout = (
98102 sse_read_timeout .total_seconds () if isinstance (sse_read_timeout , timedelta ) else sse_read_timeout
@@ -115,6 +119,12 @@ def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, st
115119 headers [MCP_PROTOCOL_VERSION ] = self .protocol_version
116120 return headers
117121
122+ def _prepare_request_extensions (self , base_extensions : dict [str , str ] | None ) -> dict [str , str ]:
123+ """Update extensions with session-specific data if available."""
124+ extensions = base_extensions .copy () if base_extensions else {}
125+ # Add any session-specific extensions here if needed
126+ return extensions
127+
118128 def _is_initialization_request (self , message : JSONRPCMessage ) -> bool :
119129 """Check if the message is an initialization request."""
120130 return isinstance (message .root , JSONRPCRequest ) and message .root .method == "initialize"
@@ -138,16 +148,14 @@ def _maybe_extract_protocol_version_from_message(
138148 message : JSONRPCMessage ,
139149 ) -> None :
140150 """Extract protocol version from initialization response message."""
141- if isinstance (message .root , JSONRPCResponse ) and message .root .result : # pragma: no branch
151+ if isinstance (message .root , JSONRPCResponse ) and message .root .result :
142152 try :
143153 # Parse the result as InitializeResult for type safety
144154 init_result = InitializeResult .model_validate (message .root .result )
145155 self .protocol_version = str (init_result .protocolVersion )
146156 logger .info (f"Negotiated protocol version: { self .protocol_version } " )
147- except Exception as exc : # pragma: no cover
148- logger .warning (
149- f"Failed to parse initialization response as InitializeResult: { exc } "
150- ) # pragma: no cover
157+ except Exception as exc :
158+ logger .warning (f"Failed to parse initialization response as InitializeResult: { exc } " )
151159 logger .warning (f"Raw result: { message .root .result } " )
152160
153161 async def _handle_sse_event (
@@ -160,9 +168,6 @@ async def _handle_sse_event(
160168 ) -> bool :
161169 """Handle an SSE event, returning True if the response is complete."""
162170 if sse .event == "message" :
163- # Skip empty data (keep-alive pings)
164- if not sse .data :
165- return False
166171 try :
167172 message = JSONRPCMessage .model_validate_json (sse .data )
168173 logger .debug (f"SSE message: { message } " )
@@ -186,11 +191,11 @@ async def _handle_sse_event(
186191 # Otherwise, return False to continue listening
187192 return isinstance (message .root , JSONRPCResponse | JSONRPCError )
188193
189- except Exception as exc : # pragma: no cover
194+ except Exception as exc :
190195 logger .exception ("Error parsing SSE message" )
191196 await read_stream_writer .send (exc )
192197 return False
193- else : # pragma: no cover
198+ else :
194199 logger .warning (f"Unknown SSE event: { sse .event } " )
195200 return False
196201
@@ -220,19 +225,19 @@ async def handle_get_stream(
220225 await self ._handle_sse_event (sse , read_stream_writer )
221226
222227 except Exception as exc :
223- logger .debug (f"GET stream error (non-fatal): { exc } " ) # pragma: no cover
228+ logger .debug (f"GET stream error (non-fatal): { exc } " )
224229
225230 async def _handle_resumption_request (self , ctx : RequestContext ) -> None :
226231 """Handle a resumption request using GET with SSE."""
227232 headers = self ._prepare_request_headers (ctx .headers )
228233 if ctx .metadata and ctx .metadata .resumption_token :
229234 headers [LAST_EVENT_ID ] = ctx .metadata .resumption_token
230235 else :
231- raise ResumptionError ("Resumption request requires a resumption token" ) # pragma: no cover
236+ raise ResumptionError ("Resumption request requires a resumption token" )
232237
233238 # Extract original request ID to map responses
234239 original_request_id = None
235- if isinstance (ctx .session_message .message .root , JSONRPCRequest ): # pragma: no branch
240+ if isinstance (ctx .session_message .message .root , JSONRPCRequest ):
236241 original_request_id = ctx .session_message .message .root .id
237242
238243 async with aconnect_sse (
@@ -245,7 +250,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
245250 event_source .response .raise_for_status ()
246251 logger .debug ("Resumption GET SSE connection established" )
247252
248- async for sse in event_source .aiter_sse (): # pragma: no branch
253+ async for sse in event_source .aiter_sse ():
249254 is_complete = await self ._handle_sse_event (
250255 sse ,
251256 ctx .read_stream_writer ,
@@ -259,6 +264,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
259264 async def _handle_post_request (self , ctx : RequestContext ) -> None :
260265 """Handle a POST request with response processing."""
261266 headers = self ._prepare_request_headers (ctx .headers )
267+ extensions = self ._prepare_request_extensions (ctx .extensions )
262268 message = ctx .session_message .message
263269 is_initialization = self ._is_initialization_request (message )
264270
@@ -267,18 +273,19 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
267273 self .url ,
268274 json = message .model_dump (by_alias = True , mode = "json" , exclude_none = True ),
269275 headers = headers ,
276+ extensions = extensions ,
270277 ) as response :
271278 if response .status_code == 202 :
272279 logger .debug ("Received 202 Accepted" )
273280 return
274281
275- if response .status_code == 404 : # pragma: no branch
282+ if response .status_code == 404 :
276283 if isinstance (message .root , JSONRPCRequest ):
277- await self ._send_session_terminated_error ( # pragma: no cover
278- ctx .read_stream_writer , # pragma: no cover
279- message .root .id , # pragma: no cover
280- ) # pragma: no cover
281- return # pragma: no cover
284+ await self ._send_session_terminated_error (
285+ ctx .read_stream_writer ,
286+ message .root .id ,
287+ )
288+ return
282289
283290 response .raise_for_status ()
284291 if is_initialization :
@@ -293,10 +300,10 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
293300 elif content_type .startswith (SSE ):
294301 await self ._handle_sse_response (response , ctx , is_initialization )
295302 else :
296- await self ._handle_unexpected_content_type ( # pragma: no cover
297- content_type , # pragma: no cover
298- ctx .read_stream_writer , # pragma: no cover
299- ) # pragma: no cover
303+ await self ._handle_unexpected_content_type (
304+ content_type ,
305+ ctx .read_stream_writer ,
306+ )
300307
301308 async def _handle_json_response (
302309 self ,
@@ -315,7 +322,7 @@ async def _handle_json_response(
315322
316323 session_message = SessionMessage (message )
317324 await read_stream_writer .send (session_message )
318- except Exception as exc : # pragma: no cover
325+ except Exception as exc :
319326 logger .exception ("Error parsing JSON response" )
320327 await read_stream_writer .send (exc )
321328
@@ -328,7 +335,7 @@ async def _handle_sse_response(
328335 """Handle SSE response from the server."""
329336 try :
330337 event_source = EventSource (response )
331- async for sse in event_source .aiter_sse (): # pragma: no branch
338+ async for sse in event_source .aiter_sse ():
332339 is_complete = await self ._handle_sse_event (
333340 sse ,
334341 ctx .read_stream_writer ,
@@ -341,18 +348,18 @@ async def _handle_sse_response(
341348 await response .aclose ()
342349 break
343350 except Exception as e :
344- logger .exception ("Error reading SSE stream:" ) # pragma: no cover
345- await ctx .read_stream_writer .send (e ) # pragma: no cover
351+ logger .exception ("Error reading SSE stream:" )
352+ await ctx .read_stream_writer .send (e )
346353
347354 async def _handle_unexpected_content_type (
348355 self ,
349356 content_type : str ,
350357 read_stream_writer : StreamWriter ,
351- ) -> None : # pragma: no cover
358+ ) -> None :
352359 """Handle unexpected content type in response."""
353- error_msg = f"Unexpected content type: { content_type } " # pragma: no cover
354- logger .error (error_msg ) # pragma: no cover
355- await read_stream_writer .send (ValueError (error_msg )) # pragma: no cover
360+ error_msg = f"Unexpected content type: { content_type } "
361+ logger .error (error_msg )
362+ await read_stream_writer .send (ValueError (error_msg ))
356363
357364 async def _send_session_terminated_error (
358365 self ,
@@ -400,6 +407,7 @@ async def post_writer(
400407 ctx = RequestContext (
401408 client = client ,
402409 headers = self .request_headers ,
410+ extensions = self .extensions ,
403411 session_id = self .session_id ,
404412 session_message = session_message ,
405413 metadata = metadata ,
@@ -420,12 +428,12 @@ async def handle_request_async():
420428 await handle_request_async ()
421429
422430 except Exception :
423- logger .exception ("Error in post_writer" ) # pragma: no cover
431+ logger .exception ("Error in post_writer" )
424432 finally :
425433 await read_stream_writer .aclose ()
426434 await write_stream .aclose ()
427435
428- async def terminate_session (self , client : httpx .AsyncClient ) -> None : # pragma: no cover
436+ async def terminate_session (self , client : httpx .AsyncClient ) -> None :
429437 """Terminate the session by sending a DELETE request."""
430438 if not self .session_id :
431439 return
@@ -450,6 +458,7 @@ def get_session_id(self) -> str | None:
450458async def streamablehttp_client (
451459 url : str ,
452460 headers : dict [str , str ] | None = None ,
461+ extensions : dict [str , str ] | None = None ,
453462 timeout : float | timedelta = 30 ,
454463 sse_read_timeout : float | timedelta = 60 * 5 ,
455464 terminate_on_close : bool = True ,
@@ -475,7 +484,14 @@ async def streamablehttp_client(
475484 - write_stream: Stream for sending messages to the server
476485 - get_session_id_callback: Function to retrieve the current session ID
477486 """
478- transport = StreamableHTTPTransport (url , headers , timeout , sse_read_timeout , auth )
487+ transport = StreamableHTTPTransport (
488+ url = url ,
489+ headers = headers ,
490+ extensions = extensions ,
491+ timeout = timeout ,
492+ sse_read_timeout = sse_read_timeout ,
493+ auth = auth ,
494+ )
479495
480496 read_stream_writer , read_stream = anyio .create_memory_object_stream [SessionMessage | Exception ](0 )
481497 write_stream , write_stream_reader = anyio .create_memory_object_stream [SessionMessage ](0 )
0 commit comments