Skip to content

Commit d99df2b

Browse files
committed
refactor: improve SSE client transpor testing
- Extract eventStream() method - Enhance connect() method with improved SSE event processing and error handling - Introduce MyResponseInfo record and enhanced TestHttpClientSseClientTransport - Add collections to track JSON-RPC requests, notifications, and responses Signed-off-by: Christian Tzolov <christian.tzolov@broadcom.com>
1 parent 420aaf2 commit d99df2b

File tree

2 files changed

+152
-13
lines changed

2 files changed

+152
-13
lines changed

mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,9 +321,89 @@ public HttpClientSseClientTransport build() {
321321

322322
}
323323

324+
protected Flux<ResponseSubscribers.SseResponseEvent> eventStream() {
325+
326+
HttpRequest request = requestBuilder.copy()
327+
.uri(Utils.resolveUri(this.baseUri, this.sseEndpoint))
328+
.header("Accept", "text/event-stream")
329+
.header("Cache-Control", "no-cache")
330+
.GET()
331+
.build();
332+
333+
Flux<ResponseSubscribers.SseResponseEvent> bla = Flux.<ResponseEvent>create(sseSink -> this.httpClient
334+
.sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink))
335+
.exceptionallyCompose(e -> {
336+
sseSink.error(e);
337+
return CompletableFuture.failedFuture(e);
338+
})).map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent);
339+
return bla;
340+
}
341+
324342
@Override
325343
public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler) {
326344

345+
return Mono.create(sink -> {
346+
347+
Flux<ResponseSubscribers.SseResponseEvent> events = eventStream();
348+
349+
Disposable connection = events.flatMap(responseEvent -> {
350+
if (isClosing) {
351+
return Mono.empty();
352+
}
353+
354+
int statusCode = responseEvent.responseInfo().statusCode();
355+
356+
if (statusCode >= 200 && statusCode < 300) {
357+
try {
358+
if (ENDPOINT_EVENT_TYPE.equals(responseEvent.sseEvent().event())) {
359+
String messageEndpointUri = responseEvent.sseEvent().data();
360+
if (this.messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) {
361+
sink.success();
362+
return Flux.empty(); // No further processing needed
363+
}
364+
else {
365+
sink.error(new McpError("Failed to handle SSE endpoint event"));
366+
}
367+
}
368+
else if (MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) {
369+
JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper,
370+
responseEvent.sseEvent().data());
371+
sink.success();
372+
return Flux.just(message);
373+
}
374+
else {
375+
logger.error("Received unrecognized SSE event type: {}", responseEvent.sseEvent().event());
376+
sink.error(new McpError(
377+
"Received unrecognized SSE event type: " + responseEvent.sseEvent().event()));
378+
}
379+
}
380+
catch (IOException e) {
381+
logger.error("Error processing SSE event", e);
382+
sink.error(new McpError("Error processing SSE event"));
383+
}
384+
}
385+
return Flux.<McpSchema.JSONRPCMessage>error(
386+
new RuntimeException("Failed to send message: " + responseEvent));
387+
388+
}).flatMap(jsonRpcMessage -> handler.apply(Mono.just(jsonRpcMessage))).onErrorComplete(t -> {
389+
if (!isClosing) {
390+
logger.warn("SSE stream observed an error", t);
391+
sink.error(t);
392+
}
393+
return true;
394+
}).doFinally(s -> {
395+
Disposable ref = this.sseSubscription.getAndSet(null);
396+
if (ref != null && !ref.isDisposed()) {
397+
ref.dispose();
398+
}
399+
}).contextWrite(sink.contextView()).subscribe();
400+
401+
this.sseSubscription.set(connection);
402+
});
403+
}
404+
405+
public Mono<Void> connect2(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler) {
406+
327407
return Mono.create(sink -> {
328408

329409
HttpRequest request = requestBuilder.copy()
@@ -333,6 +413,13 @@ public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> h
333413
.GET()
334414
.build();
335415

416+
Flux<ResponseSubscribers.SseResponseEvent> bla = Flux.<ResponseEvent>create(sseSink -> this.httpClient
417+
.sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink))
418+
.exceptionallyCompose(e -> {
419+
sseSink.error(e);
420+
return CompletableFuture.failedFuture(e);
421+
})).map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent);
422+
336423
Disposable connection = Flux.<ResponseEvent>create(sseSink -> this.httpClient
337424
.sendAsync(request, responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink))
338425
.exceptionallyCompose(e -> {

mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,43 @@
44

55
package io.modelcontextprotocol.client.transport;
66

7+
import static org.assertj.core.api.Assertions.assertThat;
8+
import static org.assertj.core.api.Assertions.assertThatCode;
9+
710
import java.net.URI;
811
import java.net.http.HttpClient;
12+
import java.net.http.HttpClient.Version;
13+
import java.net.http.HttpHeaders;
914
import java.net.http.HttpRequest;
15+
import java.net.http.HttpResponse.ResponseInfo;
1016
import java.time.Duration;
1117
import java.util.Map;
18+
import java.util.concurrent.CopyOnWriteArrayList;
1219
import java.util.concurrent.atomic.AtomicBoolean;
1320
import java.util.concurrent.atomic.AtomicInteger;
1421
import java.util.concurrent.atomic.AtomicReference;
1522
import java.util.function.Function;
1623

17-
import com.fasterxml.jackson.databind.ObjectMapper;
18-
import io.modelcontextprotocol.spec.McpSchema;
19-
import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;
2024
import org.junit.jupiter.api.AfterEach;
2125
import org.junit.jupiter.api.BeforeEach;
2226
import org.junit.jupiter.api.Test;
2327
import org.junit.jupiter.api.Timeout;
2428
import org.testcontainers.containers.GenericContainer;
2529
import org.testcontainers.containers.wait.strategy.Wait;
30+
31+
import com.fasterxml.jackson.databind.ObjectMapper;
32+
33+
import io.modelcontextprotocol.client.transport.ResponseSubscribers.SseResponseEvent;
34+
import io.modelcontextprotocol.spec.McpSchema;
35+
import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
36+
import io.modelcontextprotocol.spec.McpSchema.JSONRPCNotification;
37+
import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;
38+
import io.modelcontextprotocol.spec.McpSchema.JSONRPCResponse;
39+
import reactor.core.publisher.Flux;
2640
import reactor.core.publisher.Mono;
2741
import reactor.core.publisher.Sinks;
2842
import reactor.test.StepVerifier;
2943

30-
import org.springframework.http.codec.ServerSentEvent;
31-
32-
import static org.assertj.core.api.Assertions.assertThat;
33-
import static org.assertj.core.api.Assertions.assertThatCode;
34-
3544
/**
3645
* Tests for the {@link HttpClientSseClientTransport} class.
3746
*
@@ -51,28 +60,70 @@ class HttpClientSseClientTransportTests {
5160

5261
private TestHttpClientSseClientTransport transport;
5362

63+
public record MyResponseInfo(int statusCode, HttpHeaders headers, Version version) implements ResponseInfo {
64+
MyResponseInfo(int statusCode, HttpHeaders headers) {
65+
this(statusCode, headers, Version.HTTP_1_1);
66+
}
67+
68+
MyResponseInfo(int statusCode) {
69+
this(statusCode, HttpHeaders.of(Map.of(), (k, v) -> true), Version.HTTP_1_1);
70+
}
71+
}
72+
5473
// Test class to access protected methods
5574
static class TestHttpClientSseClientTransport extends HttpClientSseClientTransport {
5675

5776
private final AtomicInteger inboundMessageCount = new AtomicInteger(0);
5877

59-
private Sinks.Many<ServerSentEvent<String>> events = Sinks.many().unicast().onBackpressureBuffer();
78+
private Sinks.Many<SseResponseEvent> events = Sinks.many().unicast().onBackpressureBuffer();
6079

6180
public TestHttpClientSseClientTransport(final String baseUri) {
62-
super(HttpClient.newHttpClient(), HttpRequest.newBuilder(), baseUri, "/sse", new ObjectMapper());
81+
super(HttpClient.newBuilder().version(HttpClient.Version.HTTP_1_1).build(),
82+
HttpRequest.newBuilder().header("Content-Type", "application/json"), baseUri, "/sse",
83+
new ObjectMapper());
84+
}
85+
86+
CopyOnWriteArrayList<JSONRPCRequest> requestMessages = new CopyOnWriteArrayList<>();
87+
88+
CopyOnWriteArrayList<JSONRPCNotification> notificationMessages = new CopyOnWriteArrayList<>();
89+
90+
CopyOnWriteArrayList<JSONRPCResponse> responseMessages = new CopyOnWriteArrayList<>();
91+
92+
Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler = (messageMono) -> messageMono
93+
.doOnNext(message -> {
94+
// System.out.println("Received message $$$$$$$$$$$$$$: " + message);
95+
if (message instanceof JSONRPCRequest request) {
96+
requestMessages.add(request);
97+
}
98+
else if (message instanceof JSONRPCNotification notificaiton) {
99+
notificationMessages.add(notificaiton);
100+
}
101+
else if (message instanceof JSONRPCResponse response) {
102+
responseMessages.add(response);
103+
}
104+
else {
105+
throw new IllegalArgumentException("Unsupported message type: " + message.getClass());
106+
}
107+
});
108+
109+
@Override
110+
protected Flux<SseResponseEvent> eventStream() {
111+
return super.eventStream().mergeWith(events.asFlux());
63112
}
64113

65114
public int getInboundMessageCount() {
66115
return inboundMessageCount.get();
67116
}
68117

69118
public void simulateEndpointEvent(String jsonMessage) {
70-
events.tryEmitNext(ServerSentEvent.<String>builder().event("endpoint").data(jsonMessage).build());
119+
events.tryEmitNext(new SseResponseEvent(new MyResponseInfo(200),
120+
new ResponseSubscribers.SseEvent(null, "endpoint", jsonMessage)));
71121
inboundMessageCount.incrementAndGet();
72122
}
73123

74124
public void simulateMessageEvent(String jsonMessage) {
75-
events.tryEmitNext(ServerSentEvent.<String>builder().event("message").data(jsonMessage).build());
125+
events.tryEmitNext(new SseResponseEvent(new MyResponseInfo(200),
126+
new ResponseSubscribers.SseEvent(null, "message", jsonMessage)));
76127
inboundMessageCount.incrementAndGet();
77128
}
78129

@@ -88,7 +139,7 @@ void startContainer() {
88139
void setUp() {
89140
startContainer();
90141
transport = new TestHttpClientSseClientTransport(host);
91-
transport.connect(Function.identity()).block();
142+
transport.connect(transport.handler).block();
92143
}
93144

94145
@AfterEach
@@ -123,6 +174,7 @@ void testMessageProcessing() {
123174
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();
124175

125176
assertThat(transport.getInboundMessageCount()).isEqualTo(1);
177+
assertThat(transport.requestMessages).hasSize(1);
126178
}
127179

128180
@Test

0 commit comments

Comments
 (0)