diff --git a/src/client/index.ts b/src/client/index.ts index cc6819815..4dc6a20d9 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -36,10 +36,41 @@ import { SUPPORTED_PROTOCOL_VERSIONS, type SubscribeRequest, type Tool, - type UnsubscribeRequest + type UnsubscribeRequest, + ElicitResultSchema, + ElicitRequestSchema } from '../types.js'; import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js'; import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js'; +import { ZodLiteral, ZodObject, z } from 'zod'; +import type { RequestHandlerExtra } from '../shared/protocol.js'; + +/** + * Elicitation default application helper. Applies defaults to the data based on the schema. + * + * @param schema - The schema to apply defaults to. + * @param data - The data to apply defaults to. + */ +function applyElicitationDefaults(schema: JsonSchemaType | undefined, data: unknown): void { + if (!schema || data === null || typeof data !== 'object') return; + + // Handle object properties + if (schema.type === 'object' && schema.properties && typeof schema.properties === 'object') { + const obj = data as Record; + const props = schema.properties as Record; + for (const key of Object.keys(props)) { + const propSchema = props[key]; + // If missing or explicitly undefined, apply default if present + if (obj[key] === undefined && Object.prototype.hasOwnProperty.call(propSchema, 'default')) { + obj[key] = propSchema.default; + } + // Recurse into existing nested objects/arrays + if (obj[key] !== undefined) { + applyElicitationDefaults(propSchema, obj[key]); + } + } + } +} export type ClientOptions = ProtocolOptions & { /** @@ -141,6 +172,64 @@ export class Client< this._capabilities = mergeCapabilities(this._capabilities, capabilities); } + /** + * Override request handler registration to enforce client-side validation for elicitation. + */ + public override setRequestHandler< + T extends ZodObject<{ + method: ZodLiteral; + }> + >( + requestSchema: T, + handler: ( + request: z.infer, + extra: RequestHandlerExtra + ) => ClientResult | ResultT | Promise + ): void { + const method = requestSchema.shape.method.value; + if (method === 'elicitation/create') { + const wrappedHandler = async ( + request: z.infer, + extra: RequestHandlerExtra + ): Promise => { + const validatedRequest = ElicitRequestSchema.safeParse(request); + if (!validatedRequest.success) { + throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation request: ${validatedRequest.error.message}`); + } + + const result = await Promise.resolve(handler(request, extra)); + + const validationResult = ElicitResultSchema.safeParse(result); + if (!validationResult.success) { + throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation result: ${validationResult.error.message}`); + } + + const validatedResult = validationResult.data; + + if ( + this._capabilities.elicitation?.applyDefaults && + validatedResult.action === 'accept' && + validatedResult.content && + validatedRequest.data.params.requestedSchema + ) { + try { + applyElicitationDefaults(validatedRequest.data.params.requestedSchema, validatedResult.content); + } catch { + // gracefully ignore errors in default application + } + } + + return validatedResult; + }; + + // Install the wrapped handler + return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler); + } + + // Non-elicitation handlers use default behavior + return super.setRequestHandler(requestSchema, handler); + } + protected assertCapability(capability: keyof ServerCapabilities, method: string): void { if (!this._serverCapabilities?.[capability]) { throw new Error(`Server does not support ${capability} (required for ${method})`); diff --git a/src/server/elicitation.test.ts b/src/server/elicitation.test.ts index 67361fc10..671500bc7 100644 --- a/src/server/elicitation.test.ts +++ b/src/server/elicitation.test.ts @@ -609,6 +609,72 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo }); }); + test(`${validatorName}: should default missing fields from schema defaults`, async () => { + const server = new Server( + { name: 'test-server', version: '1.0.0' }, + { + capabilities: {}, + jsonSchemaValidator: validatorProvider + } + ); + + const client = new Client( + { name: 'test-client', version: '1.0.0' }, + { + capabilities: { + elicitation: { + applyDefaults: true + } + } + } + ); + + // Client returns no values; SDK should apply defaults automatically (and validate) + client.setRequestHandler(ElicitRequestSchema, request => { + expect(request.params.requestedSchema).toEqual({ + type: 'object', + properties: { + subscribe: { type: 'boolean', default: true }, + nickname: { type: 'string', default: 'Guest' }, + age: { type: 'integer', minimum: 0, maximum: 150, default: 18 }, + color: { type: 'string', enum: ['red', 'green'], default: 'green' } + }, + required: ['subscribe', 'nickname', 'age', 'color'] + }); + return { + action: 'accept', + content: {} + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const result = await server.elicitInput({ + message: 'Provide your preferences', + requestedSchema: { + type: 'object', + properties: { + subscribe: { type: 'boolean', default: true }, + nickname: { type: 'string', default: 'Guest' }, + age: { type: 'integer', minimum: 0, maximum: 150, default: 18 }, + color: { type: 'string', enum: ['red', 'green'], default: 'green' } + }, + required: ['subscribe', 'nickname', 'age', 'color'] + } + }); + + expect(result).toEqual({ + action: 'accept', + content: { + subscribe: true, + nickname: 'Guest', + age: 18, + color: 'green' + } + }); + }); + test(`${validatorName}: should reject invalid email format`, async () => { const server = new Server( { name: 'test-server', version: '1.0.0' }, diff --git a/src/spec.types.test.ts b/src/spec.types.test.ts index 653e9522c..05ae3d0ce 100644 --- a/src/spec.types.test.ts +++ b/src/spec.types.test.ts @@ -62,6 +62,19 @@ type MakeUnknownsNotOptional = } : T; +// Targeted fix: in spec, treat ClientCapabilities.elicitation?: object as Record +type FixSpecClientCapabilities = T extends { elicitation?: object } + ? Omit & { elicitation?: Record } + : T; + +type FixSpecInitializeRequestParams = T extends { capabilities: infer C } + ? Omit & { capabilities: FixSpecClientCapabilities } + : T; + +type FixSpecInitializeRequest = T extends { params: infer P } ? Omit & { params: FixSpecInitializeRequestParams

} : T; + +type FixSpecClientRequest = T extends { params: infer P } ? Omit & { params: FixSpecInitializeRequestParams

} : T; + const sdkTypeChecks = { RequestParams: (sdk: SDKTypes.RequestParams, spec: SpecTypes.RequestParams) => { sdk = spec; @@ -75,7 +88,10 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - InitializeRequestParams: (sdk: SDKTypes.InitializeRequestParams, spec: SpecTypes.InitializeRequestParams) => { + InitializeRequestParams: ( + sdk: SDKTypes.InitializeRequestParams, + spec: FixSpecInitializeRequestParams + ) => { sdk = spec; spec = sdk; }, @@ -480,7 +496,10 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - InitializeRequest: (sdk: WithJSONRPCRequest, spec: SpecTypes.InitializeRequest) => { + InitializeRequest: ( + sdk: WithJSONRPCRequest, + spec: FixSpecInitializeRequest + ) => { sdk = spec; spec = sdk; }, @@ -488,7 +507,7 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - ClientCapabilities: (sdk: SDKTypes.ClientCapabilities, spec: SpecTypes.ClientCapabilities) => { + ClientCapabilities: (sdk: SDKTypes.ClientCapabilities, spec: FixSpecClientCapabilities) => { sdk = spec; spec = sdk; }, @@ -496,7 +515,10 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - ClientRequest: (sdk: RemovePassthrough>, spec: SpecTypes.ClientRequest) => { + ClientRequest: ( + sdk: RemovePassthrough>, + spec: FixSpecClientRequest + ) => { sdk = spec; spec = sdk; }, diff --git a/src/types.ts b/src/types.ts index 8b695a73b..9e9ec3d31 100644 --- a/src/types.ts +++ b/src/types.ts @@ -286,7 +286,17 @@ export const ClientCapabilitiesSchema = z.object({ /** * Present if the client supports eliciting user input. */ - elicitation: AssertObjectSchema.optional(), + elicitation: z.intersection( + z + .object({ + /** + * Whether the client should apply defaults to the user input. + */ + applyDefaults: z.boolean().optional() + }) + .optional(), + z.record(z.string(), z.unknown()).optional() + ), /** * Present if the client supports listing roots. */ @@ -1198,9 +1208,9 @@ export const CreateMessageResultSchema = ResultSchema.extend({ */ export const BooleanSchemaSchema = z.object({ type: z.literal('boolean'), - title: z.optional(z.string()), - description: z.optional(z.string()), - default: z.optional(z.boolean()) + title: z.string().optional(), + description: z.string().optional(), + default: z.boolean().optional() }); /** @@ -1208,11 +1218,12 @@ export const BooleanSchemaSchema = z.object({ */ export const StringSchemaSchema = z.object({ type: z.literal('string'), - title: z.optional(z.string()), - description: z.optional(z.string()), - minLength: z.optional(z.number()), - maxLength: z.optional(z.number()), - format: z.optional(z.enum(['email', 'uri', 'date', 'date-time'])) + title: z.string().optional(), + description: z.string().optional(), + minLength: z.number().optional(), + maxLength: z.number().optional(), + format: z.enum(['email', 'uri', 'date', 'date-time']).optional(), + default: z.string().optional() }); /** @@ -1220,10 +1231,11 @@ export const StringSchemaSchema = z.object({ */ export const NumberSchemaSchema = z.object({ type: z.enum(['number', 'integer']), - title: z.optional(z.string()), - description: z.optional(z.string()), - minimum: z.optional(z.number()), - maximum: z.optional(z.number()) + title: z.string().optional(), + description: z.string().optional(), + minimum: z.number().optional(), + maximum: z.number().optional(), + default: z.number().optional() }); /** @@ -1231,16 +1243,17 @@ export const NumberSchemaSchema = z.object({ */ export const EnumSchemaSchema = z.object({ type: z.literal('string'), - title: z.optional(z.string()), - description: z.optional(z.string()), + title: z.string().optional(), + description: z.string().optional(), enum: z.array(z.string()), - enumNames: z.optional(z.array(z.string())) + enumNames: z.array(z.string()).optional(), + default: z.string().optional() }); /** * Union of all primitive schema definitions. */ -export const PrimitiveSchemaDefinitionSchema = z.union([BooleanSchemaSchema, StringSchemaSchema, NumberSchemaSchema, EnumSchemaSchema]); +export const PrimitiveSchemaDefinitionSchema = z.union([EnumSchemaSchema, BooleanSchemaSchema, StringSchemaSchema, NumberSchemaSchema]); /** * Parameters for an `elicitation/create` request.