|
4 | 4 | import time |
5 | 5 | from collections.abc import AsyncGenerator, Generator |
6 | 6 | from typing import Any |
| 7 | +from unittest.mock import Mock |
7 | 8 |
|
8 | 9 | import anyio |
9 | 10 | import httpx |
|
16 | 17 | from starlette.responses import Response |
17 | 18 | from starlette.routing import Mount, Route |
18 | 19 |
|
| 20 | +import mcp.client.sse |
19 | 21 | import mcp.types as types |
20 | 22 | from mcp.client.session import ClientSession |
21 | 23 | from mcp.client.sse import _extract_session_id_from_endpoint, sse_client |
@@ -220,22 +222,19 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non |
220 | 222 | async def test_sse_client_on_session_created_not_called_when_no_session_id( |
221 | 223 | server: None, server_url: str, monkeypatch: pytest.MonkeyPatch |
222 | 224 | ) -> None: |
223 | | - from mcp.client import sse |
| 225 | + callback_mock = Mock() |
224 | 226 |
|
225 | | - callback_called = False |
| 227 | + def mock_extract(url: str) -> None: |
| 228 | + return None |
226 | 229 |
|
227 | | - def on_session_created(session_id: str) -> None: |
228 | | - nonlocal callback_called |
229 | | - callback_called = True |
230 | | - |
231 | | - monkeypatch.setattr(sse, "_extract_session_id_from_endpoint", lambda url: None) |
| 230 | + monkeypatch.setattr(mcp.client.sse, "_extract_session_id_from_endpoint", mock_extract) |
232 | 231 |
|
233 | | - async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams: |
| 232 | + async with sse_client(server_url + "/sse", on_session_created=callback_mock) as streams: |
234 | 233 | async with ClientSession(*streams) as session: |
235 | 234 | result = await session.initialize() |
236 | 235 | assert isinstance(result, InitializeResult) |
237 | 236 |
|
238 | | - assert callback_called is False |
| 237 | + callback_mock.assert_not_called() |
239 | 238 |
|
240 | 239 |
|
241 | 240 | @pytest.fixture |
|
0 commit comments