Skip to content

Commit 373fbc7

Browse files
committed
Test cases
1 parent 42891f4 commit 373fbc7

File tree

2 files changed

+190
-92
lines changed

2 files changed

+190
-92
lines changed

tests/test_context_propagation.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import contextvars
2+
from collections.abc import Iterator
3+
from contextlib import contextmanager
4+
5+
import httpx
6+
import pytest
7+
from inline_snapshot import snapshot
8+
from starlette.types import Receive, Scope, Send
9+
10+
import mcp.types as types
11+
from mcp import Client
12+
from mcp.client.session import ClientSession
13+
from mcp.client.streamable_http import streamable_http_client
14+
from mcp.server import MCPServer
15+
16+
TEST_CONTEXTVAR = contextvars.ContextVar("test_var", default="initial")
17+
HOST = "testserver"
18+
19+
20+
@contextmanager
21+
def set_test_contextvar(value: str) -> Iterator[None]:
22+
token = TEST_CONTEXTVAR.set(value)
23+
try:
24+
yield
25+
finally:
26+
TEST_CONTEXTVAR.reset(token)
27+
28+
29+
@pytest.fixture
30+
def server() -> MCPServer:
31+
mcp = MCPServer("test_server")
32+
33+
# tool that returns the value of TEST_CONTEXT_VAR.
34+
@mcp.tool()
35+
async def my_tool() -> str:
36+
return TEST_CONTEXTVAR.get()
37+
38+
return mcp
39+
40+
41+
@pytest.mark.anyio
42+
async def test_memory_transport_client_to_server(server: MCPServer):
43+
async with Client(server) as client:
44+
with set_test_contextvar("client_value"):
45+
result = await client.call_tool(name="my_tool")
46+
47+
assert isinstance(result, types.CallToolResult)
48+
assert result.content == snapshot([types.TextContent(text="client_value")])
49+
50+
51+
@pytest.mark.anyio
52+
async def test_streamable_http_asgi_to_mcpserver(server: MCPServer):
53+
mcp_app = server.streamable_http_app(host=HOST)
54+
55+
# Wrap it in a middleware that sets the contextvar
56+
async def middleware_app(scope: Scope, receive: Receive, send: Send):
57+
with set_test_contextvar("from_middleware"):
58+
await mcp_app(scope, receive, send)
59+
60+
async with (
61+
mcp_app.router.lifespan_context(middleware_app),
62+
httpx.ASGITransport(app=middleware_app) as transport,
63+
httpx.AsyncClient(transport=transport) as client,
64+
streamable_http_client(f"http://{HOST}/mcp", http_client=client) as (read_stream, write_stream),
65+
ClientSession(read_stream, write_stream) as session,
66+
):
67+
await session.initialize()
68+
result = await session.call_tool("my_tool")
69+
assert result.content == snapshot([types.TextContent(text="from_middleware")])
70+
71+
72+
@pytest.mark.anyio
73+
async def test_streamable_http_mcpclient_to_httpx(server: MCPServer):
74+
mcp_app = server.streamable_http_app(host=HOST)
75+
76+
captured_context_var = None
77+
78+
# Intercepts the httpx call and capture the contextvar's value
79+
class ContextCapturingASGITransport(httpx.ASGITransport):
80+
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
81+
nonlocal captured_context_var
82+
captured_context_var = TEST_CONTEXTVAR.get()
83+
return await super().handle_async_request(request)
84+
85+
async with (
86+
mcp_app.router.lifespan_context(mcp_app),
87+
ContextCapturingASGITransport(app=mcp_app) as transport,
88+
httpx.AsyncClient(transport=transport) as client,
89+
streamable_http_client(f"http://{HOST}/mcp", http_client=client) as (read_stream, write_stream),
90+
ClientSession(read_stream, write_stream) as session,
91+
):
92+
with set_test_contextvar("client_value_initialize"):
93+
await session.initialize()
94+
assert captured_context_var == snapshot("client_value_initialize")
95+
96+
with set_test_contextvar("client_value_call_tool"):
97+
await session.call_tool("my_tool")
98+
assert captured_context_var == snapshot("client_value_call_tool")

0 commit comments

Comments
 (0)