|
1 | 1 | from contextlib import AsyncExitStack |
2 | | -from unittest.mock import AsyncMock, Mock |
| 2 | +from unittest.mock import AsyncMock, Mock, PropertyMock |
3 | 3 |
|
4 | 4 | import anyio |
5 | 5 | import pytest |
6 | 6 |
|
7 | 7 | from mcp import types |
| 8 | +from mcp.server.auth.middleware.auth_context import ( |
| 9 | + auth_context_var as user_context, |
| 10 | +) |
8 | 11 | from mcp.server.lowlevel.result_cache import ResultCache |
9 | 12 |
|
10 | 13 |
|
@@ -376,3 +379,60 @@ def test_timer(): |
376 | 379 | assert len(result.content) == 1 |
377 | 380 | assert type(result.content[0]) is types.TextContent |
378 | 381 | assert result.content[0].text == "Unknown async token" |
| 382 | + |
| 383 | + |
| 384 | +@pytest.mark.anyio |
| 385 | +async def test_async_call_pass_auth(): |
| 386 | + """Tests async calls pass auth context to background thread""" |
| 387 | + |
| 388 | + mock_user = Mock() |
| 389 | + type(mock_user).username = PropertyMock(return_value="mock_user") |
| 390 | + |
| 391 | + mock_session = AsyncMock() |
| 392 | + mock_context = Mock() |
| 393 | + mock_context.session = mock_session |
| 394 | + result_cache = ResultCache(max_size=1, max_keep_alive=1) |
| 395 | + |
| 396 | + async def test_call(call: types.CallToolRequest) -> types.ServerResult: |
| 397 | + user = user_context.get() |
| 398 | + if user is None: |
| 399 | + return types.ServerResult( |
| 400 | + types.CallToolResult( |
| 401 | + content=[types.TextContent(type="text", text="unauthorised")], |
| 402 | + isError=True, |
| 403 | + ) |
| 404 | + ) |
| 405 | + else: |
| 406 | + return types.ServerResult( |
| 407 | + types.CallToolResult( |
| 408 | + content=[types.TextContent(type="text", text=str(user.username))] |
| 409 | + ) |
| 410 | + ) |
| 411 | + |
| 412 | + async_call = types.CallToolAsyncRequest( |
| 413 | + method="tools/async/call", params=types.CallToolAsyncRequestParams(name="test") |
| 414 | + ) |
| 415 | + |
| 416 | + async with AsyncExitStack() as stack: |
| 417 | + await stack.enter_async_context(result_cache) |
| 418 | + |
| 419 | + user_context.set(mock_user) |
| 420 | + async_call_ref = await result_cache.start_call( |
| 421 | + test_call, async_call, mock_context |
| 422 | + ) |
| 423 | + assert async_call_ref.token is not None |
| 424 | + |
| 425 | + result = await result_cache.get_result( |
| 426 | + types.GetToolAsyncResultRequest( |
| 427 | + method="tools/async/get", |
| 428 | + params=types.GetToolAsyncResultRequestParams( |
| 429 | + token=async_call_ref.token |
| 430 | + ), |
| 431 | + ) |
| 432 | + ) |
| 433 | + |
| 434 | + assert not result.isError |
| 435 | + assert not result.isPending |
| 436 | + assert len(result.content) == 1 |
| 437 | + assert type(result.content[0]) is types.TextContent |
| 438 | + assert result.content[0].text == "mock_user" |
0 commit comments