Skip to content

Commit ae971de

Browse files
committed
Support Cancellation Notification
Signed-off-by: JermaineHua <crazyhzm@apache.org>
1 parent 2e953c8 commit ae971de

File tree

4 files changed

+97
-8
lines changed

4 files changed

+97
-8
lines changed

mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.Map;
99
import java.util.UUID;
1010
import java.util.concurrent.ConcurrentHashMap;
11+
import java.util.concurrent.TimeoutException;
1112
import java.util.concurrent.atomic.AtomicLong;
1213

1314
import com.fasterxml.jackson.core.type.TypeReference;
@@ -236,7 +237,18 @@ public <T> Mono<T> sendRequest(String method, Object requestParams, TypeReferenc
236237
this.pendingResponses.remove(requestId);
237238
sink.error(error);
238239
});
239-
}).timeout(this.requestTimeout).handle((jsonRpcResponse, sink) -> {
240+
}).timeout(this.requestTimeout).onErrorResume(e -> {
241+
if (e instanceof TimeoutException) {
242+
return Mono.fromRunnable(() -> {
243+
this.pendingResponses.remove(requestId);
244+
McpSchema.CancellationMessageNotification cancellationMessageNotification = new McpSchema.CancellationMessageNotification(
245+
requestId, "The request times out, timeout: " + requestTimeout.toMillis() + " ms");
246+
sendNotification(McpSchema.METHOD_NOTIFICATION_CANCELLED, cancellationMessageNotification)
247+
.subscribe();
248+
}).then(Mono.error(e));
249+
}
250+
return Mono.error(e);
251+
}).handle((jsonRpcResponse, sink) -> {
240252
if (jsonRpcResponse.error() != null) {
241253
sink.error(new McpError(jsonRpcResponse.error()));
242254
}

mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
* Context Protocol Schema</a>.
3030
*
3131
* @author Christian Tzolov
32+
* @author Jermaine Hua
3233
*/
3334
public final class McpSchema {
3435

@@ -50,6 +51,8 @@ private McpSchema() {
5051

5152
public static final String METHOD_NOTIFICATION_INITIALIZED = "notifications/initialized";
5253

54+
public static final String METHOD_NOTIFICATION_CANCELLED = "notifications/cancelled";
55+
5356
public static final String METHOD_PING = "ping";
5457

5558
// Tool Methods
@@ -211,6 +214,16 @@ public record JSONRPCError(
211214
}
212215
}// @formatter:on
213216

217+
// ---------------------------
218+
// Cancellation Message Notification
219+
// ---------------------------
220+
@JsonInclude(JsonInclude.Include.NON_ABSENT)
221+
@JsonIgnoreProperties(ignoreUnknown = true)
222+
public record CancellationMessageNotification( // @formatter:off
223+
@JsonProperty("requestId") String requestId,
224+
@JsonProperty("reason") String reason){
225+
} // @formatter:on
226+
214227
// ---------------------------
215228
// Initialization
216229
// ---------------------------

mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ public class McpServerSession implements McpSession {
5353

5454
private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED);
5555

56+
/**
57+
* keyed by request ID, value is true if the request is being cancelled.
58+
*/
59+
private final Map<Object, Boolean> requestCancellation = new ConcurrentHashMap<>();
60+
5661
/**
5762
* Creates a new server session with the given parameters and the transport to use.
5863
* @param id session id
@@ -165,13 +170,18 @@ public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
165170
}
166171
else if (message instanceof McpSchema.JSONRPCRequest request) {
167172
logger.debug("Received request: {}", request);
173+
requestCancellation.put(request.id(), false);
168174
return handleIncomingRequest(request).onErrorResume(error -> {
169175
var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null,
170176
new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
171177
error.getMessage(), null));
172178
// TODO: Should the error go to SSE or back as POST return?
173-
return this.transport.sendMessage(errorResponse).then(Mono.empty());
174-
}).flatMap(this.transport::sendMessage);
179+
return this.transport.sendMessage(errorResponse)
180+
.doFinally(signal -> requestCancellation.remove(request.id()))
181+
.then(Mono.empty());
182+
})
183+
.flatMap(response -> this.transport.sendMessage(response)
184+
.doFinally(signal -> requestCancellation.remove(request.id())));
175185
}
176186
else if (message instanceof McpSchema.JSONRPCNotification notification) {
177187
// TODO handle errors for communication to without initialization
@@ -207,6 +217,11 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
207217
resultMono = this.initRequestHandler.handle(initializeRequest);
208218
}
209219
else {
220+
// cancellation request
221+
if (requestCancellation.get(request.id())) {
222+
requestCancellation.remove(request.id());
223+
return Mono.empty();
224+
}
210225
// TODO handle errors for communication to this session without
211226
// initialization happening first
212227
var handler = this.requestHandlers.get(request.method());
@@ -217,14 +232,32 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
217232
error.message(), error.data())));
218233
}
219234

220-
resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, request.params()));
235+
resultMono = this.exchangeSink.asMono()
236+
.flatMap(exchange -> handler.handle(exchange, request.params()).flatMap(result -> {
237+
if (requestCancellation.get(request.id())) {
238+
requestCancellation.remove(request.id());
239+
return Mono.empty();
240+
}
241+
else {
242+
return Mono.just(result);
243+
}
244+
}).doOnCancel(() -> requestCancellation.remove(request.id())));
245+
221246
}
222247
return resultMono
223248
.map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null))
224-
.onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(),
225-
null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
226-
error.getMessage(), null)))); // TODO: add error message
227-
// through the data field
249+
.onErrorResume(error -> {
250+
if (requestCancellation.get(request.id())) {
251+
requestCancellation.remove(request.id());
252+
return Mono.empty();
253+
}
254+
else {
255+
return Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null,
256+
new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
257+
error.getMessage(), null)));
258+
}
259+
}); // TODO: add error message
260+
// through the data field
228261
});
229262
}
230263

@@ -240,6 +273,17 @@ private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification noti
240273
exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get()));
241274
return this.initNotificationHandler.handle();
242275
}
276+
else if (McpSchema.METHOD_NOTIFICATION_CANCELLED.equals(notification.method())) {
277+
McpSchema.CancellationMessageNotification cancellationMessageNotification = transport
278+
.unmarshalFrom(notification.params(), new TypeReference<>() {
279+
});
280+
if (requestCancellation.containsKey(cancellationMessageNotification.requestId())) {
281+
logger.warn("Received cancellation notification for request {}, cancellation reason is {}",
282+
cancellationMessageNotification.requestId(), cancellationMessageNotification.reason());
283+
requestCancellation.put(cancellationMessageNotification.requestId(), true);
284+
}
285+
return Mono.empty();
286+
}
243287

244288
var handler = notificationHandlers.get(notification.method());
245289
if (handler == null) {

mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import reactor.core.publisher.Sinks;
1919
import reactor.test.StepVerifier;
2020

21+
import static io.modelcontextprotocol.spec.McpSchema.METHOD_NOTIFICATION_CANCELLED;
2122
import static org.assertj.core.api.Assertions.assertThat;
2223
import static org.assertj.core.api.Assertions.assertThatThrownBy;
2324

@@ -119,6 +120,25 @@ void testRequestTimeout() {
119120
.verify(TIMEOUT.plusSeconds(1));
120121
}
121122

123+
@Test
124+
void testCancellationMessageNotificationForRequestTimeout() {
125+
Mono<String> responseMono = session.sendRequest(TEST_METHOD, "test", responseType);
126+
127+
StepVerifier.create(responseMono)
128+
.expectError(java.util.concurrent.TimeoutException.class)
129+
.verify(TIMEOUT.plusSeconds(1));
130+
131+
McpSchema.JSONRPCMessage sentMessage = transport.getLastSentMessage();
132+
assertThat(sentMessage).isInstanceOf(McpSchema.JSONRPCNotification.class);
133+
McpSchema.JSONRPCNotification notification = (McpSchema.JSONRPCNotification) sentMessage;
134+
assertThat(notification.method()).isEqualTo(METHOD_NOTIFICATION_CANCELLED);
135+
McpSchema.CancellationMessageNotification cancellationMessageNotification = transport
136+
.unmarshalFrom(notification.params(), new TypeReference<>() {
137+
});
138+
assertThat(cancellationMessageNotification.reason()
139+
.contains("The request times out, timeout: " + TIMEOUT.toMillis() + " ms")).isTrue();
140+
}
141+
122142
@Test
123143
void testSendNotification() {
124144
Map<String, Object> params = Map.of("key", "value");

0 commit comments

Comments
 (0)