Skip to content

Commit 236a041

Browse files
committed
fix: Resolve integration test issues and import problems
- Fix broken test_notifications function with correct implementation - Fix multiprocessing import issues in run_server_with_transport by adding snippets path - Apply code formatting improvements - Tests should now start servers properly and run without hanging
1 parent 2f568c0 commit 236a041

File tree

1 file changed

+79
-16
lines changed

1 file changed

+79
-16
lines changed

tests/server/fastmcp/test_integration.py

Lines changed: 79 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,29 @@ def server_url(server_port: int) -> str:
8585

8686
def run_server_with_transport(module_name: str, port: int, transport: str) -> None:
8787
"""Run server with specified transport."""
88+
import sys
89+
import os
90+
91+
# Add examples/snippets to Python path for multiprocessing context
92+
snippets_path = os.path.join(
93+
os.path.dirname(__file__), "..", "..", "..", "examples", "snippets"
94+
)
95+
sys.path.insert(0, os.path.abspath(snippets_path))
96+
97+
# Import the servers module in the multiprocessing context
98+
from servers import (
99+
basic_tool,
100+
basic_resource,
101+
basic_prompt,
102+
tool_progress,
103+
sampling,
104+
elicitation,
105+
completion,
106+
notifications,
107+
fastmcp_quickstart,
108+
structured_output,
109+
)
110+
88111
# Get the MCP instance based on module name
89112
if module_name == "basic_tool":
90113
mcp = basic_tool.mcp
@@ -117,7 +140,9 @@ def run_server_with_transport(module_name: str, port: int, transport: str) -> No
117140
else:
118141
raise ValueError(f"Invalid transport for test server: {transport}")
119142

120-
server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="error"))
143+
server = uvicorn.Server(
144+
config=uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="error")
145+
)
121146
print(f"Starting {transport} server on port {port}")
122147
server.run()
123148

@@ -325,10 +350,14 @@ async def test_basic_prompts(server_transport: str, server_url: str) -> None:
325350

326351
# Test review_code prompt
327352
prompts = await session.list_prompts()
328-
review_prompt = next((p for p in prompts.prompts if p.name == "review_code"), None)
353+
review_prompt = next(
354+
(p for p in prompts.prompts if p.name == "review_code"), None
355+
)
329356
assert review_prompt is not None
330357

331-
prompt_result = await session.get_prompt("review_code", {"code": "def hello():\n print('Hello')"})
358+
prompt_result = await session.get_prompt(
359+
"review_code", {"code": "def hello():\n print('Hello')"}
360+
)
332361
assert isinstance(prompt_result, GetPromptResult)
333362
assert len(prompt_result.messages) == 1
334363
assert isinstance(prompt_result.messages[0].content, TextContent)
@@ -384,16 +413,18 @@ async def test_tool_progress(server_transport: str, server_url: str) -> None:
384413
assert result.capabilities.tools is not None
385414

386415
# Test long_running_task tool that reports progress
387-
tool_result = await session.call_tool("long_running_task", {"task_name": "test", "steps": 3})
416+
tool_result = await session.call_tool(
417+
"long_running_task", {"task_name": "test", "steps": 3}
418+
)
388419
assert len(tool_result.content) == 1
389420
assert isinstance(tool_result.content[0], TextContent)
390421
assert "Task 'test' completed" in tool_result.content[0].text
391422

392423
# Verify that progress notifications or log messages were sent
393424
# Progress can come through either progress notifications or log messages
394-
total_notifications = len(notification_collector.progress_notifications) + len(
395-
notification_collector.log_messages
396-
)
425+
total_notifications = len(
426+
notification_collector.progress_notifications
427+
) + len(notification_collector.log_messages)
397428
assert total_notifications > 0
398429

399430

@@ -414,7 +445,9 @@ async def test_sampling(server_transport: str, server_url: str) -> None:
414445

415446
async with client_cm as client_streams:
416447
read_stream, write_stream = unpack_streams(client_streams)
417-
async with ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session:
448+
async with ClientSession(
449+
read_stream, write_stream, sampling_callback=sampling_callback
450+
) as session:
418451
# Test initialization
419452
result = await session.initialize()
420453
assert isinstance(result, InitializeResult)
@@ -445,7 +478,9 @@ async def test_elicitation(server_transport: str, server_url: str) -> None:
445478

446479
async with client_cm as client_streams:
447480
read_stream, write_stream = unpack_streams(client_streams)
448-
async with ClientSession(read_stream, write_stream, elicitation_callback=elicitation_callback) as session:
481+
async with ClientSession(
482+
read_stream, write_stream, elicitation_callback=elicitation_callback
483+
) as session:
449484
# Test initialization
450485
result = await session.initialize()
451486
assert isinstance(result, InitializeResult)
@@ -491,7 +526,9 @@ async def test_completion(server_transport: str, server_url: str) -> None:
491526
assert len(prompts.prompts) > 0
492527

493528
# Test getting a prompt
494-
prompt_result = await session.get_prompt("review_code", {"language": "python", "code": "def test(): pass"})
529+
prompt_result = await session.get_prompt(
530+
"review_code", {"language": "python", "code": "def test(): pass"}
531+
)
495532
assert len(prompt_result.messages) > 0
496533

497534

@@ -510,11 +547,35 @@ async def test_notifications(server_transport: str, server_url: str) -> None:
510547
transport = server_transport
511548
client_cm = create_client_for_transport(transport, server_url)
512549

513-
assert completion_result is not None
514-
assert hasattr(completion_result, "completion")
515-
assert completion_result.completion is not None
516-
assert "python" in completion_result.completion.values
517-
assert all(lang.startswith("py") for lang in completion_result.completion.values)
550+
notification_collector = NotificationCollector()
551+
552+
async with client_cm as client_streams:
553+
read_stream, write_stream = unpack_streams(client_streams)
554+
async with ClientSession(
555+
read_stream,
556+
write_stream,
557+
message_handler=notification_collector.handle_generic_notification,
558+
) as session:
559+
# Test initialization
560+
result = await session.initialize()
561+
assert isinstance(result, InitializeResult)
562+
assert result.serverInfo.name == "Notifications Example"
563+
assert result.capabilities.tools is not None
564+
565+
# Test process_data tool that sends log notifications
566+
tool_result = await session.call_tool("process_data", {"data": "test_data"})
567+
assert len(tool_result.content) == 1
568+
assert isinstance(tool_result.content[0], TextContent)
569+
assert "Processed: test_data" in tool_result.content[0].text
570+
571+
# Verify log messages were sent at different levels
572+
assert len(notification_collector.log_messages) >= 1
573+
log_levels = {msg.level for msg in notification_collector.log_messages}
574+
# Should have at least one of these log levels
575+
assert log_levels & {"debug", "info", "warning", "error"}
576+
577+
# Verify resource list change notification was sent
578+
assert len(notification_collector.resource_notifications) > 0
518579

519580

520581
# Test FastMCP quickstart example
@@ -579,7 +640,9 @@ async def test_structured_output(server_transport: str, server_url: str) -> None
579640
assert result.serverInfo.name == "Structured Output Example"
580641

581642
# Test get_weather tool
582-
weather_result = await session.call_tool("get_weather", {"city": "New York"})
643+
weather_result = await session.call_tool(
644+
"get_weather", {"city": "New York"}
645+
)
583646
assert len(weather_result.content) == 1
584647
assert isinstance(weather_result.content[0], TextContent)
585648

0 commit comments

Comments
 (0)