diff --git a/tests/integrations/mcp/test_mcp.py b/tests/integrations/mcp/test_mcp.py index 7358c3e6b4..f45564d26b 100644 --- a/tests/integrations/mcp/test_mcp.py +++ b/tests/integrations/mcp/test_mcp.py @@ -28,18 +28,15 @@ async def __call__(self, *args, **kwargs): return super(AsyncMock, self).__call__(*args, **kwargs) +from mcp.types import GetPromptResult, PromptMessage, TextContent +from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.lowlevel import Server from mcp.server.lowlevel.server import request_ctx -from mcp.types import ( - JSONRPCMessage, - JSONRPCNotification, - JSONRPCRequest, - GetPromptResult, - PromptMessage, - TextContent, -) -from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.shared.message import SessionMessage +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager + +from starlette.routing import Mount +from starlette.applications import Starlette +from starlette.testclient import TestClient try: from mcp.server.lowlevel.server import request_ctx @@ -51,6 +48,71 @@ async def __call__(self, *args, **kwargs): from sentry_sdk.integrations.mcp import MCPIntegration +def json_rpc(app, method: str, params, request_id: str): + with TestClient(app) as client: + init_response = client.post( + "/mcp/", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + }, + json={ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "clientInfo": {"name": "test-client", "version": "1.0"}, + "protocolVersion": "2025-11-25", + "capabilities": {}, + }, + "id": request_id, + }, + ) + + session_id = init_response.headers["mcp-session-id"] + + # The client must send the initialized notification after initialize. 
+ # https://modelcontextprotocol.io/specification/2025-11-25/basic/lifecycle + client.post( + "/mcp/", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + "mcp-session-id": session_id, + }, + json={ + "jsonrpc": "2.0", + "method": "notifications/initialized", + "params": {}, + }, + ) + + response = client.post( + "/mcp/", + headers={ + "Accept": "application/json, text/event-stream", + "Content-Type": "application/json", + "mcp-session-id": session_id, + }, + json={ + "jsonrpc": "2.0", + "method": method, + "params": params, + "id": request_id, + }, + ) + + return session_id, response + + +def select_mcp_transactions(events): + return [ + event + for event in events + if event["type"] == "transaction" + and event["contexts"]["trace"]["op"] == "mcp.server" + ] + + @pytest.fixture(autouse=True) def reset_request_ctx(): """Reset request context before and after each test""" @@ -238,45 +300,71 @@ async def test_tool_handler_async( server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext( - request_id="req-456", session_id="session-789", transport="http" + session_manager = StreamableHTTPSessionManager( + app=server, + json_response=True, + ) + + app = Starlette( + routes=[ + Mount("/mcp", app=session_manager.handle_request), + ], + lifespan=lambda app: session_manager.run(), ) - request_ctx.set(mock_ctx) @server.call_tool() async def test_tool_async(tool_name, arguments): - return {"status": "completed"} - - with start_transaction(name="mcp tx"): - result = await test_tool_async("process", {"data": "test"}) + return [ + TextContent( + type="text", + text=json.dumps({"status": "completed"}), + ) + ] - assert result == {"status": "completed"} + session_id, result = json_rpc( + app, + method="tools/call", + params={ + "name": "process", + "arguments": { + "data": "test", + }, + }, + request_id="req-456", + ) + assert result.json()["result"]["content"][0]["text"] == json.dumps( + 
{"status": "completed"} + ) - (tx,) = events + transactions = select_mcp_transactions(events) + assert len(transactions) == 1 + tx = transactions[0] assert tx["type"] == "transaction" - assert len(tx["spans"]) == 1 - span = tx["spans"][0] - assert span["op"] == OP.MCP_SERVER - assert span["description"] == "tools/call process" - assert span["origin"] == "auto.ai.mcp" + assert tx["contexts"]["trace"]["op"] == OP.MCP_SERVER + assert tx["transaction"] == "tools/call process" + assert tx["contexts"]["trace"]["origin"] == "auto.ai.mcp" # Check span data - assert span["data"][SPANDATA.MCP_TOOL_NAME] == "process" - assert span["data"][SPANDATA.MCP_METHOD_NAME] == "tools/call" - assert span["data"][SPANDATA.MCP_TRANSPORT] == "http" - assert span["data"][SPANDATA.MCP_REQUEST_ID] == "req-456" - assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-789" - assert span["data"]["mcp.request.argument.data"] == '"test"' + assert tx["contexts"]["trace"]["data"][SPANDATA.MCP_TOOL_NAME] == "process" + assert tx["contexts"]["trace"]["data"][SPANDATA.MCP_METHOD_NAME] == "tools/call" + assert tx["contexts"]["trace"]["data"][SPANDATA.MCP_TRANSPORT] == "http" + assert tx["contexts"]["trace"]["data"][SPANDATA.MCP_REQUEST_ID] == "req-456" + assert tx["contexts"]["trace"]["data"][SPANDATA.MCP_SESSION_ID] == session_id + assert tx["contexts"]["trace"]["data"]["mcp.request.argument.data"] == '"test"' # Check PII-sensitive data if send_default_pii and include_prompts: - assert span["data"][SPANDATA.MCP_TOOL_RESULT_CONTENT] == json.dumps( - {"status": "completed"} + # TODO: Investigate why tool result is double-serialized. 
+ assert tx["contexts"]["trace"]["data"][ + SPANDATA.MCP_TOOL_RESULT_CONTENT + ] == json.dumps( + json.dumps( + {"status": "completed"}, + ) ) else: - assert SPANDATA.MCP_TOOL_RESULT_CONTENT not in span["data"] + assert SPANDATA.MCP_TOOL_RESULT_CONTENT not in tx["contexts"]["trace"]["data"] @pytest.mark.asyncio @@ -426,39 +514,66 @@ async def test_prompt_handler_async( server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext( - request_id="req-async-prompt", session_id="session-abc", transport="http" + session_manager = StreamableHTTPSessionManager( + app=server, + json_response=True, + ) + + app = Starlette( + routes=[ + Mount("/mcp", app=session_manager.handle_request), + ], + lifespan=lambda app: session_manager.run(), ) - request_ctx.set(mock_ctx) @server.get_prompt() async def test_prompt_async(name, arguments): - return MockGetPromptResult( - [ - MockPromptMessage("system", "You are a helpful assistant"), - MockPromptMessage("user", "What is MCP?"), - ] + return GetPromptResult( + description="A helpful test prompt", + messages=[ + PromptMessage( + role="user", + content=TextContent( + type="text", text="You are a helpful assistant" + ), + ), + PromptMessage( + role="user", content=TextContent(type="text", text="What is MCP?") + ), + ], ) - with start_transaction(name="mcp tx"): - result = await test_prompt_async("mcp_info", {}) - - assert len(result.messages) == 2 + _, result = json_rpc( + app, + method="prompts/get", + params={ + "name": "mcp_info", + "arguments": {}, + }, + request_id="req-async-prompt", + ) + assert len(result.json()["result"]["messages"]) == 2 - (tx,) = events + transactions = select_mcp_transactions(events) + assert len(transactions) == 1 + tx = transactions[0] assert tx["type"] == "transaction" - assert len(tx["spans"]) == 1 - span = tx["spans"][0] - assert span["op"] == OP.MCP_SERVER - assert span["description"] == "prompts/get mcp_info" + assert tx["contexts"]["trace"]["op"] == OP.MCP_SERVER + 
assert tx["transaction"] == "prompts/get mcp_info" # For multi-message prompts, count is always captured - assert span["data"][SPANDATA.MCP_PROMPT_RESULT_MESSAGE_COUNT] == 2 + assert ( + tx["contexts"]["trace"]["data"][SPANDATA.MCP_PROMPT_RESULT_MESSAGE_COUNT] == 2 + ) # Role/content are never captured for multi-message prompts (even with PII) - assert SPANDATA.MCP_PROMPT_RESULT_MESSAGE_ROLE not in span["data"] - assert SPANDATA.MCP_PROMPT_RESULT_MESSAGE_CONTENT not in span["data"] + assert ( + SPANDATA.MCP_PROMPT_RESULT_MESSAGE_ROLE not in tx["contexts"]["trace"]["data"] + ) + assert ( + SPANDATA.MCP_PROMPT_RESULT_MESSAGE_CONTENT + not in tx["contexts"]["trace"]["data"] + ) @pytest.mark.asyncio @@ -560,33 +675,53 @@ async def test_resource_handler_async(sentry_init, capture_events): server = Server("test-server") - # Set up mock request context - mock_ctx = MockRequestContext( - request_id="req-async-resource", session_id="session-res", transport="http" + session_manager = StreamableHTTPSessionManager( + app=server, + json_response=True, + ) + + app = Starlette( + routes=[ + Mount("/mcp", app=session_manager.handle_request), + ], + lifespan=lambda app: session_manager.run(), ) - request_ctx.set(mock_ctx) @server.read_resource() async def test_resource_async(uri): - return {"data": "resource data"} + return [ + ReadResourceContents( + content=json.dumps({"data": "resource data"}), mime_type="text/plain" + ) + ] - with start_transaction(name="mcp tx"): - uri = MockURI("https://example.com/resource") - result = await test_resource_async(uri) + session_id, result = json_rpc( + app, + method="resources/read", + params={ + "uri": "https://example.com/resource", + }, + request_id="req-async-resource", + ) - assert result["data"] == "resource data" + assert result.json()["result"]["contents"][0]["text"] == json.dumps( + {"data": "resource data"} + ) - (tx,) = events + transactions = select_mcp_transactions(events) + assert len(transactions) == 1 + tx = transactions[0] 
assert tx["type"] == "transaction" - assert len(tx["spans"]) == 1 - span = tx["spans"][0] - assert span["op"] == OP.MCP_SERVER - assert span["description"] == "resources/read https://example.com/resource" + assert tx["contexts"]["trace"]["op"] == OP.MCP_SERVER + assert tx["transaction"] == "resources/read https://example.com/resource" - assert span["data"][SPANDATA.MCP_RESOURCE_URI] == "https://example.com/resource" - assert span["data"][SPANDATA.MCP_RESOURCE_PROTOCOL] == "https" - assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-res" + assert ( + tx["contexts"]["trace"]["data"][SPANDATA.MCP_RESOURCE_URI] + == "https://example.com/resource" + ) + assert tx["contexts"]["trace"]["data"][SPANDATA.MCP_RESOURCE_PROTOCOL] == "https" + assert tx["contexts"]["trace"]["data"][SPANDATA.MCP_SESSION_ID] == session_id @pytest.mark.asyncio @@ -1044,28 +1179,48 @@ def test_streamable_http_transport_detection(sentry_init, capture_events): server = Server("test-server") - # Set up mock request context with StreamableHTTP transport - mock_ctx = MockRequestContext( - request_id="req-http", session_id="session-http-456", transport="http" + session_manager = StreamableHTTPSessionManager( + app=server, + json_response=True, ) - request_ctx.set(mock_ctx) - @server.call_tool() - def test_tool(tool_name, arguments): - return {"result": "success"} + app = Starlette( + routes=[ + Mount("/mcp", app=session_manager.handle_request), + ], + lifespan=lambda app: session_manager.run(), + ) - with start_transaction(name="mcp tx"): - result = test_tool("http_tool", {}) + @server.call_tool() + async def test_tool(tool_name, arguments): + return [ + TextContent( + type="text", + text=json.dumps({"status": "success"}), + ) + ] - assert result == {"result": "success"} + session_id, result = json_rpc( + app, + method="tools/call", + params={ + "name": "http_tool", + "arguments": {}, + }, + request_id="req-http", + ) + assert result.json()["result"]["content"][0]["text"] == json.dumps( + 
{"status": "success"} + ) - (tx,) = events - span = tx["spans"][0] + transactions = select_mcp_transactions(events) + assert len(transactions) == 1 + tx = transactions[0] # Check that HTTP transport is detected - assert span["data"][SPANDATA.MCP_TRANSPORT] == "http" - assert span["data"][SPANDATA.NETWORK_TRANSPORT] == "tcp" - assert span["data"][SPANDATA.MCP_SESSION_ID] == "session-http-456" + assert tx["contexts"]["trace"]["data"][SPANDATA.MCP_TRANSPORT] == "http" + assert tx["contexts"]["trace"]["data"][SPANDATA.NETWORK_TRANSPORT] == "tcp" + assert tx["contexts"]["trace"]["data"][SPANDATA.MCP_SESSION_ID] == session_id @pytest.mark.asyncio