Skip to content

Commit ae7ad85

Browse files
committed
simplify SamplingMessage
1 parent bd36044 commit ae7ad85

File tree

2 files changed

+42
-68
lines changed

2 files changed

+42
-68
lines changed

src/types.test.ts

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ import {
99
ToolUseContentSchema,
1010
ToolResultContentSchema,
1111
ToolChoiceSchema,
12-
UserMessageSchema,
13-
AssistantMessageSchema,
1412
SamplingMessageSchema,
1513
CreateMessageRequestSchema,
1614
CreateMessageResultSchema,
@@ -455,18 +453,20 @@ describe('Types', () => {
455453
});
456454
});
457455

458-
describe("UserMessage and AssistantMessage", () => {
456+
describe("SamplingMessage content types", () => {
459457
test("should validate user message with text", () => {
460458
const userMessage = {
461459
role: "user",
462460
content: { type: "text", text: "What's the weather?" }
463461
};
464462

465-
const result = UserMessageSchema.safeParse(userMessage);
463+
const result = SamplingMessageSchema.safeParse(userMessage);
466464
expect(result.success).toBe(true);
467465
if (result.success) {
468466
expect(result.data.role).toBe("user");
469-
expect(result.data.content.type).toBe("text");
467+
if (!Array.isArray(result.data.content)) {
468+
expect(result.data.content.type).toBe("text");
469+
}
470470
}
471471
});
472472

@@ -476,13 +476,13 @@ describe('Types', () => {
476476
content: {
477477
type: "tool_result",
478478
toolUseId: "call_123",
479-
content: { temperature: 72 }
479+
content: []
480480
}
481481
};
482482

483-
const result = UserMessageSchema.safeParse(userMessage);
483+
const result = SamplingMessageSchema.safeParse(userMessage);
484484
expect(result.success).toBe(true);
485-
if (result.success) {
485+
if (result.success && !Array.isArray(result.data.content)) {
486486
expect(result.data.content.type).toBe("tool_result");
487487
}
488488
});
@@ -493,7 +493,7 @@ describe('Types', () => {
493493
content: { type: "text", text: "I'll check the weather for you." }
494494
};
495495

496-
const result = AssistantMessageSchema.safeParse(assistantMessage);
496+
const result = SamplingMessageSchema.safeParse(assistantMessage);
497497
expect(result.success).toBe(true);
498498
if (result.success) {
499499
expect(result.data.role).toBe("assistant");
@@ -511,29 +511,28 @@ describe('Types', () => {
511511
}
512512
};
513513

514-
const result = AssistantMessageSchema.safeParse(assistantMessage);
514+
const result = SamplingMessageSchema.safeParse(assistantMessage);
515515
expect(result.success).toBe(true);
516-
if (result.success) {
516+
if (result.success && !Array.isArray(result.data.content)) {
517517
expect(result.data.content.type).toBe("tool_use");
518518
}
519519
});
520520

521-
test("should fail validation for assistant with tool result", () => {
522-
const invalidMessage = {
521+
test("should validate any content type for any role", () => {
522+
// The simplified schema allows any content type for any role
523+
const assistantWithToolResult = {
523524
role: "assistant",
524525
content: {
525526
type: "tool_result",
526527
toolUseId: "call_123",
527-
content: {}
528+
content: []
528529
}
529530
};
530531

531-
const result = AssistantMessageSchema.safeParse(invalidMessage);
532-
expect(result.success).toBe(false);
533-
});
532+
const result1 = SamplingMessageSchema.safeParse(assistantWithToolResult);
533+
expect(result1.success).toBe(true);
534534

535-
test("should fail validation for user with tool call", () => {
536-
const invalidMessage = {
535+
const userWithToolUse = {
537536
role: "user",
538537
content: {
539538
type: "tool_use",
@@ -543,8 +542,8 @@ describe('Types', () => {
543542
}
544543
};
545544

546-
const result = UserMessageSchema.safeParse(invalidMessage);
547-
expect(result.success).toBe(false);
545+
const result2 = SamplingMessageSchema.safeParse(userWithToolUse);
546+
expect(result2.success).toBe(true);
548547
});
549548
});
550549

src/types.ts

Lines changed: 22 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,7 @@ export const ToolChoiceSchema = z
11811181
* Controls when tools are used:
11821182
* - "auto": Model decides whether to use tools (default)
11831183
* - "required": Model MUST use at least one tool before completing
1184+
* - "none": Model MUST NOT use any tools
11841185
*/
11851186
mode: z.optional(z.enum(["auto", "required", "none"])),
11861187
/**
@@ -1212,59 +1213,33 @@ export const ToolResultContentSchema = z
12121213
})
12131214
.passthrough();
12141215

1215-
export const UserMessageContentSchema = z.discriminatedUnion("type", [
1216-
TextContentSchema,
1217-
ImageContentSchema,
1218-
AudioContentSchema,
1219-
ToolResultContentSchema,
1220-
]);
1221-
12221216
/**
1223-
* A message from the user (server) in a sampling conversation.
1217+
* Content block types allowed in sampling messages.
1218+
* This includes text, image, audio, tool use requests, and tool results.
12241219
*/
1225-
export const UserMessageSchema = z
1226-
.object({
1227-
role: z.literal("user"),
1228-
content: z.union([UserMessageContentSchema, z.array(UserMessageContentSchema)]),
1229-
/**
1230-
* See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
1231-
* for notes on _meta usage.
1232-
*/
1233-
_meta: z.optional(z.object({}).passthrough()),
1234-
})
1235-
.passthrough();
1236-
1237-
export const AssistantMessageContentSchema = z.discriminatedUnion("type", [
1220+
export const SamplingMessageContentBlockSchema = z.discriminatedUnion("type", [
12381221
TextContentSchema,
12391222
ImageContentSchema,
12401223
AudioContentSchema,
12411224
ToolUseContentSchema,
1225+
ToolResultContentSchema,
12421226
]);
12431227

1244-
/**
1245-
* A message from the assistant (LLM) in a sampling conversation.
1246-
*/
1247-
export const AssistantMessageSchema = z
1248-
.object({
1249-
role: z.literal("assistant"),
1250-
content: z.union([AssistantMessageContentSchema, z.array(AssistantMessageContentSchema)]),
1251-
/**
1252-
* See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
1253-
* for notes on _meta usage.
1254-
*/
1255-
_meta: z.optional(z.object({}).passthrough()),
1256-
})
1257-
.passthrough();
1258-
12591228
/**
12601229
* Describes a message issued to or received from an LLM API.
1261-
* This is a discriminated union of UserMessage and AssistantMessage, where
1262-
* each role has its own set of allowed content types.
12631230
*/
1264-
export const SamplingMessageSchema = z.discriminatedUnion("role", [
1265-
UserMessageSchema,
1266-
AssistantMessageSchema,
1267-
]);
1231+
export const SamplingMessageSchema = z.object({
1232+
role: z.enum(['user', 'assistant']),
1233+
content: z.union([
1234+
SamplingMessageContentBlockSchema,
1235+
z.array(SamplingMessageContentBlockSchema)
1236+
]),
1237+
/**
1238+
* See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
1239+
* for notes on _meta usage.
1240+
*/
1241+
_meta: z.optional(z.object({}).passthrough()),
1242+
}).passthrough();
12681243

12691244
/**
12701245
* Parameters for a `sampling/createMessage` request.
@@ -1341,7 +1316,10 @@ export const CreateMessageResultSchema = ResultSchema.extend({
13411316
/**
13421317
* Response content. May be ToolUseContent if stopReason is "toolUse".
13431318
*/
1344-
content: z.union([AssistantMessageContentSchema, z.array(AssistantMessageContentSchema)]),
1319+
content: z.union([
1320+
SamplingMessageContentBlockSchema,
1321+
z.array(SamplingMessageContentBlockSchema)
1322+
]),
13451323
});
13461324

13471325
/* Elicitation */
@@ -1798,16 +1776,13 @@ export type LoggingMessageNotification = Infer<typeof LoggingMessageNotification
17981776

17991777
/* Sampling */
18001778
export type ToolChoice = Infer<typeof ToolChoiceSchema>;
1801-
export type UserMessage = Infer<typeof UserMessageSchema>;
1802-
export type AssistantMessage = Infer<typeof AssistantMessageSchema>;
18031779
export type ModelHint = Infer<typeof ModelHintSchema>;
18041780
export type ModelPreferences = Infer<typeof ModelPreferencesSchema>;
1781+
export type SamplingMessageContentBlock = Infer<typeof SamplingMessageContentBlockSchema>;
18051782
export type SamplingMessage = Infer<typeof SamplingMessageSchema>;
18061783
export type CreateMessageRequestParams = Infer<typeof CreateMessageRequestParamsSchema>;
18071784
export type CreateMessageRequest = Infer<typeof CreateMessageRequestSchema>;
18081785
export type CreateMessageResult = Infer<typeof CreateMessageResultSchema>;
1809-
export type AssistantMessageContent = Infer<typeof AssistantMessageContentSchema>;
1810-
export type UserMessageContent = Infer<typeof UserMessageContentSchema>;
18111786

18121787
/* Elicitation */
18131788
export type BooleanSchema = Infer<typeof BooleanSchemaSchema>;

0 commit comments

Comments
 (0)