Skip to content

Commit 0e2437d

Browse files
committed
Add initial test for auth context propagation in async context
1 parent ef7944c commit 0e2437d

File tree

1 file changed

+61
-1
lines changed

1 file changed

+61
-1
lines changed

tests/server/lowlevel/test_result_cache.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from contextlib import AsyncExitStack
2-
from unittest.mock import AsyncMock, Mock
2+
from unittest.mock import AsyncMock, Mock, PropertyMock
33

44
import anyio
55
import pytest
66

77
from mcp import types
8+
from mcp.server.auth.middleware.auth_context import (
9+
auth_context_var as user_context,
10+
)
811
from mcp.server.lowlevel.result_cache import ResultCache
912

1013

@@ -376,3 +379,60 @@ def test_timer():
376379
assert len(result.content) == 1
377380
assert type(result.content[0]) is types.TextContent
378381
assert result.content[0].text == "Unknown async token"
382+
383+
384+
@pytest.mark.anyio
385+
async def test_async_call_pass_auth():
386+
"""Tests async calls pass auth context to background thread"""
387+
388+
mock_user = Mock()
389+
type(mock_user).username = PropertyMock(return_value="mock_user")
390+
391+
mock_session = AsyncMock()
392+
mock_context = Mock()
393+
mock_context.session = mock_session
394+
result_cache = ResultCache(max_size=1, max_keep_alive=1)
395+
396+
async def test_call(call: types.CallToolRequest) -> types.ServerResult:
397+
user = user_context.get()
398+
if user is None:
399+
return types.ServerResult(
400+
types.CallToolResult(
401+
content=[types.TextContent(type="text", text="unauthorised")],
402+
isError=True,
403+
)
404+
)
405+
else:
406+
return types.ServerResult(
407+
types.CallToolResult(
408+
content=[types.TextContent(type="text", text=str(user.username))]
409+
)
410+
)
411+
412+
async_call = types.CallToolAsyncRequest(
413+
method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test")
414+
)
415+
416+
async with AsyncExitStack() as stack:
417+
await stack.enter_async_context(result_cache)
418+
419+
user_context.set(mock_user)
420+
async_call_ref = await result_cache.start_call(
421+
test_call, async_call, mock_context
422+
)
423+
assert async_call_ref.token is not None
424+
425+
result = await result_cache.get_result(
426+
types.GetToolAsyncResultRequest(
427+
method="tools/async/get",
428+
params=types.GetToolAsyncResultRequestParams(
429+
token=async_call_ref.token
430+
),
431+
)
432+
)
433+
434+
assert not result.isError
435+
assert not result.isPending
436+
assert len(result.content) == 1
437+
assert type(result.content[0]) is types.TextContent
438+
assert result.content[0].text == "mock_user"

0 commit comments

Comments
 (0)