22
33import com .fasterxml .jackson .core .type .TypeReference ;
44import com .fasterxml .jackson .databind .ObjectMapper ;
5+ import io .modelcontextprotocol .spec .DefaultMcpTransportContext ;
56import io .modelcontextprotocol .spec .McpError ;
67import io .modelcontextprotocol .spec .McpSchema ;
78import io .modelcontextprotocol .spec .McpServerTransport ;
89import io .modelcontextprotocol .spec .McpStreamableServerSession ;
910import io .modelcontextprotocol .spec .McpStreamableServerTransportProvider ;
11+ import io .modelcontextprotocol .spec .McpTransportContext ;
1012import io .modelcontextprotocol .util .Assert ;
1113import org .slf4j .Logger ;
1214import org .slf4j .LoggerFactory ;
2527
2628import java .io .IOException ;
2729import java .util .concurrent .ConcurrentHashMap ;
30+ import java .util .function .Function ;
2831
2932public class WebFluxStreamableServerTransportProvider implements McpStreamableServerTransportProvider {
3033
@@ -48,6 +51,9 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe
4851
4952 private final ConcurrentHashMap <String , McpStreamableServerSession > sessions = new ConcurrentHashMap <>();
5053
54+ // TODO: add means to specify this
55+ private Function <ServerRequest , McpTransportContext > contextExtractor = req -> new DefaultMcpTransportContext ();
56+
5157 /**
5258 * Flag indicating if the transport is shutting down.
5359 */
@@ -183,6 +189,8 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
183189 return ServerResponse .status (HttpStatus .SERVICE_UNAVAILABLE ).bodyValue ("Server is shutting down" );
184190 }
185191
192+ McpTransportContext transportContext = this .contextExtractor .apply (request );
193+
186194 return Mono .defer (() -> {
187195 if (!request .headers ().asHttpHeaders ().containsKey ("mcp-session-id" )) {
188196 return ServerResponse .badRequest ().build (); // TODO: say we need a session id
@@ -204,11 +212,11 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
204212 return ServerResponse .ok ().contentType (MediaType .TEXT_EVENT_STREAM )
205213 .body (Flux .<ServerSentEvent <?>>create (sink -> {
206214 WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport (sink );
207- McpStreamableServerSession .McpStreamableServerSessionStream genericStream = session .newStream (sessionTransport );
208- sink .onDispose (genericStream ::close );
215+ McpStreamableServerSession .McpStreamableServerSessionStream listeningStream = session .listeningStream (sessionTransport );
216+ sink .onDispose (listeningStream ::close );
209217 }), ServerSentEvent .class );
210218
211- });
219+ }). contextWrite ( ctx -> ctx . put ( McpTransportContext . KEY , transportContext )) ;
212220 }
213221
214222 /**
@@ -231,6 +239,8 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
231239 return ServerResponse .status (HttpStatus .SERVICE_UNAVAILABLE ).bodyValue ("Server is shutting down" );
232240 }
233241
242+ McpTransportContext transportContext = this .contextExtractor .apply (request );
243+
234244 return request .bodyToMono (String .class ).<ServerResponse >flatMap (body -> {
235245 try {
236246 McpSchema .JSONRPCMessage message = McpSchema .deserializeJsonRpcMessage (objectMapper , body );
@@ -261,7 +271,7 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
261271 return ServerResponse .ok ().contentType (MediaType .TEXT_EVENT_STREAM )
262272 .body (Flux .<ServerSentEvent <?>>create (sink -> {
263273 WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport (sink );
264- Mono <Void > stream = session .handleStream (jsonrpcRequest , st );
274+ Mono <Void > stream = session .responseStream (jsonrpcRequest , st );
265275 Disposable streamSubscription = stream
266276 .doOnError (err -> sink .error (err ))
267277 .contextWrite (sink .contextView ())
@@ -276,7 +286,7 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
276286 logger .error ("Failed to deserialize message: {}" , e .getMessage ());
277287 return ServerResponse .badRequest ().bodyValue (new McpError ("Invalid message format" ));
278288 }
279- });
289+ }). contextWrite ( ctx -> ctx . put ( McpTransportContext . KEY , transportContext )) ;
280290 }
281291
282292 private class WebFluxStreamableMcpSessionTransport implements McpServerTransport {
0 commit comments