Skip to content

Commit 54df752

Browse files
committed
fix: properly queue responses and notifications during Task execution
1 parent 21ca745 commit 54df752

29 files changed

+1619
-1136
lines changed

mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskExtra.java

Lines changed: 89 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import io.modelcontextprotocol.server.McpAsyncServerExchange;
1010
import io.modelcontextprotocol.spec.McpSchema;
11+
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
12+
import io.modelcontextprotocol.spec.McpSchema.TaskStatus;
1113
import reactor.core.publisher.Mono;
1214

1315
/**
@@ -18,21 +20,15 @@
1820
*
1921
* <pre>{@code
2022
* CreateTaskHandler handler = (args, extra) -> {
21-
* // Decide TTL based on request or use a default
22-
* long ttl = extra.requestTtl() != null
23-
* ? Math.min(extra.requestTtl(), Duration.ofMinutes(30).toMillis())
24-
* : Duration.ofMinutes(5).toMillis();
23+
* return extra.createTask(opts -> opts.pollInterval(500L)).flatMap(task -> {
24+
* // Start async work that will complete the task later
25+
* doAsyncWork(args)
26+
* .flatMap(result -> extra.completeTask(task.taskId(), result))
27+
* .onErrorResume(e -> extra.failTask(task.taskId(), e.getMessage()))
28+
* .subscribe();
2529
*
26-
* return extra.taskStore()
27-
* .createTask(CreateTaskOptions.builder()
28-
* .requestedTtl(ttl)
29-
* .sessionId(extra.sessionId())
30-
* .build())
31-
* .flatMap(task -> {
32-
* // Use exchange for client communication
33-
* startBackgroundWork(task.taskId(), args, extra.exchange()).subscribe();
34-
* return Mono.just(new McpSchema.CreateTaskResult(task, null));
35-
* });
30+
* return Mono.just(McpSchema.CreateTaskResult.builder().task(task).build());
31+
* });
3632
* };
3733
* }</pre>
3834
*
@@ -47,38 +43,9 @@
4743
*
4844
* @see CreateTaskHandler
4945
* @see SyncCreateTaskExtra
50-
* @see TaskStore
51-
* @see TaskMessageQueue
5246
*/
5347
public interface CreateTaskExtra {
5448

55-
/**
56-
* The task store for creating and managing tasks.
57-
*
58-
* <p>
59-
* Tools use this to create tasks with their desired configuration:
60-
*
61-
* <pre>{@code
62-
* extra.taskStore().createTask(CreateTaskOptions.builder()
63-
* .requestedTtl(Duration.ofMinutes(5).toMillis())
64-
* .pollInterval(Duration.ofSeconds(1).toMillis())
65-
* .sessionId(extra.sessionId())
66-
* .build());
67-
* }</pre>
68-
* @return the TaskStore instance
69-
*/
70-
TaskStore<McpSchema.ServerTaskPayloadResult> taskStore();
71-
72-
/**
73-
* The message queue for task communication during INPUT_REQUIRED state.
74-
*
75-
* <p>
76-
* Use this for interactive tasks that need to communicate with the client during
77-
* execution.
78-
* @return the TaskMessageQueue instance, or null if not configured
79-
*/
80-
TaskMessageQueue taskMessageQueue();
81-
8249
/**
8350
* The server exchange for client interaction.
8451
*
@@ -130,81 +97,116 @@ public interface CreateTaskExtra {
13097
McpSchema.Request originatingRequest();
13198

13299
// --------------------------
133-
// Convenience Methods
100+
// Task Creation
134101
// --------------------------
135102

136103
/**
137104
* Convenience method to create a task with default options derived from this context.
138105
*
139106
* <p>
140107
* This method automatically uses {@link #originatingRequest()}, {@link #sessionId()},
141-
* and {@link #requestTtl()} from this context, eliminating common boilerplate:
108+
* and {@link #requestTtl()} from this context.
142109
*
143110
* <pre>{@code
144-
* // Instead of:
145-
* extra.taskStore().createTask(CreateTaskOptions.builder(extra.originatingRequest())
146-
* .sessionId(extra.sessionId())
147-
* .requestedTtl(extra.requestTtl())
148-
* .build())
149-
*
150-
* // You can simply use:
151-
* extra.createTask()
111+
* extra.createTask().flatMap(task -> {
112+
* doAsyncWork(args)
113+
* .flatMap(result -> extra.completeTask(task.taskId(), result))
114+
* .subscribe();
115+
* return Mono.just(McpSchema.CreateTaskResult.builder().task(task).build());
116+
* });
152117
* }</pre>
153-
* @return Mono that completes with the created Task
118+
* @return Mono that completes with the created task
154119
*/
155-
default Mono<McpSchema.Task> createTask() {
156-
return taskStore().createTask(CreateTaskOptions.builder(originatingRequest())
157-
.sessionId(sessionId())
158-
.requestedTtl(requestTtl())
159-
.build());
160-
}
120+
Mono<McpSchema.Task> createTask();
161121

162122
/**
163123
* Convenience method to create a task with custom options, but inheriting session
164124
* context.
165125
*
166126
* <p>
167127
* This method pre-populates the builder with {@link #originatingRequest()},
168-
* {@link #sessionId()}, and {@link #requestTtl()}, then allows customization:
128+
* {@link #sessionId()}, and {@link #requestTtl()}, then allows customization.
169129
*
170130
* <pre>{@code
171131
* // Create a task with custom poll interval:
172-
* extra.createTask(opts -> opts.pollInterval(500L))
173-
*
174-
* // Create a task with custom TTL (ignoring client request):
175-
* extra.createTask(opts -> opts.requestedTtl(Duration.ofMinutes(10).toMillis()))
132+
* extra.createTask(opts -> opts.pollInterval(500L)).flatMap(task -> {
133+
* // Pass task ID explicitly for side-channeling
134+
* extra.exchange().createElicitation(request, task.taskId()).subscribe();
135+
* return Mono.just(McpSchema.CreateTaskResult.builder().task(task).build());
136+
* });
176137
* }</pre>
177138
* @param customizer function to customize options beyond the defaults
178-
* @return Mono that completes with the created Task
139+
* @return Mono that completes with the created task
140+
*/
141+
Mono<McpSchema.Task> createTask(Consumer<CreateTaskOptions.Builder> customizer);
142+
143+
// --------------------------
144+
// Task Lifecycle
145+
// --------------------------
146+
147+
/**
148+
* Complete a task with a successful result.
149+
*
150+
* <p>
151+
* This marks the task as {@link TaskStatus#COMPLETED} and stores the result for
152+
* client retrieval.
153+
*
154+
* <pre>{@code
155+
* extra.createTask().flatMap(task -> {
156+
* doAsyncWork(args)
157+
* .flatMap(result -> extra.completeTask(task.taskId(), result))
158+
* .subscribe();
159+
* return Mono.just(McpSchema.CreateTaskResult.builder().task(task).build());
160+
* });
161+
* }</pre>
162+
* @param taskId the ID of the task to complete
163+
* @param result the tool result to store
164+
* @return Mono that completes when the task is updated
165+
*/
166+
Mono<Void> completeTask(String taskId, CallToolResult result);
167+
168+
/**
169+
* Mark a task as failed with an error message.
170+
*
171+
* <p>
172+
* This marks the task as {@link TaskStatus#FAILED} with the provided message.
173+
*
174+
* <pre>{@code
175+
* extra.createTask().flatMap(task -> {
176+
* doAsyncWork(args)
177+
* .flatMap(result -> extra.completeTask(task.taskId(), result))
178+
* .onErrorResume(e -> extra.failTask(task.taskId(), e.getMessage()))
179+
* .subscribe();
180+
* return Mono.just(McpSchema.CreateTaskResult.builder().task(task).build());
181+
* });
182+
* }</pre>
183+
* @param taskId the ID of the task to fail
184+
* @param message the error message describing what went wrong
185+
* @return Mono that completes when the task is updated
179186
*/
180-
default Mono<McpSchema.Task> createTask(Consumer<CreateTaskOptions.Builder> customizer) {
181-
CreateTaskOptions.Builder builder = CreateTaskOptions.builder(originatingRequest())
182-
.sessionId(sessionId())
183-
.requestedTtl(requestTtl());
184-
customizer.accept(builder);
185-
return taskStore().createTask(builder.build());
186-
}
187+
Mono<Void> failTask(String taskId, String message);
187188

188189
/**
189-
* Create a TaskContext for managing the given task's lifecycle.
190+
* Set a task to INPUT_REQUIRED status, triggering side-channel delivery.
190191
*
191192
* <p>
192-
* This convenience method creates a TaskContext that uses this extra's task store and
193-
* message queue, reducing boilerplate in task handlers:
193+
* When a task is in {@link TaskStatus#INPUT_REQUIRED}, the client will poll via
194+
* {@code tasks/result} and receive any queued notifications or requests via
195+
* side-channeling.
194196
*
195197
* <pre>{@code
196-
* extra.createTask()
197-
* .map(task -> extra.createTaskContext(task))
198-
* .flatMap(ctx -> {
199-
* // Use ctx to update status, send messages, etc.
200-
* return ctx.complete(result);
201-
* });
198+
* extra.createTask().flatMap(task -> {
199+
* // Queue a notification for side-channel delivery
200+
* extra.exchange().loggingNotification(notification, task.taskId())
201+
* .then(extra.setInputRequired(task.taskId(), "Waiting for user input"))
202+
* .subscribe();
203+
* return Mono.just(McpSchema.CreateTaskResult.builder().task(task).build());
204+
* });
202205
* }</pre>
203-
* @param task the task to create a context for
204-
* @return a TaskContext bound to the given task and this extra's infrastructure
206+
* @param taskId the ID of the task
207+
* @param message a status message describing what input is required
208+
* @return Mono that completes when the task is updated
205209
*/
206-
default TaskContext createTaskContext(McpSchema.Task task) {
207-
return new DefaultTaskContext<>(task.taskId(), sessionId(), taskStore(), taskMessageQueue());
208-
}
210+
Mono<Void> setInputRequired(String taskId, String message);
209211

210212
}

mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/CreateTaskHandler.java

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,15 @@
1616
*
1717
* <pre>{@code
1818
* CreateTaskHandler handler = (args, extra) -> {
19-
* // Tool decides TTL directly
20-
* long ttl = Duration.ofMinutes(5).toMillis();
19+
* return extra.createTask(opts -> opts.pollInterval(500L)).flatMap(task -> {
20+
* // Start background work that will complete the task later
21+
* doAsyncWork(args)
22+
* .flatMap(result -> extra.completeTask(task.taskId(), result))
23+
* .onErrorResume(e -> extra.failTask(task.taskId(), e.getMessage()))
24+
* .subscribe();
2125
*
22-
* return extra.taskStore()
23-
* .createTask(CreateTaskOptions.builder()
24-
* .requestedTtl(ttl)
25-
* .sessionId(extra.sessionId())
26-
* .build())
27-
* .flatMap(task -> {
28-
* // Start background work
29-
* doWork(task.taskId(), args, extra.exchange()).subscribe();
30-
* return Mono.just(new McpSchema.CreateTaskResult(task, null));
31-
* });
26+
* return Mono.just(McpSchema.CreateTaskResult.builder().task(task).build());
27+
* });
3228
* };
3329
* }</pre>
3430
*
@@ -52,7 +48,8 @@ public interface CreateTaskHandler {
5248
* <li>Returning the created task wrapped in a CreateTaskResult</li>
5349
* </ul>
5450
* @param args The parsed tool arguments from the CallToolRequest
55-
* @param extra Context providing taskStore, exchange, and request metadata
51+
* @param extra Context providing task lifecycle methods, exchange, and request
52+
* metadata
5653
* @return a Mono emitting the CreateTaskResult containing the created Task
5754
*/
5855
Mono<McpSchema.CreateTaskResult> createTask(Map<String, Object> args, CreateTaskExtra extra);

mcp-core/src/main/java/io/modelcontextprotocol/experimental/tasks/DefaultCreateTaskExtra.java

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44

55
package io.modelcontextprotocol.experimental.tasks;
66

7+
import java.util.function.Consumer;
8+
79
import io.modelcontextprotocol.server.McpAsyncServerExchange;
810
import io.modelcontextprotocol.spec.McpSchema;
11+
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
12+
import io.modelcontextprotocol.spec.McpSchema.TaskStatus;
913
import io.modelcontextprotocol.util.Assert;
14+
import reactor.core.publisher.Mono;
1015

1116
/**
1217
* Default implementation of {@link CreateTaskExtra}.
@@ -61,16 +66,32 @@ public DefaultCreateTaskExtra(TaskStore<McpSchema.ServerTaskPayloadResult> taskS
6166
this.originatingRequest = originatingRequest;
6267
}
6368

64-
@Override
65-
public TaskStore<McpSchema.ServerTaskPayloadResult> taskStore() {
69+
// --------------------------
70+
// Internal accessors (for framework use only)
71+
// --------------------------
72+
73+
/**
74+
* Returns the task store. This method is package-private for internal framework use
75+
* only.
76+
* @return the task store
77+
*/
78+
TaskStore<McpSchema.ServerTaskPayloadResult> taskStore() {
6679
return this.taskStore;
6780
}
6881

69-
@Override
70-
public TaskMessageQueue taskMessageQueue() {
82+
/**
83+
* Returns the message queue. This method is package-private for internal framework
84+
* use only.
85+
* @return the message queue, or null if not configured
86+
*/
87+
TaskMessageQueue taskMessageQueue() {
7188
return this.taskMessageQueue;
7289
}
7390

91+
// --------------------------
92+
// CreateTaskExtra implementation
93+
// --------------------------
94+
7495
@Override
7596
public McpAsyncServerExchange exchange() {
7697
return this.exchange;
@@ -91,4 +112,36 @@ public McpSchema.Request originatingRequest() {
91112
return this.originatingRequest;
92113
}
93114

115+
@Override
116+
public Mono<McpSchema.Task> createTask() {
117+
return this.taskStore.createTask(CreateTaskOptions.builder(originatingRequest())
118+
.sessionId(sessionId())
119+
.requestedTtl(requestTtl())
120+
.build());
121+
}
122+
123+
@Override
124+
public Mono<McpSchema.Task> createTask(Consumer<CreateTaskOptions.Builder> customizer) {
125+
CreateTaskOptions.Builder builder = CreateTaskOptions.builder(originatingRequest())
126+
.sessionId(sessionId())
127+
.requestedTtl(requestTtl());
128+
customizer.accept(builder);
129+
return this.taskStore.createTask(builder.build());
130+
}
131+
132+
@Override
133+
public Mono<Void> completeTask(String taskId, CallToolResult result) {
134+
return this.taskStore.storeTaskResult(taskId, this.sessionId, TaskStatus.COMPLETED, result);
135+
}
136+
137+
@Override
138+
public Mono<Void> failTask(String taskId, String message) {
139+
return this.taskStore.updateTaskStatus(taskId, this.sessionId, TaskStatus.FAILED, message);
140+
}
141+
142+
@Override
143+
public Mono<Void> setInputRequired(String taskId, String message) {
144+
return this.taskStore.updateTaskStatus(taskId, this.sessionId, TaskStatus.INPUT_REQUIRED, message);
145+
}
146+
94147
}

0 commit comments

Comments
 (0)