Skip to content

Commit 7ffd5ba

Browse files
committed
terminate session
1 parent 110526d commit 7ffd5ba

File tree

2 files changed

+130
-17
lines changed

2 files changed

+130
-17
lines changed

src/mcp/client/streamableHttp.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414
import httpx
1515
from httpx_sse import EventSource, aconnect_sse
1616

17-
from mcp.types import JSONRPCMessage, JSONRPCNotification, JSONRPCRequest
17+
from mcp.types import (
18+
ErrorData,
19+
JSONRPCError,
20+
JSONRPCMessage,
21+
JSONRPCNotification,
22+
JSONRPCRequest,
23+
)
1824

1925
logger = logging.getLogger(__name__)
2026

@@ -39,6 +45,9 @@ async def streamablehttp_client(
3945
4046
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
4147
event before disconnecting. All other HTTP operations are controlled by `timeout`.
48+
49+
Yields:
50+
Tuple of (read_stream, write_stream, terminate_callback)
4251
"""
4352

4453
read_stream_writer, read_stream = anyio.create_memory_object_stream[
@@ -122,9 +131,19 @@ async def post_writer():
122131
headers=post_headers,
123132
) as new_response:
124133
response = new_response
125-
else:
126-
response.raise_for_status()
127-
134+
elif isinstance(message.root, JSONRPCRequest):
135+
jsonrpc_error = JSONRPCError(
136+
jsonrpc="2.0",
137+
id=message.root.id,
138+
error=ErrorData(
139+
code=32600,
140+
message="Session terminated",
141+
),
142+
)
143+
await read_stream_writer.send(
144+
JSONRPCMessage(jsonrpc_error)
145+
)
146+
continue
128147
response.raise_for_status()
129148

130149
# Extract session ID from response headers
@@ -204,7 +223,6 @@ async def post_writer():
204223

205224
except Exception as exc:
206225
logger.error(f"Error in post_writer: {exc}")
207-
await read_stream_writer.send(exc)
208226
finally:
209227
await read_stream_writer.aclose()
210228
await write_stream.aclose()
@@ -223,7 +241,11 @@ async def get_stream():
223241
get_headers[MCP_SESSION_ID_HEADER] = session_id
224242

225243
async with aconnect_sse(
226-
client, "GET", url, headers=get_headers
244+
client,
245+
"GET",
246+
url,
247+
headers=get_headers,
248+
timeout=httpx.Timeout(timeout, read=sse_read_timeout),
227249
) as event_source:
228250
event_source.response.raise_for_status()
229251
logger.debug("GET SSE connection established")
@@ -251,8 +273,35 @@ async def get_stream():
251273

252274
tg.start_soon(post_writer)
253275

276+
async def terminate_session():
277+
"""
278+
Terminate the session by sending a DELETE request.
279+
"""
280+
nonlocal session_id
281+
if not session_id:
282+
return # No session to terminate
283+
284+
try:
285+
delete_headers = request_headers.copy()
286+
delete_headers[MCP_SESSION_ID_HEADER] = session_id
287+
288+
response = await client.delete(
289+
url,
290+
headers=delete_headers,
291+
)
292+
293+
if response.status_code == 405:
294+
# Server doesn't allow client-initiated termination
295+
logger.debug("Server does not allow session termination")
296+
elif response.status_code != 200:
297+
logger.warning(
298+
f"Session termination failed: {response.status_code}"
299+
)
300+
except Exception as exc:
301+
logger.warning(f"Session termination failed: {exc}")
302+
254303
try:
255-
yield read_stream, write_stream
304+
yield read_stream, write_stream, terminate_session
256305
finally:
257306
tg.cancel_scope.cancel()
258307
finally:

tests/shared/test_streamableHttp.py

Lines changed: 74 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -650,17 +650,31 @@ async def http_client(basic_server, basic_server_url):
650650
@pytest.fixture
651651
async def initialized_client_session(basic_server, basic_server_url):
652652
"""Create initialized StreamableHTTP client session."""
653-
async with streamablehttp_client(f"{basic_server_url}/mcp") as streams:
654-
async with ClientSession(*streams) as session:
653+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
654+
read_stream,
655+
write_stream,
656+
_,
657+
):
658+
async with ClientSession(
659+
read_stream,
660+
write_stream,
661+
) as session:
655662
await session.initialize()
656663
yield session
657664

658665

659666
@pytest.mark.anyio
660667
async def test_streamablehttp_client_basic_connection(basic_server, basic_server_url):
661668
"""Test basic client connection with initialization."""
662-
async with streamablehttp_client(f"{basic_server_url}/mcp") as streams:
663-
async with ClientSession(*streams) as session:
669+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
670+
read_stream,
671+
write_stream,
672+
_,
673+
):
674+
async with ClientSession(
675+
read_stream,
676+
write_stream,
677+
) as session:
664678
# Test initialization
665679
result = await session.initialize()
666680
assert isinstance(result, InitializeResult)
@@ -709,8 +723,15 @@ async def test_streamablehttp_client_session_persistence(
709723
basic_server, basic_server_url
710724
):
711725
"""Test that session ID persists across requests."""
712-
async with streamablehttp_client(f"{basic_server_url}/mcp") as streams:
713-
async with ClientSession(*streams) as session:
726+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
727+
read_stream,
728+
write_stream,
729+
_,
730+
):
731+
async with ClientSession(
732+
read_stream,
733+
write_stream,
734+
) as session:
714735
# Initialize the session
715736
result = await session.initialize()
716737
assert isinstance(result, InitializeResult)
@@ -732,8 +753,15 @@ async def test_streamablehttp_client_json_response(
732753
json_response_server, json_server_url
733754
):
734755
"""Test client with JSON response mode."""
735-
async with streamablehttp_client(f"{json_server_url}/mcp") as streams:
736-
async with ClientSession(*streams) as session:
756+
async with streamablehttp_client(f"{json_server_url}/mcp") as (
757+
read_stream,
758+
write_stream,
759+
_,
760+
):
761+
async with ClientSession(
762+
read_stream,
763+
write_stream,
764+
) as session:
737765
# Initialize the session
738766
result = await session.initialize()
739767
assert isinstance(result, InitializeResult)
@@ -767,8 +795,14 @@ async def message_handler(
767795
if isinstance(message, types.ServerNotification):
768796
notifications_received.append(message)
769797

770-
async with streamablehttp_client(f"{basic_server_url}/mcp") as streams:
771-
async with ClientSession(*streams, message_handler=message_handler) as session:
798+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
799+
read_stream,
800+
write_stream,
801+
_,
802+
):
803+
async with ClientSession(
804+
read_stream, write_stream, message_handler=message_handler
805+
) as session:
772806
# Initialize the session - this triggers the GET stream setup
773807
result = await session.initialize()
774808
assert isinstance(result, InitializeResult)
@@ -789,3 +823,33 @@ async def message_handler(
789823
assert (
790824
resource_update_found
791825
), "ResourceUpdatedNotification not received via GET stream"
826+
827+
828+
@pytest.mark.anyio
829+
async def test_streamablehttp_client_session_termination(
830+
basic_server, basic_server_url
831+
):
832+
"""Test client session termination functionality."""
833+
834+
# Create the streamablehttp_client with a custom httpx client to capture headers
835+
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
836+
read_stream,
837+
write_stream,
838+
terminate_session,
839+
):
840+
async with ClientSession(read_stream, write_stream) as session:
841+
# Initialize the session
842+
result = await session.initialize()
843+
assert isinstance(result, InitializeResult)
844+
845+
# Make a request to confirm session is working
846+
tools = await session.list_tools()
847+
assert len(tools.tools) == 2
848+
849+
# After exiting ClientSession context, explicitly terminate the session
850+
await terminate_session()
851+
with pytest.raises(
852+
McpError,
853+
match="Session terminated",
854+
):
855+
await session.list_tools()

0 commit comments

Comments
 (0)