Skip to content

Commit a8056b1

Browse files
committed
wip: HttpRequest.Builder customizer
1 parent c3a0b18 commit a8056b1

File tree

4 files changed

+425
-286
lines changed

4 files changed

+425
-286
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package io.modelcontextprotocol.client.transport;
2+
3+
import java.net.URI;
4+
import java.net.http.HttpRequest;
5+
import org.reactivestreams.Publisher;
6+
import reactor.core.publisher.Mono;
7+
import reactor.util.annotation.Nullable;
8+
9+
/**
10+
* Customize {@link HttpRequest.Builder} before sending out SSE or Streamable HTTP
11+
* transport.
12+
* <p>
13+
* When used in a non-blocking context, implementations MUST be non-blocking.
14+
*/
15+
public interface AsyncHttpRequestCustomizer {
16+
17+
Publisher<HttpRequest.Builder> customize(HttpRequest.Builder builder, String method, URI endpoint,
18+
@Nullable String body);
19+
20+
AsyncHttpRequestCustomizer NOOP = new Noop();
21+
22+
/**
23+
* Wrap a sync implementation in an async wrapper.
24+
* <p>
25+
* Do NOT use in a non-blocking context.
26+
*/
27+
static AsyncHttpRequestCustomizer fromSync(SyncHttpRequestCustomizer customizer) {
28+
return (builder, method, uri, body) -> Mono.defer(() -> {
29+
customizer.customize(builder, method, uri, body);
30+
return Mono.just(builder);
31+
});
32+
}
33+
34+
class Noop implements AsyncHttpRequestCustomizer {
35+
36+
@Override
37+
public Publisher<HttpRequest.Builder> customize(HttpRequest.Builder builder, String method, URI endpoint,
38+
String body) {
39+
return Mono.just(builder);
40+
}
41+
42+
}
43+
44+
}

mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java

Lines changed: 116 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33
*/
44
package io.modelcontextprotocol.client.transport;
55

6+
import com.fasterxml.jackson.core.type.TypeReference;
7+
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent;
9+
import io.modelcontextprotocol.spec.McpClientTransport;
10+
import io.modelcontextprotocol.spec.McpError;
11+
import io.modelcontextprotocol.spec.McpSchema;
12+
import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
13+
import io.modelcontextprotocol.util.Assert;
14+
import io.modelcontextprotocol.util.Utils;
615
import java.io.IOException;
716
import java.net.URI;
817
import java.net.http.HttpClient;
@@ -13,20 +22,8 @@
1322
import java.util.concurrent.atomic.AtomicReference;
1423
import java.util.function.Consumer;
1524
import java.util.function.Function;
16-
1725
import org.slf4j.Logger;
1826
import org.slf4j.LoggerFactory;
19-
20-
import com.fasterxml.jackson.core.type.TypeReference;
21-
import com.fasterxml.jackson.databind.ObjectMapper;
22-
23-
import io.modelcontextprotocol.client.transport.ResponseSubscribers.ResponseEvent;
24-
import io.modelcontextprotocol.spec.McpClientTransport;
25-
import io.modelcontextprotocol.spec.McpError;
26-
import io.modelcontextprotocol.spec.McpSchema;
27-
import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
28-
import io.modelcontextprotocol.util.Assert;
29-
import io.modelcontextprotocol.util.Utils;
3027
import reactor.core.Disposable;
3128
import reactor.core.publisher.Flux;
3229
import reactor.core.publisher.Mono;
@@ -102,6 +99,9 @@ public class HttpClientSseClientTransport implements McpClientTransport {
10299
*/
103100
protected final Sinks.One<String> messageEndpointSink = Sinks.one();
104101

102+
// TODO
103+
private final AsyncHttpRequestCustomizer httpRequestCustomizer;
104+
105105
/**
106106
* Creates a new transport instance with default HTTP client and object mapper.
107107
* @param baseUri the base URI of the MCP server
@@ -172,18 +172,38 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques
172172
* @param objectMapper the object mapper for JSON serialization/deserialization
173173
* @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null
174174
*/
175+
@Deprecated(forRemoval = true)
175176
HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri,
176177
String sseEndpoint, ObjectMapper objectMapper) {
178+
this(httpClient, requestBuilder, baseUri, sseEndpoint, objectMapper, AsyncHttpRequestCustomizer.NOOP);
179+
}
180+
181+
/**
182+
* Creates a new transport instance with custom HTTP client builder, object mapper,
183+
* and headers.
184+
* @param httpClient the HTTP client to use
185+
* @param requestBuilder the HTTP request builder to use
186+
* @param baseUri the base URI of the MCP server
187+
* @param sseEndpoint the SSE endpoint path
188+
* @param objectMapper the object mapper for JSON serialization/deserialization
189+
* @param httpRequestCustomizer customizer for the requestBuilder before sending
190+
* requests
191+
* @throws IllegalArgumentException if objectMapper, clientBuilder, or headers is null
192+
*/
193+
HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri,
194+
String sseEndpoint, ObjectMapper objectMapper, AsyncHttpRequestCustomizer httpRequestCustomizer) {
177195
Assert.notNull(objectMapper, "ObjectMapper must not be null");
178196
Assert.hasText(baseUri, "baseUri must not be empty");
179197
Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
180198
Assert.notNull(httpClient, "httpClient must not be null");
181199
Assert.notNull(requestBuilder, "requestBuilder must not be null");
200+
Assert.notNull(httpRequestCustomizer, "httpRequestCustomizer must not be null");
182201
this.baseUri = URI.create(baseUri);
183202
this.sseEndpoint = sseEndpoint;
184203
this.objectMapper = objectMapper;
185204
this.httpClient = httpClient;
186205
this.requestBuilder = requestBuilder;
206+
this.httpRequestCustomizer = httpRequestCustomizer;
187207
}
188208

189209
/**
@@ -213,6 +233,8 @@ public static class Builder {
213233
private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder()
214234
.header("Content-Type", "application/json");
215235

236+
private AsyncHttpRequestCustomizer httpRequestCustomizer = AsyncHttpRequestCustomizer.NOOP;
237+
216238
/**
217239
* Creates a new builder instance.
218240
*/
@@ -310,94 +332,109 @@ public Builder objectMapper(ObjectMapper objectMapper) {
310332
return this;
311333
}
312334

335+
/**
336+
* In reactive, DONT USE THIS. Use AsyncHttpRequestCustomizer.
337+
*/
338+
public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCustomizer) {
339+
this.httpRequestCustomizer = AsyncHttpRequestCustomizer.fromSync(syncHttpRequestCustomizer);
340+
return this;
341+
}
342+
343+
public Builder httpRequestCustomizer(AsyncHttpRequestCustomizer asyncHttpRequestCustomizer) {
344+
this.httpRequestCustomizer = asyncHttpRequestCustomizer;
345+
return this;
346+
}
347+
313348
/**
314349
* Builds a new {@link HttpClientSseClientTransport} instance.
315350
* @return a new transport instance
316351
*/
317352
public HttpClientSseClientTransport build() {
318353
return new HttpClientSseClientTransport(clientBuilder.build(), requestBuilder, baseUri, sseEndpoint,
319-
objectMapper);
354+
objectMapper, httpRequestCustomizer);
320355
}
321356

322357
}
323358

324359
@Override
325360
public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler) {
361+
var uri = Utils.resolveUri(this.baseUri, this.sseEndpoint);
326362

327-
return Mono.create(sink -> {
328-
329-
HttpRequest request = requestBuilder.copy()
330-
.uri(Utils.resolveUri(this.baseUri, this.sseEndpoint))
363+
return Mono
364+
.just(requestBuilder.copy()
365+
.uri(uri)
331366
.header("Accept", "text/event-stream")
332367
.header("Cache-Control", "no-cache")
333-
.GET()
334-
.build();
335-
336-
Disposable connection = Flux.<ResponseEvent>create(sseSink -> this.httpClient
337-
.sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink))
338-
.exceptionallyCompose(e -> {
339-
sseSink.error(e);
340-
return CompletableFuture.failedFuture(e);
341-
}))
342-
.map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent)
343-
.flatMap(responseEvent -> {
344-
if (isClosing) {
345-
return Mono.empty();
346-
}
347-
348-
int statusCode = responseEvent.responseInfo().statusCode();
368+
.GET())
369+
.flatMap(builder -> Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null)))
370+
.map(HttpRequest.Builder::build)
371+
.flatMap(request -> Mono.create(sink -> {
372+
Disposable connection = Flux.<ResponseEvent>create(sseSink -> this.httpClient
373+
.sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink))
374+
.exceptionallyCompose(e -> {
375+
sseSink.error(e);
376+
return CompletableFuture.failedFuture(e);
377+
}))
378+
.map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent)
379+
.flatMap(responseEvent -> {
380+
if (isClosing) {
381+
return Mono.empty();
382+
}
349383

350-
if (statusCode >= 200 && statusCode < 300) {
351-
try {
352-
if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) {
353-
String messageEndpointUri = responseEvent.sseEvent().data();
354-
if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) {
384+
int statusCode = responseEvent.responseInfo().statusCode();
385+
386+
if (statusCode >= 200 && statusCode < 300) {
387+
try {
388+
if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) {
389+
String messageEndpointUri = responseEvent.sseEvent().data();
390+
if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) {
391+
sink.success();
392+
return Flux.empty(); // No further processing
393+
// needed
394+
}
395+
else {
396+
sink.error(new McpError("Failed to handle SSE endpoint event"));
397+
}
398+
}
399+
else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) {
400+
JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper,
401+
responseEvent.sseEvent().data());
355402
sink.success();
356-
return Flux.empty(); // No further processing needed
403+
return Flux.just(message);
357404
}
358405
else {
359-
sink.error(new McpError("Failed to handle SSE endpoint event"));
406+
logger.debug("Received unrecognized SSE event type: {}", responseEvent.sseEvent());
407+
sink.success();
360408
}
361409
}
362-
else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) {
363-
JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper,
364-
responseEvent.sseEvent().data());
365-
sink.success();
366-
return Flux.just(message);
367-
}
368-
else {
369-
logger.debug("Received unrecognized SSE event type: {}", responseEvent.sseEvent());
370-
sink.success();
410+
catch (IOException e) {
411+
logger.error("Error processing SSE event", e);
412+
sink.error(new McpError("Error processing SSE event"));
371413
}
372414
}
373-
catch (IOException e) {
374-
logger.error("Error processing SSE event", e);
375-
sink.error(new McpError("Error processing SSE event"));
415+
return Flux.<McpSchema.JSONRPCMessage>error(
416+
new RuntimeException("Failed to send message: " + responseEvent));
417+
418+
})
419+
.flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage)))
420+
.onErrorComplete(t -> {
421+
if (!isClosing) {
422+
logger.warn("SSE stream observed an error", t);
423+
sink.error(t);
376424
}
377-
}
378-
return Flux.<McpSchema.JSONRPCMessage>error(
379-
new RuntimeException("Failed to send message: " + responseEvent));
380-
381-
})
382-
.flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage)))
383-
.onErrorComplete(t -> {
384-
if (!isClosing) {
385-
logger.warn("SSE stream observed an error", t);
386-
sink.error(t);
387-
}
388-
return true;
389-
})
390-
.doFinally(s -> {
391-
Disposable ref = this.sseSubscription.getAndSet(null);
392-
if (ref != null && !ref.isDisposed()) {
393-
ref.dispose();
394-
}
395-
})
396-
.contextWrite(sink.contextView())
397-
.subscribe();
425+
return true;
426+
})
427+
.doFinally(s -> {
428+
Disposable ref = this.sseSubscription.getAndSet(null);
429+
if (ref != null && !ref.isDisposed()) {
430+
ref.dispose();
431+
}
432+
})
433+
.contextWrite(sink.contextView())
434+
.subscribe();
398435

399-
this.sseSubscription.set(connection);
400-
});
436+
this.sseSubscription.set(connection);
437+
}));
401438
}
402439

403440
/**
@@ -453,13 +490,10 @@ private Mono<String> serializeMessage(final JSONRPCMessage message) {
453490

454491
private Mono<HttpResponse<String>> sendHttpPost(final String endpoint, final String body) {
455492
final URI requestUri = Utils.resolveUri(baseUri, endpoint);
456-
final HttpRequest request = this.requestBuilder.copy()
457-
.uri(requestUri)
458-
.POST(HttpRequest.BodyPublishers.ofString(body))
459-
.build();
460-
461-
// TODO: why discard the body?
462-
return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()));
493+
return Mono.just(this.requestBuilder.copy().uri(requestUri).POST(HttpRequest.BodyPublishers.ofString(body)))
494+
.flatMap(builder -> Mono.from(this.httpRequestCustomizer.customize(builder, "POST", requestUri, body)))
495+
.map(HttpRequest.Builder::build)
496+
.flatMap(request -> Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())));
463497
}
464498

465499
/**

0 commit comments

Comments
 (0)