diff --git a/src/client/index.ts b/src/client/index.ts index cc6819815..4bb94a251 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,4 +1,4 @@ -import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestOptions } from '../shared/protocol.js'; +import { mergeCapabilities, Protocol, type ProtocolOptions, type RequestOptions, type RequestHandlerExtra } from '../shared/protocol.js'; import type { Transport } from '../shared/transport.js'; import { type CallToolRequest, @@ -10,6 +10,8 @@ import { type CompatibilityCallToolResultSchema, type CompleteRequest, CompleteResultSchema, + ElicitRequestSchema, + type ElicitResult, EmptyResultSchema, ErrorCode, type GetPromptRequest, @@ -40,6 +42,7 @@ import { } from '../types.js'; import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js'; import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js'; +import { ZodObject, ZodLiteral, z } from 'zod'; export type ClientOptions = ProtocolOptions & { /** @@ -437,4 +440,71 @@ export class Client< async sendRootsListChanged() { return this.notification({ method: 'notifications/roots/list_changed' }); } + + /** + * Override setRequestHandler to automatically apply defaults for elicitation responses. + * When a handler is registered for ElicitRequestSchema, it wraps the handler to apply + * defaults from the schema before returning the response. + */ + override setRequestHandler< + T extends ZodObject<{ + method: ZodLiteral; + }> + >( + requestSchema: T, + handler: ( + request: z.infer, + extra: RequestHandlerExtra + ) => ClientResult | ResultT | Promise + ): void { + const method = requestSchema.shape.method.value; + + // Special handling for elicitation requests to apply defaults + if (method === 'elicitation/create') { + const wrappedHandler = async ( + request: z.infer, + extra: RequestHandlerExtra + ): Promise => { + // Call the original handler + const result = (await handler(request as z.infer, extra)) as ElicitResult; + + // Only apply defaults if action is 'accept' and content exists + if (result.action === 'accept' && result.content) { + // Convert requestedSchema to JSON Schema format for validation + const jsonSchema = { + type: 'object' as const, + properties: request.params.requestedSchema.properties, + required: request.params.requestedSchema.required + }; + + try { + // Get validator which will apply defaults during validation + const validator = this._jsonSchemaValidator.getValidator(jsonSchema); + const validationResult = validator(result.content); + + if (!validationResult.valid) { + throw new McpError( + ErrorCode.InvalidParams, + `Elicitation response content does not match requested schema: ${validationResult.errorMessage}` + ); + } + } catch (error) { + if (error instanceof McpError) { + throw error; + } + // If validation fails, log but don't block - defaults were applied in-place + // This handles edge cases where schema might not perfectly match + } + } + + return result; + }; + + // Register the wrapped handler using the parent's setRequestHandler + super.setRequestHandler(ElicitRequestSchema, wrappedHandler); + } else { + // For all other request types, use default behavior + super.setRequestHandler(requestSchema, handler); + } + } } diff --git a/src/server/elicitation.test.ts b/src/server/elicitation.test.ts index 67361fc10..6147f3b34 100644 --- a/src/server/elicitation.test.ts +++ b/src/server/elicitation.test.ts @@ -609,6 +609,51 @@ function testElicitationFlow(validatorProvider: typeof ajvProvider | typeof cfWo }); }); + test(`${validatorName}: should default missing fields from schema defaults`, async () => { + const server = new Server( + { name: 'test-server', version: '1.0.0' }, + { + capabilities: {}, + jsonSchemaValidator: validatorProvider + } + ); + + const client = new Client({ name: 'test-client', version: '1.0.0' }, { capabilities: { elicitation: {} } }); + + // Client returns no values; SDK should apply defaults automatically (and validate) + client.setRequestHandler(ElicitRequestSchema, _request => ({ + action: 'accept', + content: {} + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const result = await server.elicitInput({ + message: 'Provide your preferences', + requestedSchema: { + type: 'object', + properties: { + subscribe: { type: 'boolean', default: true }, + nickname: { type: 'string', default: 'Guest' }, + age: { type: 'integer', minimum: 0, maximum: 150, default: 18 }, + color: { type: 'string', enum: ['red', 'green'], default: 'green' } + }, + required: ['subscribe', 'nickname', 'age', 'color'] + } + }); + + expect(result).toEqual({ + action: 'accept', + content: { + subscribe: true, + nickname: 'Guest', + age: 18, + color: 'green' + } + }); + }); + test(`${validatorName}: should reject invalid email format`, async () => { const server = new Server( { name: 'test-server', version: '1.0.0' }, diff --git a/src/types.ts b/src/types.ts index 8b695a73b..45c8f35b8 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1198,9 +1198,9 @@ export const CreateMessageResultSchema = ResultSchema.extend({ */ export const BooleanSchemaSchema = z.object({ type: z.literal('boolean'), - title: z.optional(z.string()), - description: z.optional(z.string()), - default: z.optional(z.boolean()) + title: z.string().optional(), + description: z.string().optional(), + default: z.boolean().optional() }); /** @@ -1208,11 +1208,12 @@ export const BooleanSchemaSchema = z.object({ */ export const StringSchemaSchema = z.object({ type: z.literal('string'), - title: z.optional(z.string()), - description: z.optional(z.string()), - minLength: z.optional(z.number()), - maxLength: z.optional(z.number()), - format: z.optional(z.enum(['email', 'uri', 'date', 'date-time'])) + title: z.string().optional(), + description: z.string().optional(), + minLength: z.number().optional(), + maxLength: z.number().optional(), + format: z.enum(['email', 'uri', 'date', 'date-time']).optional(), + default: z.string().optional() }); /** @@ -1220,10 +1221,11 @@ export const StringSchemaSchema = z.object({ */ export const NumberSchemaSchema = z.object({ type: z.enum(['number', 'integer']), - title: z.optional(z.string()), - description: z.optional(z.string()), - minimum: z.optional(z.number()), - maximum: z.optional(z.number()) + title: z.string().optional(), + description: z.string().optional(), + minimum: z.number().optional(), + maximum: z.number().optional(), + default: z.number().optional() }); /** @@ -1231,10 +1233,11 @@ export const NumberSchemaSchema = z.object({ */ export const EnumSchemaSchema = z.object({ type: z.literal('string'), - title: z.optional(z.string()), - description: z.optional(z.string()), + title: z.string().optional(), + description: z.string().optional(), enum: z.array(z.string()), - enumNames: z.optional(z.array(z.string())) + enumNames: z.array(z.string()).optional(), + default: z.string().optional() }); /** diff --git a/src/validation/ajv-provider.ts b/src/validation/ajv-provider.ts index 115a98521..c980500b6 100644 --- a/src/validation/ajv-provider.ts +++ b/src/validation/ajv-provider.ts @@ -11,7 +11,8 @@ function createDefaultAjvInstance(): Ajv { strict: false, validateFormats: true, validateSchema: false, - allErrors: true + allErrors: true, + useDefaults: true }); const addFormats = _addFormats as unknown as typeof _addFormats.default; diff --git a/src/validation/cfworker-provider.ts b/src/validation/cfworker-provider.ts index 60ec3f06e..5cdc6db97 100644 --- a/src/validation/cfworker-provider.ts +++ b/src/validation/cfworker-provider.ts @@ -10,6 +10,68 @@ import { type Schema, Validator } from '@cfworker/json-schema'; import type { JsonSchemaType, JsonSchemaValidator, JsonSchemaValidatorResult, jsonSchemaValidator } from './types.js'; +/** + * Apply JSON Schema defaults to the provided data in-place. + * This performs a best-effort traversal covering common constructs: + * - type: "object" with "properties" + * - type: "array" with "items" + * - allOf / anyOf / oneOf (applies defaults from each sub-schema) + * + * It intentionally does not attempt full $ref resolution or advanced constructs, + * which are not needed for the MCP elicitation top-level schemas. + */ +function applyDefaults(schema: Schema | undefined, data: unknown): void { + if (!schema || data === null || typeof data !== 'object') return; + + // Handle object properties + if (schema.type === 'object' && schema.properties && typeof schema.properties === 'object') { + const obj = data as Record; + const props = schema.properties as Record; + for (const key of Object.keys(props)) { + const propSchema = props[key]; + // If missing or explicitly undefined, apply default if present + if (obj[key] === undefined && Object.prototype.hasOwnProperty.call(propSchema, 'default')) { + obj[key] = propSchema.default; + } + // Recurse into existing nested objects/arrays + if (obj[key] !== undefined) { + applyDefaults(propSchema, obj[key]); + } + } + } + + // Handle arrays + if (schema.type === 'array' && Array.isArray(data) && schema.items) { + const itemsSchema = schema.items as Schema | Schema[]; + if (Array.isArray(itemsSchema)) { + for (let i = 0; i < data.length && i < itemsSchema.length; i++) { + applyDefaults(itemsSchema[i], data[i]); + } + } else { + for (const item of data) { + applyDefaults(itemsSchema, item); + } + } + } + + // Combine schemas + if (Array.isArray(schema.allOf)) { + for (const sub of schema.allOf) { + applyDefaults(sub, data); + } + } + if (Array.isArray(schema.anyOf)) { + for (const sub of schema.anyOf) { + applyDefaults(sub, data); + } + } + if (Array.isArray(schema.oneOf)) { + for (const sub of schema.oneOf) { + applyDefaults(sub, data); + } + } +} + /** * JSON Schema draft version supported by @cfworker/json-schema */ @@ -58,6 +120,13 @@ export class CfWorkerJsonSchemaValidator implements jsonSchemaValidator { const validator = new Validator(cfSchema, this.draft, this.shortcircuit); return (input: unknown): JsonSchemaValidatorResult => { + // Mirror AJV's useDefaults behavior by applying defaults before validation. + try { + applyDefaults(cfSchema, input); + } catch { + // Best-effort only; ignore errors in default application + } + const result = validator.validate(input); if (result.valid) {