From 3491098019d2d1cd3cbbb73fed10893f3a8fa5e0 Mon Sep 17 00:00:00 2001 From: "a.darafeyeu" Date: Fri, 11 Apr 2025 14:48:28 +0200 Subject: [PATCH 1/6] feat(client): adds StreamableHttpClientTransport --- .../StreamableHttpClientTransport.java | 410 ++++++++++++++++++ ...treamableHttpClientTransportAsyncTest.java | 44 ++ ...StreamableHttpClientTransportSyncTest.java | 44 ++ 3 files changed, 498 insertions(+) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportAsyncTest.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportSyncTest.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java new file mode 100644 index 000000000..446a74138 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java @@ -0,0 +1,410 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ + +package io.modelcontextprotocol.client.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.spec.McpClientTransport; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +/** + * A transport implementation for the Model Context Protocol (MCP) using JSON streaming. + * + * @author Aliaksei Darafeyeu + */ +public class StreamableHttpClientTransport implements McpClientTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(StreamableHttpClientTransport.class); + + private final HttpClientSseClientTransport sseClientTransport; + + private final HttpClient httpClient; + + private final HttpRequest.Builder requestBuilder; + + private final ObjectMapper objectMapper; + + private final URI uri; + + private final AtomicReference state = new AtomicReference<>(TransportState.DISCONNECTED); + + private final AtomicReference lastEventId = new AtomicReference<>(); + + private final AtomicBoolean fallbackToSse = new AtomicBoolean(false); + + StreamableHttpClientTransport(final HttpClient httpClient, final HttpRequest.Builder requestBuilder, + final ObjectMapper objectMapper, final String baseUri, final String endpoint, + final HttpClientSseClientTransport sseClientTransport) { + this.httpClient = httpClient; + this.requestBuilder = requestBuilder; + this.objectMapper = objectMapper; + this.uri = URI.create(baseUri + endpoint); + this.sseClientTransport = sseClientTransport; + } + + /** + * Creates a new StreamableHttpClientTransport instance with the specified URI. + * @param uri the URI to connect to + * @return a new Builder instance + */ + public static Builder builder(final String uri) { + return new Builder().withBaseUri(uri); + } + + /** + * The state of the Transport connection. + */ + public enum TransportState { + + DISCONNECTED, CONNECTING, CONNECTED, CLOSED + + } + + /** + * A builder for creating instances of WebSocketClientTransport. + */ + public static class Builder { + + private final HttpClient.Builder clientBuilder = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_1_1) + .connectTimeout(Duration.ofSeconds(10)); + + private final HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() + .header("Accept", "application/json, text/event-stream"); + + private ObjectMapper objectMapper = new ObjectMapper(); + + private String baseUri; + + private String endpoint = "/mcp"; + + private Consumer clientCustomizer; + + private Consumer requestCustomizer; + + public Builder withCustomizeClient(final Consumer clientCustomizer) { + Assert.notNull(clientCustomizer, "clientCustomizer must not be null"); + clientCustomizer.accept(clientBuilder); + this.clientCustomizer = clientCustomizer; + return this; + } + + public Builder withCustomizeRequest(final Consumer requestCustomizer) { + Assert.notNull(requestCustomizer, "requestCustomizer must not be null"); + requestCustomizer.accept(requestBuilder); + this.requestCustomizer = requestCustomizer; + return this; + } + + public Builder withObjectMapper(final ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + public Builder withBaseUri(final String baseUri) { + Assert.hasText(baseUri, "baseUri must not be empty"); + this.baseUri = baseUri; + return this; + } + + public Builder withEndpoint(final String endpoint) { + Assert.hasText(endpoint, "endpoint must not be empty"); + this.endpoint = endpoint; + return this; + } + + public StreamableHttpClientTransport build() { + final HttpClientSseClientTransport.Builder builder = HttpClientSseClientTransport.builder(baseUri) + .objectMapper(objectMapper); + if (clientCustomizer != null) { + builder.customizeClient(clientCustomizer); + } + + if (requestCustomizer != null) { + builder.customizeRequest(requestCustomizer); + } + + if (!endpoint.equals("/mcp")) { + builder.sseEndpoint(endpoint); + } + + return new StreamableHttpClientTransport(clientBuilder.build(), requestBuilder, objectMapper, baseUri, + endpoint, builder.build()); + } + + } + + @Override + public Mono connect(final Function, Mono> handler) { + if (fallbackToSse.get()) { + return sseClientTransport.connect(handler); + } + + if (!state.compareAndSet(TransportState.DISCONNECTED, TransportState.CONNECTING)) { + return Mono.error(new IllegalStateException("Already connected or connecting")); + } + + return Mono.defer(() -> Mono.fromFuture(() -> { + final HttpRequest.Builder builder = requestBuilder.copy().GET().uri(uri); + final String lastId = lastEventId.get(); + if (lastId != null) { + builder.header("Last-Event-ID", lastId); + } + return httpClient.sendAsync(builder.build(), HttpResponse.BodyHandlers.ofInputStream()); + }).flatMap(response -> { + if (response.statusCode() == 405 || response.statusCode() == 404) { + LOGGER.warn("Operation not allowed, falling back to SSE"); + fallbackToSse.set(true); + return sseClientTransport.connect(handler); + } + return handleStreamingResponse(response, handler); + }) + .retryWhen(Retry.backoff(3, Duration.ofSeconds(3)).filter(err -> err instanceof IllegalStateException)) + .doOnSuccess(v -> state.set(TransportState.CONNECTED)) + .doOnTerminate(() -> state.set(TransportState.CLOSED)) + .onErrorResume(e -> { + LOGGER.error("Streamable transport connection error", e); + return Mono.error(e); + })); + } + + @Override + public Mono sendMessage(final McpSchema.JSONRPCMessage message) { + return sendMessage(message, msg -> msg); + } + + public Mono sendMessage(final McpSchema.JSONRPCMessage message, + final Function, Mono> handler) { + if (fallbackToSse.get()) { + return sseClientTransport.sendMessage(message); + } + + if (state.get() == TransportState.CLOSED) { + return Mono.empty(); + } + + return sentPost(message, handler).onErrorResume(e -> { + LOGGER.error("Streamable transport sendMessage error", e); + return Mono.error(e); + }); + } + + /** + * Sends a list of messages to the server. + * @param messages the list of messages to send + * @return a Mono that completes when all messages have been sent + */ + public Mono sendMessages(final List messages, + final Function, Mono> handler) { + if (fallbackToSse.get()) { + return Flux.fromIterable(messages).flatMap(this::sendMessage).then(); + } + + if (state.get() == TransportState.CLOSED) { + return Mono.empty(); + } + + return sentPost(messages, handler).onErrorResume(e -> { + LOGGER.error("Streamable transport sendMessages error", e); + return Mono.error(e); + }); + } + + private Mono sentPost(final Object msg, + final Function, Mono> handler) { + return serializeJson(msg).flatMap(json -> { + final HttpRequest request = requestBuilder.copy() + .POST(HttpRequest.BodyPublishers.ofString(json)) + .uri(uri) + .build(); + return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream())) + .flatMap(response -> { + + // If the response is 202 Accepted, there's no body to process + if (response.statusCode() == 202) { + return Mono.empty(); + } + + if (response.statusCode() == 405 || response.statusCode() == 404) { + LOGGER.warn("Operation not allowed, falling back to SSE"); + fallbackToSse.set(true); + if (msg instanceof McpSchema.JSONRPCMessage message) { + return sseClientTransport.sendMessage(message); + } + + if (msg instanceof List list) { + @SuppressWarnings("unchecked") + final List messages = (List) list; + return Flux.fromIterable(messages).flatMap(this::sendMessage).then(); + } + } + + if (response.statusCode() >= 400) { + return Mono + .error(new IllegalArgumentException("Unexpected status code: " + response.statusCode())); + } + + return handleStreamingResponse(response, handler); + }); + }); + + } + + private Mono serializeJson(final Object input) { + try { + if (input instanceof McpSchema.JSONRPCMessage || input instanceof List) { + return Mono.just(objectMapper.writeValueAsString(input)); + } + else { + return Mono.error(new IllegalArgumentException("Unsupported message type for serialization")); + } + } + catch (IOException e) { + LOGGER.error("Error serializing JSON-RPC message", e); + return Mono.error(e); + } + } + + private Mono handleStreamingResponse(final HttpResponse response, + final Function, Mono> handler) { + final String contentType = response.headers().firstValue("Content-Type").orElse(""); + if (contentType.contains("application/json-seq")) { + return handleJsonStream(response, handler); + } + else if (contentType.contains("text/event-stream")) { + return handleSseStream(response, handler); + } + else if (contentType.contains("application/json")) { + return handleSingleJson(response, handler); + } + else { + return Mono.error(new UnsupportedOperationException("Unsupported Content-Type: " + contentType)); + } + } + + private Mono handleSingleJson(final HttpResponse response, + final Function, Mono> handler) { + return Mono.fromCallable(() -> { + final McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, + new String(response.body().readAllBytes(), StandardCharsets.UTF_8)); + return handler.apply(Mono.just(msg)); + }).flatMap(Function.identity()).then(); + } + + private Mono handleJsonStream(final HttpResponse response, + final Function, Mono> handler) { + return Flux.fromStream(new BufferedReader(new InputStreamReader(response.body())).lines()).flatMap(jsonLine -> { + try { + final McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, jsonLine); + return handler.apply(Mono.just(message)); + } + catch (IOException e) { + LOGGER.error("Error processing JSON line", e); + return Mono.empty(); + } + }).then(); + } + + private Mono handleSseStream(final HttpResponse response, + final Function, Mono> handler) { + return Flux.fromStream(new BufferedReader(new InputStreamReader(response.body())).lines()) + .map(String::trim) + .bufferUntil(String::isEmpty) + .map(eventLines -> { + String event = ""; + String data = ""; + String id = ""; + + for (String line : eventLines) { + if (line.startsWith("event: ")) + event = line.substring(7).trim(); + else if (line.startsWith("data: ")) + data += line.substring(6).trim() + "\n"; + else if (line.startsWith("id: ")) + id = line.substring(4).trim(); + } + + if (data.endsWith("\n")) { + data = data.substring(0, data.length() - 1); + } + + return new FlowSseClient.SseEvent(event, data, id); + }) + .filter(sseEvent -> "message".equals(sseEvent.type())) + .doOnNext(sseEvent -> { + lastEventId.set(sseEvent.id()); + try { + String rawData = sseEvent.data().trim(); + JsonNode node = objectMapper.readTree(rawData); + + if (node.isArray()) { + for (JsonNode item : node) { + String rawMessage = objectMapper.writeValueAsString(item); + McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, + rawMessage); + handler.apply(Mono.just(msg)).subscribe(); + } + } + else if (node.isObject()) { + String rawMessage = objectMapper.writeValueAsString(node); + McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, rawMessage); + handler.apply(Mono.just(msg)).subscribe(); + } + else { + LOGGER.warn("Unexpected JSON in SSE data: {}", rawData); + } + } + catch (IOException e) { + LOGGER.error("Error processing SSE event: {}", sseEvent.data(), e); + } + }) + .then(); + } + + @Override + public Mono closeGracefully() { + state.set(TransportState.CLOSED); + if (fallbackToSse.get()) { + return sseClientTransport.closeGracefully(); + } + return Mono.empty(); + } + + @Override + public T unmarshalFrom(final Object data, final TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + public TransportState getState() { + return state.get(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportAsyncTest.java b/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportAsyncTest.java new file mode 100644 index 000000000..4447d0b57 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportAsyncTest.java @@ -0,0 +1,44 @@ +package io.modelcontextprotocol.client; + +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.StreamableHttpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpAsyncClient} with {@link StreamableHttpClientTransport}. + * + * @author Aliaksei Darafeyeu + */ +@Timeout(15) +public class StreamableHttpClientTransportAsyncTest extends AbstractMcpAsyncClientTests { + + String host = "http://localhost:3003"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return StreamableHttpClientTransport.builder(host).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + protected void onClose() { + container.stop(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportSyncTest.java b/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportSyncTest.java new file mode 100644 index 000000000..4fc203953 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StreamableHttpClientTransportSyncTest.java @@ -0,0 +1,44 @@ +package io.modelcontextprotocol.client; + +import org.junit.jupiter.api.Timeout; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +import io.modelcontextprotocol.client.transport.StreamableHttpClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Tests for the {@link McpSyncClient} with {@link StreamableHttpClientTransport}. + * + * @author Aliaksei Darafeyeu + */ +@Timeout(15) +public class StreamableHttpClientTransportSyncTest extends AbstractMcpSyncClientTests { + + String host = "http://localhost:3003"; + + // Uses the https://github.com/tzolov/mcp-everything-server-docker-image + @SuppressWarnings("resource") + GenericContainer container = new GenericContainer<>("docker.io/tzolov/mcp-everything-server:v1") + .withLogConsumer(outputFrame -> System.out.println(outputFrame.getUtf8String())) + .withExposedPorts(3001) + .waitingFor(Wait.forHttp("/").forStatusCode(404)); + + @Override + protected McpClientTransport createMcpTransport() { + return StreamableHttpClientTransport.builder(host).build(); + } + + @Override + protected void onStart() { + container.start(); + int port = container.getMappedPort(3001); + host = "http://" + container.getHost() + ":" + port; + } + + @Override + protected void onClose() { + container.stop(); + } + +} From 3582abf9015ca4768659af5e4af771fe790be503 Mon Sep 17 00:00:00 2001 From: "a.darafeyeu" Date: Mon, 14 Apr 2025 20:25:40 +0200 Subject: [PATCH 2/6] feat(client): small enhancements + adds Batch to McpSchema to simplify StreamableHttpClientTransport --- .../StreamableHttpClientTransport.java | 204 ++++++++++-------- .../modelcontextprotocol/spec/McpSchema.java | 28 ++- 2 files changed, 147 insertions(+), 85 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java index 446a74138..556d85593 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java @@ -14,6 +14,7 @@ import java.net.http.HttpResponse; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -43,6 +44,24 @@ public class StreamableHttpClientTransport implements McpClientTransport { private static final Logger LOGGER = LoggerFactory.getLogger(StreamableHttpClientTransport.class); + private static final String DEFAULT_MCP_ENDPOINT = "/mcp"; + + private static final String MCP_SESSION_ID = "Mcp-Session-Id"; + + private static final String LAST_EVENT_ID = "Last-Event-ID"; + + private static final String ACCEPT = "Accept"; + + private static final String CONTENT_TYPE = "Content-Type"; + + private static final String APPLICATION_JSON = "application/json"; + + private static final String TEXT_EVENT_STREAM = "text/event-stream"; + + private static final String APPLICATION_JSON_SEQ = "application/json-seq"; + + private static final String DEFAULT_ACCEPT_VALUES = "%s, %s".formatted(APPLICATION_JSON, TEXT_EVENT_STREAM); + private final HttpClientSseClientTransport sseClientTransport; private final HttpClient httpClient; @@ -57,6 +76,8 @@ public class StreamableHttpClientTransport implements McpClientTransport { private final AtomicReference lastEventId = new AtomicReference<>(); + private final AtomicReference mcpSessionId = new AtomicReference<>(); + private final AtomicBoolean fallbackToSse = new AtomicBoolean(false); StreamableHttpClientTransport(final HttpClient httpClient, final HttpRequest.Builder requestBuilder, @@ -96,14 +117,13 @@ public static class Builder { .version(HttpClient.Version.HTTP_1_1) .connectTimeout(Duration.ofSeconds(10)); - private final HttpRequest.Builder requestBuilder = HttpRequest.newBuilder() - .header("Accept", "application/json, text/event-stream"); + private final HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(); private ObjectMapper objectMapper = new ObjectMapper(); private String baseUri; - private String endpoint = "/mcp"; + private String endpoint = DEFAULT_MCP_ENDPOINT; private Consumer clientCustomizer; @@ -152,7 +172,7 @@ public StreamableHttpClientTransport build() { builder.customizeRequest(requestCustomizer); } - if (!endpoint.equals("/mcp")) { + if (!endpoint.equals(DEFAULT_MCP_ENDPOINT)) { builder.sseEndpoint(endpoint); } @@ -173,13 +193,24 @@ public Mono connect(final Function, Mono Mono.fromFuture(() -> { - final HttpRequest.Builder builder = requestBuilder.copy().GET().uri(uri); + final HttpRequest.Builder request = requestBuilder.copy().GET().header(ACCEPT, TEXT_EVENT_STREAM).uri(uri); final String lastId = lastEventId.get(); if (lastId != null) { - builder.header("Last-Event-ID", lastId); + request.header(LAST_EVENT_ID, lastId); } - return httpClient.sendAsync(builder.build(), HttpResponse.BodyHandlers.ofInputStream()); + if (mcpSessionId.get() != null) { + request.header(MCP_SESSION_ID, mcpSessionId.get()); + } + + return httpClient.sendAsync(request.build(), HttpResponse.BodyHandlers.ofInputStream()); }).flatMap(response -> { + // must like server terminate session and the client need to start a + // new session by sending a new `InitializeRequest` without a session + // ID attached. + if (mcpSessionId.get() != null && response.statusCode() == 404) { + mcpSessionId.set(null); + } + if (response.statusCode() == 405 || response.statusCode() == 404) { LOGGER.warn("Operation not allowed, falling back to SSE"); fallbackToSse.set(true); @@ -192,6 +223,7 @@ public Mono connect(final Function, Mono state.set(TransportState.CLOSED)) .onErrorResume(e -> { LOGGER.error("Streamable transport connection error", e); + state.set(TransportState.DISCONNECTED); return Mono.error(e); })); } @@ -204,67 +236,52 @@ public Mono sendMessage(final McpSchema.JSONRPCMessage message) { public Mono sendMessage(final McpSchema.JSONRPCMessage message, final Function, Mono> handler) { if (fallbackToSse.get()) { - return sseClientTransport.sendMessage(message); + return fallbackToSse(message); } if (state.get() == TransportState.CLOSED) { return Mono.empty(); } - return sentPost(message, handler).onErrorResume(e -> { - LOGGER.error("Streamable transport sendMessage error", e); - return Mono.error(e); - }); - } - - /** - * Sends a list of messages to the server. - * @param messages the list of messages to send - * @return a Mono that completes when all messages have been sent - */ - public Mono sendMessages(final List messages, - final Function, Mono> handler) { - if (fallbackToSse.get()) { - return Flux.fromIterable(messages).flatMap(this::sendMessage).then(); - } - - if (state.get() == TransportState.CLOSED) { - return Mono.empty(); - } - - return sentPost(messages, handler).onErrorResume(e -> { - LOGGER.error("Streamable transport sendMessages error", e); - return Mono.error(e); - }); - } - - private Mono sentPost(final Object msg, - final Function, Mono> handler) { - return serializeJson(msg).flatMap(json -> { - final HttpRequest request = requestBuilder.copy() + return serializeJson(message).flatMap(json -> { + final HttpRequest.Builder request = requestBuilder.copy() .POST(HttpRequest.BodyPublishers.ofString(json)) - .uri(uri) - .build(); - return Mono.fromFuture(httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream())) + .header(ACCEPT, DEFAULT_ACCEPT_VALUES) + .header(CONTENT_TYPE, APPLICATION_JSON) + .uri(uri); + if (mcpSessionId.get() != null) { + request.header(MCP_SESSION_ID, mcpSessionId.get()); + } + + return Mono.fromFuture(httpClient.sendAsync(request.build(), HttpResponse.BodyHandlers.ofInputStream())) .flatMap(response -> { + // server may assign a session ID at initialization time, if yes we + // have to use it for any subsequent requests + if (message instanceof McpSchema.JSONRPCRequest + && ((McpSchema.JSONRPCRequest) message).method().equals(McpSchema.METHOD_INITIALIZE)) { + response.headers() + .firstValue(MCP_SESSION_ID) + .map(String::trim) + .ifPresent(this.mcpSessionId::set); + } + // If the response is 202 Accepted, there's no body to process if (response.statusCode() == 202) { return Mono.empty(); } + // must like server terminate session and the client need to start a + // new session by sending a new `InitializeRequest` without a session + // ID attached. + if (mcpSessionId.get() != null && response.statusCode() == 404) { + mcpSessionId.set(null); + } + if (response.statusCode() == 405 || response.statusCode() == 404) { LOGGER.warn("Operation not allowed, falling back to SSE"); fallbackToSse.set(true); - if (msg instanceof McpSchema.JSONRPCMessage message) { - return sseClientTransport.sendMessage(message); - } - - if (msg instanceof List list) { - @SuppressWarnings("unchecked") - final List messages = (List) list; - return Flux.fromIterable(messages).flatMap(this::sendMessage).then(); - } + return fallbackToSse(message); } if (response.statusCode() >= 400) { @@ -274,18 +291,28 @@ private Mono sentPost(final Object msg, return handleStreamingResponse(response, handler); }); + }).onErrorResume(e -> { + LOGGER.error("Streamable transport sendMessages error", e); + return Mono.error(e); }); } - private Mono serializeJson(final Object input) { + private Mono fallbackToSse(final McpSchema.JSONRPCMessage msg) { + if (msg instanceof McpSchema.JSONRPCBatchRequest batch) { + return Flux.fromIterable(batch.items()).flatMap(sseClientTransport::sendMessage).then(); + } + + if (msg instanceof McpSchema.JSONRPCBatchResponse batch) { + return Flux.fromIterable(batch.items()).flatMap(sseClientTransport::sendMessage).then(); + } + + return sseClientTransport.sendMessage(msg); + } + + private Mono serializeJson(final McpSchema.JSONRPCMessage msg) { try { - if (input instanceof McpSchema.JSONRPCMessage || input instanceof List) { - return Mono.just(objectMapper.writeValueAsString(input)); - } - else { - return Mono.error(new IllegalArgumentException("Unsupported message type for serialization")); - } + return Mono.just(objectMapper.writeValueAsString(msg)); } catch (IOException e) { LOGGER.error("Error serializing JSON-RPC message", e); @@ -295,27 +322,31 @@ private Mono serializeJson(final Object input) { private Mono handleStreamingResponse(final HttpResponse response, final Function, Mono> handler) { - final String contentType = response.headers().firstValue("Content-Type").orElse(""); - if (contentType.contains("application/json-seq")) { + final String contentType = response.headers().firstValue(CONTENT_TYPE).orElse(""); + if (contentType.contains(APPLICATION_JSON_SEQ)) { return handleJsonStream(response, handler); } - else if (contentType.contains("text/event-stream")) { + else if (contentType.contains(TEXT_EVENT_STREAM)) { return handleSseStream(response, handler); } - else if (contentType.contains("application/json")) { + else if (contentType.contains(APPLICATION_JSON)) { return handleSingleJson(response, handler); } - else { - return Mono.error(new UnsupportedOperationException("Unsupported Content-Type: " + contentType)); - } + return Mono.error(new UnsupportedOperationException("Unsupported Content-Type: " + contentType)); } private Mono handleSingleJson(final HttpResponse response, final Function, Mono> handler) { return Mono.fromCallable(() -> { - final McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, - new String(response.body().readAllBytes(), StandardCharsets.UTF_8)); - return handler.apply(Mono.just(msg)); + try { + final McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, + new String(response.body().readAllBytes(), StandardCharsets.UTF_8)); + return handler.apply(Mono.just(msg)); + } + catch (IOException e) { + LOGGER.error("Error processing JSON response", e); + return Mono.error(e); + } }).flatMap(Function.identity()).then(); } @@ -328,7 +359,7 @@ private Mono handleJsonStream(final HttpResponse response, } catch (IOException e) { LOGGER.error("Error processing JSON line", e); - return Mono.empty(); + return Mono.error(e); } }).then(); } @@ -347,7 +378,7 @@ private Mono handleSseStream(final HttpResponse response, if (line.startsWith("event: ")) event = line.substring(7).trim(); else if (line.startsWith("data: ")) - data += line.substring(6).trim() + "\n"; + data += line.substring(6) + "\n"; else if (line.startsWith("id: ")) id = line.substring(4).trim(); } @@ -356,34 +387,39 @@ else if (line.startsWith("id: ")) data = data.substring(0, data.length() - 1); } - return new FlowSseClient.SseEvent(event, data, id); + return new FlowSseClient.SseEvent(id, event, data); }) .filter(sseEvent -> "message".equals(sseEvent.type())) - .doOnNext(sseEvent -> { - lastEventId.set(sseEvent.id()); + .concatMap(sseEvent -> { + String rawData = sseEvent.data().trim(); try { - String rawData = sseEvent.data().trim(); JsonNode node = objectMapper.readTree(rawData); - + List messages = new ArrayList<>(); if (node.isArray()) { for (JsonNode item : node) { - String rawMessage = objectMapper.writeValueAsString(item); - McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, - rawMessage); - handler.apply(Mono.just(msg)).subscribe(); + messages.add(McpSchema.deserializeJsonRpcMessage(objectMapper, item.toString())); } } else if (node.isObject()) { - String rawMessage = objectMapper.writeValueAsString(node); - McpSchema.JSONRPCMessage msg = McpSchema.deserializeJsonRpcMessage(objectMapper, rawMessage); - handler.apply(Mono.just(msg)).subscribe(); + messages.add(McpSchema.deserializeJsonRpcMessage(objectMapper, node.toString())); } else { - LOGGER.warn("Unexpected JSON in SSE data: {}", rawData); + String warning = "Unexpected JSON in SSE data: " + rawData; + LOGGER.warn(warning); + return Mono.error(new IllegalArgumentException(warning)); } + + return Flux.fromIterable(messages) + .concatMap(msg -> handler.apply(Mono.just(msg))) + .then(Mono.fromRunnable(() -> { + if (!sseEvent.id().isEmpty()) { + lastEventId.set(sseEvent.id()); + } + })); } catch (IOException e) { - LOGGER.error("Error processing SSE event: {}", sseEvent.data(), e); + LOGGER.error("Error parsing SSE JSON: {}", rawData, e); + return Mono.error(e); } }) .then(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 8df8a1584..93bfb7489 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -10,6 +10,7 @@ import java.util.List; import java.util.Map; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; @@ -173,12 +174,37 @@ else if (map.containsKey("result") || map.containsKey("error")) { // --------------------------- // JSON-RPC Message Types // --------------------------- - public sealed interface JSONRPCMessage permits JSONRPCRequest, JSONRPCNotification, JSONRPCResponse { + public sealed interface JSONRPCMessage + permits JSONRPCBatchRequest, JSONRPCBatchResponse, JSONRPCRequest, JSONRPCNotification, JSONRPCResponse { String jsonrpc(); } + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record JSONRPCBatchRequest( // @formatter:off + @JsonProperty("items") List items) implements JSONRPCMessage { + + @Override + @JsonIgnore + public String jsonrpc() { + return JSONRPC_VERSION; + } + } // @formatter:on + + @JsonInclude(JsonInclude.Include.NON_ABSENT) + @JsonIgnoreProperties(ignoreUnknown = true) + public record JSONRPCBatchResponse( // @formatter:off + @JsonProperty("items") List items) implements JSONRPCMessage { + + @Override + @JsonIgnore + public String jsonrpc() { + return JSONRPC_VERSION; + } + } // @formatter:on + @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record JSONRPCRequest( // @formatter:off From 028ad6d8b2e6d4110d39ec629312452facc46343 Mon Sep 17 00:00:00 2001 From: "a.darafeyeu" Date: Wed, 16 Apr 2025 20:39:27 +0200 Subject: [PATCH 3/6] feat(client): makes possible to use StreamableHttpClientTransport in McpClient --- .../client/McpAsyncClient.java | 18 ++++++++- .../client/McpClient.java | 37 ++++++++++++++++--- .../client/McpSyncClient.java | 25 ++++++++----- .../StreamableHttpClientTransport.java | 31 ++-------------- .../spec/McpClientSession.java | 17 ++++++++- .../spec/McpClientSessionTests.java | 4 +- 6 files changed, 85 insertions(+), 47 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index e3a997ba3..39f33c95f 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -160,7 +160,7 @@ public class McpAsyncClient { * @param features the MCP Client supported features. */ McpAsyncClient(McpClientTransport transport, Duration requestTimeout, Duration initializationTimeout, - McpClientFeatures.Async features) { + McpClientFeatures.Async features, boolean connectOnInit) { Assert.notNull(transport, "Transport must not be null"); Assert.notNull(requestTimeout, "Request timeout must not be null"); @@ -235,7 +235,9 @@ public class McpAsyncClient { asyncLoggingNotificationHandler(loggingConsumersFinal)); this.mcpSession = new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); - + if (connectOnInit) { + this.mcpSession.openSSE(); + } } /** @@ -302,6 +304,18 @@ public Mono closeGracefully() { return this.mcpSession.closeGracefully(); } + // --------------------------- + // open an SSE stream + // --------------------------- + /** + * The client may issue an HTTP GET to the MCP endpoint. This can be used to open an + * SSE stream, allowing the server to communicate to the client, without the client + * first sending data via HTTP POST. + */ + public void openSSE() { + this.mcpSession.openSSE(); + } + // -------------------------- // Initialization // -------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index a1dc11685..dac2ee8a3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -157,11 +157,13 @@ class SyncSpec { private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout + private boolean connectOnInit = true; // Default true, for backward compatibility + private Duration initializationTimeout = Duration.ofSeconds(20); private ClientCapabilities capabilities; - private Implementation clientInfo = new Implementation("Java SDK MCP Client", "1.0.0"); + private Implementation clientInfo = new Implementation("Java SDK MCP Sync Client", "0.10.0"); private final Map roots = new HashMap<>(); @@ -195,6 +197,17 @@ public SyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * Sets whether to connect to the server during the initialization phase (open an + * SSE stream). + * @param connectOnInit true to open an SSE stream during the initialization + * @return This builder instance for method chaining + */ + public SyncSpec withConnectOnInit(final boolean connectOnInit) { + this.connectOnInit = connectOnInit; + return this; + } + /** * @param initializationTimeout The duration to wait for the initialization * lifecycle step to complete. @@ -368,8 +381,8 @@ public McpSyncClient build() { McpClientFeatures.Async asyncFeatures = McpClientFeatures.Async.fromSync(syncFeatures); - return new McpSyncClient( - new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, asyncFeatures)); + return new McpSyncClient(new McpAsyncClient(transport, this.requestTimeout, this.initializationTimeout, + asyncFeatures, this.connectOnInit)); } } @@ -396,11 +409,13 @@ class AsyncSpec { private Duration requestTimeout = Duration.ofSeconds(20); // Default timeout + private boolean connectOnInit = true; // Default true, for backward compatibility + private Duration initializationTimeout = Duration.ofSeconds(20); private ClientCapabilities capabilities; - private Implementation clientInfo = new Implementation("Spring AI MCP Client", "0.3.1"); + private Implementation clientInfo = new Implementation("Java SDK MCP Async Client", "0.10.0"); private final Map roots = new HashMap<>(); @@ -434,6 +449,17 @@ public AsyncSpec requestTimeout(Duration requestTimeout) { return this; } + /** + * Sets whether to connect to the server during the initialization phase (open an + * SSE stream). + * @param connectOnInit true to open an SSE stream during the initialization + * @return This builder instance for method chaining + */ + public AsyncSpec withConnectOnInit(final boolean connectOnInit) { + this.connectOnInit = connectOnInit; + return this; + } + /** * @param initializationTimeout The duration to wait for the initialization * lifecycle step to complete. @@ -606,7 +632,8 @@ public McpAsyncClient build() { return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.promptsChangeConsumers, - this.loggingConsumers, this.samplingHandler)); + this.loggingConsumers, this.samplingHandler), + this.connectOnInit); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java index a8fb979e1..e9676e8a0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpSyncClient.java @@ -137,6 +137,18 @@ public boolean closeGracefully() { return true; } + // --------------------------- + // open an SSE stream + // --------------------------- + /** + * The client may issue an HTTP GET to the MCP endpoint. This can be used to open an + * SSE stream, allowing the server to communicate to the client, without the client + * first sending data via HTTP POST. + */ + public void openSSE() { + this.delegate.openSSE(); + } + /** * The initialization phase MUST be the first interaction between client and server. * During this phase, the client and server: @@ -156,9 +168,7 @@ public boolean closeGracefully() { * The server MUST respond with its own capabilities and information: * {@link McpSchema.ServerCapabilities}.
* After successful initialization, the client MUST send an initialized notification - * to indicate it is ready to begin normal operations. - * - *
+ * to indicate it is ready to begin normal operations.
* * Initialization @@ -280,9 +290,8 @@ public McpSchema.ReadResourceResult readResource(McpSchema.ReadResourceRequest r /** * Resource templates allow servers to expose parameterized resources using URI - * templates. Arguments may be auto-completed through the completion API. - * - * Request a list of resource templates the server has. + * templates. Arguments may be auto-completed through the completion API. Request a + * list of resource templates the server has. * @param cursor the cursor * @return the list of resource templates result. */ @@ -301,9 +310,7 @@ public McpSchema.ListResourceTemplatesResult listResourceTemplates() { /** * Subscriptions. The protocol supports optional subscriptions to resource changes. * Clients can subscribe to specific resources and receive notifications when they - * change. - * - * Send a resources/subscribe request. + * change. Send a resources/subscribe request. * @param subscribeRequest the subscribe request contains the uri of the resource to * subscribe to. */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java index 556d85593..6e0c5d7ba 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/StreamableHttpClientTransport.java @@ -72,8 +72,6 @@ public class StreamableHttpClientTransport implements McpClientTransport { private final URI uri; - private final AtomicReference state = new AtomicReference<>(TransportState.DISCONNECTED); - private final AtomicReference lastEventId = new AtomicReference<>(); private final AtomicReference mcpSessionId = new AtomicReference<>(); @@ -99,15 +97,6 @@ public static Builder builder(final String uri) { return new Builder().withBaseUri(uri); } - /** - * The state of the Transport connection. - */ - public enum TransportState { - - DISCONNECTED, CONNECTING, CONNECTED, CLOSED - - } - /** * A builder for creating instances of WebSocketClientTransport. */ @@ -188,10 +177,6 @@ public Mono connect(final Function, Mono Mono.fromFuture(() -> { final HttpRequest.Builder request = requestBuilder.copy().GET().header(ACCEPT, TEXT_EVENT_STREAM).uri(uri); final String lastId = lastEventId.get(); @@ -219,13 +204,10 @@ public Mono connect(final Function, Mono err instanceof IllegalStateException)) - .doOnSuccess(v -> state.set(TransportState.CONNECTED)) - .doOnTerminate(() -> state.set(TransportState.CLOSED)) .onErrorResume(e -> { LOGGER.error("Streamable transport connection error", e); - state.set(TransportState.DISCONNECTED); return Mono.error(e); - })); + })).doOnTerminate(this::closeGracefully); } @Override @@ -239,10 +221,6 @@ public Mono sendMessage(final McpSchema.JSONRPCMessage message, return fallbackToSse(message); } - if (state.get() == TransportState.CLOSED) { - return Mono.empty(); - } - return serializeJson(message).flatMap(json -> { final HttpRequest.Builder request = requestBuilder.copy() .POST(HttpRequest.BodyPublishers.ofString(json)) @@ -427,7 +405,8 @@ else if (node.isObject()) { @Override public Mono closeGracefully() { - state.set(TransportState.CLOSED); + mcpSessionId.set(null); + lastEventId.set(null); if (fallbackToSse.get()) { return sseClientTransport.closeGracefully(); } @@ -439,8 +418,4 @@ public T unmarshalFrom(final Object data, final TypeReference typeRef) { return objectMapper.convertValue(data, typeRef); } - public TransportState getState() { - return state.get(); - } - } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index f577b493a..4c10fba6a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -61,7 +61,7 @@ public class McpClientSession implements McpSession { /** Atomic counter for generating unique request IDs */ private final AtomicLong requestCounter = new AtomicLong(0); - private final Disposable connection; + private Disposable connection; /** * Functional interface for handling incoming JSON-RPC requests. Implementations @@ -116,6 +116,17 @@ public McpClientSession(Duration requestTimeout, McpClientTransport transport, this.transport = transport; this.requestHandlers.putAll(requestHandlers); this.notificationHandlers.putAll(notificationHandlers); + } + + /** + * The client may issue an HTTP GET to the MCP endpoint. This can be used to open an + * SSE stream, allowing the server to communicate to the client, without the client + * first sending data via HTTP POST. + */ + public void openSSE() { + if (this.connection != null && !this.connection.isDisposed()) { + return; // already connected and still active + } // TODO: consider mono.transformDeferredContextual where the Context contains // the @@ -288,7 +299,9 @@ public Mono closeGracefully() { */ @Override public void close() { - this.connection.dispose(); + if (this.connection != null) { + this.connection.dispose(); + } transport.close(); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java index f72be43e0..223e17eb5 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java @@ -48,6 +48,7 @@ void setUp() { transport = new MockMcpClientTransport(); session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params)))); + session.openSSE(); } @AfterEach @@ -141,6 +142,7 @@ void testRequestHandling() { params -> Mono.just(params)); transport = new MockMcpClientTransport(); session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()); + session.openSSE(); // Simulate incoming request McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD, @@ -162,7 +164,7 @@ void testNotificationHandling() { transport = new MockMcpClientTransport(); session = new McpClientSession(TIMEOUT, transport, Map.of(), Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))); - + session.openSSE(); // Simulate incoming notification from the server Map notificationParams = Map.of("status", "ready"); From 9cc92aa6dddee114998e0adaed8ff95701e3483e Mon Sep 17 00:00:00 2001 From: Aliaksei_Darafeyeu Date: Thu, 15 May 2025 20:43:44 +0200 Subject: [PATCH 4/6] fix: adds StreamableHttpServerTransportProvide --- ...StreamableHttpServerTransportProvider.java | 276 ++++++++++++++++++ .../modelcontextprotocol/spec/McpSession.java | 18 ++ .../spec/StatelessMcpSession.java | 82 ++++++ 3 files changed, 376 insertions(+) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/StatelessMcpSession.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java new file mode 100644 index 000000000..20934f857 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java @@ -0,0 +1,276 @@ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSession; +import io.modelcontextprotocol.spec.StatelessMcpSession; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + +/** + * @author Aliaksei_Darafeyeu + */ +public class StreamableHttpServerTransportProvider extends HttpServlet implements McpServerTransportProvider { + /** + * Logger for this class + */ + private static final Logger logger = LoggerFactory.getLogger(StreamableHttpServerTransportProvider.class); + + private static final String MCP_SESSION_ID = "Mcp-Session-Id"; + private static final String APPLICATION_JSON = "application/json"; + private static final String TEXT_EVENT_STREAM = "text/event-stream"; + + private McpServerSession.Factory sessionFactory; + + private final ObjectMapper objectMapper; + + private final McpServerTransportProvider legacyTransportProvider; + + private final Set allowedOrigins; + + /** + * Map of active client sessions, keyed by session ID + */ + private final Map sessions = new ConcurrentHashMap<>(); + + public StreamableHttpServerTransportProvider(final ObjectMapper objectMapper, final McpServerTransportProvider legacyTransportProvider, final Set allowedOrigins) { + this.objectMapper = objectMapper; + this.legacyTransportProvider = legacyTransportProvider; + this.allowedOrigins = allowedOrigins; + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError(e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + @Override + public Mono closeGracefully() { + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + return Flux.fromIterable(sessions.values()).flatMap(McpSession::closeGracefully).then(); + } + + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + // 1. Origin header check + String origin = req.getHeader("Origin"); + if (origin != null && !allowedOrigins.contains(origin)) { + resp.sendError(HttpServletResponse.SC_FORBIDDEN, "Origin not allowed"); + return; + } + + // 2. Accept header routing + final String accept = Optional.ofNullable(req.getHeader("Accept")).orElse(""); + final List acceptTypes = Arrays.stream(accept.split(",")) + .map(String::trim) + .toList(); + + // todo!!!! + if (!acceptTypes.contains(APPLICATION_JSON) && !acceptTypes.contains(TEXT_EVENT_STREAM)) { + if (legacyTransportProvider instanceof HttpServletSseServerTransportProvider legacy) { + legacy.doPost(req, resp); + } else { + resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Legacy transport not available"); + } + return; + } + + // 3. Enable async + final AsyncContext asyncContext = req.startAsync(); + asyncContext.setTimeout(0); + + // resp + resp.setStatus(HttpServletResponse.SC_OK); + resp.setCharacterEncoding("UTF-8"); + + final McpServerTransport transport = new StreamableHttpServerTransport(resp.getOutputStream(), objectMapper); + final McpSession session = getOrCreateSession(req.getHeader(MCP_SESSION_ID), transport); + if (!"stateless".equals(session.getId())) { + resp.setHeader(MCP_SESSION_ID, session.getId()); + } + final Flux messages = parseRequestBodyAsStream(req); + + if (accept.contains(TEXT_EVENT_STREAM)) { + // TODO: Handle streaming JSON-RPC over HTTP + resp.setContentType(TEXT_EVENT_STREAM); + resp.setHeader("Connection", "keep-alive"); + + messages.flatMap(session::handle) + .doOnError(e -> sendError(resp, 500, "Streaming failed: " + e.getMessage())) + .then(transport.closeGracefully()) + .subscribe(); + } else if (accept.contains(APPLICATION_JSON)) { + // TODO: Handle traditional JSON-RPC response + resp.setContentType(APPLICATION_JSON); + + messages.flatMap(session::handle) + .collectList() + .flatMap(responses -> { + try { + String json = new ObjectMapper().writeValueAsString( + responses.size() == 1 ? responses.get(0) : responses + ); + resp.getWriter().write(json); + return transport.closeGracefully(); + } catch (IOException e) { + return Mono.error(e); + } + }) + .doOnError(e -> sendError(resp, 500, "JSON response failed: " + e.getMessage())) + .subscribe(); + + } else { + resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Unsupported Accept header"); + } + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + if (legacyTransportProvider instanceof HttpServletSseServerTransportProvider legacy) { + legacy.doGet(req, resp); + } else { + resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Legacy transport not available"); + } + } + + protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + final String sessionId = req.getHeader("mcp-session-id"); + if (sessionId == null || !sessions.containsKey(sessionId)) { + resp.sendError(HttpServletResponse.SC_NOT_FOUND, "Session not found"); + return; + } + + final McpSession session = sessions.remove(sessionId); + session.closeGracefully().subscribe(); + resp.setStatus(HttpServletResponse.SC_NO_CONTENT); + } + + // todo:!!! + private Flux parseRequestBodyAsStream(final HttpServletRequest req) { + return Mono.fromCallable(() -> { + try (final InputStream inputStream = req.getInputStream()) { + final JsonNode node = objectMapper.readTree(inputStream); + if (node.isArray()) { + final List messages = new ArrayList<>(); + for (final JsonNode item : node) { + messages.add(objectMapper.treeToValue(item, McpSchema.JSONRPCMessage.class)); + } + return messages; + } else if (node.isObject()) { + return List.of(objectMapper.treeToValue(node, McpSchema.JSONRPCMessage.class)); + } else { + throw new IllegalArgumentException("Invalid JSON-RPC request: not object or array"); + } + } + }).flatMapMany(Flux::fromIterable); + } + + private McpSession getOrCreateSession(final String sessionId, final McpServerTransport transport) { + if (sessionId != null && sessionFactory != null) { + // Reuse or track sessions if you support that; for now, we just create new ones + return sessions.get(sessionId); + } else if (sessionFactory != null) { + final String newSessionId = UUID.randomUUID().toString(); + return sessions.put(newSessionId, sessionFactory.create(transport)); + } else { + return new StatelessMcpSession(transport); + } + } + + private void sendError(final HttpServletResponse resp, final int code, final String msg) { + try { + resp.sendError(code, msg); + } catch (IOException ignored) { + logger.debug("Exception during send error"); + } + } + + public static class StreamableHttpServerTransport implements McpServerTransport { + private final ObjectMapper objectMapper; + private final OutputStream outputStream; + + public StreamableHttpServerTransport(final OutputStream outputStream, final ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + this.outputStream = outputStream; + } + + @Override + public Mono sendMessage(final McpSchema.JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + try { + String json = objectMapper.writeValueAsString(message); + outputStream.write(json.getBytes(StandardCharsets.UTF_8)); + outputStream.write('\n'); + outputStream.flush(); + } catch (IOException e) { + throw new RuntimeException("Failed to send message", e); + } + }); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + try { + outputStream.flush(); + outputStream.close(); + } catch (IOException e) { + // ignore or log + } + }); + } + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java index 473a860c2..28fe44f9d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -25,6 +25,24 @@ */ public interface McpSession { + /** + * Retrieve the session id. + * @return session id + */ + String getId(); + + /** + * Called by the {@link McpServerTransportProvider} once the session is determined. + * The purpose of this method is to dispatch the message to an appropriate handler as + * specified by the MCP server implementation + * ({@link io.modelcontextprotocol.server.McpAsyncServer} or + * {@link io.modelcontextprotocol.server.McpSyncServer}) via + * {@link McpServerSession.Factory} that the server creates. + * @param message the incoming JSON-RPC message + * @return a Mono that completes when the message is processed + */ + Mono handle(McpSchema.JSONRPCMessage message); + /** * Sends a request to the model counterparty and expects a response of type T. * diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/StatelessMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/StatelessMcpSession.java new file mode 100644 index 000000000..2b911d11e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/StatelessMcpSession.java @@ -0,0 +1,82 @@ +package io.modelcontextprotocol.spec; + +import com.fasterxml.jackson.core.type.TypeReference; +import reactor.core.publisher.Mono; + +import java.util.UUID; + +/** + * @author Aliaksei_Darafeyeu + */ +public class StatelessMcpSession implements McpSession { + + private final McpTransport transport; + + public StatelessMcpSession(final McpTransport transport) { + this.transport = transport; + } + + @Override + public String getId() { + return "stateless"; + } + + @Override + public Mono handle(McpSchema.JSONRPCMessage message) { + if (message instanceof McpSchema.JSONRPCRequest request) { + // Stateless sessions do not support incoming requests + McpSchema.JSONRPCResponse errorResponse = new McpSchema.JSONRPCResponse( + McpSchema.JSONRPC_VERSION, + request.id(), + null, + new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.METHOD_NOT_FOUND, + "Stateless session does not handle requests", + null + ) + ); + return transport.sendMessage(errorResponse); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + // Stateless session ignores incoming notifications + return Mono.empty(); + } + else if (message instanceof McpSchema.JSONRPCResponse response) { + // No request/response correlation in stateless mode + return Mono.empty(); + } + else { + return Mono.empty(); + } + } + + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + // Stateless = no request/response correlation + String requestId = UUID.randomUUID().toString(); + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest( + McpSchema.JSONRPC_VERSION, method, requestId, requestParams + ); + + return Mono.defer(() -> Mono.from(this.transport.sendMessage(request)).then(Mono.error(new IllegalStateException("Stateless session cannot receive responses"))) + ); + } + + @Override + public Mono sendNotification(String method, Object params) { + McpSchema.JSONRPCNotification notification = + new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params); + return Mono.from(this.transport.sendMessage(notification)); + } + + @Override + public Mono closeGracefully() { + return this.transport.closeGracefully(); + } + + @Override + public void close() { + this.closeGracefully().subscribe(); + } +} From b1cd84f30546e73b869c35344029c5e9f08e1245 Mon Sep 17 00:00:00 2001 From: Aliaksei_Darafeyeu Date: Tue, 3 Jun 2025 15:05:50 +0200 Subject: [PATCH 5/6] feat(server): adds clean foundations for StreamableHttpServer --- .../server/McpAsyncServerExchange.java | 6 +- ...StreamableHttpServerTransportProvider.java | 586 +++++++++++------- .../spec/McpClientSession.java | 63 +- .../spec/McpLastEventId.java | 10 + .../spec/McpServerSession.java | 6 +- .../spec/McpStatefulSession.java | 197 ++++++ .../spec/McpStatelessSession.java | 93 +++ .../spec/SessionWrapper.java | 9 + .../spec/StatelessMcpSession.java | 82 --- 9 files changed, 701 insertions(+), 351 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpLastEventId.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpStatefulSession.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessSession.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/SessionWrapper.java delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/StatelessMcpSession.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 889dc66d0..edd4f7310 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -9,7 +9,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; -import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpSession; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -22,7 +22,7 @@ */ public class McpAsyncServerExchange { - private final McpServerSession session; + private final McpSession session; private final McpSchema.ClientCapabilities clientCapabilities; @@ -43,7 +43,7 @@ public class McpAsyncServerExchange { * features and functionality. * @param clientInfo The client implementation information. */ - public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, + public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { this.session = session; this.clientCapabilities = clientCapabilities; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java index 20934f857..d41618550 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java @@ -3,14 +3,18 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpLastEventId; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.McpSession; -import io.modelcontextprotocol.spec.StatelessMcpSession; +import io.modelcontextprotocol.spec.McpStatelessSession; +import io.modelcontextprotocol.spec.SessionWrapper; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.annotation.WebServlet; import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; @@ -22,255 +26,365 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.io.PrintWriter; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; /** * @author Aliaksei_Darafeyeu */ +@WebServlet(asyncSupported = true) public class StreamableHttpServerTransportProvider extends HttpServlet implements McpServerTransportProvider { - /** - * Logger for this class - */ - private static final Logger logger = LoggerFactory.getLogger(StreamableHttpServerTransportProvider.class); - - private static final String MCP_SESSION_ID = "Mcp-Session-Id"; - private static final String APPLICATION_JSON = "application/json"; - private static final String TEXT_EVENT_STREAM = "text/event-stream"; - - private McpServerSession.Factory sessionFactory; - - private final ObjectMapper objectMapper; - - private final McpServerTransportProvider legacyTransportProvider; - - private final Set allowedOrigins; - - /** - * Map of active client sessions, keyed by session ID - */ - private final Map sessions = new ConcurrentHashMap<>(); - - public StreamableHttpServerTransportProvider(final ObjectMapper objectMapper, final McpServerTransportProvider legacyTransportProvider, final Set allowedOrigins) { - this.objectMapper = objectMapper; - this.legacyTransportProvider = legacyTransportProvider; - this.allowedOrigins = allowedOrigins; - } - - @Override - public void setSessionFactory(McpServerSession.Factory sessionFactory) { - this.sessionFactory = sessionFactory; - } - - @Override - public Mono notifyClients(String method, Object params) { - if (sessions.isEmpty()) { - logger.debug("No active sessions to broadcast message to"); - return Mono.empty(); - } - - logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); - return Flux.fromIterable(sessions.values()) - .flatMap(session -> session.sendNotification(method, params) - .doOnError(e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) - .onErrorComplete()) - .then(); - } - - @Override - public Mono closeGracefully() { - logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); - return Flux.fromIterable(sessions.values()).flatMap(McpSession::closeGracefully).then(); - } - - @Override - public void destroy() { - closeGracefully().block(); - super.destroy(); - } - - @Override - protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - // 1. Origin header check - String origin = req.getHeader("Origin"); - if (origin != null && !allowedOrigins.contains(origin)) { - resp.sendError(HttpServletResponse.SC_FORBIDDEN, "Origin not allowed"); - return; - } - - // 2. Accept header routing - final String accept = Optional.ofNullable(req.getHeader("Accept")).orElse(""); - final List acceptTypes = Arrays.stream(accept.split(",")) - .map(String::trim) - .toList(); - - // todo!!!! - if (!acceptTypes.contains(APPLICATION_JSON) && !acceptTypes.contains(TEXT_EVENT_STREAM)) { - if (legacyTransportProvider instanceof HttpServletSseServerTransportProvider legacy) { - legacy.doPost(req, resp); - } else { - resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Legacy transport not available"); - } - return; - } - - // 3. Enable async - final AsyncContext asyncContext = req.startAsync(); - asyncContext.setTimeout(0); - - // resp - resp.setStatus(HttpServletResponse.SC_OK); - resp.setCharacterEncoding("UTF-8"); - - final McpServerTransport transport = new StreamableHttpServerTransport(resp.getOutputStream(), objectMapper); - final McpSession session = getOrCreateSession(req.getHeader(MCP_SESSION_ID), transport); - if (!"stateless".equals(session.getId())) { - resp.setHeader(MCP_SESSION_ID, session.getId()); - } - final Flux messages = parseRequestBodyAsStream(req); - - if (accept.contains(TEXT_EVENT_STREAM)) { - // TODO: Handle streaming JSON-RPC over HTTP - resp.setContentType(TEXT_EVENT_STREAM); - resp.setHeader("Connection", "keep-alive"); - - messages.flatMap(session::handle) - .doOnError(e -> sendError(resp, 500, "Streaming failed: " + e.getMessage())) - .then(transport.closeGracefully()) - .subscribe(); - } else if (accept.contains(APPLICATION_JSON)) { - // TODO: Handle traditional JSON-RPC response - resp.setContentType(APPLICATION_JSON); - - messages.flatMap(session::handle) - .collectList() - .flatMap(responses -> { - try { - String json = new ObjectMapper().writeValueAsString( - responses.size() == 1 ? responses.get(0) : responses - ); - resp.getWriter().write(json); - return transport.closeGracefully(); - } catch (IOException e) { - return Mono.error(e); - } - }) - .doOnError(e -> sendError(resp, 500, "JSON response failed: " + e.getMessage())) - .subscribe(); - - } else { - resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Unsupported Accept header"); - } - } - - @Override - protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - if (legacyTransportProvider instanceof HttpServletSseServerTransportProvider legacy) { - legacy.doGet(req, resp); - } else { - resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Legacy transport not available"); - } - } - - protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { - final String sessionId = req.getHeader("mcp-session-id"); - if (sessionId == null || !sessions.containsKey(sessionId)) { - resp.sendError(HttpServletResponse.SC_NOT_FOUND, "Session not found"); - return; - } - - final McpSession session = sessions.remove(sessionId); - session.closeGracefully().subscribe(); - resp.setStatus(HttpServletResponse.SC_NO_CONTENT); - } - - // todo:!!! - private Flux parseRequestBodyAsStream(final HttpServletRequest req) { - return Mono.fromCallable(() -> { - try (final InputStream inputStream = req.getInputStream()) { - final JsonNode node = objectMapper.readTree(inputStream); - if (node.isArray()) { - final List messages = new ArrayList<>(); - for (final JsonNode item : node) { - messages.add(objectMapper.treeToValue(item, McpSchema.JSONRPCMessage.class)); - } - return messages; - } else if (node.isObject()) { - return List.of(objectMapper.treeToValue(node, McpSchema.JSONRPCMessage.class)); - } else { - throw new IllegalArgumentException("Invalid JSON-RPC request: not object or array"); - } - } - }).flatMapMany(Flux::fromIterable); - } - - private McpSession getOrCreateSession(final String sessionId, final McpServerTransport transport) { - if (sessionId != null && sessionFactory != null) { - // Reuse or track sessions if you support that; for now, we just create new ones - return sessions.get(sessionId); - } else if (sessionFactory != null) { - final String newSessionId = UUID.randomUUID().toString(); - return sessions.put(newSessionId, sessionFactory.create(transport)); - } else { - return new StatelessMcpSession(transport); - } - } - - private void sendError(final HttpServletResponse resp, final int code, final String msg) { - try { - resp.sendError(code, msg); - } catch (IOException ignored) { - logger.debug("Exception during send error"); - } - } - - public static class StreamableHttpServerTransport implements McpServerTransport { - private final ObjectMapper objectMapper; - private final OutputStream outputStream; - - public StreamableHttpServerTransport(final OutputStream outputStream, final ObjectMapper objectMapper) { - this.objectMapper = objectMapper; - this.outputStream = outputStream; - } - - @Override - public Mono sendMessage(final McpSchema.JSONRPCMessage message) { - return Mono.fromRunnable(() -> { - try { - String json = objectMapper.writeValueAsString(message); - outputStream.write(json.getBytes(StandardCharsets.UTF_8)); - outputStream.write('\n'); - outputStream.flush(); - } catch (IOException e) { - throw new RuntimeException("Failed to send message", e); - } - }); - } - - @Override - public T unmarshalFrom(Object data, TypeReference typeRef) { - return objectMapper.convertValue(data, typeRef); - } - - @Override - public Mono closeGracefully() { - return Mono.fromRunnable(() -> { - try { - outputStream.flush(); - outputStream.close(); - } catch (IOException e) { - // ignore or log - } - }); - } - } + + /** + * Logger for this class + */ + private static final Logger logger = LoggerFactory.getLogger(StreamableHttpServerTransportProvider.class); + + private static final String MCP_SESSION_ID = "Mcp-Session-Id"; + + private static final String APPLICATION_JSON = "application/json"; + + private static final String TEXT_EVENT_STREAM = "text/event-stream"; + + private static final String LAST_EVENT_ID = "Last-Event-ID"; + + private McpServerSession.Factory sessionFactory; + + private final ObjectMapper objectMapper; + + private final McpServerTransportProvider legacyTransportProvider; + + private final Set allowedOrigins; + + /** + * Map of active client sessions, keyed by session ID + */ + private final Map sessions = new ConcurrentHashMap<>(); + + private final Duration sessionTimeout = Duration.ofMinutes(10); + + public StreamableHttpServerTransportProvider(final ObjectMapper objectMapper, + final McpServerTransportProvider legacyTransportProvider, final Set allowedOrigins) { + this.objectMapper = objectMapper; + this.legacyTransportProvider = legacyTransportProvider; + this.allowedOrigins = allowedOrigins; + + // clean-up sessions + Executors.newSingleThreadScheduledExecutor() + .scheduleAtFixedRate(this::cleanupExpiredSessions, 5, 30, TimeUnit.SECONDS); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + @Override + public Mono notifyClients(final String method, final Object params) { + if (legacyTransportProvider instanceof HttpServletSseServerTransportProvider legacy) { + return legacy.notifyClients(method, params); + } + + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.session() + .sendNotification(method, params) + .doOnError(e -> logger.error("Failed to send message to session {}: {}", session.session().getId(), + e.getMessage())) + .onErrorComplete()) + .then(); + } + + @Override + public Mono closeGracefully() { + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + return Flux.fromIterable(sessions.values()).flatMap(session -> session.session().closeGracefully()).then(); + } + + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + @Override + protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + // 1. Origin header check + String origin = req.getHeader("Origin"); + if (origin != null && !allowedOrigins.contains(origin)) { + resp.sendError(HttpServletResponse.SC_FORBIDDEN, "Origin not allowed"); + return; + } + + // 2. Accept header routing + final String accept = Optional.ofNullable(req.getHeader("Accept")).orElse(""); + final List acceptTypes = Arrays.stream(accept.split(",")).map(String::trim).toList(); + + // todo!!!! + if (!acceptTypes.contains(APPLICATION_JSON) && !acceptTypes.contains(TEXT_EVENT_STREAM)) { + if (legacyTransportProvider instanceof HttpServletSseServerTransportProvider legacy) { + legacy.doPost(req, resp); + } + else { + resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Legacy transport not available"); + } + return; + } + + // resp + resp.setStatus(HttpServletResponse.SC_OK); + resp.setCharacterEncoding("UTF-8"); + + final McpServerTransport transport = new StreamableHttpServerTransport(resp.getOutputStream(), objectMapper); + final McpSession session = getOrCreateSession(req.getHeader(MCP_SESSION_ID), transport); + if (!"stateless".equals(session.getId())) { + resp.setHeader(MCP_SESSION_ID, session.getId()); + } + + final String lastEventId = req.getHeader(LAST_EVENT_ID); + if (session instanceof McpLastEventId resumeAwareSession) { + resumeAwareSession.resumeFrom(lastEventId); + } + + final List messages; + try { + messages = parseRequestBodyAsStream(req); + } + catch (Exception e) { + resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid JSON input"); + return; + } + + boolean hasRequest = messages.stream().anyMatch(m -> m instanceof McpSchema.JSONRPCRequest); + if (!hasRequest) { + resp.setStatus(HttpServletResponse.SC_ACCEPTED); + return; + } + + if (accept.contains(TEXT_EVENT_STREAM)) { + // TODO: stream with SSE + resp.setContentType(TEXT_EVENT_STREAM); + resp.setHeader("Connection", "keep-alive"); + // Enable async + final AsyncContext asyncContext = req.startAsync(); + asyncContext.setTimeout(0); + + Flux.fromIterable(messages) + .flatMap(session::handle) + .doOnError(e -> sendError(resp, 500, "Streaming failed: " + e.getMessage())) + .then(transport.closeGracefully()) + .subscribe(); + } + else if (accept.contains(APPLICATION_JSON)) { + // TODO: Handle traditional JSON-RPC, + resp.setContentType(APPLICATION_JSON); + Flux.fromIterable(messages).flatMap(session::handle).collectList().flatMap(responses -> { + try { + // todo: collect result if it's a response, + // hm handle should not be void ... + String json = objectMapper.writeValueAsString(responses.size() == 1 ? responses.get(0) : responses); + resp.getWriter().write(json); + return transport.closeGracefully(); + } + catch (IOException e) { + return Mono.error(e); + } + }).doOnError(e -> sendError(resp, 500, "JSON response failed: " + e.getMessage())).subscribe(); + + } + else { + resp.sendError(HttpServletResponse.SC_NOT_ACCEPTABLE, "Unsupported Accept header"); + } + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + if (!"text/event-stream".equalsIgnoreCase(req.getHeader("Accept"))) { + resp.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); + return; + } + + // todo: legacy support + if (legacyTransportProvider instanceof HttpServletSseServerTransportProvider legacy) { + legacy.doGet(req, resp); + } + + final String sessionId = req.getHeader(MCP_SESSION_ID); + if (sessionId == null) { + final ServletOutputStream out = resp.getOutputStream(); + final McpServerTransport transport = new StreamableHttpServerTransport(out, new ObjectMapper()); + final McpSession newSession = getOrCreateSession(req.getHeader(MCP_SESSION_ID), transport); + } + + final SessionWrapper wrapper = sessionId != null ? sessions.get(sessionId) : null; + if (wrapper == null) { + resp.sendError(HttpServletResponse.SC_BAD_REQUEST); + return; + } + + if (wrapper.session() instanceof McpLastEventId resumable) { + String lastEventId = req.getHeader(LAST_EVENT_ID); + if (lastEventId != null && !lastEventId.isBlank()) { + resumable.resumeFrom(lastEventId); + } + } + + resp.setContentType(TEXT_EVENT_STREAM); + resp.setCharacterEncoding("UTF-8"); + + AsyncContext async = req.startAsync(); + async.setTimeout(0); + } + + protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { + final String sessionId = req.getHeader(MCP_SESSION_ID); + if (sessionId == null || !sessions.containsKey(sessionId)) { + resp.sendError(HttpServletResponse.SC_NOT_FOUND, "Session not found"); + return; + } + + final McpSession session = sessions.remove(sessionId).session(); + session.closeGracefully().subscribe(); + resp.setStatus(HttpServletResponse.SC_NO_CONTENT); + } + + private List parseRequestBodyAsStream(final HttpServletRequest req) { + try (final InputStream inputStream = req.getInputStream()) { + final JsonNode node = objectMapper.readTree(inputStream); + if (node.isArray()) { + final List messages = new ArrayList<>(); + for (final JsonNode item : node) { + messages.add(objectMapper.treeToValue(item, McpSchema.JSONRPCMessage.class)); + } + return messages; + } + else if (node.isObject()) { + final McpSchema.JSONRPCMessage message = objectMapper.treeToValue(node, McpSchema.JSONRPCMessage.class); + if (message instanceof McpSchema.JSONRPCBatchRequest batch) { + return batch.items(); + } + return List.of(message); + } + else { + throw new IllegalArgumentException("Invalid JSON-RPC request: not object or array"); + + } + } + catch (Exception e) { + throw new IllegalArgumentException("Invalid JSON-RPC request: " + e.getMessage()); + } + + } + + private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException { + writer.write("event: " + eventType + "\n"); + writer.write("data: " + data + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + + private McpSession getOrCreateSession(final String sessionId, final McpServerTransport transport) { + if (sessionId != null && sessionFactory != null) { + // Reuse or track sessions if you support that; for now, we just create new + // ones + return sessions.get(sessionId).session(); + } + else if (sessionFactory != null) { + final String newSessionId = UUID.randomUUID().toString(); + return sessions.put(newSessionId, new SessionWrapper(sessionFactory.create(transport), Instant.now())) + .session(); + } + else { + return new McpStatelessSession(transport); + } + } + + private void sendError(final HttpServletResponse resp, final int code, final String msg) { + try { + resp.sendError(code, msg); + } + catch (IOException ignored) { + logger.debug("Exception during send error"); + } + } + + private void cleanupExpiredSessions() { + final Instant now = Instant.now(); + final Iterator> it = sessions.entrySet().iterator(); + while (it.hasNext()) { + final Map.Entry entry = it.next(); + if (Duration.between(entry.getValue().lastAccessed(), now).compareTo(sessionTimeout) > 0) { + entry.getValue().session().closeGracefully().subscribe(); + it.remove(); + } + } + } + + public static class StreamableHttpServerTransport implements McpServerTransport { + + private final ObjectMapper objectMapper; + + private final OutputStream outputStream; + + public StreamableHttpServerTransport(final OutputStream outputStream, final ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + this.outputStream = outputStream; + } + + @Override + public Mono sendMessage(final McpSchema.JSONRPCMessage message) { + return Mono.fromRunnable(() -> { + try { + String json = objectMapper.writeValueAsString(message); + outputStream.write(json.getBytes(StandardCharsets.UTF_8)); + outputStream.write('\n'); + outputStream.flush(); + } + catch (IOException e) { + throw new RuntimeException("Failed to send message", e); + } + }); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + try { + outputStream.flush(); + outputStream.close(); + } + catch (IOException e) { + // ignore or log + } + }); + } + + } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java index 4c10fba6a..9b09ccb95 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java @@ -133,38 +133,47 @@ public void openSSE() { // Observation associated with the individual message - it can be used to // create child Observation and emit it together with the message to the // consumer - this.connection = this.transport.connect(mono -> mono.doOnNext(this::handle)).subscribe(); + this.connection = this.transport.connect(mono -> mono.doOnNext(message -> handle(message).subscribe())) + .subscribe(); } - private void handle(McpSchema.JSONRPCMessage message) { - if (message instanceof McpSchema.JSONRPCResponse response) { - logger.debug("Received Response: {}", response); - var sink = pendingResponses.remove(response.id()); - if (sink == null) { - logger.warn("Unexpected response for unknown id {}", response.id()); + @Override + public String getId() { + return "mcpClientSession"; + } + + public Mono handle(McpSchema.JSONRPCMessage message) { + return Mono.defer(() -> { + if (message instanceof McpSchema.JSONRPCResponse response) { + logger.debug("Received Response: {}", response); + var sink = pendingResponses.remove(response.id()); + if (sink == null) { + logger.warn("Unexpected response for unknown id {}", response.id()); + } + else { + sink.success(response); + } + return Mono.empty(); + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + logger.debug("Received request: {}", request); + return handleIncomingRequest(request).onErrorResume(error -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + error.getMessage(), null)); + return this.transport.sendMessage(errorResponse).then(Mono.empty()); + }).flatMap(this.transport::sendMessage); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + logger.debug("Received notification: {}", notification); + return handleIncomingNotification(notification) + .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())); } else { - sink.success(response); + logger.warn("Received unknown message type: {}", message); + return Mono.empty(); } - } - else if (message instanceof McpSchema.JSONRPCRequest request) { - logger.debug("Received request: {}", request); - handleIncomingRequest(request).onErrorResume(error -> { - var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)); - return this.transport.sendMessage(errorResponse).then(Mono.empty()); - }).flatMap(this.transport::sendMessage).subscribe(); - } - else if (message instanceof McpSchema.JSONRPCNotification notification) { - logger.debug("Received notification: {}", notification); - handleIncomingNotification(notification) - .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage())) - .subscribe(); - } - else { - logger.warn("Received unknown message type: {}", message); - } + }); } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpLastEventId.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpLastEventId.java new file mode 100644 index 000000000..5fbf13032 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpLastEventId.java @@ -0,0 +1,10 @@ +package io.modelcontextprotocol.spec; + +/** + * @author Aliaksei_Darafeyeu + */ +public interface McpLastEventId { + + void resumeFrom(final String lastEventId); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d859..076056d2a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -48,11 +48,11 @@ public class McpServerSession implements McpSession { private final AtomicReference clientInfo = new AtomicReference<>(); - private static final int STATE_UNINITIALIZED = 0; + public static final int STATE_UNINITIALIZED = 0; - private static final int STATE_INITIALIZING = 1; + public static final int STATE_INITIALIZING = 1; - private static final int STATE_INITIALIZED = 2; + public static final int STATE_INITIALIZED = 2; private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatefulSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatefulSession.java new file mode 100644 index 000000000..748e7583b --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatefulSession.java @@ -0,0 +1,197 @@ +package io.modelcontextprotocol.spec; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.core.publisher.Sinks; + +import java.time.Duration; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +/** + * @author Aliaksei_Darafeyeu + */ +public class McpStatefulSession implements McpSession, McpLastEventId { + + /** + * Logger for this class + */ + private static final Logger logger = LoggerFactory.getLogger(McpStatefulSession.class); + + private static final int MAX_EVENT_HISTORY = 100; + + private final LinkedHashMap eventHistory = new LinkedHashMap<>(MAX_EVENT_HISTORY, + 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > MAX_EVENT_HISTORY; + } + }; + + private final Sinks.One exchangeSink = Sinks.one(); + + private final Map> pendingResponses = new ConcurrentHashMap<>(); + + private final AtomicReference clientCapabilities = new AtomicReference<>(); + + private final AtomicReference clientInfo = new AtomicReference<>(); + + private final AtomicInteger state = new AtomicInteger(McpServerSession.STATE_UNINITIALIZED); + + private final String id; + + private final McpServerTransport transport; + + private final Map> requestHandlers; + + private final Map notificationHandlers; + + private final McpServerSession.InitRequestHandler initRequestHandler; + + private final McpServerSession.InitNotificationHandler initNotificationHandler; + + public McpStatefulSession(final String id, final McpServerTransport transport, + final Map> requestHandlers, + final Map notificationHandlers, + final McpServerSession.InitRequestHandler initRequestHandler, + final McpServerSession.InitNotificationHandler initNotificationHandler) { + this.id = id; + this.transport = transport; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + this.initRequestHandler = initRequestHandler; + this.initNotificationHandler = initNotificationHandler; + } + + public void init(final McpSchema.ClientCapabilities clientCapabilities, final McpSchema.Implementation clientInfo) { + this.clientCapabilities.lazySet(clientCapabilities); + this.clientInfo.lazySet(clientInfo); + } + + @Override + public String getId() { + return id; + } + + @Override + public Mono handle(final McpSchema.JSONRPCMessage message) { + if (message instanceof McpSchema.JSONRPCResponse response) { + MonoSink sink = pendingResponses.remove(response.id()); + if (sink != null) { + sink.success(response); + } + return Mono.empty(); + } + else if (message instanceof McpSchema.JSONRPCRequest request) { + if (McpSchema.METHOD_INITIALIZE.equals(request.method())) { + McpSchema.InitializeRequest initializeRequest = transport.unmarshalFrom(request.params(), + new TypeReference<>() { + }); + this.state.lazySet(McpServerSession.STATE_INITIALIZING); + this.init(initializeRequest.capabilities(), initializeRequest.clientInfo()); + + return this.initRequestHandler.handle(initializeRequest) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .flatMap(this::storeAndSendMessage); + } + + final McpServerSession.RequestHandler handler = requestHandlers.get(request.method()); + if (handler == null) { + McpSchema.JSONRPCResponse error = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + "Unknown method: " + request.method(), null)); + return transport.sendMessage(error); + } + return this.exchangeSink.asMono() + .flatMap(exchange -> handler.handle(exchange, request.params())) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .flatMap(this::storeAndSendMessage); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { + this.state.lazySet(McpServerSession.STATE_INITIALIZED); + exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); + return this.initNotificationHandler.handle(); + } + + final McpServerSession.NotificationHandler handler = notificationHandlers.get(notification.method()); + if (handler != null) { + return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(exchange, notification.params())); + } + return Mono.empty(); + } + return Mono.empty(); + } + + @Override + public void resumeFrom(final String lastEventId) { + logger.info("session received Last-Event-ID: {}", lastEventId); + if (lastEventId != null && !lastEventId.isBlank()) { + return; + } + boolean resume = false; + for (Map.Entry entry : eventHistory.entrySet()) { + if (entry.getKey().equals(lastEventId)) { + resume = true; + } + if (resume) { + transport.sendMessage(entry.getValue()).subscribe(); + } + } + + // todo if resume false, replay all ... + } + + @Override + public Mono sendNotification(String method, Object params) { + logger.debug("sendNotification: {}, {}", method, params); + return transport.sendMessage(new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params)); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + logger.debug("sendRequest: {}, {}, {}", method, requestParams, typeRef); + + // add requestId creations + final String requestId = UUID.randomUUID().toString(); + return Mono.create(sink -> { + pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, + requestId, requestParams); + transport.sendMessage(request).subscribe(); + }).timeout(Duration.ofSeconds(10)).handle((response, sink) -> { + if (response.error() != null) { + sink.error(new RuntimeException(response.error().message())); + } + else { + sink.next(transport.unmarshalFrom(response.result(), typeRef)); + } + }); + } + + private Mono storeAndSendMessage(McpSchema.JSONRPCMessage message) { + if (message instanceof McpSchema.JSONRPCRequest rq && rq.id() != null) { + eventHistory.put(rq.id().toString(), message); + } + return transport.sendMessage(message); + } + + @Override + public Mono closeGracefully() { + return transport.closeGracefully(); + } + + @Override + public void close() { + closeGracefully().subscribe(); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessSession.java new file mode 100644 index 000000000..53554e439 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessSession.java @@ -0,0 +1,93 @@ +package io.modelcontextprotocol.spec; + +import com.fasterxml.jackson.core.type.TypeReference; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import java.util.UUID; + +/** + * @author Aliaksei_Darafeyeu + */ +public class McpStatelessSession implements McpSession, McpLastEventId { + + /** + * Logger for this class + */ + private static final Logger logger = LoggerFactory.getLogger(McpStatelessSession.class); + + private final McpTransport transport; + + public McpStatelessSession(final McpTransport transport) { + this.transport = transport; + } + + @Override + public String getId() { + return "stateless"; + } + + @Override + public Mono handle(McpSchema.JSONRPCMessage message) { + logger.info("Handling message: {}", message); + + if (message instanceof McpSchema.JSONRPCRequest request) { + if (McpSchema.METHOD_PING.equals(request.method())) { + McpSchema.JSONRPCResponse pong = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), + "", null); + return transport.sendMessage(pong); + } + + // Stateless sessions do not support incoming requests + McpSchema.JSONRPCResponse errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.METHOD_NOT_FOUND, "Stateless session does not handle requests", null)); + return transport.sendMessage(errorResponse); + } + else if (message instanceof McpSchema.JSONRPCNotification notification) { + // Stateless session ignores incoming notifications + return Mono.empty(); + } + else if (message instanceof McpSchema.JSONRPCResponse response) { + // No request/response correlation in stateless mode + return Mono.empty(); + } + else { + return Mono.empty(); + } + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + // Stateless = no request/response correlation + String requestId = UUID.randomUUID().toString(); + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, requestId, + requestParams); + + return Mono.defer(() -> Mono.from(this.transport.sendMessage(request)) + .then(Mono.error(new IllegalStateException("Stateless session cannot receive responses")))); + } + + @Override + public Mono sendNotification(String method, Object params) { + logger.debug("sendNotification: {}, {}", method, params); + return this.transport.sendMessage(new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params)); + } + + @Override + public Mono closeGracefully() { + return this.transport.closeGracefully(); + } + + @Override + public void close() { + this.closeGracefully().subscribe(); + } + + @Override + public void resumeFrom(String lastEventId) { + logger.info("session received Last-Event-ID: {}", lastEventId); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/SessionWrapper.java b/mcp/src/main/java/io/modelcontextprotocol/spec/SessionWrapper.java new file mode 100644 index 000000000..a2016e7b0 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/SessionWrapper.java @@ -0,0 +1,9 @@ +package io.modelcontextprotocol.spec; + +import java.time.Instant; + +/** + * @author Aliaksei_Darafeyeu + */ +public record SessionWrapper(McpServerSession session, Instant lastAccessed) { +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/StatelessMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/StatelessMcpSession.java deleted file mode 100644 index 2b911d11e..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/StatelessMcpSession.java +++ /dev/null @@ -1,82 +0,0 @@ -package io.modelcontextprotocol.spec; - -import com.fasterxml.jackson.core.type.TypeReference; -import reactor.core.publisher.Mono; - -import java.util.UUID; - -/** - * @author Aliaksei_Darafeyeu - */ -public class StatelessMcpSession implements McpSession { - - private final McpTransport transport; - - public StatelessMcpSession(final McpTransport transport) { - this.transport = transport; - } - - @Override - public String getId() { - return "stateless"; - } - - @Override - public Mono handle(McpSchema.JSONRPCMessage message) { - if (message instanceof McpSchema.JSONRPCRequest request) { - // Stateless sessions do not support incoming requests - McpSchema.JSONRPCResponse errorResponse = new McpSchema.JSONRPCResponse( - McpSchema.JSONRPC_VERSION, - request.id(), - null, - new McpSchema.JSONRPCResponse.JSONRPCError( - McpSchema.ErrorCodes.METHOD_NOT_FOUND, - "Stateless session does not handle requests", - null - ) - ); - return transport.sendMessage(errorResponse); - } - else if (message instanceof McpSchema.JSONRPCNotification notification) { - // Stateless session ignores incoming notifications - return Mono.empty(); - } - else if (message instanceof McpSchema.JSONRPCResponse response) { - // No request/response correlation in stateless mode - return Mono.empty(); - } - else { - return Mono.empty(); - } - } - - - @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { - // Stateless = no request/response correlation - String requestId = UUID.randomUUID().toString(); - McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest( - McpSchema.JSONRPC_VERSION, method, requestId, requestParams - ); - - return Mono.defer(() -> Mono.from(this.transport.sendMessage(request)).then(Mono.error(new IllegalStateException("Stateless session cannot receive responses"))) - ); - } - - @Override - public Mono sendNotification(String method, Object params) { - McpSchema.JSONRPCNotification notification = - new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, method, params); - return Mono.from(this.transport.sendMessage(notification)); - } - - @Override - public Mono closeGracefully() { - return this.transport.closeGracefully(); - } - - @Override - public void close() { - this.closeGracefully().subscribe(); - } -} From 40036e329d1c9eb4f403d6ba9215d06ee673512c Mon Sep 17 00:00:00 2001 From: Aliaksei_Darafeyeu Date: Wed, 4 Jun 2025 12:57:25 +0200 Subject: [PATCH 6/6] feat(server): clean up --- ...StreamableHttpServerTransportProvider.java | 38 +++++++++---------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java index d41618550..067ae0c52 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java @@ -149,8 +149,6 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws S return; } - // resp - resp.setStatus(HttpServletResponse.SC_OK); resp.setCharacterEncoding("UTF-8"); final McpServerTransport transport = new StreamableHttpServerTransport(resp.getOutputStream(), objectMapper); @@ -168,7 +166,7 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws S try { messages = parseRequestBodyAsStream(req); } - catch (Exception e) { + catch (final Exception e) { resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Invalid JSON input"); return; } @@ -176,6 +174,9 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws S boolean hasRequest = messages.stream().anyMatch(m -> m instanceof McpSchema.JSONRPCRequest); if (!hasRequest) { resp.setStatus(HttpServletResponse.SC_ACCEPTED); + if ("stateless".equals(session.getId())) { + transport.closeGracefully().subscribe(); + } return; } @@ -190,24 +191,17 @@ protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws S Flux.fromIterable(messages) .flatMap(session::handle) .doOnError(e -> sendError(resp, 500, "Streaming failed: " + e.getMessage())) - .then(transport.closeGracefully()) + .then(closeIfStateless(session, transport)) .subscribe(); } else if (accept.contains(APPLICATION_JSON)) { // TODO: Handle traditional JSON-RPC, resp.setContentType(APPLICATION_JSON); - Flux.fromIterable(messages).flatMap(session::handle).collectList().flatMap(responses -> { - try { - // todo: collect result if it's a response, - // hm handle should not be void ... - String json = objectMapper.writeValueAsString(responses.size() == 1 ? responses.get(0) : responses); - resp.getWriter().write(json); - return transport.closeGracefully(); - } - catch (IOException e) { - return Mono.error(e); - } - }).doOnError(e -> sendError(resp, 500, "JSON response failed: " + e.getMessage())).subscribe(); + Flux.fromIterable(messages) + .flatMap(session::handle) + .doOnError(e -> sendError(resp, 500, "Streaming failed: " + e.getMessage())) + .then(closeIfStateless(session, transport)) + .subscribe(); } else { @@ -266,7 +260,7 @@ protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws resp.setStatus(HttpServletResponse.SC_NO_CONTENT); } - private List parseRequestBodyAsStream(final HttpServletRequest req) { + private List parseRequestBodyAsStream(final HttpServletRequest req) throws IOException { try (final InputStream inputStream = req.getInputStream()) { final JsonNode node = objectMapper.readTree(inputStream); if (node.isArray()) { @@ -288,10 +282,6 @@ else if (node.isObject()) { } } - catch (Exception e) { - throw new IllegalArgumentException("Invalid JSON-RPC request: " + e.getMessage()); - } - } private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException { @@ -320,6 +310,12 @@ else if (sessionFactory != null) { } } + Mono closeIfStateless(final McpSession session, final McpServerTransport transport) { + return "stateless".equals(session.getId()) + ? transport.closeGracefully() + : Mono.empty(); + } + private void sendError(final HttpServletResponse resp, final int code, final String msg) { try { resp.sendError(code, msg);