44
55package io .modelcontextprotocol .client .transport ;
66
7+ import static org .assertj .core .api .Assertions .assertThat ;
8+ import static org .assertj .core .api .Assertions .assertThatCode ;
9+
710import java .net .URI ;
811import java .net .http .HttpClient ;
12+ import java .net .http .HttpClient .Version ;
13+ import java .net .http .HttpHeaders ;
914import java .net .http .HttpRequest ;
15+ import java .net .http .HttpResponse .ResponseInfo ;
1016import java .time .Duration ;
1117import java .util .Map ;
18+ import java .util .concurrent .CopyOnWriteArrayList ;
1219import java .util .concurrent .atomic .AtomicBoolean ;
1320import java .util .concurrent .atomic .AtomicInteger ;
1421import java .util .concurrent .atomic .AtomicReference ;
1522import java .util .function .Function ;
1623
17- import com .fasterxml .jackson .databind .ObjectMapper ;
18- import io .modelcontextprotocol .spec .McpSchema ;
19- import io .modelcontextprotocol .spec .McpSchema .JSONRPCRequest ;
2024import org .junit .jupiter .api .AfterEach ;
2125import org .junit .jupiter .api .BeforeEach ;
2226import org .junit .jupiter .api .Test ;
2327import org .junit .jupiter .api .Timeout ;
2428import org .testcontainers .containers .GenericContainer ;
2529import org .testcontainers .containers .wait .strategy .Wait ;
30+
31+ import com .fasterxml .jackson .databind .ObjectMapper ;
32+
33+ import io .modelcontextprotocol .client .transport .ResponseSubscribers .SseResponseEvent ;
34+ import io .modelcontextprotocol .spec .McpSchema ;
35+ import io .modelcontextprotocol .spec .McpSchema .JSONRPCMessage ;
36+ import io .modelcontextprotocol .spec .McpSchema .JSONRPCNotification ;
37+ import io .modelcontextprotocol .spec .McpSchema .JSONRPCRequest ;
38+ import io .modelcontextprotocol .spec .McpSchema .JSONRPCResponse ;
39+ import reactor .core .publisher .Flux ;
2640import reactor .core .publisher .Mono ;
2741import reactor .core .publisher .Sinks ;
2842import reactor .test .StepVerifier ;
2943
30- import org .springframework .http .codec .ServerSentEvent ;
31-
32- import static org .assertj .core .api .Assertions .assertThat ;
33- import static org .assertj .core .api .Assertions .assertThatCode ;
34-
3544/**
3645 * Tests for the {@link HttpClientSseClientTransport} class.
3746 *
@@ -51,28 +60,70 @@ class HttpClientSseClientTransportTests {
5160
5261 private TestHttpClientSseClientTransport transport ;
5362
63+ public record MyResponseInfo (int statusCode , HttpHeaders headers , Version version ) implements ResponseInfo {
64+ MyResponseInfo (int statusCode , HttpHeaders headers ) {
65+ this (statusCode , headers , Version .HTTP_1_1 );
66+ }
67+
68+ MyResponseInfo (int statusCode ) {
69+ this (statusCode , HttpHeaders .of (Map .of (), (k , v ) -> true ), Version .HTTP_1_1 );
70+ }
71+ }
72+
5473 // Test class to access protected methods
5574 static class TestHttpClientSseClientTransport extends HttpClientSseClientTransport {
5675
5776 private final AtomicInteger inboundMessageCount = new AtomicInteger (0 );
5877
59- private Sinks .Many <ServerSentEvent < String > > events = Sinks .many ().unicast ().onBackpressureBuffer ();
78+ private Sinks .Many <SseResponseEvent > events = Sinks .many ().unicast ().onBackpressureBuffer ();
6079
6180 public TestHttpClientSseClientTransport (final String baseUri ) {
62- super (HttpClient .newHttpClient (), HttpRequest .newBuilder (), baseUri , "/sse" , new ObjectMapper ());
81+ super (HttpClient .newBuilder ().version (HttpClient .Version .HTTP_1_1 ).build (),
82+ HttpRequest .newBuilder ().header ("Content-Type" , "application/json" ), baseUri , "/sse" ,
83+ new ObjectMapper ());
84+ }
85+
86+ CopyOnWriteArrayList <JSONRPCRequest > requestMessages = new CopyOnWriteArrayList <>();
87+
88+ CopyOnWriteArrayList <JSONRPCNotification > notificationMessages = new CopyOnWriteArrayList <>();
89+
90+ CopyOnWriteArrayList <JSONRPCResponse > responseMessages = new CopyOnWriteArrayList <>();
91+
92+ Function <Mono <JSONRPCMessage >, Mono <JSONRPCMessage >> handler = (messageMono ) -> messageMono
93+ .doOnNext (message -> {
94+ // System.out.println("Received message $$$$$$$$$$$$$$: " + message);
95+ if (message instanceof JSONRPCRequest request ) {
96+ requestMessages .add (request );
97+ }
98+ else if (message instanceof JSONRPCNotification notificaiton ) {
99+ notificationMessages .add (notificaiton );
100+ }
101+ else if (message instanceof JSONRPCResponse response ) {
102+ responseMessages .add (response );
103+ }
104+ else {
105+ throw new IllegalArgumentException ("Unsupported message type: " + message .getClass ());
106+ }
107+ });
108+
109+ @ Override
110+ protected Flux <SseResponseEvent > eventStream () {
111+ return super .eventStream ().mergeWith (events .asFlux ());
63112 }
64113
65114 public int getInboundMessageCount () {
66115 return inboundMessageCount .get ();
67116 }
68117
69118 public void simulateEndpointEvent (String jsonMessage ) {
70- events .tryEmitNext (ServerSentEvent .<String >builder ().event ("endpoint" ).data (jsonMessage ).build ());
119+ events .tryEmitNext (new SseResponseEvent (new MyResponseInfo (200 ),
120+ new ResponseSubscribers .SseEvent (null , "endpoint" , jsonMessage )));
71121 inboundMessageCount .incrementAndGet ();
72122 }
73123
74124 public void simulateMessageEvent (String jsonMessage ) {
75- events .tryEmitNext (ServerSentEvent .<String >builder ().event ("message" ).data (jsonMessage ).build ());
125+ events .tryEmitNext (new SseResponseEvent (new MyResponseInfo (200 ),
126+ new ResponseSubscribers .SseEvent (null , "message" , jsonMessage )));
76127 inboundMessageCount .incrementAndGet ();
77128 }
78129
@@ -88,7 +139,7 @@ void startContainer() {
88139 void setUp () {
89140 startContainer ();
90141 transport = new TestHttpClientSseClientTransport (host );
91- transport .connect (Function . identity () ).block ();
142+ transport .connect (transport . handler ).block ();
92143 }
93144
94145 @ AfterEach
@@ -123,6 +174,7 @@ void testMessageProcessing() {
123174 StepVerifier .create (transport .sendMessage (testMessage )).verifyComplete ();
124175
125176 assertThat (transport .getInboundMessageCount ()).isEqualTo (1 );
177+ assertThat (transport .requestMessages ).hasSize (1 );
126178 }
127179
128180 @ Test
0 commit comments