Skip to content

Commit 36d9b3e

Browse files
ochafikclaude
andcommitted
SEP-1577: Validate tools/toolChoice requires sampling.tools capability
The client MUST return an error if tools or toolChoice are provided but ClientCapabilities.sampling.tools is not declared. This validation is implemented in the client's setRequestHandler override, following the same pattern used for elicitation validation. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 41c6b35 commit 36d9b3e

File tree

2 files changed

+351
-1
lines changed

2 files changed

+351
-1
lines changed

src/client/index.test.ts

Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1668,3 +1668,321 @@ describe('getSupportedElicitationModes', () => {
16681668
expect(result.supportsUrlMode).toBe(false);
16691669
});
16701670
});
1671+
1672+
// SEP-1577: Sampling tools capability validation tests
1673+
test('should reject sampling/createMessage with tools when client lacks sampling.tools capability', async () => {
1674+
const client = new Client(
1675+
{
1676+
name: 'test-client',
1677+
version: '1.0.0'
1678+
},
1679+
{
1680+
capabilities: {
1681+
sampling: {} // No tools sub-capability
1682+
}
1683+
}
1684+
);
1685+
1686+
const handler = vi.fn().mockResolvedValue({
1687+
model: 'test-model',
1688+
role: 'assistant',
1689+
content: { type: 'text', text: 'Response' }
1690+
});
1691+
client.setRequestHandler(CreateMessageRequestSchema, handler);
1692+
1693+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1694+
1695+
let resolveResponse: ((message: unknown) => void) | undefined;
1696+
const responsePromise = new Promise<unknown>(resolve => {
1697+
resolveResponse = resolve;
1698+
});
1699+
1700+
serverTransport.onmessage = async message => {
1701+
if ('method' in message) {
1702+
if (message.method === 'initialize') {
1703+
if (!('id' in message) || message.id === undefined) {
1704+
throw new Error('Expected initialize request to include an id');
1705+
}
1706+
const messageId = message.id;
1707+
await serverTransport.send({
1708+
jsonrpc: '2.0',
1709+
id: messageId,
1710+
result: {
1711+
protocolVersion: LATEST_PROTOCOL_VERSION,
1712+
capabilities: {},
1713+
serverInfo: {
1714+
name: 'test-server',
1715+
version: '1.0.0'
1716+
}
1717+
}
1718+
});
1719+
} else if (message.method === 'notifications/initialized') {
1720+
// ignore
1721+
}
1722+
} else {
1723+
resolveResponse?.(message);
1724+
}
1725+
};
1726+
1727+
await client.connect(clientTransport);
1728+
1729+
// Server sends request with tools despite client not advertising sampling.tools
1730+
const requestId = 1;
1731+
await serverTransport.send({
1732+
jsonrpc: '2.0',
1733+
id: requestId,
1734+
method: 'sampling/createMessage',
1735+
params: {
1736+
messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }],
1737+
maxTokens: 100,
1738+
tools: [{ name: 'test_tool', inputSchema: { type: 'object', properties: {} } }]
1739+
}
1740+
});
1741+
1742+
const response = (await responsePromise) as { id: number; error: { code: number; message: string } };
1743+
1744+
expect(response.id).toBe(requestId);
1745+
expect(response.error.code).toBe(ErrorCode.InvalidParams);
1746+
expect(response.error.message).toContain('Client does not support sampling with tools');
1747+
expect(response.error.message).toContain('tools');
1748+
expect(handler).not.toHaveBeenCalled();
1749+
1750+
await client.close();
1751+
});
1752+
1753+
test('should reject sampling/createMessage with toolChoice when client lacks sampling.tools capability', async () => {
1754+
const client = new Client(
1755+
{
1756+
name: 'test-client',
1757+
version: '1.0.0'
1758+
},
1759+
{
1760+
capabilities: {
1761+
sampling: {} // No tools sub-capability
1762+
}
1763+
}
1764+
);
1765+
1766+
const handler = vi.fn().mockResolvedValue({
1767+
model: 'test-model',
1768+
role: 'assistant',
1769+
content: { type: 'text', text: 'Response' }
1770+
});
1771+
client.setRequestHandler(CreateMessageRequestSchema, handler);
1772+
1773+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1774+
1775+
let resolveResponse: ((message: unknown) => void) | undefined;
1776+
const responsePromise = new Promise<unknown>(resolve => {
1777+
resolveResponse = resolve;
1778+
});
1779+
1780+
serverTransport.onmessage = async message => {
1781+
if ('method' in message) {
1782+
if (message.method === 'initialize') {
1783+
if (!('id' in message) || message.id === undefined) {
1784+
throw new Error('Expected initialize request to include an id');
1785+
}
1786+
const messageId = message.id;
1787+
await serverTransport.send({
1788+
jsonrpc: '2.0',
1789+
id: messageId,
1790+
result: {
1791+
protocolVersion: LATEST_PROTOCOL_VERSION,
1792+
capabilities: {},
1793+
serverInfo: {
1794+
name: 'test-server',
1795+
version: '1.0.0'
1796+
}
1797+
}
1798+
});
1799+
} else if (message.method === 'notifications/initialized') {
1800+
// ignore
1801+
}
1802+
} else {
1803+
resolveResponse?.(message);
1804+
}
1805+
};
1806+
1807+
await client.connect(clientTransport);
1808+
1809+
// Server sends request with toolChoice despite client not advertising sampling.tools
1810+
const requestId = 2;
1811+
await serverTransport.send({
1812+
jsonrpc: '2.0',
1813+
id: requestId,
1814+
method: 'sampling/createMessage',
1815+
params: {
1816+
messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }],
1817+
maxTokens: 100,
1818+
toolChoice: { mode: 'required' }
1819+
}
1820+
});
1821+
1822+
const response = (await responsePromise) as { id: number; error: { code: number; message: string } };
1823+
1824+
expect(response.id).toBe(requestId);
1825+
expect(response.error.code).toBe(ErrorCode.InvalidParams);
1826+
expect(response.error.message).toContain('Client does not support sampling with tools');
1827+
expect(response.error.message).toContain('toolChoice');
1828+
expect(handler).not.toHaveBeenCalled();
1829+
1830+
await client.close();
1831+
});
1832+
1833+
test('should accept sampling/createMessage with tools when client has sampling.tools capability', async () => {
1834+
const client = new Client(
1835+
{
1836+
name: 'test-client',
1837+
version: '1.0.0'
1838+
},
1839+
{
1840+
capabilities: {
1841+
sampling: {
1842+
tools: {} // Has tools capability
1843+
}
1844+
}
1845+
}
1846+
);
1847+
1848+
const handler = vi.fn().mockResolvedValue({
1849+
model: 'test-model',
1850+
role: 'assistant',
1851+
content: { type: 'tool_use', name: 'test_tool', id: 'call_1', input: {} },
1852+
stopReason: 'toolUse'
1853+
});
1854+
client.setRequestHandler(CreateMessageRequestSchema, handler);
1855+
1856+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1857+
1858+
let resolveResponse: ((message: unknown) => void) | undefined;
1859+
const responsePromise = new Promise<unknown>(resolve => {
1860+
resolveResponse = resolve;
1861+
});
1862+
1863+
serverTransport.onmessage = async message => {
1864+
if ('method' in message) {
1865+
if (message.method === 'initialize') {
1866+
if (!('id' in message) || message.id === undefined) {
1867+
throw new Error('Expected initialize request to include an id');
1868+
}
1869+
const messageId = message.id;
1870+
await serverTransport.send({
1871+
jsonrpc: '2.0',
1872+
id: messageId,
1873+
result: {
1874+
protocolVersion: LATEST_PROTOCOL_VERSION,
1875+
capabilities: {},
1876+
serverInfo: {
1877+
name: 'test-server',
1878+
version: '1.0.0'
1879+
}
1880+
}
1881+
});
1882+
} else if (message.method === 'notifications/initialized') {
1883+
// ignore
1884+
}
1885+
} else {
1886+
resolveResponse?.(message);
1887+
}
1888+
};
1889+
1890+
await client.connect(clientTransport);
1891+
1892+
const requestId = 3;
1893+
await serverTransport.send({
1894+
jsonrpc: '2.0',
1895+
id: requestId,
1896+
method: 'sampling/createMessage',
1897+
params: {
1898+
messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }],
1899+
maxTokens: 100,
1900+
tools: [{ name: 'test_tool', inputSchema: { type: 'object', properties: {} } }]
1901+
}
1902+
});
1903+
1904+
const response = (await responsePromise) as { id: number; result: unknown };
1905+
1906+
expect(response.id).toBe(requestId);
1907+
expect(response.result).toBeDefined();
1908+
expect(handler).toHaveBeenCalled();
1909+
1910+
await client.close();
1911+
});
1912+
1913+
test('should accept sampling/createMessage without tools when client lacks sampling.tools capability', async () => {
1914+
const client = new Client(
1915+
{
1916+
name: 'test-client',
1917+
version: '1.0.0'
1918+
},
1919+
{
1920+
capabilities: {
1921+
sampling: {} // No tools sub-capability, but not needed for this request
1922+
}
1923+
}
1924+
);
1925+
1926+
const handler = vi.fn().mockResolvedValue({
1927+
model: 'test-model',
1928+
role: 'assistant',
1929+
content: { type: 'text', text: 'Response' }
1930+
});
1931+
client.setRequestHandler(CreateMessageRequestSchema, handler);
1932+
1933+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
1934+
1935+
let resolveResponse: ((message: unknown) => void) | undefined;
1936+
const responsePromise = new Promise<unknown>(resolve => {
1937+
resolveResponse = resolve;
1938+
});
1939+
1940+
serverTransport.onmessage = async message => {
1941+
if ('method' in message) {
1942+
if (message.method === 'initialize') {
1943+
if (!('id' in message) || message.id === undefined) {
1944+
throw new Error('Expected initialize request to include an id');
1945+
}
1946+
const messageId = message.id;
1947+
await serverTransport.send({
1948+
jsonrpc: '2.0',
1949+
id: messageId,
1950+
result: {
1951+
protocolVersion: LATEST_PROTOCOL_VERSION,
1952+
capabilities: {},
1953+
serverInfo: {
1954+
name: 'test-server',
1955+
version: '1.0.0'
1956+
}
1957+
}
1958+
});
1959+
} else if (message.method === 'notifications/initialized') {
1960+
// ignore
1961+
}
1962+
} else {
1963+
resolveResponse?.(message);
1964+
}
1965+
};
1966+
1967+
await client.connect(clientTransport);
1968+
1969+
// Request without tools/toolChoice should succeed
1970+
const requestId = 4;
1971+
await serverTransport.send({
1972+
jsonrpc: '2.0',
1973+
id: requestId,
1974+
method: 'sampling/createMessage',
1975+
params: {
1976+
messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }],
1977+
maxTokens: 100
1978+
}
1979+
});
1980+
1981+
const response = (await responsePromise) as { id: number; result: unknown };
1982+
1983+
expect(response.id).toBe(requestId);
1984+
expect(response.result).toBeDefined();
1985+
expect(handler).toHaveBeenCalled();
1986+
1987+
await client.close();
1988+
});

src/client/index.ts

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ import {
3838
type Tool,
3939
type UnsubscribeRequest,
4040
ElicitResultSchema,
41-
ElicitRequestSchema
41+
ElicitRequestSchema,
42+
CreateMessageRequestSchema
4243
} from '../types.js';
4344
import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js';
4445
import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js';
@@ -307,6 +308,37 @@ export class Client<
307308
return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler);
308309
}
309310

311+
// SEP-1577: Validate tools/toolChoice require sampling.tools capability
312+
if (method === 'sampling/createMessage') {
313+
const wrappedHandler = async (
314+
request: SchemaOutput<T>,
315+
extra: RequestHandlerExtra<ClientRequest | RequestT, ClientNotification | NotificationT>
316+
): Promise<ClientResult | ResultT> => {
317+
const validatedRequest = safeParse(CreateMessageRequestSchema, request);
318+
if (!validatedRequest.success) {
319+
const errorMessage =
320+
validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error);
321+
throw new McpError(ErrorCode.InvalidParams, `Invalid sampling request: ${errorMessage}`);
322+
}
323+
324+
const { params } = validatedRequest.data;
325+
326+
// Check if tools or toolChoice are provided but capability not declared
327+
if ((params.tools !== undefined || params.toolChoice !== undefined) && !this._capabilities.sampling?.tools) {
328+
throw new McpError(
329+
ErrorCode.InvalidParams,
330+
'Client does not support sampling with tools but request contains ' +
331+
(params.tools !== undefined ? 'tools' : 'toolChoice') +
332+
' parameter'
333+
);
334+
}
335+
336+
return await Promise.resolve(handler(request, extra));
337+
};
338+
339+
return super.setRequestHandler(requestSchema, wrappedHandler as unknown as typeof handler);
340+
}
341+
310342
// Non-elicitation handlers use default behavior
311343
return super.setRequestHandler(requestSchema, handler);
312344
}

0 commit comments

Comments
 (0)