|
3 | 3 | */ |
4 | 4 | package io.modelcontextprotocol.client.transport; |
5 | 5 |
|
| 6 | +import com.fasterxml.jackson.core.type.TypeReference; |
| 7 | +import com.fasterxml.jackson.databind.ObjectMapper; |
| 8 | +import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; |
| 9 | +import io.modelcontextprotocol.spec.McpClientTransport; |
| 10 | +import io.modelcontextprotocol.spec.McpError; |
| 11 | +import io.modelcontextprotocol.spec.McpSchema; |
| 12 | +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; |
| 13 | +import io.modelcontextprotocol.util.Assert; |
| 14 | +import io.modelcontextprotocol.util.Utils; |
6 | 15 | import java.io.IOException; |
7 | 16 | import java.net.URI; |
8 | 17 | import java.net.http.HttpClient; |
|
13 | 22 | import java.util.concurrent.atomic.AtomicReference; |
14 | 23 | import java.util.function.Consumer; |
15 | 24 | import java.util.function.Function; |
16 | | - |
17 | 25 | import org.slf4j.Logger; |
18 | 26 | import org.slf4j.LoggerFactory; |
19 | | - |
20 | | -import com.fasterxml.jackson.core.type.TypeReference; |
21 | | -import com.fasterxml.jackson.databind.ObjectMapper; |
22 | | - |
23 | | -import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent; |
24 | | -import io.modelcontextprotocol.spec.McpClientTransport; |
25 | | -import io.modelcontextprotocol.spec.McpError; |
26 | | -import io.modelcontextprotocol.spec.McpSchema; |
27 | | -import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; |
28 | | -import io.modelcontextprotocol.util.Assert; |
29 | | -import io.modelcontextprotocol.util.Utils; |
30 | 27 | import reactor.core.Disposable; |
31 | 28 | import reactor.core.publisher.Flux; |
32 | 29 | import reactor.core.publisher.Mono; |
@@ -102,6 +99,9 @@ public class HttpClientSseClientTransport implements McpClientTransport { |
102 | 99 | */ |
103 | 100 | protected final Sinks.One<String> messageEndpointSink = Sinks.one(); |
104 | 101 |
|
| 102 | + // TODO |
| 103 | + private final AsyncHttpRequestCustomizer httpRequestCustomizer; |
| 104 | + |
105 | 105 | /** |
106 | 106 | * Creates a new transport instance with default HTTP client and object mapper. |
107 | 107 | * @param baseUri the base URI of the MCP server |
@@ -172,18 +172,38 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques |
172 | 172 | * @param objectMapper the object mapper for JSON serialization/deserialization |
173 | 173 | * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null |
174 | 174 | */ |
| 175 | + @Deprecated(forRemoval = true) |
175 | 176 | HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, |
176 | 177 | String sseEndpoint, ObjectMapper objectMapper) { |
| 178 | + this(httpClient, requestBuilder, baseUri, sseEndpoint, objectMapper, AsyncHttpRequestCustomizer.NOOP); |
| 179 | + } |
| 180 | + |
| 181 | + /** |
| 182 | + * Creates a new transport instance with custom HTTP client builder, object mapper, |
| 183 | + * and headers. |
| 184 | + * @param httpClient the HTTP client to use |
| 185 | + * @param requestBuilder the HTTP request builder to use |
| 186 | + * @param baseUri the base URI of the MCP server |
| 187 | + * @param sseEndpoint the SSE endpoint path |
| 188 | + * @param objectMapper the object mapper for JSON serialization/deserialization |
| 189 | + * @param httpRequestCustomizer customizer for the requestBuilder before sending |
| 190 | + * requests |
| 191 | + * @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null |
| 192 | + */ |
| 193 | + HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, |
| 194 | + String sseEndpoint, ObjectMapper objectMapper, AsyncHttpRequestCustomizer httpRequestCustomizer) { |
177 | 195 | Assert.notNull(objectMapper, "ObjectMapper must not be null"); |
178 | 196 | Assert.hasText(baseUri, "baseUri must not be empty"); |
179 | 197 | Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); |
180 | 198 | Assert.notNull(httpClient, "httpClient must not be null"); |
181 | 199 | Assert.notNull(requestBuilder, "requestBuilder must not be null"); |
| 200 | + Assert.notNull(httpRequestCustomizer, "httpRequestCustomizer must not be null"); |
182 | 201 | this.baseUri = URI.create(baseUri); |
183 | 202 | this.sseEndpoint = sseEndpoint; |
184 | 203 | this.objectMapper = objectMapper; |
185 | 204 | this.httpClient = httpClient; |
186 | 205 | this.requestBuilder = requestBuilder; |
| 206 | + this.httpRequestCustomizer = httpRequestCustomizer; |
187 | 207 | } |
188 | 208 |
|
189 | 209 | /** |
@@ -213,6 +233,8 @@ public static class Builder { |
213 | 233 | private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() |
214 | 234 | .header("Content-Type", "application/json"); |
215 | 235 |
|
| 236 | + private AsyncHttpRequestCustomizer httpRequestCustomizer = AsyncHttpRequestCustomizer.NOOP; |
| 237 | + |
216 | 238 | /** |
217 | 239 | * Creates a new builder instance. |
218 | 240 | */ |
@@ -310,94 +332,109 @@ public Builder objectMapper(ObjectMapper objectMapper) { |
310 | 332 | return this; |
311 | 333 | } |
312 | 334 |
|
| 335 | + /** |
| 336 | + * In reactive, DONT USE THIS. Use AsyncHttpRequestCustomizer. |
| 337 | + */ |
| 338 | + public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCustomizer) { |
| 339 | + this.httpRequestCustomizer = AsyncHttpRequestCustomizer.fromSync(syncHttpRequestCustomizer); |
| 340 | + return this; |
| 341 | + } |
| 342 | + |
| 343 | + public Builder httpRequestCustomizer(AsyncHttpRequestCustomizer asyncHttpRequestCustomizer) { |
| 344 | + this.httpRequestCustomizer = asyncHttpRequestCustomizer; |
| 345 | + return this; |
| 346 | + } |
| 347 | + |
313 | 348 | /** |
314 | 349 | * Builds a new {@link HttpClientSseClientTransport} instance. |
315 | 350 | * @return a new transport instance |
316 | 351 | */ |
317 | 352 | public HttpClientSseClientTransport build() { |
318 | 353 | return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint, |
319 | | - objectMapper); |
| 354 | + objectMapper, httpRequestCustomizer); |
320 | 355 | } |
321 | 356 |
|
322 | 357 | } |
323 | 358 |
|
324 | 359 | @Override |
325 | 360 | public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler) { |
| 361 | + var uri = Utils.resolveUri(this.baseUri, this.sseEndpoint); |
326 | 362 |
|
327 | | - return Mono.create(sink -> { |
328 | | - |
329 | | - HttpRequest request = requestBuilder.copy() |
330 | | - .uri(Utils.resolveUri(this.baseUri, this.sseEndpoint)) |
| 363 | + return Mono |
| 364 | + .just(requestBuilder.copy() |
| 365 | + .uri(uri) |
331 | 366 | .header("Accept", "text/event-stream") |
332 | 367 | .header("Cache-Control", "no-cache") |
333 | | - .GET() |
334 | | - .build(); |
335 | | - |
336 | | - Disposable connection = Flux.<ResponseEvent>create(sseSink -> this.httpClient |
337 | | - .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) |
338 | | - .exceptionallyCompose(e -> { |
339 | | - sseSink.error(e); |
340 | | - return CompletableFuture.failedFuture(e); |
341 | | - })) |
342 | | - .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) |
343 | | - .flatMap(responseEvent -> { |
344 | | - if (isClosing) { |
345 | | - return Mono.empty(); |
346 | | - } |
347 | | - |
348 | | - int statusCode = responseEvent.responseInfo().statusCode(); |
| 368 | + .GET()) |
| 369 | + .flatMap(builder -> Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null))) |
| 370 | + .map(HttpRequest.Builder::build) |
| 371 | + .flatMap(request -> Mono.create(sink -> { |
| 372 | + Disposable connection = Flux.<ResponseEvent>create(sseSink -> this.httpClient |
| 373 | + .sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink)) |
| 374 | + .exceptionallyCompose(e -> { |
| 375 | + sseSink.error(e); |
| 376 | + return CompletableFuture.failedFuture(e); |
| 377 | + })) |
| 378 | + .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent) |
| 379 | + .flatMap(responseEvent -> { |
| 380 | + if (isClosing) { |
| 381 | + return Mono.empty(); |
| 382 | + } |
349 | 383 |
|
350 | | - if (statusCode >= 200 && statusCode < 300) { |
351 | | - try { |
352 | | - if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { |
353 | | - String messageEndpointUri = responseEvent.sseEvent().data(); |
354 | | - if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { |
| 384 | + int statusCode = responseEvent.responseInfo().statusCode(); |
| 385 | + |
| 386 | + if (statusCode >= 200 && statusCode < 300) { |
| 387 | + try { |
| 388 | + if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { |
| 389 | + String messageEndpointUri = responseEvent.sseEvent().data(); |
| 390 | + if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) { |
| 391 | + sink.success(); |
| 392 | + return Flux.empty(); // No further processing |
| 393 | + // needed |
| 394 | + } |
| 395 | + else { |
| 396 | + sink.error(new McpError("Failed to handle SSE endpoint event")); |
| 397 | + } |
| 398 | + } |
| 399 | + else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { |
| 400 | + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, |
| 401 | + responseEvent.sseEvent().data()); |
355 | 402 | sink.success(); |
356 | | - return Flux.empty(); // No further processing needed |
| 403 | + return Flux.just(message); |
357 | 404 | } |
358 | 405 | else { |
359 | | - sink.error(new McpError("Failed to handle SSE endpoint event")); |
| 406 | + logger.debug("Received unrecognized SSE event type: {}", responseEvent.sseEvent()); |
| 407 | + sink.success(); |
360 | 408 | } |
361 | 409 | } |
362 | | - else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) { |
363 | | - JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, |
364 | | - responseEvent.sseEvent().data()); |
365 | | - sink.success(); |
366 | | - return Flux.just(message); |
367 | | - } |
368 | | - else { |
369 | | - logger.debug("Received unrecognized SSE event type: {}", responseEvent.sseEvent()); |
370 | | - sink.success(); |
| 410 | + catch (IOException e) { |
| 411 | + logger.error("Error processing SSE event", e); |
| 412 | + sink.error(new McpError("Error processing SSE event")); |
371 | 413 | } |
372 | 414 | } |
373 | | - catch (IOException e) { |
374 | | - logger.error("Error processing SSE event", e); |
375 | | - sink.error(new McpError("Error processing SSE event")); |
| 415 | + return Flux.<McpSchema.JSONRPCMessage>error( |
| 416 | + new RuntimeException("Failed to send message: " + responseEvent)); |
| 417 | + |
| 418 | + }) |
| 419 | + .flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))) |
| 420 | + .onErrorComplete(t -> { |
| 421 | + if (!isClosing) { |
| 422 | + logger.warn("SSE stream observed an error", t); |
| 423 | + sink.error(t); |
376 | 424 | } |
377 | | - } |
378 | | - return Flux.<McpSchema.JSONRPCMessage>error( |
379 | | - new RuntimeException("Failed to send message: " + responseEvent)); |
380 | | - |
381 | | - }) |
382 | | - .flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))) |
383 | | - .onErrorComplete(t -> { |
384 | | - if (!isClosing) { |
385 | | - logger.warn("SSE stream observed an error", t); |
386 | | - sink.error(t); |
387 | | - } |
388 | | - return true; |
389 | | - }) |
390 | | - .doFinally(s -> { |
391 | | - Disposable ref = this.sseSubscription.getAndSet(null); |
392 | | - if (ref != null && !ref.isDisposed()) { |
393 | | - ref.dispose(); |
394 | | - } |
395 | | - }) |
396 | | - .contextWrite(sink.contextView()) |
397 | | - .subscribe(); |
| 425 | + return true; |
| 426 | + }) |
| 427 | + .doFinally(s -> { |
| 428 | + Disposable ref = this.sseSubscription.getAndSet(null); |
| 429 | + if (ref != null && !ref.isDisposed()) { |
| 430 | + ref.dispose(); |
| 431 | + } |
| 432 | + }) |
| 433 | + .contextWrite(sink.contextView()) |
| 434 | + .subscribe(); |
398 | 435 |
|
399 | | - this.sseSubscription.set(connection); |
400 | | - }); |
| 436 | + this.sseSubscription.set(connection); |
| 437 | + })); |
401 | 438 | } |
402 | 439 |
|
403 | 440 | /** |
@@ -453,13 +490,10 @@ private Mono<String> serializeMessage(final JSONRPCMessage message) { |
453 | 490 |
|
454 | 491 | private Mono<HttpResponse<String>> sendHttpPost(final String endpoint, final String body) { |
455 | 492 | final URI requestUri = Utils.resolveUri(baseUri, endpoint); |
456 | | - final HttpRequest request = this.requestBuilder.copy() |
457 | | - .uri(requestUri) |
458 | | - .POST(HttpRequest.BodyPublishers.ofString(body)) |
459 | | - .build(); |
460 | | - |
461 | | - // TODO: why discard the body? |
462 | | - return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())); |
| 493 | + return Mono.just(this.requestBuilder.copy().uri(requestUri).POST(HttpRequest.BodyPublishers.ofString(body))) |
| 494 | + .flatMap(builder -> Mono.from(this.httpRequestCustomizer.customize(builder, "POST", requestUri, body))) |
| 495 | + .map(HttpRequest.Builder::build) |
| 496 | + .flatMap(request -> Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()))); |
463 | 497 | } |
464 | 498 |
|
465 | 499 | /** |
|
0 commit comments