From db9ba7b419bfe0e0686576c5365e8a969bd97637 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Sat, 14 Jun 2025 01:08:02 +0300 Subject: [PATCH 1/3] raw request propagation in tools - implementation, unit tests, types --- package-lock.json | 4 +- package.json | 2 +- src/client/index.test.ts | 15 +- src/server/index.test.ts | 14 +- src/server/mcp.test.ts | 365 ++++++++++++++---------------- src/server/mcp.ts | 4 + src/server/sse.test.ts | 204 ++++++++++++++++- src/server/sse.ts | 8 +- src/server/streamableHttp.test.ts | 65 ++++++ src/server/streamableHttp.ts | 8 +- src/server/types/types.ts | 31 +++ src/shared/protocol.test.ts | 126 +++++++++++ src/shared/protocol.ts | 14 +- src/shared/transport.ts | 7 +- 14 files changed, 641 insertions(+), 226 deletions(-) create mode 100644 src/server/types/types.ts diff --git a/package-lock.json b/package-lock.json index 40bad9fe2..1a9a8f454 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.11.4", + "version": "1.12.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.11.4", + "version": "1.12.0", "license": "MIT", "dependencies": { "ajv": "^6.12.6", diff --git a/package.json b/package.json index 467800fc4..764ce2cbb 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.12.0", + "version": "1.12.2", "description": "Model Context Protocol implementation for TypeScript", "license": "MIT", "author": "Anthropic, PBC (https://anthropic.com)", diff --git a/src/client/index.test.ts b/src/client/index.test.ts index bbfa80faf..f80459f1f 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -20,7 +20,14 @@ import { import { Transport } from "../shared/transport.js"; import { Server } from "../server/index.js"; import { InMemoryTransport } from "../inMemory.js"; - +import { RequestInfo } from "../server/types/types.js"; + +const mockRequestInfo: RequestInfo = { + headers: { + 'content-type': 'application/json', + 'accept': 'application/json', + }, +}; /*** * Test: Initialize with Matching Protocol Version */ @@ -42,7 +49,7 @@ test("should initialize with matching protocol version", async () => { }, instructions: "test instructions", }, - }); + }, { requestInfo: mockRequestInfo }); } return Promise.resolve(); }), @@ -100,7 +107,7 @@ test("should initialize with supported older protocol version", async () => { version: "1.0", }, }, - }); + }, { requestInfo: mockRequestInfo }); } return Promise.resolve(); }), @@ -150,7 +157,7 @@ test("should reject unsupported protocol version", async () => { version: "1.0", }, }, - }); + }, { requestInfo: mockRequestInfo }); } return Promise.resolve(); }), diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 7c0fbc51a..e015be94c 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -19,6 +19,14 @@ import { import { Transport } from "../shared/transport.js"; import { InMemoryTransport } from "../inMemory.js"; import { Client } from "../client/index.js"; +import { RequestInfo } from "./types/types.js"; + +const mockRequestInfo: RequestInfo = { + headers: { + 'content-type': 'application/json', + 'traceparent': '00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01', + }, +}; test("should accept latest protocol version", async () => { let sendPromiseResolve: (value: unknown) => void; @@ -77,7 +85,7 @@ test("should accept latest protocol version", async () => { version: "1.0", }, }, - }); + }, { requestInfo: mockRequestInfo }); await expect(sendPromise).resolves.toBeUndefined(); }); @@ -138,7 +146,7 @@ test("should accept supported older protocol version", async () => { version: "1.0", }, }, - }); + }, { requestInfo: mockRequestInfo }); await expect(sendPromise).resolves.toBeUndefined(); }); @@ -198,7 +206,7 @@ test("should handle unsupported protocol version", async () => { version: "1.0", }, }, - }); + }, { requestInfo: mockRequestInfo }); await expect(sendPromise).resolves.toBeUndefined(); }); diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 49f852d65..773777cbb 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -18,6 +18,14 @@ import { import { ResourceTemplate } from "./mcp.js"; import { completable } from "./completable.js"; import { UriTemplate } from "../shared/uriTemplate.js"; +import { RequestInfo } from "./types/types.js"; + +const mockRequestInfo: RequestInfo = { + headers: { + 'content-type': 'application/json', + 'accept': 'application/json', + }, +}; describe("McpServer", () => { /*** @@ -212,7 +220,8 @@ describe("ResourceTemplate", () => { signal: abortController.signal, requestId: 'not-implemented', sendRequest: () => { throw new Error("Not implemented") }, - sendNotification: () => { throw new Error("Not implemented") } + sendNotification: () => { throw new Error("Not implemented") }, + requestInfo: mockRequestInfo }); expect(result?.resources).toHaveLength(1); expect(list).toHaveBeenCalled(); @@ -913,18 +922,10 @@ describe("tool()", () => { name: "test server", version: "1.0", }); - - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); mcpServer.tool( "test", @@ -1056,17 +1057,10 @@ describe("tool()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); // Register a tool with outputSchema mcpServer.registerTool( @@ -1169,17 +1163,10 @@ describe("tool()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); // Register a tool with outputSchema that returns only content without structuredContent mcpServer.registerTool( @@ -1233,17 +1220,10 @@ describe("tool()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); // Register a tool with outputSchema that returns invalid data mcpServer.registerTool( @@ -1308,17 +1288,10 @@ describe("tool()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); let receivedSessionId: string | undefined; mcpServer.tool("test-tool", async (extra) => { @@ -1364,17 +1337,10 @@ describe("tool()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); let receivedRequestId: string | number | undefined; mcpServer.tool("request-id-test", async (extra) => { @@ -1423,17 +1389,10 @@ describe("tool()", () => { { capabilities: { logging: {} } }, ); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); let receivedLogMessage: string | undefined; const loggingMessage = "hello here is log message 1"; @@ -1480,17 +1439,10 @@ describe("tool()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); mcpServer.tool( "test", @@ -1546,17 +1498,10 @@ describe("tool()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); mcpServer.tool("error-test", async () => { throw new Error("Tool execution failed"); @@ -1598,17 +1543,10 @@ describe("tool()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - tools: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); mcpServer.tool("test-tool", async () => ({ content: [ @@ -2393,26 +2331,61 @@ describe("resource()", () => { }); /*** - * Test: Resource Template Parameter Completion + * Test: Registering a resource template with a complete callback should update server capabilities to advertise support for completion */ - test("should support completion of resource template parameters", async () => { + test("should advertise support for completion when a resource template with a complete callback is defined", async () => { const mcpServer = new McpServer({ name: "test server", version: "1.0", }); + const client = new Client({ + name: "test client", + version: "1.0", + }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - resources: {}, + mcpServer.resource( + "test", + new ResourceTemplate("test://resource/{category}", { + list: undefined, + complete: { + category: () => ["books", "movies", "music"], }, - }, + }), + async () => ({ + contents: [ + { + uri: "test://resource/test", + text: "Test content", + }, + ], + }), ); + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + expect(client.getServerCapabilities()).toMatchObject({ completions: {} }) + }) + + /*** + * Test: Resource Template Parameter Completion + */ + test("should support completion of resource template parameters", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client({ + name: "test client", + version: "1.0", + }); + mcpServer.resource( "test", new ResourceTemplate("test://resource/{category}", { @@ -2469,17 +2442,10 @@ describe("resource()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - resources: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); mcpServer.resource( "test", @@ -2540,17 +2506,10 @@ describe("resource()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - resources: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); let receivedRequestId: string | number | undefined; mcpServer.resource("request-id-test", "test://resource", async (_uri, extra) => { @@ -3052,17 +3011,10 @@ describe("prompt()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); mcpServer.prompt( "test", @@ -3258,17 +3210,10 @@ describe("prompt()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); mcpServer.prompt("test-prompt", async () => ({ messages: [ @@ -3303,27 +3248,63 @@ describe("prompt()", () => { ).rejects.toThrow(/Prompt nonexistent-prompt not found/); }); + /*** - * Test: Prompt Argument Completion + * Test: Registering a prompt with a completable argument should update server capabilities to advertise support for completion */ - test("should support completion of prompt arguments", async () => { + test("should advertise support for completion when a prompt with a completable argument is defined", async () => { const mcpServer = new McpServer({ name: "test server", version: "1.0", }); + const client = new Client({ + name: "test client", + version: "1.0", + }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, + mcpServer.prompt( + "test-prompt", { - capabilities: { - prompts: {}, - }, + name: completable(z.string(), () => ["Alice", "Bob", "Charlie"]), }, + async ({ name }) => ({ + messages: [ + { + role: "assistant", + content: { + type: "text", + text: `Hello ${name}`, + }, + }, + ], + }), ); + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + expect(client.getServerCapabilities()).toMatchObject({ completions: {} }) + }) + + /*** + * Test: Prompt Argument Completion + */ + test("should support completion of prompt arguments", async () => { + const mcpServer = new McpServer({ + name: "test server", + version: "1.0", + }); + + const client = new Client({ + name: "test client", + version: "1.0", + }); + mcpServer.prompt( "test-prompt", { @@ -3380,17 +3361,10 @@ describe("prompt()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); mcpServer.prompt( "test-prompt", @@ -3450,17 +3424,10 @@ describe("prompt()", () => { version: "1.0", }); - const client = new Client( - { - name: "test client", - version: "1.0", - }, - { - capabilities: { - prompts: {}, - }, - }, - ); + const client = new Client({ + name: "test client", + version: "1.0", + }); let receivedRequestId: string | number | undefined; mcpServer.prompt("request-id-test", async (extra) => { diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 5b864b8b4..38c869c78 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -236,6 +236,10 @@ export class McpServer { CompleteRequestSchema.shape.method.value, ); + this.server.registerCapabilities({ + completions: {}, + }); + this.server.setRequestHandler( CompleteRequestSchema, async (request): Promise => { diff --git a/src/server/sse.test.ts b/src/server/sse.test.ts index 2fd2c0424..7edef6af0 100644 --- a/src/server/sse.test.ts +++ b/src/server/sse.test.ts @@ -1,20 +1,146 @@ import http from 'http'; import { jest } from '@jest/globals'; import { SSEServerTransport } from './sse.js'; +import { McpServer } from './mcp.js'; +import { createServer, type Server } from "node:http"; +import { AddressInfo } from "node:net"; +import { z } from 'zod'; +import { CallToolResult, JSONRPCMessage } from 'src/types.js'; const createMockResponse = () => { const res = { - writeHead: jest.fn(), - write: jest.fn().mockReturnValue(true), - on: jest.fn(), + writeHead: jest.fn().mockReturnThis(), + write: jest.fn().mockReturnThis(), + on: jest.fn().mockReturnThis(), + end: jest.fn().mockReturnThis(), }; - res.writeHead.mockReturnThis(); - res.on.mockReturnThis(); - return res as unknown as http.ServerResponse; + return res as unknown as jest.Mocked; }; +/** + * Helper to create and start test HTTP server with MCP setup + */ +async function createTestServerWithSse(args: { + mockRes: http.ServerResponse; +}): Promise<{ + server: Server; + transport: SSEServerTransport; + mcpServer: McpServer; + baseUrl: URL; + sessionId: string + serverPort: number; +}> { + const mcpServer = new McpServer( + { name: "test-server", version: "1.0.0" }, + { capabilities: { logging: {} } } + ); + + mcpServer.tool( + "greet", + "A simple greeting tool", + { name: z.string().describe("Name to greet") }, + async ({ name }): Promise => { + return { content: [{ type: "text", text: `Hello, ${name}!` }] }; + } + ); + + const endpoint = '/messages'; + + const transport = new SSEServerTransport(endpoint, args.mockRes); + const sessionId = transport.sessionId; + + await mcpServer.connect(transport); + + const server = createServer(async (req, res) => { + try { + await transport.handlePostMessage(req, res); + } catch (error) { + console.error("Error handling request:", error); + if (!res.headersSent) res.writeHead(500).end(); + } + }); + + const baseUrl = await new Promise((resolve) => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); + + const port = (server.address() as AddressInfo).port; + + return { server, transport, mcpServer, baseUrl, sessionId, serverPort: port }; +} + +async function readAllSSEEvents(response: Response): Promise { + const reader = response.body?.getReader(); + if (!reader) throw new Error('No readable stream'); + + const events: string[] = []; + const decoder = new TextDecoder(); + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + if (value) { + events.push(decoder.decode(value)); + } + } + } finally { + reader.releaseLock(); + } + + return events; +} + +/** + * Helper to send JSON-RPC request + */ +async function sendSsePostRequest(baseUrl: URL, message: JSONRPCMessage | JSONRPCMessage[], sessionId?: string, extraHeaders?: Record): Promise { + const headers: Record = { + "Content-Type": "application/json", + Accept: "application/json, text/event-stream", + ...extraHeaders + }; + + if (sessionId) { + baseUrl.searchParams.set('sessionId', sessionId); + } + + return fetch(baseUrl, { + method: "POST", + headers, + body: JSON.stringify(message), + }); +} + describe('SSEServerTransport', () => { + + async function initializeServer(baseUrl: URL): Promise { + const response = await sendSsePostRequest(baseUrl, { + jsonrpc: "2.0", + method: "initialize", + params: { + clientInfo: { name: "test-client", version: "1.0" }, + protocolVersion: "2025-03-26", + capabilities: { + }, + }, + + id: "init-1", + } as JSONRPCMessage); + + expect(response.status).toBe(202); + + const text = await readAllSSEEvents(response); + + expect(text).toHaveLength(1); + expect(text[0]).toBe('Accepted'); + } + describe('start method', () => { it('should correctly append sessionId to a simple relative endpoint', async () => { const mockRes = createMockResponse(); @@ -105,5 +231,71 @@ describe('SSEServerTransport', () => { `event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n` ); }); + + /*** + * Test: Tool With Request Info + */ + it("should pass request info to tool callback", async () => { + const mockRes = createMockResponse(); + const { mcpServer, baseUrl, sessionId, serverPort } = await createTestServerWithSse({ mockRes }); + await initializeServer(baseUrl); + + mcpServer.tool( + "test-request-info", + "A simple test tool with request info", + { name: z.string().describe("Name to greet") }, + async ({ name }, { requestInfo }): Promise => { + return { content: [{ type: "text", text: `Hello, ${name}!` }, { type: "text", text: `${JSON.stringify(requestInfo)}` }] }; + } + ); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "tools/call", + params: { + name: "test-request-info", + arguments: { + name: "Test User", + }, + }, + id: "call-1", + }; + + const response = await sendSsePostRequest(baseUrl, toolCallMessage, sessionId); + + expect(response.status).toBe(202); + + expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /messages?sessionId=${sessionId}\n\n`); + + const expectedMessage = { + result: { + content: [ + { + type: "text", + text: "Hello, Test User!", + }, + { + type: "text", + text: JSON.stringify({ + headers: { + host: `127.0.0.1:${serverPort}`, + connection: 'keep-alive', + 'content-type': 'application/json', + accept: 'application/json, text/event-stream', + 'accept-language': '*', + 'sec-fetch-mode': 'cors', + 'user-agent': 'node', + 'accept-encoding': 'gzip, deflate', + 'content-length': '124' + }, + }) + }, + ], + }, + jsonrpc: "2.0", + id: "call-1", + }; + expect(mockRes.write).toHaveBeenCalledWith(`event: message\ndata: ${JSON.stringify(expectedMessage)}\n\n`); + }); }); }); diff --git a/src/server/sse.ts b/src/server/sse.ts index 03f6fefc9..bac58c80a 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -5,6 +5,7 @@ import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; import getRawBody from "raw-body"; import contentType from "content-type"; import { AuthInfo } from "./auth/types.js"; +import { MessageExtraInfo, RequestInfo } from "./types/types.js"; import { URL } from 'url'; const MAXIMUM_MESSAGE_SIZE = "4mb"; @@ -20,7 +21,7 @@ export class SSEServerTransport implements Transport { onclose?: () => void; onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; + onmessage?: (message: JSONRPCMessage, extra: { authInfo?: AuthInfo, requestInfo: RequestInfo }) => void; /** * Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`. @@ -87,6 +88,7 @@ export class SSEServerTransport implements Transport { throw new Error(message); } const authInfo: AuthInfo | undefined = req.auth; + const requestInfo: RequestInfo = { headers: req.headers }; let body: string | unknown; try { @@ -106,7 +108,7 @@ export class SSEServerTransport implements Transport { } try { - await this.handleMessage(typeof body === 'string' ? JSON.parse(body) : body, { authInfo }); + await this.handleMessage(typeof body === 'string' ? JSON.parse(body) : body, { requestInfo, authInfo }); } catch { res.writeHead(400).end(`Invalid message: ${body}`); return; @@ -118,7 +120,7 @@ export class SSEServerTransport implements Transport { /** * Handle a client message, regardless of how it arrived. This can be used to inform the server of messages that arrive via a means different than HTTP POST. */ - async handleMessage(message: unknown, extra?: { authInfo?: AuthInfo }): Promise { + async handleMessage(message: unknown, extra: MessageExtraInfo): Promise { let parsedMessage: JSONRPCMessage; try { parsedMessage = JSONRPCMessageSchema.parse(message); diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index b961f6c41..83af86cc8 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -206,6 +206,7 @@ function expectErrorResponse(data: unknown, expectedCode: number, expectedMessag describe("StreamableHTTPServerTransport", () => { let server: Server; + let mcpServer: McpServer; let transport: StreamableHTTPServerTransport; let baseUrl: URL; let sessionId: string; @@ -214,6 +215,7 @@ describe("StreamableHTTPServerTransport", () => { const result = await createTestServer(); server = result.server; transport = result.transport; + mcpServer = result.mcpServer; baseUrl = result.baseUrl; }); @@ -345,6 +347,69 @@ describe("StreamableHTTPServerTransport", () => { }); }); + /*** + * Test: Tool With Request Info + */ + it("should pass request info to tool callback", async () => { + sessionId = await initializeServer(); + + mcpServer.tool( + "test-request-info", + "A simple test tool with request info", + { name: z.string().describe("Name to greet") }, + async ({ name }, { requestInfo }): Promise => { + return { content: [{ type: "text", text: `Hello, ${name}!` }, { type: "text", text: `${JSON.stringify(requestInfo)}` }] }; + } + ); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "tools/call", + params: { + name: "test-request-info", + arguments: { + name: "Test User", + }, + }, + id: "call-1", + }; + + const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId); + expect(response.status).toBe(200); + + const text = await readSSEEvent(response); + const eventLines = text.split("\n"); + const dataLine = eventLines.find(line => line.startsWith("data:")); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + + expect(eventData).toMatchObject({ + jsonrpc: "2.0", + result: { + content: [ + { type: "text", text: "Hello, Test User!" }, + { type: "text", text: expect.any(String) } + ], + }, + id: "call-1", + }); + + const requestInfo = JSON.parse(eventData.result.content[1].text); + expect(requestInfo).toMatchObject({ + headers: { + 'content-type': 'application/json', + accept: 'application/json, text/event-stream', + connection: 'keep-alive', + 'mcp-session-id': sessionId, + 'accept-language': '*', + 'user-agent': expect.any(String), + 'accept-encoding': expect.any(String), + 'content-length': expect.any(String), + }, + }); + }); + it("should reject requests without a valid session ID", async () => { const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList); diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index dc99c3065..779410957 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -5,6 +5,7 @@ import getRawBody from "raw-body"; import contentType from "content-type"; import { randomUUID } from "node:crypto"; import { AuthInfo } from "./auth/types.js"; +import { MessageExtraInfo, RequestInfo } from "./types/types.js"; const MAXIMUM_MESSAGE_SIZE = "4mb"; @@ -113,7 +114,7 @@ export class StreamableHTTPServerTransport implements Transport { sessionId?: string | undefined; onclose?: () => void; onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; + onmessage?: (message: JSONRPCMessage, extra: MessageExtraInfo) => void; constructor(options: StreamableHTTPServerTransportOptions) { this.sessionIdGenerator = options.sessionIdGenerator; @@ -318,6 +319,7 @@ export class StreamableHTTPServerTransport implements Transport { } const authInfo: AuthInfo | undefined = req.auth; + const requestInfo: RequestInfo = { headers: req.headers }; let rawMessage; if (parsedBody !== undefined) { @@ -395,7 +397,7 @@ export class StreamableHTTPServerTransport implements Transport { // handle each message for (const message of messages) { - this.onmessage?.(message, { authInfo }); + this.onmessage?.(message, { authInfo, requestInfo }); } } else if (hasRequests) { // The default behavior is to use SSE streaming @@ -430,7 +432,7 @@ export class StreamableHTTPServerTransport implements Transport { // handle each message for (const message of messages) { - this.onmessage?.(message, { authInfo }); + this.onmessage?.(message, { authInfo, requestInfo }); } // The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses // This will be handled by the send() method when responses are ready diff --git a/src/server/types/types.ts b/src/server/types/types.ts new file mode 100644 index 000000000..1114e50b7 --- /dev/null +++ b/src/server/types/types.ts @@ -0,0 +1,31 @@ +import { AuthInfo } from "../auth/types.js"; + +/** + * Headers that are compatible with both Node.js and the browser. + */ +export type IsomorphicHeaders = Record; + +/** + * Information about the incoming request. + */ +export interface RequestInfo { + /** + * The headers of the request. + */ + headers: IsomorphicHeaders; +} + +/** + * Extra information about a message. + */ +export interface MessageExtraInfo { + /** + * The request information. + */ + requestInfo: RequestInfo; + + /** + * The authentication information. + */ + authInfo?: AuthInfo; +} \ No newline at end of file diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index e0141da19..05bc8f3bc 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -27,9 +27,11 @@ class MockTransport implements Transport { describe("protocol tests", () => { let protocol: Protocol; let transport: MockTransport; + let sendSpy: jest.SpyInstance; beforeEach(() => { transport = new MockTransport(); + sendSpy = jest.spyOn(transport, 'send'); protocol = new (class extends Protocol { protected assertCapabilityForMethod(): void {} protected assertNotificationCapability(): void {} @@ -63,6 +65,130 @@ describe("protocol tests", () => { expect(oncloseMock).toHaveBeenCalled(); }); + describe("_meta preservation with onprogress", () => { + test("should preserve existing _meta when adding progressToken", async () => { + await protocol.connect(transport); + const request = { + method: "example", + params: { + data: "test", + _meta: { + customField: "customValue", + anotherField: 123 + } + } + }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const onProgressMock = jest.fn(); + + protocol.request(request, mockSchema, { + onprogress: onProgressMock, + }); + + expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ + method: "example", + params: { + data: "test", + _meta: { + customField: "customValue", + anotherField: 123, + progressToken: expect.any(Number) + } + }, + jsonrpc: "2.0", + id: expect.any(Number) + }), expect.any(Object)); + }); + + test("should create _meta with progressToken when no _meta exists", async () => { + await protocol.connect(transport); + const request = { + method: "example", + params: { + data: "test" + } + }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const onProgressMock = jest.fn(); + + protocol.request(request, mockSchema, { + onprogress: onProgressMock, + }); + + expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ + method: "example", + params: { + data: "test", + _meta: { + progressToken: expect.any(Number) + } + }, + jsonrpc: "2.0", + id: expect.any(Number) + }), expect.any(Object)); + }); + + test("should not modify _meta when onprogress is not provided", async () => { + await protocol.connect(transport); + const request = { + method: "example", + params: { + data: "test", + _meta: { + customField: "customValue" + } + } + }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + + protocol.request(request, mockSchema); + + expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ + method: "example", + params: { + data: "test", + _meta: { + customField: "customValue" + } + }, + jsonrpc: "2.0", + id: expect.any(Number) + }), expect.any(Object)); + }); + + test("should handle params being undefined with onprogress", async () => { + await protocol.connect(transport); + const request = { + method: "example" + }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const onProgressMock = jest.fn(); + + protocol.request(request, mockSchema, { + onprogress: onProgressMock, + }); + + expect(sendSpy).toHaveBeenCalledWith(expect.objectContaining({ + method: "example", + params: { + _meta: { + progressToken: expect.any(Number) + } + }, + jsonrpc: "2.0", + id: expect.any(Number) + }), expect.any(Object)); + }); + }); + describe("progress notification timeout behavior", () => { beforeEach(() => { jest.useFakeTimers(); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 4694929d7..ae539c177 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -25,6 +25,7 @@ import { } from "../types.js"; import { Transport, TransportSendOptions } from "./transport.js"; import { AuthInfo } from "../server/auth/types.js"; +import { MessageExtraInfo, RequestInfo } from "../server/types/types.js"; /** * Callback for progress notifications. @@ -127,6 +128,11 @@ export type RequestHandlerExtra void; + onmessage?: (message: JSONRPCMessage, extra: MessageExtraInfo) => void; /** * The session ID generated for this connection. From 166da76b0070a431c9f254a59e799a8c927f84fb Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Tue, 24 Jun 2025 19:39:37 +0300 Subject: [PATCH 2/3] extra parameter - remain optional for backwards compatibility --- package-lock.json | 1 + src/server/sse.ts | 4 ++-- src/server/streamableHttp.ts | 2 +- src/server/types/types.ts | 2 +- src/shared/protocol.ts | 6 +++--- src/shared/transport.ts | 2 +- 6 files changed, 9 insertions(+), 8 deletions(-) diff --git a/package-lock.json b/package-lock.json index 9dd8236bd..016adf948 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,5 +1,6 @@ { "name": "@modelcontextprotocol/sdk", + "version": "1.13.1", "lockfileVersion": 3, "requires": true, "packages": { diff --git a/src/server/sse.ts b/src/server/sse.ts index 06c0bc8d4..a54e5788f 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -20,7 +20,7 @@ export class SSEServerTransport implements Transport { private _sessionId: string; onclose?: () => void; onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage, extra: { authInfo?: AuthInfo, requestInfo: RequestInfo }) => void; + onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; /** * Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`. @@ -119,7 +119,7 @@ export class SSEServerTransport implements Transport { /** * Handle a client message, regardless of how it arrived. This can be used to inform the server of messages that arrive via a means different than HTTP POST. */ - async handleMessage(message: unknown, extra: MessageExtraInfo): Promise { + async handleMessage(message: unknown, extra?: MessageExtraInfo): Promise { let parsedMessage: JSONRPCMessage; try { parsedMessage = JSONRPCMessageSchema.parse(message); diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index b5f8aca77..807743eb2 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -114,7 +114,7 @@ export class StreamableHTTPServerTransport implements Transport { sessionId?: string; onclose?: () => void; onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage, extra: MessageExtraInfo) => void; + onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; constructor(options: StreamableHTTPServerTransportOptions) { this.sessionIdGenerator = options.sessionIdGenerator; diff --git a/src/server/types/types.ts b/src/server/types/types.ts index 1114e50b7..3892af6cb 100644 --- a/src/server/types/types.ts +++ b/src/server/types/types.ts @@ -22,7 +22,7 @@ export interface MessageExtraInfo { /** * The request information. */ - requestInfo: RequestInfo; + requestInfo?: RequestInfo; /** * The authentication information. diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index ae539c177..33afd70ee 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -131,7 +131,7 @@ export type RequestHandlerExtra void; + onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; /** * The session ID generated for this connection. From 606c278668c4328b2592da73f59d1b98b2ccf062 Mon Sep 17 00:00:00 2001 From: Konstantin Konstantinov Date: Wed, 25 Jun 2025 16:24:30 +0300 Subject: [PATCH 3/3] clean up tests - remove mockRequestInfo --- src/client/index.test.ts | 15 ++++----------- src/server/index.test.ts | 16 ++++------------ src/server/mcp.test.ts | 13 ++----------- src/server/sse.test.ts | 2 +- src/server/sse.ts | 3 +-- src/server/streamableHttp.ts | 3 +-- src/server/types/types.ts | 31 ------------------------------- src/shared/protocol.ts | 3 ++- src/shared/transport.ts | 3 +-- src/types.ts | 31 +++++++++++++++++++++++++++++++ 10 files changed, 47 insertions(+), 73 deletions(-) delete mode 100644 src/server/types/types.ts diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 02d6781c9..abd0c34e4 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -21,14 +21,7 @@ import { import { Transport } from "../shared/transport.js"; import { Server } from "../server/index.js"; import { InMemoryTransport } from "../inMemory.js"; -import { RequestInfo } from "../server/types/types.js"; - -const mockRequestInfo: RequestInfo = { - headers: { - 'content-type': 'application/json', - 'accept': 'application/json', - }, -}; + /*** * Test: Initialize with Matching Protocol Version */ @@ -50,7 +43,7 @@ test("should initialize with matching protocol version", async () => { }, instructions: "test instructions", }, - }, { requestInfo: mockRequestInfo }); + }); } return Promise.resolve(); }), @@ -108,7 +101,7 @@ test("should initialize with supported older protocol version", async () => { version: "1.0", }, }, - }, { requestInfo: mockRequestInfo }); + }); } return Promise.resolve(); }), @@ -158,7 +151,7 @@ test("should reject unsupported protocol version", async () => { version: "1.0", }, }, - }, { requestInfo: mockRequestInfo }); + }); } return Promise.resolve(); }), diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 137b89348..d91b90a9c 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -15,19 +15,11 @@ import { ListResourcesRequestSchema, ListToolsRequestSchema, SetLevelRequestSchema, - ErrorCode, + ErrorCode } from "../types.js"; import { Transport } from "../shared/transport.js"; import { InMemoryTransport } from "../inMemory.js"; import { Client } from "../client/index.js"; -import { RequestInfo } from "./types/types.js"; - -const mockRequestInfo: RequestInfo = { - headers: { - 'content-type': 'application/json', - 'traceparent': '00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01', - }, -}; test("should accept latest protocol version", async () => { let sendPromiseResolve: (value: unknown) => void; @@ -86,7 +78,7 @@ test("should accept latest protocol version", async () => { version: "1.0", }, }, - }, { requestInfo: mockRequestInfo }); + }); await expect(sendPromise).resolves.toBeUndefined(); }); @@ -147,7 +139,7 @@ test("should accept supported older protocol version", async () => { version: "1.0", }, }, - }, { requestInfo: mockRequestInfo }); + }); await expect(sendPromise).resolves.toBeUndefined(); }); @@ -207,7 +199,7 @@ test("should handle unsupported protocol version", async () => { version: "1.0", }, }, - }, { requestInfo: mockRequestInfo }); + }); await expect(sendPromise).resolves.toBeUndefined(); }); diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index d208d51e6..0764ffe88 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -14,21 +14,13 @@ import { LoggingMessageNotificationSchema, Notification, TextContent, - ElicitRequestSchema, + ElicitRequestSchema } from "../types.js"; import { ResourceTemplate } from "./mcp.js"; import { completable } from "./completable.js"; import { UriTemplate } from "../shared/uriTemplate.js"; -import { RequestInfo } from "./types/types.js"; import { getDisplayName } from "../shared/metadataUtils.js"; -const mockRequestInfo: RequestInfo = { - headers: { - 'content-type': 'application/json', - 'accept': 'application/json', - }, -}; - describe("McpServer", () => { /*** * Test: Basic Server Instance @@ -222,8 +214,7 @@ describe("ResourceTemplate", () => { signal: abortController.signal, requestId: 'not-implemented', sendRequest: () => { throw new Error("Not implemented") }, - sendNotification: () => { throw new Error("Not implemented") }, - requestInfo: mockRequestInfo + sendNotification: () => { throw new Error("Not implemented") } }); expect(result?.resources).toHaveLength(1); expect(list).toHaveBeenCalled(); diff --git a/src/server/sse.test.ts b/src/server/sse.test.ts index 7edef6af0..703cc5146 100644 --- a/src/server/sse.test.ts +++ b/src/server/sse.test.ts @@ -232,7 +232,7 @@ describe('SSEServerTransport', () => { ); }); - /*** + /** * Test: Tool With Request Info */ it("should pass request info to tool callback", async () => { diff --git a/src/server/sse.ts b/src/server/sse.ts index a54e5788f..de4dd60a6 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -1,11 +1,10 @@ import { randomUUID } from "node:crypto"; import { IncomingMessage, ServerResponse } from "node:http"; import { Transport } from "../shared/transport.js"; -import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; +import { JSONRPCMessage, JSONRPCMessageSchema, MessageExtraInfo, RequestInfo } from "../types.js"; import getRawBody from "raw-body"; import contentType from "content-type"; import { AuthInfo } from "./auth/types.js"; -import { MessageExtraInfo, RequestInfo } from "./types/types.js"; import { URL } from 'url'; const MAXIMUM_MESSAGE_SIZE = "4mb"; diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 807743eb2..677da45ea 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -1,11 +1,10 @@ import { IncomingMessage, ServerResponse } from "node:http"; import { Transport } from "../shared/transport.js"; -import { isInitializeRequest, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema, RequestId, SUPPORTED_PROTOCOL_VERSIONS, DEFAULT_NEGOTIATED_PROTOCOL_VERSION } from "../types.js"; +import { MessageExtraInfo, RequestInfo, isInitializeRequest, isJSONRPCError, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema, RequestId, SUPPORTED_PROTOCOL_VERSIONS, DEFAULT_NEGOTIATED_PROTOCOL_VERSION } from "../types.js"; import getRawBody from "raw-body"; import contentType from "content-type"; import { randomUUID } from "node:crypto"; import { AuthInfo } from "./auth/types.js"; -import { MessageExtraInfo, RequestInfo } from "./types/types.js"; const MAXIMUM_MESSAGE_SIZE = "4mb"; diff --git a/src/server/types/types.ts b/src/server/types/types.ts deleted file mode 100644 index 3892af6cb..000000000 --- a/src/server/types/types.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { AuthInfo } from "../auth/types.js"; - -/** - * Headers that are compatible with both Node.js and the browser. - */ -export type IsomorphicHeaders = Record; - -/** - * Information about the incoming request. - */ -export interface RequestInfo { - /** - * The headers of the request. - */ - headers: IsomorphicHeaders; -} - -/** - * Extra information about a message. - */ -export interface MessageExtraInfo { - /** - * The request information. - */ - requestInfo?: RequestInfo; - - /** - * The authentication information. - */ - authInfo?: AuthInfo; -} \ No newline at end of file diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 33afd70ee..35839a4f8 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -22,10 +22,11 @@ import { Result, ServerCapabilities, RequestMeta, + MessageExtraInfo, + RequestInfo, } from "../types.js"; import { Transport, TransportSendOptions } from "./transport.js"; import { AuthInfo } from "../server/auth/types.js"; -import { MessageExtraInfo, RequestInfo } from "../server/types/types.js"; /** * Callback for progress notifications. diff --git a/src/shared/transport.ts b/src/shared/transport.ts index 69fce10ed..96b291fab 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -1,5 +1,4 @@ -import { MessageExtraInfo } from "../server/types/types.js"; -import { JSONRPCMessage, RequestId } from "../types.js"; +import { JSONRPCMessage, MessageExtraInfo, RequestId } from "../types.js"; /** * Options for sending a JSON-RPC message. diff --git a/src/types.ts b/src/types.ts index 3606a6be7..f66d2c4b6 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,4 +1,5 @@ import { z, ZodTypeAny } from "zod"; +import { AuthInfo } from "./server/auth/types.js"; export const LATEST_PROTOCOL_VERSION = "2025-06-18"; export const DEFAULT_NEGOTIATED_PROTOCOL_VERSION = "2025-03-26"; @@ -1463,6 +1464,36 @@ type Flatten = T extends Primitive type Infer = Flatten>; +/** + * Headers that are compatible with both Node.js and the browser. + */ +export type IsomorphicHeaders = Record; + +/** + * Information about the incoming request. + */ +export interface RequestInfo { + /** + * The headers of the request. + */ + headers: IsomorphicHeaders; +} + +/** + * Extra information about a message. + */ +export interface MessageExtraInfo { + /** + * The request information. + */ + requestInfo?: RequestInfo; + + /** + * The authentication information. + */ + authInfo?: AuthInfo; +} + /* JSON-RPC types */ export type ProgressToken = Infer; export type Cursor = Infer;