Skip to content

Commit a09d0b1

Browse files
committed
fix: Eliminate port allocation race condition in all affected test files
Extended the fix from test_integration.py to all test files that suffered from the same TOCTOU (Time-of-Check-Time-of-Use) race condition. When running with pytest-xdist parallel workers, multiple workers would get the same "free" port from `socket.bind(("127.0.0.1", 0))`, leading to: - "address already in use" errors - Tests connecting to wrong servers Solution: - Updated 7 test files to use `get_worker_specific_port(worker_id)`: * tests/server/test_streamable_http_security.py * tests/server/test_sse_security.py * tests/shared/test_streamable_http.py (3 fixtures) * tests/shared/test_sse.py * tests/shared/test_ws.py * tests/client/test_http_unicode.py * tests/client/test_notification_response.py - Fixed socket cleanup issue in tests/test_test_helpers.py Each worker now gets an exclusive port range from 40000-59999, eliminating the race condition across all test files that start servers. Testing: - All 703 tests pass with pytest-xdist parallel execution - Port allocation now consistent across all parallel workers
1 parent 7769a83 commit a09d0b1

File tree

8 files changed

+39
-52
lines changed

8 files changed

+39
-52
lines changed

tests/client/test_http_unicode.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from mcp.client.session import ClientSession
1616
from mcp.client.streamable_http import streamablehttp_client
17+
from tests.test_helpers import get_worker_specific_port
1718

1819
# Test constants with various Unicode characters
1920
UNICODE_TEST_STRINGS = {
@@ -145,11 +146,9 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
145146

146147

147148
@pytest.fixture
148-
def unicode_server_port() -> int:
149+
def unicode_server_port(worker_id: str) -> int:
149150
"""Find an available port for the Unicode test server."""
150-
with socket.socket() as s:
151-
s.bind(("127.0.0.1", 0))
152-
return s.getsockname()[1]
151+
return get_worker_specific_port(worker_id)
153152

154153

155154
@pytest.fixture

tests/client/test_notification_response.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from mcp.client.streamable_http import streamablehttp_client
2323
from mcp.shared.session import RequestResponder
2424
from mcp.types import ClientNotification, RootsListChangedNotification
25+
from tests.test_helpers import get_worker_specific_port
2526

2627

2728
def create_non_sdk_server_app() -> Starlette:
@@ -81,11 +82,9 @@ def run_non_sdk_server(port: int) -> None:
8182

8283

8384
@pytest.fixture
84-
def non_sdk_server_port() -> int:
85+
def non_sdk_server_port(worker_id: str) -> int:
8586
"""Get an available port for the test server."""
86-
with socket.socket() as s:
87-
s.bind(("127.0.0.1", 0))
88-
return s.getsockname()[1]
87+
return get_worker_specific_port(worker_id)
8988

9089

9190
@pytest.fixture

tests/server/test_sse_security.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import logging
44
import multiprocessing
5-
import socket
65

76
import httpx
87
import pytest
@@ -16,17 +15,15 @@
1615
from mcp.server.sse import SseServerTransport
1716
from mcp.server.transport_security import TransportSecuritySettings
1817
from mcp.types import Tool
19-
from tests.test_helpers import wait_for_server
18+
from tests.test_helpers import get_worker_specific_port, wait_for_server
2019

2120
logger = logging.getLogger(__name__)
2221
SERVER_NAME = "test_sse_security_server"
2322

2423

2524
@pytest.fixture
26-
def server_port() -> int:
27-
with socket.socket() as s:
28-
s.bind(("127.0.0.1", 0))
29-
return s.getsockname()[1]
25+
def server_port(worker_id: str) -> int:
26+
return get_worker_specific_port(worker_id)
3027

3128

3229
@pytest.fixture

tests/server/test_streamable_http_security.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import logging
44
import multiprocessing
5-
import socket
65
from collections.abc import AsyncGenerator
76
from contextlib import asynccontextmanager
87

@@ -17,17 +16,15 @@
1716
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
1817
from mcp.server.transport_security import TransportSecuritySettings
1918
from mcp.types import Tool
20-
from tests.test_helpers import wait_for_server
19+
from tests.test_helpers import get_worker_specific_port, wait_for_server
2120

2221
logger = logging.getLogger(__name__)
2322
SERVER_NAME = "test_streamable_http_security_server"
2423

2524

2625
@pytest.fixture
27-
def server_port() -> int:
28-
with socket.socket() as s:
29-
s.bind(("127.0.0.1", 0))
30-
return s.getsockname()[1]
26+
def server_port(worker_id: str) -> int:
27+
return get_worker_specific_port(worker_id)
3128

3229

3330
@pytest.fixture

tests/shared/test_sse.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,14 @@
3232
TextResourceContents,
3333
Tool,
3434
)
35-
from tests.test_helpers import wait_for_server
35+
from tests.test_helpers import get_worker_specific_port, wait_for_server
3636

3737
SERVER_NAME = "test_server_for_SSE"
3838

3939

4040
@pytest.fixture
41-
def server_port() -> int:
42-
with socket.socket() as s:
43-
s.bind(("127.0.0.1", 0))
44-
return s.getsockname()[1]
41+
def server_port(worker_id: str) -> int:
42+
return get_worker_specific_port(worker_id)
4543

4644

4745
@pytest.fixture

tests/shared/test_streamable_http.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import json
88
import multiprocessing
9-
import socket
109
from collections.abc import Generator
1110
from typing import Any
1211

@@ -42,7 +41,7 @@
4241
from mcp.shared.message import ClientMessageMetadata
4342
from mcp.shared.session import RequestResponder
4443
from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool
45-
from tests.test_helpers import wait_for_server
44+
from tests.test_helpers import get_worker_specific_port, wait_for_server
4645

4746
# Test constants
4847
SERVER_NAME = "test_streamable_http_server"
@@ -322,19 +321,15 @@ def run_server(port: int, is_json_response_enabled: bool = False, event_store: E
322321

323322
# Test fixtures - using same approach as SSE tests
324323
@pytest.fixture
325-
def basic_server_port() -> int:
324+
def basic_server_port(worker_id: str) -> int:
326325
"""Find an available port for the basic server."""
327-
with socket.socket() as s:
328-
s.bind(("127.0.0.1", 0))
329-
return s.getsockname()[1]
326+
return get_worker_specific_port(worker_id)
330327

331328

332329
@pytest.fixture
333-
def json_server_port() -> int:
330+
def json_server_port(worker_id: str) -> int:
334331
"""Find an available port for the JSON response server."""
335-
with socket.socket() as s:
336-
s.bind(("127.0.0.1", 0))
337-
return s.getsockname()[1]
332+
return get_worker_specific_port(worker_id)
338333

339334

340335
@pytest.fixture
@@ -360,11 +355,9 @@ def event_store() -> SimpleEventStore:
360355

361356

362357
@pytest.fixture
363-
def event_server_port() -> int:
358+
def event_server_port(worker_id: str) -> int:
364359
"""Find an available port for the event store server."""
365-
with socket.socket() as s:
366-
s.bind(("127.0.0.1", 0))
367-
return s.getsockname()[1]
360+
return get_worker_specific_port(worker_id)
368361

369362

370363
@pytest.fixture

tests/shared/test_ws.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import multiprocessing
2-
import socket
32
import time
43
from collections.abc import AsyncGenerator, Generator
54
from typing import Any
@@ -26,16 +25,14 @@
2625
TextResourceContents,
2726
Tool,
2827
)
29-
from tests.test_helpers import wait_for_server
28+
from tests.test_helpers import get_worker_specific_port, wait_for_server
3029

3130
SERVER_NAME = "test_server_for_WS"
3231

3332

3433
@pytest.fixture
35-
def server_port() -> int:
36-
with socket.socket() as s:
37-
s.bind(("127.0.0.1", 0))
38-
return s.getsockname()[1]
34+
def server_port(worker_id: str) -> int:
35+
return get_worker_specific_port(worker_id)
3936

4037

4138
@pytest.fixture

tests/test_test_helpers.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -206,13 +206,20 @@ def test_get_worker_specific_port_raises_when_no_ports_available(monkeypatch: py
206206
try:
207207
# Try to bind all ports in range (may not succeed on all platforms)
208208
for port in range(start, min(start + 10, end)): # Just bind first 10 for speed
209+
s: socket.socket | None = None
209210
try:
210211
s = socket.socket()
211-
s.bind(("127.0.0.1", port))
212-
sockets.append(s)
213-
except OSError:
214-
# Port already in use, skip
215-
pass
212+
try:
213+
s.bind(("127.0.0.1", port))
214+
sockets.append(s)
215+
except OSError:
216+
# Port already in use, skip
217+
s.close()
218+
except Exception:
219+
# Clean up socket if any unexpected error
220+
if s is not None:
221+
s.close()
222+
raise
216223

217224
# If we managed to bind some ports, temporarily exhaust the small range
218225
if sockets:
@@ -221,5 +228,5 @@ def test_get_worker_specific_port_raises_when_no_ports_available(monkeypatch: py
221228
pass
222229
finally:
223230
# Clean up sockets
224-
for s in sockets:
225-
s.close()
231+
for sock in sockets:
232+
sock.close()

0 commit comments

Comments
 (0)