Skip to content

Commit b4207b9

Browse files
committed
Introduce shim task interface for tools
1 parent 0db292a commit b4207b9

File tree

8 files changed

+291
-96
lines changed

8 files changed

+291
-96
lines changed

src/client/index.test.ts

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2495,14 +2495,16 @@ test('should respect server task capabilities', async () => {
24952495

24962496
// This should throw because server doesn't support task creation for tools/list
24972497
await expect(
2498-
client.beginRequest(
2499-
{
2500-
method: 'tools/list',
2501-
params: {}
2502-
},
2503-
z.object({ tools: z.array(z.any()) }),
2504-
{ task: { taskId: 'test-task-2', keepAlive: 60000 } }
2505-
).result()
2498+
client
2499+
.beginRequest(
2500+
{
2501+
method: 'tools/list',
2502+
params: {}
2503+
},
2504+
z.object({ tools: z.array(z.any()) }),
2505+
{ task: { taskId: 'test-task-2', keepAlive: 60000 } }
2506+
)
2507+
.result()
25062508
).rejects.toThrow('Server does not support task creation for tools/list');
25072509

25082510
serverTaskStore.cleanup();

src/examples/server/simpleStreamableHttp.ts

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ const getServer = () => {
454454
);
455455

456456
// Register a long-running tool that demonstrates task execution
457-
server.registerTool(
457+
server.registerToolTask(
458458
'delay',
459459
{
460460
title: 'Delay',
@@ -463,16 +463,37 @@ const getServer = () => {
463463
duration: z.number().describe('Duration in milliseconds').default(5000)
464464
}
465465
},
466-
async ({ duration }): Promise<CallToolResult> => {
467-
await new Promise(resolve => setTimeout(resolve, duration));
468-
return {
469-
content: [
470-
{
471-
type: 'text',
472-
text: `Completed ${duration}ms delay`
473-
}
474-
]
475-
};
466+
{
467+
async createTask({ duration }, { taskId, taskStore, taskRequestedKeepAlive }) {
468+
// Simulate out-of-band work
469+
(async () => {
470+
await new Promise(resolve => setTimeout(resolve, duration));
471+
await taskStore.storeTaskResult(taskId, {
472+
content: [
473+
{
474+
type: 'text',
475+
text: `Completed ${duration}ms delay`
476+
}
477+
]
478+
});
479+
})();
480+
return await taskStore.createTask({
481+
taskId,
482+
keepAlive: taskRequestedKeepAlive
483+
});
484+
},
485+
async getTask(_args, { taskId, taskStore }) {
486+
const task = await taskStore.getTask(taskId);
487+
if (!task) {
488+
throw new Error(`Task ${taskId} not found`);
489+
}
490+
491+
return task;
492+
},
493+
async getTaskResult(_args, { taskId, taskStore }) {
494+
const result = await taskStore.getTaskResult(taskId);
495+
return result as CallToolResult;
496+
}
476497
}
477498
);
478499

src/examples/shared/inMemoryTaskStore.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ export class InMemoryTaskStore implements TaskStore {
2121
private tasks = new Map<string, StoredTask>();
2222
private cleanupTimers = new Map<string, ReturnType<typeof setTimeout>>();
2323

24-
async createTask(metadata: TaskMetadata, requestId: RequestId, request: Request, _sessionId?: string): Promise<void> {
24+
async createTask(metadata: TaskMetadata, requestId: RequestId, request: Request, _sessionId?: string): Promise<Task> {
2525
const taskId = metadata.taskId;
2626

2727
if (this.tasks.has(taskId)) {
@@ -50,6 +50,8 @@ export class InMemoryTaskStore implements TaskStore {
5050

5151
this.cleanupTimers.set(taskId, timer);
5252
}
53+
54+
return task;
5355
}
5456

5557
async getTask(taskId: string, _sessionId?: string): Promise<Task | null> {

src/server/mcp.ts

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,20 @@ import {
3131
ServerNotification,
3232
ToolAnnotations,
3333
LoggingMessageNotification,
34+
CreateTaskResult,
35+
GetTaskResult,
36+
Result,
37+
TASK_META_KEY,
3438
CompleteRequestPrompt,
3539
CompleteRequestResourceTemplate,
3640
assertCompleteRequestPrompt,
3741
assertCompleteRequestResourceTemplate
3842
} from '../types.js';
3943
import { Completable, CompletableDef } from './completable.js';
4044
import { UriTemplate, Variables } from '../shared/uriTemplate.js';
41-
import { RequestHandlerExtra } from '../shared/protocol.js';
45+
import { RequestHandlerExtra, RequestTaskStore } from '../shared/protocol.js';
4246
import { Transport } from '../shared/transport.js';
47+
import { isTerminal } from '../shared/task.js';
4348

4449
/**
4550
* High-level MCP server that provides a simpler API for working with resources, tools, and prompts.
@@ -836,6 +841,51 @@ export class McpServer {
836841
);
837842
}
838843

844+
/**
845+
* Registers a task-based tool with a config object and callback.
846+
*/
847+
registerToolTask<InputArgs extends ZodRawShape, OutputArgs extends ZodRawShape>(
848+
name: string,
849+
config: {
850+
title?: string;
851+
description?: string;
852+
inputSchema?: InputArgs;
853+
outputSchema?: OutputArgs;
854+
annotations?: ToolAnnotations;
855+
_meta?: Record<string, unknown>;
856+
},
857+
handler: ToolTaskHandler<InputArgs>
858+
): RegisteredTool {
859+
// TODO: Attach to individual request handlers and remove this wrapper
860+
const cb: ToolCallback<InputArgs> = (async (...args) => {
861+
const [inputArgs, extra] = args;
862+
863+
const taskStore = extra.taskStore;
864+
if (!taskStore) {
865+
throw new Error('Task store is not available');
866+
}
867+
868+
const taskMetadata = extra._meta?.[TASK_META_KEY];
869+
const taskId = taskMetadata?.taskId;
870+
if (!taskId) {
871+
throw new Error('No task ID provided');
872+
}
873+
874+
// Internal polling to allow using this interface before internals are hooked up
875+
const taskExtra = { ...extra, taskId, taskStore };
876+
let task = await handler.createTask(inputArgs, taskExtra);
877+
do {
878+
await new Promise(resolve => setTimeout(resolve, task.pollInterval ?? 5000));
879+
task = await handler.getTask(inputArgs, taskExtra);
880+
} while (!isTerminal(task.status));
881+
882+
const result: CallToolResult = await handler.getTaskResult(inputArgs, taskExtra);
883+
return result;
884+
}) as ToolCallback<InputArgs>;
885+
886+
return this.registerTool(name, { ...config, annotations: { ...config.annotations, taskHint: true } }, cb);
887+
}
888+
839889
/**
840890
* Registers a zero-argument prompt `name`, which will run the given function when the client calls it.
841891
* @deprecated Use `registerPrompt` instead.
@@ -1044,6 +1094,21 @@ export type ToolCallback<Args extends undefined | ZodRawShape | ZodType<object>
10441094
? (args: T, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => CallToolResult | Promise<CallToolResult>
10451095
: (extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => CallToolResult | Promise<CallToolResult>;
10461096

1097+
export interface TaskRequestHandlerExtra extends RequestHandlerExtra<ServerRequest, ServerNotification> {
1098+
taskId: string;
1099+
taskStore: RequestTaskStore;
1100+
}
1101+
1102+
export type TaskRequestHandler<SendResultT extends Result, Args extends undefined | ZodRawShape = undefined> = Args extends ZodRawShape
1103+
? (args: z.objectOutputType<Args, ZodTypeAny>, extra: TaskRequestHandlerExtra) => SendResultT | Promise<SendResultT>
1104+
: (extra: TaskRequestHandlerExtra) => SendResultT | Promise<SendResultT>;
1105+
1106+
export interface ToolTaskHandler<Args extends undefined | ZodRawShape = undefined> {
1107+
createTask: TaskRequestHandler<CreateTaskResult, Args>;
1108+
getTask: TaskRequestHandler<GetTaskResult, Args>;
1109+
getTaskResult: TaskRequestHandler<CallToolResult, Args>;
1110+
}
1111+
10471112
export type RegisteredTool = {
10481113
title?: string;
10491114
description?: string;

src/shared/protocol.test.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ function createMockTaskStore(options?: {
3737
const tasks: Record<string, Task & { result?: Result }> = {};
3838
return {
3939
createTask: jest.fn((taskMetadata: TaskMetadata, _1: RequestId, _2: Request) => {
40-
tasks[taskMetadata.taskId] = {
40+
const task = (tasks[taskMetadata.taskId] = {
4141
taskId: taskMetadata.taskId,
4242
status: (taskMetadata.status as Task['status'] | undefined) ?? 'submitted',
4343
keepAlive: taskMetadata.keepAlive ?? null,
4444
pollInterval: (taskMetadata.pollInterval as Task['pollInterval'] | undefined) ?? 1000
45-
};
45+
});
4646
options?.onStatus?.('submitted');
47-
return Promise.resolve();
47+
return Promise.resolve(task);
4848
}),
4949
getTask: jest.fn((taskId: string) => {
5050
return Promise.resolve(tasks[taskId] ?? null);

0 commit comments

Comments
 (0)