Skip to content

Commit d8a3924

Browse files
ochafikbhosmer-ant
authored andcommitted
simplify SamplingMessage
1 parent 7439cc0 commit d8a3924

File tree

2 files changed

+42
-68
lines changed

2 files changed

+42
-68
lines changed

src/types.test.ts

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ import {
99
ToolUseContentSchema,
1010
ToolResultContentSchema,
1111
ToolChoiceSchema,
12-
UserMessageSchema,
13-
AssistantMessageSchema,
1412
SamplingMessageSchema,
1513
CreateMessageRequestSchema,
1614
CreateMessageResultSchema,
@@ -455,18 +453,20 @@ describe('Types', () => {
455453
});
456454
});
457455

458-
describe("UserMessage and AssistantMessage", () => {
456+
describe("SamplingMessage content types", () => {
459457
test("should validate user message with text", () => {
460458
const userMessage = {
461459
role: "user",
462460
content: { type: "text", text: "What's the weather?" }
463461
};
464462

465-
const result = UserMessageSchema.safeParse(userMessage);
463+
const result = SamplingMessageSchema.safeParse(userMessage);
466464
expect(result.success).toBe(true);
467465
if (result.success) {
468466
expect(result.data.role).toBe("user");
469-
expect(result.data.content.type).toBe("text");
467+
if (!Array.isArray(result.data.content)) {
468+
expect(result.data.content.type).toBe("text");
469+
}
470470
}
471471
});
472472

@@ -476,13 +476,13 @@ describe('Types', () => {
476476
content: {
477477
type: "tool_result",
478478
toolUseId: "call_123",
479-
content: { temperature: 72 }
479+
content: []
480480
}
481481
};
482482

483-
const result = UserMessageSchema.safeParse(userMessage);
483+
const result = SamplingMessageSchema.safeParse(userMessage);
484484
expect(result.success).toBe(true);
485-
if (result.success) {
485+
if (result.success && !Array.isArray(result.data.content)) {
486486
expect(result.data.content.type).toBe("tool_result");
487487
}
488488
});
@@ -493,7 +493,7 @@ describe('Types', () => {
493493
content: { type: "text", text: "I'll check the weather for you." }
494494
};
495495

496-
const result = AssistantMessageSchema.safeParse(assistantMessage);
496+
const result = SamplingMessageSchema.safeParse(assistantMessage);
497497
expect(result.success).toBe(true);
498498
if (result.success) {
499499
expect(result.data.role).toBe("assistant");
@@ -511,29 +511,28 @@ describe('Types', () => {
511511
}
512512
};
513513

514-
const result = AssistantMessageSchema.safeParse(assistantMessage);
514+
const result = SamplingMessageSchema.safeParse(assistantMessage);
515515
expect(result.success).toBe(true);
516-
if (result.success) {
516+
if (result.success && !Array.isArray(result.data.content)) {
517517
expect(result.data.content.type).toBe("tool_use");
518518
}
519519
});
520520

521-
test("should fail validation for assistant with tool result", () => {
522-
const invalidMessage = {
521+
test("should validate any content type for any role", () => {
522+
// The simplified schema allows any content type for any role
523+
const assistantWithToolResult = {
523524
role: "assistant",
524525
content: {
525526
type: "tool_result",
526527
toolUseId: "call_123",
527-
content: {}
528+
content: []
528529
}
529530
};
530531

531-
const result = AssistantMessageSchema.safeParse(invalidMessage);
532-
expect(result.success).toBe(false);
533-
});
532+
const result1 = SamplingMessageSchema.safeParse(assistantWithToolResult);
533+
expect(result1.success).toBe(true);
534534

535-
test("should fail validation for user with tool call", () => {
536-
const invalidMessage = {
535+
const userWithToolUse = {
537536
role: "user",
538537
content: {
539538
type: "tool_use",
@@ -543,8 +542,8 @@ describe('Types', () => {
543542
}
544543
};
545544

546-
const result = UserMessageSchema.safeParse(invalidMessage);
547-
expect(result.success).toBe(false);
545+
const result2 = SamplingMessageSchema.safeParse(userWithToolUse);
546+
expect(result2.success).toBe(true);
548547
});
549548
});
550549

src/types.ts

Lines changed: 22 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,7 @@ export const ToolChoiceSchema = z
12091209
* Controls when tools are used:
12101210
* - "auto": Model decides whether to use tools (default)
12111211
* - "required": Model MUST use at least one tool before completing
1212+
* - "none": Model MUST NOT use any tools
12121213
*/
12131214
mode: z.optional(z.enum(["auto", "required", "none"])),
12141215
/**
@@ -1240,59 +1241,33 @@ export const ToolResultContentSchema = z
12401241
})
12411242
.passthrough();
12421243

1243-
export const UserMessageContentSchema = z.discriminatedUnion("type", [
1244-
TextContentSchema,
1245-
ImageContentSchema,
1246-
AudioContentSchema,
1247-
ToolResultContentSchema,
1248-
]);
1249-
12501244
/**
1251-
* A message from the user (server) in a sampling conversation.
1245+
* Content block types allowed in sampling messages.
1246+
* This includes text, image, audio, tool use requests, and tool results.
12521247
*/
1253-
export const UserMessageSchema = z
1254-
.object({
1255-
role: z.literal("user"),
1256-
content: z.union([UserMessageContentSchema, z.array(UserMessageContentSchema)]),
1257-
/**
1258-
* See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
1259-
* for notes on _meta usage.
1260-
*/
1261-
_meta: z.optional(z.object({}).passthrough()),
1262-
})
1263-
.passthrough();
1264-
1265-
export const AssistantMessageContentSchema = z.discriminatedUnion("type", [
1248+
export const SamplingMessageContentBlockSchema = z.discriminatedUnion("type", [
12661249
TextContentSchema,
12671250
ImageContentSchema,
12681251
AudioContentSchema,
12691252
ToolUseContentSchema,
1253+
ToolResultContentSchema,
12701254
]);
12711255

1272-
/**
1273-
* A message from the assistant (LLM) in a sampling conversation.
1274-
*/
1275-
export const AssistantMessageSchema = z
1276-
.object({
1277-
role: z.literal("assistant"),
1278-
content: z.union([AssistantMessageContentSchema, z.array(AssistantMessageContentSchema)]),
1279-
/**
1280-
* See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
1281-
* for notes on _meta usage.
1282-
*/
1283-
_meta: z.optional(z.object({}).passthrough()),
1284-
})
1285-
.passthrough();
1286-
12871256
/**
12881257
* Describes a message issued to or received from an LLM API.
1289-
* This is a discriminated union of UserMessage and AssistantMessage, where
1290-
* each role has its own set of allowed content types.
12911258
*/
1292-
export const SamplingMessageSchema = z.discriminatedUnion("role", [
1293-
UserMessageSchema,
1294-
AssistantMessageSchema,
1295-
]);
1259+
export const SamplingMessageSchema = z.object({
1260+
role: z.enum(['user', 'assistant']),
1261+
content: z.union([
1262+
SamplingMessageContentBlockSchema,
1263+
z.array(SamplingMessageContentBlockSchema)
1264+
]),
1265+
/**
1266+
* See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
1267+
* for notes on _meta usage.
1268+
*/
1269+
_meta: z.optional(z.object({}).passthrough()),
1270+
}).passthrough();
12961271

12971272
/**
12981273
* Parameters for a `sampling/createMessage` request.
@@ -1369,7 +1344,10 @@ export const CreateMessageResultSchema = ResultSchema.extend({
13691344
/**
13701345
* Response content. May be ToolUseContent if stopReason is "toolUse".
13711346
*/
1372-
content: z.union([AssistantMessageContentSchema, z.array(AssistantMessageContentSchema)]),
1347+
content: z.union([
1348+
SamplingMessageContentBlockSchema,
1349+
z.array(SamplingMessageContentBlockSchema)
1350+
]),
13731351
});
13741352

13751353
/* Elicitation */
@@ -1994,16 +1972,13 @@ export type LoggingMessageNotification = Infer<typeof LoggingMessageNotification
19941972

19951973
/* Sampling */
19961974
export type ToolChoice = Infer<typeof ToolChoiceSchema>;
1997-
export type UserMessage = Infer<typeof UserMessageSchema>;
1998-
export type AssistantMessage = Infer<typeof AssistantMessageSchema>;
19991975
export type ModelHint = Infer<typeof ModelHintSchema>;
20001976
export type ModelPreferences = Infer<typeof ModelPreferencesSchema>;
1977+
export type SamplingMessageContentBlock = Infer<typeof SamplingMessageContentBlockSchema>;
20011978
export type SamplingMessage = Infer<typeof SamplingMessageSchema>;
20021979
export type CreateMessageRequestParams = Infer<typeof CreateMessageRequestParamsSchema>;
20031980
export type CreateMessageRequest = Infer<typeof CreateMessageRequestSchema>;
20041981
export type CreateMessageResult = Infer<typeof CreateMessageResultSchema>;
2005-
export type AssistantMessageContent = Infer<typeof AssistantMessageContentSchema>;
2006-
export type UserMessageContent = Infer<typeof UserMessageContentSchema>;
20071982

20081983
/* Elicitation */
20091984
export type BooleanSchema = Infer<typeof BooleanSchemaSchema>;

0 commit comments

Comments
 (0)