Skip to content

Commit 3c4cf10

Browse files
committed
terminations of a session
1 parent 27bc01e commit 3c4cf10

File tree

2 files changed

+281
-132
lines changed

2 files changed

+281
-132
lines changed

src/mcp/server/streamableHttp.py

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,16 @@ def __init__(
7777
ValueError: If the session ID contains invalid characters.
7878
"""
7979
if mcp_session_id is not None and (
80-
not SESSION_ID_PATTERN.match(mcp_session_id) or
81-
SESSION_ID_PATTERN.fullmatch(mcp_session_id) is None
80+
not SESSION_ID_PATTERN.match(mcp_session_id)
81+
or SESSION_ID_PATTERN.fullmatch(mcp_session_id) is None
8282
):
8383
raise ValueError(
8484
"Session ID must only contain visible ASCII characters (0x21-0x7E)"
8585
)
8686

8787
self.mcp_session_id = mcp_session_id
8888
self._request_streams = {}
89+
self._terminated = False
8990

9091
def _create_error_response(
9192
self,
@@ -126,6 +127,14 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
126127
send: ASGI send function
127128
"""
128129
request = Request(scope, receive)
130+
if self._terminated:
131+
# If the session has been terminated, return 404 Not Found
132+
response = self._create_error_response(
133+
"Not Found: Session has been terminated",
134+
HTTPStatus.NOT_FOUND,
135+
)
136+
await response(scope, receive, send)
137+
return
129138

130139
if request.method == "POST":
131140
await self._handle_post_request(scope, request, receive, send)
@@ -192,7 +201,6 @@ async def _handle_post_request(
192201
raise ValueError(
193202
"No read stream writer available. Ensure connect() is called first."
194203
)
195-
return
196204
try:
197205
# Check Accept headers
198206
has_json, has_sse = self._check_accept_headers(request)
@@ -417,7 +425,6 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
417425
# Validate session ID if server has one
418426
if not await self._validate_session(request, send):
419427
return
420-
421428
# Validate Accept header - must include text/event-stream
422429
_, has_sse = self._check_accept_headers(request)
423430

@@ -454,9 +461,46 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None:
454461
)
455462
await response(request.scope, request.receive, send)
456463
return
464+
457465
if not await self._validate_session(request, send):
458466
return
459-
# TODO : Implement session termination logic
467+
468+
# Terminate the session
469+
self._terminate_session()
470+
471+
# Return success response
472+
response = self._create_error_response(
473+
"Session terminated",
474+
HTTPStatus.OK,
475+
)
476+
await response(request.scope, request.receive, send)
477+
478+
def _terminate_session(self) -> None:
479+
"""
480+
Terminate the current session, closing all streams and marking as terminated.
481+
482+
Once terminated, all requests with this session ID will receive 404 Not Found.
483+
"""
484+
485+
self._terminated = True
486+
logger.info(f"Terminating session: {self.mcp_session_id}")
487+
488+
# We need a copy of the keys to avoid modification during iteration
489+
request_stream_keys = list(self._request_streams.keys())
490+
491+
# Close all request streams (synchronously)
492+
for key in request_stream_keys:
493+
try:
494+
# Get the stream
495+
stream = self._request_streams.get(key)
496+
if stream:
497+
# We must use close() here, not aclose() since this is a sync method
498+
stream.close()
499+
except Exception as e:
500+
logger.debug(f"Error closing stream {key} during termination: {e}")
501+
502+
# Clear the request streams dictionary immediately
503+
self._request_streams.clear()
460504

461505
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
462506
"""
@@ -599,10 +643,16 @@ async def message_router():
599643
# Yield the streams for the caller to use
600644
yield read_stream, write_stream
601645
finally:
602-
# Clean up any remaining request streams
603646
for stream in list(self._request_streams.values()):
604647
try:
605648
await stream.aclose()
606649
except Exception:
607650
pass
608651
self._request_streams.clear()
652+
# Clean up read/write streams
653+
if self._read_stream_writer:
654+
try:
655+
await self._read_stream_writer.aclose()
656+
except Exception:
657+
pass
658+
self._read_stream_writer = None

0 commit comments

Comments
 (0)