Skip to content

Commit 5cc10fa

Browse files
committed
run precommit
1 parent bd70a51 commit 5cc10fa

File tree

1 file changed

+28
-15
lines changed

1 file changed

+28
-15
lines changed

tests/shared/test_sse.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import json
22
from collections.abc import AsyncGenerator
3-
from anyio.abc import TaskGroup
43
from typing import Any
54

65
import anyio
76
import httpx
8-
from mcp.shared._httpx_utils import McpHttpClientFactory
97
import pytest
8+
from anyio.abc import TaskGroup
109
from inline_snapshot import snapshot
1110
from pydantic import AnyUrl
1211
from starlette.applications import Starlette
@@ -21,6 +20,7 @@
2120
from mcp.server.sse import SseServerTransport
2221
from mcp.server.streaming_asgi_transport import StreamingASGITransport
2322
from mcp.server.transport_security import TransportSecuritySettings
23+
from mcp.shared._httpx_utils import McpHttpClientFactory
2424
from mcp.shared.exceptions import McpError
2525
from mcp.types import (
2626
EmptyResult,
@@ -67,23 +67,23 @@ async def handle_list_tools() -> list[Tool]:
6767
async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]:
6868
return [TextContent(type="text", text=f"Called {name}")]
6969

70+
7071
def create_asgi_client_factory(app: Starlette, tg: TaskGroup) -> McpHttpClientFactory:
7172
"""Factory function to create httpx clients with StreamingASGITransport"""
73+
7274
def asgi_client_factory(
7375
headers: dict[str, str] | None = None,
7476
timeout: httpx.Timeout | None = None,
7577
auth: httpx.Auth | None = None,
7678
) -> httpx.AsyncClient:
7779
transport = StreamingASGITransport(app=app, task_group=tg)
7880
return httpx.AsyncClient(
79-
transport=transport,
80-
base_url=TEST_SERVER_BASE_URL,
81-
headers=headers,
82-
timeout=timeout,
83-
auth=auth
81+
transport=transport, base_url=TEST_SERVER_BASE_URL, headers=headers, timeout=timeout, auth=auth
8482
)
83+
8584
return asgi_client_factory
8685

86+
8787
def create_sse_app(server: Server) -> Starlette:
8888
"""Helper to create SSE app with given server"""
8989
security_settings = TransportSecuritySettings(
@@ -107,12 +107,14 @@ async def handle_sse(request: Request) -> Response:
107107

108108
# Test fixtures
109109

110+
110111
@pytest.fixture()
111112
def server_app() -> Starlette:
112113
"""Create test Starlette app with SSE transport"""
113114
app = create_sse_app(ServerTest())
114115
return app
115116

117+
116118
@pytest.fixture()
117119
async def tg() -> AsyncGenerator[TaskGroup, None]:
118120
async with anyio.create_task_group() as tg:
@@ -126,11 +128,14 @@ async def http_client(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[ht
126128
async with httpx.AsyncClient(transport=transport, base_url=TEST_SERVER_BASE_URL) as client:
127129
yield client
128130

131+
129132
@pytest.fixture()
130133
async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGenerator[ClientSession, None]:
131134
asgi_client_factory = create_asgi_client_factory(server_app, tg)
132-
133-
async with sse_client(f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory) as streams:
135+
136+
async with sse_client(
137+
f"{TEST_SERVER_BASE_URL}/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory
138+
) as streams:
134139
async with ClientSession(*streams) as session:
135140
yield session
136141

@@ -139,7 +144,7 @@ async def sse_client_session(tg: TaskGroup, server_app: Starlette) -> AsyncGener
139144
@pytest.mark.anyio
140145
async def test_raw_sse_connection(http_client: httpx.AsyncClient) -> None:
141146
"""Test the SSE connection establishment simply with an HTTP client."""
142-
147+
143148
async def connection_test() -> None:
144149
async with http_client.stream("GET", "/sse") as response:
145150
assert response.status_code == 200
@@ -227,13 +232,18 @@ async def mounted_server_app(server_app: Starlette) -> Starlette:
227232

228233

229234
@pytest.fixture()
230-
async def sse_client_mounted_server_app_session(tg: TaskGroup, mounted_server_app: Starlette) -> AsyncGenerator[ClientSession, None]:
235+
async def sse_client_mounted_server_app_session(
236+
tg: TaskGroup, mounted_server_app: Starlette
237+
) -> AsyncGenerator[ClientSession, None]:
231238
asgi_client_factory = create_asgi_client_factory(mounted_server_app, tg)
232-
233-
async with sse_client(f"{TEST_SERVER_BASE_URL}/mounted_app/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory) as streams:
239+
240+
async with sse_client(
241+
f"{TEST_SERVER_BASE_URL}/mounted_app/sse", sse_read_timeout=0.5, httpx_client_factory=asgi_client_factory
242+
) as streams:
234243
async with ClientSession(*streams) as session:
235244
yield session
236245

246+
237247
@pytest.mark.anyio
238248
async def test_sse_client_basic_connection_mounted_app(sse_client_mounted_server_app_session: ClientSession) -> None:
239249
session = sse_client_mounted_server_app_session
@@ -296,6 +306,7 @@ async def context_server_app() -> Starlette:
296306
app = create_sse_app(RequestContextServer())
297307
return app
298308

309+
299310
@pytest.mark.anyio
300311
async def test_request_context_propagation(tg: TaskGroup, context_server_app: Starlette) -> None:
301312
"""Test that request context is properly propagated through SSE transport."""
@@ -337,14 +348,16 @@ async def test_request_context_propagation(tg: TaskGroup, context_server_app: St
337348
async def test_request_context_isolation(tg: TaskGroup, context_server_app: Starlette) -> None:
338349
"""Test that request contexts are isolated between different SSE clients."""
339350
contexts: list[dict[str, Any]] = []
340-
351+
341352
asgi_client_factory = create_asgi_client_factory(context_server_app, tg)
342353

343354
# Create multiple clients with different headers
344355
for i in range(3):
345356
headers = {"X-Request-Id": f"request-{i}", "X-Custom-Value": f"value-{i}"}
346357

347-
async with sse_client(f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory) as (
358+
async with sse_client(
359+
f"{TEST_SERVER_BASE_URL}/sse", headers=headers, httpx_client_factory=asgi_client_factory
360+
) as (
348361
read_stream,
349362
write_stream,
350363
):

0 commit comments

Comments
 (0)