Skip to content

Commit 7a53c30

Browse files
committed
sampling: validate tools, tool_use, tool_result constraints
as per modelcontextprotocol/modelcontextprotocol#1577
1 parent 41c6b35 commit 7a53c30

File tree

2 files changed

+229
-1
lines changed

2 files changed

+229
-1
lines changed

src/server/index.test.ts

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,6 +1582,197 @@ test('should respect log level for transport without sessionId', async () => {
15821582
expect(clientTransport.onmessage).toHaveBeenCalled();
15831583
});
15841584

1585+
describe('createMessage validation', () => {
1586+
test('should throw when tools are provided without sampling.tools capability', async () => {
1587+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1588+
1589+
const client = new Client(
1590+
{ name: 'test client', version: '1.0' },
1591+
{ capabilities: { sampling: {} } } // No tools capability
1592+
);
1593+
1594+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1595+
model: 'test-model',
1596+
role: 'assistant',
1597+
content: { type: 'text', text: 'Response' }
1598+
}));
1599+
1600+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1601+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1602+
1603+
await expect(
1604+
server.createMessage({
1605+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
1606+
maxTokens: 100,
1607+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
1608+
})
1609+
).rejects.toThrow('Client does not support sampling tools capability.');
1610+
});
1611+
1612+
test('should throw when toolChoice is provided without sampling.tools capability', async () => {
1613+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1614+
1615+
const client = new Client(
1616+
{ name: 'test client', version: '1.0' },
1617+
{ capabilities: { sampling: {} } } // No tools capability
1618+
);
1619+
1620+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1621+
model: 'test-model',
1622+
role: 'assistant',
1623+
content: { type: 'text', text: 'Response' }
1624+
}));
1625+
1626+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1627+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1628+
1629+
await expect(
1630+
server.createMessage({
1631+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
1632+
maxTokens: 100,
1633+
toolChoice: { type: 'auto' }
1634+
})
1635+
).rejects.toThrow('Client does not support sampling tools capability.');
1636+
});
1637+
1638+
test('should throw when tool_result is mixed with other content', async () => {
1639+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1640+
1641+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1642+
1643+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1644+
model: 'test-model',
1645+
role: 'assistant',
1646+
content: { type: 'text', text: 'Response' }
1647+
}));
1648+
1649+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1650+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1651+
1652+
await expect(
1653+
server.createMessage({
1654+
messages: [
1655+
{ role: 'user', content: { type: 'text', text: 'hello' } },
1656+
{ role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } },
1657+
{
1658+
role: 'user',
1659+
content: [
1660+
{ type: 'tool_result', toolUseId: 'call_1', content: [] },
1661+
{ type: 'text', text: 'mixed content' } // Mixed!
1662+
]
1663+
}
1664+
],
1665+
maxTokens: 100,
1666+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
1667+
})
1668+
).rejects.toThrow('The last message must contain only tool_result content if any is present');
1669+
});
1670+
1671+
test('should throw when tool_result has no matching tool_use in previous message', async () => {
1672+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1673+
1674+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1675+
1676+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1677+
model: 'test-model',
1678+
role: 'assistant',
1679+
content: { type: 'text', text: 'Response' }
1680+
}));
1681+
1682+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1683+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1684+
1685+
// tool_result without previous tool_use
1686+
await expect(
1687+
server.createMessage({
1688+
messages: [
1689+
{ role: 'user', content: { type: 'text', text: 'hello' } },
1690+
{ role: 'user', content: { type: 'tool_result', toolUseId: 'call_1', content: [] } }
1691+
],
1692+
maxTokens: 100,
1693+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
1694+
})
1695+
).rejects.toThrow('tool_result blocks are not matching any tool_use from the previous message');
1696+
});
1697+
1698+
test('should throw when tool_result IDs do not match tool_use IDs', async () => {
1699+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1700+
1701+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1702+
1703+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1704+
model: 'test-model',
1705+
role: 'assistant',
1706+
content: { type: 'text', text: 'Response' }
1707+
}));
1708+
1709+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1710+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1711+
1712+
await expect(
1713+
server.createMessage({
1714+
messages: [
1715+
{ role: 'user', content: { type: 'text', text: 'hello' } },
1716+
{ role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } },
1717+
{ role: 'user', content: { type: 'tool_result', toolUseId: 'wrong_id', content: [] } }
1718+
],
1719+
maxTokens: 100,
1720+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
1721+
})
1722+
).rejects.toThrow('ids of tool_result blocks and tool_use blocks from previous message do not match');
1723+
});
1724+
1725+
test('should allow text-only messages with tools (no tool_results)', async () => {
1726+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1727+
1728+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1729+
1730+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1731+
model: 'test-model',
1732+
role: 'assistant',
1733+
content: { type: 'text', text: 'Response' }
1734+
}));
1735+
1736+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1737+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1738+
1739+
await expect(
1740+
server.createMessage({
1741+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
1742+
maxTokens: 100,
1743+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
1744+
})
1745+
).resolves.toMatchObject({ model: 'test-model' });
1746+
});
1747+
1748+
test('should allow valid matching tool_result/tool_use IDs', async () => {
1749+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
1750+
1751+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
1752+
1753+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
1754+
model: 'test-model',
1755+
role: 'assistant',
1756+
content: { type: 'text', text: 'Response' }
1757+
}));
1758+
1759+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1760+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
1761+
1762+
await expect(
1763+
server.createMessage({
1764+
messages: [
1765+
{ role: 'user', content: { type: 'text', text: 'hello' } },
1766+
{ role: 'assistant', content: { type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} } },
1767+
{ role: 'user', content: { type: 'tool_result', toolUseId: 'call_1', content: [] } }
1768+
],
1769+
maxTokens: 100,
1770+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
1771+
})
1772+
).resolves.toMatchObject({ model: 'test-model' });
1773+
});
1774+
});
1775+
15851776
test('should respect log level for transport with sessionId', async () => {
15861777
const server = new Server(
15871778
{

src/server/index.ts

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ import {
3030
type ServerRequest,
3131
type ServerResult,
3232
SetLevelRequestSchema,
33-
SUPPORTED_PROTOCOL_VERSIONS
33+
SUPPORTED_PROTOCOL_VERSIONS,
34+
type ToolResultContent,
35+
type ToolUseContent
3436
} from '../types.js';
3537
import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js';
3638
import type { JsonSchemaType, jsonSchemaValidator } from '../validation/types.js';
@@ -326,6 +328,41 @@ export class Server<
326328
}
327329

328330
async createMessage(params: CreateMessageRequest['params'], options?: RequestOptions) {
331+
if (params.tools || params.toolChoice) {
332+
if (!this._clientCapabilities?.sampling?.tools) {
333+
throw new Error('Client does not support sampling tools capability.');
334+
}
335+
336+
const lastMessage = params.messages[params.messages.length - 1];
337+
const lastContent = Array.isArray(lastMessage.content) ? lastMessage.content : [lastMessage.content];
338+
const hasToolResults = lastContent.some(c => c.type === 'tool_result');
339+
340+
const previousMessage = params.messages[params.messages.length - 2];
341+
const previousContent = previousMessage
342+
? Array.isArray(previousMessage.content)
343+
? previousMessage.content
344+
: [previousMessage.content]
345+
: [];
346+
const hasPreviousToolUse = previousContent.some(c => c.type === 'tool_use');
347+
348+
if (hasToolResults) {
349+
if (lastContent.some(c => c.type !== 'tool_result')) {
350+
throw new Error('The last message must contain only tool_result content if any is present');
351+
}
352+
if (!hasPreviousToolUse) {
353+
throw new Error('tool_result blocks are not matching any tool_use from the previous message');
354+
}
355+
}
356+
if (hasPreviousToolUse) {
357+
const toolUseIds = new Set(previousContent.filter(c => c.type === 'tool_use').map(c => (c as ToolUseContent).id));
358+
const toolResultIds = new Set(
359+
lastContent.filter(c => c.type === 'tool_result').map(c => (c as ToolResultContent).toolUseId)
360+
);
361+
if (toolUseIds.size !== toolResultIds.size || ![...toolUseIds].every(id => toolResultIds.has(id))) {
362+
throw new Error('ids of tool_result blocks and tool_use blocks from previous message do not match');
363+
}
364+
}
365+
}
329366
return this.request({ method: 'sampling/createMessage', params }, CreateMessageResultSchema, options);
330367
}
331368

0 commit comments

Comments
 (0)