Skip to content

Commit 1df201f

Browse files
authored
Merge branch 'main' into ochafik/sep-1577-sampling-tools
2 parents 80f3408 + 5bcf53f commit 1df201f

39 files changed

+3054
-3666
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ on:
33
branches:
44
- main
55
pull_request:
6+
workflow_dispatch:
67
release:
78
types: [published]
89

jest.config.js

Lines changed: 0 additions & 14 deletions
This file was deleted.

package-lock.json

Lines changed: 1315 additions & 3048 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@modelcontextprotocol/sdk",
3-
"version": "1.21.1",
3+
"version": "1.22.0",
44
"description": "Model Context Protocol implementation for TypeScript",
55
"license": "MIT",
66
"author": "Anthropic, PBC (https://anthropic.com)",
@@ -71,7 +71,8 @@
7171
"lint": "eslint src/ && prettier --check .",
7272
"lint:fix": "eslint src/ --fix && prettier --write .",
7373
"check": "npm run typecheck && npm run lint",
74-
"test": "jest",
74+
"test": "vitest run",
75+
"test:watch": "vitest",
7576
"start": "npm run server",
7677
"server": "tsx watch --clear-screen=false scripts/cli.ts server",
7778
"client": "tsx scripts/cli.ts client"
@@ -102,27 +103,24 @@
102103
"devDependencies": {
103104
"@cfworker/json-schema": "^4.1.1",
104105
"@eslint/js": "^9.8.0",
105-
"@jest-mock/express": "^3.0.0",
106106
"@types/content-type": "^1.1.8",
107107
"@types/cors": "^2.8.17",
108108
"@types/cross-spawn": "^6.0.6",
109109
"@types/eslint__js": "^8.42.3",
110110
"@types/eventsource": "^1.1.15",
111111
"@types/express": "^5.0.0",
112-
"@types/jest": "^29.5.12",
113112
"@types/node": "^22.0.2",
114113
"@types/supertest": "^6.0.2",
115114
"@types/ws": "^8.5.12",
116115
"@typescript/native-preview": "^7.0.0-dev.20251103.1",
117116
"eslint": "^9.8.0",
118117
"eslint-config-prettier": "^10.1.8",
119-
"jest": "^29.7.0",
120118
"prettier": "3.6.2",
121119
"supertest": "^7.0.0",
122-
"ts-jest": "^29.2.4",
123120
"tsx": "^4.16.5",
124121
"typescript": "^5.5.4",
125122
"typescript-eslint": "^8.0.0",
123+
"vitest": "^4.0.8",
126124
"ws": "^8.18.0"
127125
},
128126
"resolutions": {

src/client/auth.test.ts

Lines changed: 151 additions & 69 deletions
Large diffs are not rendered by default.

src/client/auth.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ async function authInternal(
348348
): Promise<AuthResult> {
349349
let resourceMetadata: OAuthProtectedResourceMetadata | undefined;
350350
let authorizationServerUrl: string | URL | undefined;
351+
351352
try {
352353
resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl, { resourceMetadataUrl }, fetchFn);
353354
if (resourceMetadata.authorization_servers && resourceMetadata.authorization_servers.length > 0) {
@@ -359,10 +360,10 @@ async function authInternal(
359360

360361
/**
361362
* If we don't get a valid authorization server metadata from protected resource metadata,
362-
* fallback to the legacy MCP spec's implementation (version 2025-03-26): MCP server acts as the Authorization server.
363+
* fallback to the legacy MCP spec's implementation (version 2025-03-26): MCP server base URL acts as the Authorization server.
363364
*/
364365
if (!authorizationServerUrl) {
365-
authorizationServerUrl = serverUrl;
366+
authorizationServerUrl = new URL('/', serverUrl);
366367
}
367368

368369
const resource: URL | undefined = await selectResourceURL(serverUrl, provider, resourceMetadata);

src/client/cross-spawn.test.ts

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,34 @@ import { StdioClientTransport, getDefaultEnvironment } from './stdio.js';
22
import spawn from 'cross-spawn';
33
import { JSONRPCMessage } from '../types.js';
44
import { ChildProcess } from 'node:child_process';
5+
import { Mock, MockedFunction } from 'vitest';
56

67
// mock cross-spawn
7-
jest.mock('cross-spawn');
8-
const mockSpawn = spawn as jest.MockedFunction<typeof spawn>;
8+
vi.mock('cross-spawn');
9+
const mockSpawn = spawn as unknown as MockedFunction<typeof spawn>;
910

1011
describe('StdioClientTransport using cross-spawn', () => {
1112
beforeEach(() => {
1213
// mock cross-spawn's return value
1314
mockSpawn.mockImplementation(() => {
1415
const mockProcess: {
15-
on: jest.Mock;
16-
stdin?: { on: jest.Mock; write: jest.Mock };
17-
stdout?: { on: jest.Mock };
16+
on: Mock;
17+
stdin?: { on: Mock; write: Mock };
18+
stdout?: { on: Mock };
1819
stderr?: null;
1920
} = {
20-
on: jest.fn((event: string, callback: () => void) => {
21+
on: vi.fn((event: string, callback: () => void) => {
2122
if (event === 'spawn') {
2223
callback();
2324
}
2425
return mockProcess;
2526
}),
2627
stdin: {
27-
on: jest.fn(),
28-
write: jest.fn().mockReturnValue(true)
28+
on: vi.fn(),
29+
write: vi.fn().mockReturnValue(true)
2930
},
3031
stdout: {
31-
on: jest.fn()
32+
on: vi.fn()
3233
},
3334
stderr: null
3435
};
@@ -37,7 +38,7 @@ describe('StdioClientTransport using cross-spawn', () => {
3738
});
3839

3940
afterEach(() => {
40-
jest.clearAllMocks();
41+
vi.clearAllMocks();
4142
});
4243

4344
test('should call cross-spawn correctly', async () => {
@@ -105,30 +106,30 @@ describe('StdioClientTransport using cross-spawn', () => {
105106

106107
// get the mock process object
107108
const mockProcess: {
108-
on: jest.Mock;
109+
on: Mock;
109110
stdin: {
110-
on: jest.Mock;
111-
write: jest.Mock;
112-
once: jest.Mock;
111+
on: Mock;
112+
write: Mock;
113+
once: Mock;
113114
};
114115
stdout: {
115-
on: jest.Mock;
116+
on: Mock;
116117
};
117118
stderr: null;
118119
} = {
119-
on: jest.fn((event: string, callback: () => void) => {
120+
on: vi.fn((event: string, callback: () => void) => {
120121
if (event === 'spawn') {
121122
callback();
122123
}
123124
return mockProcess;
124125
}),
125126
stdin: {
126-
on: jest.fn(),
127-
write: jest.fn().mockReturnValue(true),
128-
once: jest.fn()
127+
on: vi.fn(),
128+
write: vi.fn().mockReturnValue(true),
129+
once: vi.fn()
129130
},
130131
stdout: {
131-
on: jest.fn()
132+
on: vi.fn()
132133
},
133134
stderr: null
134135
};

src/client/index.test.ts

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ import { InMemoryTransport } from '../inMemory.js';
2727
*/
2828
test('should initialize with matching protocol version', async () => {
2929
const clientTransport: Transport = {
30-
start: jest.fn().mockResolvedValue(undefined),
31-
close: jest.fn().mockResolvedValue(undefined),
32-
send: jest.fn().mockImplementation(message => {
30+
start: vi.fn().mockResolvedValue(undefined),
31+
close: vi.fn().mockResolvedValue(undefined),
32+
send: vi.fn().mockImplementation(message => {
3333
if (message.method === 'initialize') {
3434
clientTransport.onmessage?.({
3535
jsonrpc: '2.0',
@@ -86,9 +86,9 @@ test('should initialize with matching protocol version', async () => {
8686
test('should initialize with supported older protocol version', async () => {
8787
const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1];
8888
const clientTransport: Transport = {
89-
start: jest.fn().mockResolvedValue(undefined),
90-
close: jest.fn().mockResolvedValue(undefined),
91-
send: jest.fn().mockImplementation(message => {
89+
start: vi.fn().mockResolvedValue(undefined),
90+
close: vi.fn().mockResolvedValue(undefined),
91+
send: vi.fn().mockImplementation(message => {
9292
if (message.method === 'initialize') {
9393
clientTransport.onmessage?.({
9494
jsonrpc: '2.0',
@@ -136,9 +136,9 @@ test('should initialize with supported older protocol version', async () => {
136136
*/
137137
test('should reject unsupported protocol version', async () => {
138138
const clientTransport: Transport = {
139-
start: jest.fn().mockResolvedValue(undefined),
140-
close: jest.fn().mockResolvedValue(undefined),
141-
send: jest.fn().mockImplementation(message => {
139+
start: vi.fn().mockResolvedValue(undefined),
140+
close: vi.fn().mockResolvedValue(undefined),
141+
send: vi.fn().mockImplementation(message => {
142142
if (message.method === 'initialize') {
143143
clientTransport.onmessage?.({
144144
jsonrpc: '2.0',

src/client/index.ts

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,54 @@ import {
3636
SUPPORTED_PROTOCOL_VERSIONS,
3737
type SubscribeRequest,
3838
type Tool,
39-
type UnsubscribeRequest
39+
type UnsubscribeRequest,
40+
ElicitResultSchema,
41+
ElicitRequestSchema
4042
} from '../types.js';
4143
import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js';
4244
import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js';
45+
import { ZodLiteral, ZodObject, z } from 'zod';
46+
import type { RequestHandlerExtra } from '../shared/protocol.js';
47+
48+
/**
49+
* Elicitation default application helper. Applies defaults to the data based on the schema.
50+
*
51+
* @param schema - The schema to apply defaults to.
52+
* @param data - The data to apply defaults to.
53+
*/
54+
function applyElicitationDefaults(schema: JsonSchemaType | undefined, data: unknown): void {
55+
if (!schema || data === null || typeof data !== 'object') return;
56+
57+
// Handle object properties
58+
if (schema.type === 'object' && schema.properties && typeof schema.properties === 'object') {
59+
const obj = data as Record<string, unknown>;
60+
const props = schema.properties as Record<string, JsonSchemaType & { default?: unknown }>;
61+
for (const key of Object.keys(props)) {
62+
const propSchema = props[key];
63+
// If missing or explicitly undefined, apply default if present
64+
if (obj[key] === undefined && Object.prototype.hasOwnProperty.call(propSchema, 'default')) {
65+
obj[key] = propSchema.default;
66+
}
67+
// Recurse into existing nested objects/arrays
68+
if (obj[key] !== undefined) {
69+
applyElicitationDefaults(propSchema, obj[key]);
70+
}
71+
}
72+
}
73+
74+
if (Array.isArray(schema.anyOf)) {
75+
for (const sub of schema.anyOf) {
76+
applyElicitationDefaults(sub, data);
77+
}
78+
}
79+
80+
// Combine schemas
81+
if (Array.isArray(schema.oneOf)) {
82+
for (const sub of schema.oneOf) {
83+
applyElicitationDefaults(sub, data);
84+
}
85+
}
86+
}
4387

4488
export type ClientOptions = ProtocolOptions & {
4589
/**
@@ -141,6 +185,64 @@ export class Client<
141185
this._capabilities = mergeCapabilities(this._capabilities, capabilities);
142186
}
143187

188+
/**
189+
* Override request handler registration to enforce client-side validation for elicitation.
190+
*/
191+
public override setRequestHandler<
192+
T extends ZodObject<{
193+
method: ZodLiteral<string>;
194+
}>
195+
>(
196+
requestSchema: T,
197+
handler: (
198+
request: z.infer<T>,
199+
extra: RequestHandlerExtra<ClientRequest | RequestT, ClientNotification | NotificationT>
200+
) => ClientResult | ResultT | Promise<ClientResult | ResultT>
201+
): void {
202+
const method = requestSchema.shape.method.value;
203+
if (method === 'elicitation/create') {
204+
const wrappedHandler = async (
205+
request: z.infer<T>,
206+
extra: RequestHandlerExtra<ClientRequest | RequestT, ClientNotification | NotificationT>
207+
): Promise<ClientResult | ResultT> => {
208+
const validatedRequest = ElicitRequestSchema.safeParse(request);
209+
if (!validatedRequest.success) {
210+
throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation request: ${validatedRequest.error.message}`);
211+
}
212+
213+
const result = await Promise.resolve(handler(request, extra));
214+
215+
const validationResult = ElicitResultSchema.safeParse(result);
216+
if (!validationResult.success) {
217+
throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation result: ${validationResult.error.message}`);
218+
}
219+
220+
const validatedResult = validationResult.data;
221+
222+
if (
223+
this._capabilities.elicitation?.applyDefaults &&
224+
validatedResult.action === 'accept' &&
225+
validatedResult.content &&
226+
validatedRequest.data.params.requestedSchema
227+
) {
228+
try {
229+
applyElicitationDefaults(validatedRequest.data.params.requestedSchema, validatedResult.content);
230+
} catch {
231+
// gracefully ignore errors in default application
232+
}
233+
}
234+
235+
return validatedResult;
236+
};
237+
238+
// Install the wrapped handler
239+
return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler);
240+
}
241+
242+
// Non-elicitation handlers use default behavior
243+
return super.setRequestHandler(requestSchema, handler);
244+
}
245+
144246
protected assertCapability(capability: keyof ServerCapabilities, method: string): void {
145247
if (!this._serverCapabilities?.[capability]) {
146248
throw new Error(`Server does not support ${capability} (required for ${method})`);

0 commit comments

Comments
 (0)