From 1c674cdbd96d2e4690f54bbf8029a0a75bd6892e Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 2 Sep 2025 18:10:17 +0200 Subject: [PATCH 1/2] test: Add additional MCP transport context integration tests - Add integration tests for transport context propagation between MCP clients and servers - Test both sync and async server implementations across all transport types (stateless, streamable, SSE) - Cover Spring WebFlux and WebMVC environments with dedicated test suites - Validate context flow through HTTP headers for authentication, correlation IDs, and metadata - Rename existing McpTransportContextIntegrationTests to SyncServerMcpTransportContextIntegrationTests for clarity Signed-off-by: Christian Tzolov --- ...erMcpTransportContextIntegrationTests.java | 381 +++++++++++++++++ ...erMcpTransportContextIntegrationTests.java | 267 ++++++++++++ mcp-spring/mcp-spring-webmvc/pom.xml | 2 +- .../McpTransportContextIntegrationTests.java | 320 ++++++++++++++ ...erMcpTransportContextIntegrationTests.java | 393 ++++++++++++++++++ ...rMcpTransportContextIntegrationTests.java} | 2 +- 6 files changed, 1363 insertions(+), 2 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java create mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java rename mcp/src/test/java/io/modelcontextprotocol/common/{McpTransportContextIntegrationTests.java => SyncServerMcpTransportContextIntegrationTests.java} (99%) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..0fb59b22e --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,381 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP clients and + * async servers using Spring WebFlux infrastructure. + * + *

+ * This test class validates the end-to-end flow of transport context propagation in MCP + * communication for asynchronous server implementations. It tests various combinations of + * client types (sync/async) and server transport mechanisms (stateless, streamable, SSE) + * to ensure proper context handling across different configurations. + * + *

Context Propagation Flow

+ *
    + *
  1. Client sets a value in its transport context (either via thread-local for sync or + * Reactor context for async)
  2. + *
  3. Client-side context provider extracts the value and adds it as an HTTP header to + * the request
  4. + *
  5. Server-side context extractor reads the header from the incoming request
  6. + *
  7. Server handler receives the extracted context and returns the value as the tool + * call result
  8. + *
  9. Test verifies the round-trip context propagation was successful
  10. + *
+ * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + * @see McpTransportContext + * @see McpTransportContextExtractor + * @see WebFluxStatelessServerTransport + * @see WebFluxStreamableServerTransportProvider + * @see WebFluxSseServerTransportProvider + */ +@Timeout(15) +public class AsyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final ThreadLocal SYNC_CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + // Sync client context provider + private final Supplier syncClientContextProvider = () -> { + var headerValue = SYNC_CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + // Async client context provider + ExchangeFilterFunction asyncClientContextProvider = (request, next) -> Mono.deferContextual(ctx -> { + var context = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // // do stuff with the context + var headerValue = context.get("client-side-header-value"); + if (headerValue == null) { + return next.exchange(request); + } + var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build(); + return next.exchange(reqWithHeader); + }); + + // Tools + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + private final BiFunction> asyncStatelessHandler = ( + transportContext, request) -> { + return Mono + .just(new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null)); + }; + + private final BiFunction> asyncStatefulHandler = ( + exchange, request) -> { + return asyncStatelessHandler.apply(exchange.transportContext(), request); + }; + + // Server context extractor + private final McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> { + var headerValue = r.headers().firstHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + // Server transports + private final WebFluxStatelessServerTransport statelessServerTransport = WebFluxStatelessServerTransport.builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxStreamableServerTransportProvider streamableServerTransport = WebFluxStreamableServerTransportProvider + .builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxSseServerTransportProvider sseServerTransport = WebFluxSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .messageEndpoint("/mcp/message") + .build(); + + // Async clients + private final McpAsyncClient asyncStreamableClient = McpClient + .async(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider)) + .build()) + .build(); + + private final McpAsyncClient asyncSseClient = McpClient + .async(WebFluxSseClientTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider)) + .build()) + .build(); + + // Sync clients + private final McpSyncClient syncStreamableClient = McpClient + .sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider)) + .build()) + .transportContextProvider(syncClientContextProvider) + .build(); + + private final McpSyncClient syncSseClient = McpClient + .sync(WebFluxSseClientTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider)) + .build()) + .transportContextProvider(syncClientContextProvider) + .build(); + + private DisposableServer httpServer; + + @AfterEach + public void after() { + SYNC_CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + stopHttpServer(); + } + + @Test + void syncClientStatelessServer() { + + startHttpServer(statelessServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler)) + .build(); + + McpSchema.InitializeResult initResult = syncStreamableClient.initialize(); + assertThat(initResult).isNotNull(); + + SYNC_CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = syncStreamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void asyncClientStatelessServer() { + + startHttpServer(statelessServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void syncClientStreamableServer() { + + startHttpServer(streamableServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = syncStreamableClient.initialize(); + assertThat(initResult).isNotNull(); + + SYNC_CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = syncStreamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void asyncClientStreamableServer() { + + startHttpServer(streamableServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void syncClientSseServer() { + + startHttpServer(sseServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = syncSseClient.initialize(); + assertThat(initResult).isNotNull(); + + SYNC_CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = syncSseClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void asyncClientSseServer() { + + startHttpServer(sseServerTransport.getRouterFunction()); + + var mcpServer = McpServer.async(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncSseClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncSseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + private void startHttpServer(RouterFunction routerFunction) { + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(routerFunction); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + private void stopHttpServer() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..fefae35cb --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,267 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP client and + * server using synchronous operations in a Spring WebFlux environment. + *

+ * This test class validates the end-to-end flow of transport context propagation across + * different WebFlux-based MCP transport implementations + * + *

+ * The test scenario follows these steps: + *

    + *
  1. The client stores a value in a thread-local variable
  2. + *
  3. The client's transport context provider reads this value and includes it in the MCP + * context
  4. + *
  5. A WebClient filter extracts the context value and adds it as an HTTP header + * (x-test)
  6. + *
  7. The server's {@link McpTransportContextExtractor} reads the header from the + * request
  8. + *
  9. The server returns the header value as the tool call result, validating the + * round-trip
  10. + *
+ * + *

+ * This test demonstrates how custom context can be propagated through HTTP headers in a + * reactive WebFlux environment, enabling features like authentication tokens, correlation + * IDs, or other metadata to flow between MCP client and server. + * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + * @since 1.0.0 + * @see McpTransportContext + * @see McpTransportContextExtractor + * @see WebFluxStatelessServerTransport + * @see WebFluxStreamableServerTransportProvider + * @see WebFluxSseServerTransportProvider + */ +@Timeout(15) +public class SyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + private final Supplier clientContextProvider = () -> { + var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final BiFunction statelessHandler = ( + transportContext, request) -> { + return new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null); + }; + + private final BiFunction statefulHandler = ( + exchange, request) -> statelessHandler.apply(exchange.transportContext(), request); + + private final McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> { + var headerValue = r.headers().firstHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final WebFluxStatelessServerTransport statelessServerTransport = WebFluxStatelessServerTransport.builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxStreamableServerTransportProvider streamableServerTransport = WebFluxStreamableServerTransportProvider + .builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .build(); + + private final WebFluxSseServerTransportProvider sseServerTransport = WebFluxSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .messageEndpoint("/mcp/message") + .build(); + + private final McpSyncClient streamableClient = McpClient + .sync(WebClientStreamableHttpTransport.builder(WebClient.builder() + .baseUrl("http://localhost:" + PORT) + .filter((request, next) -> Mono.deferContextual(ctx -> { + var context = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // // do stuff with the context + var headerValue = context.get("client-side-header-value"); + if (headerValue == null) { + return next.exchange(request); + } + var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build(); + return next.exchange(reqWithHeader); + }))).build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSyncClient sseClient = McpClient.sync(WebFluxSseClientTransport.builder(WebClient.builder() + .baseUrl("http://localhost:" + PORT) + .filter((request, next) -> Mono.deferContextual(ctx -> { + var context = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + // // do stuff with the context + var headerValue = context.get("client-side-header-value"); + if (headerValue == null) { + return next.exchange(request); + } + var reqWithHeader = ClientRequest.from(request).header(HEADER_NAME, headerValue.toString()).build(); + return next.exchange(reqWithHeader); + }))).build()).transportContextProvider(clientContextProvider).build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + private DisposableServer httpServer; + + @AfterEach + public void after() { + CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + stopHttpServer(); + } + + @Test + void statelessServer() { + + startHttpServer(statelessServerTransport.getRouterFunction()); + + var mcpServer = McpServer.sync(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void streamableServer() { + + startHttpServer(streamableServerTransport.getRouterFunction()); + + var mcpServer = McpServer.sync(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void sseServer() { + startHttpServer(sseServerTransport.getRouterFunction()); + + var mcpServer = McpServer.sync(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = sseClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + private void startHttpServer(RouterFunction routerFunction) { + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(routerFunction); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + } + + private void stopHttpServer() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index ea262d3a1..170309211 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -41,7 +41,7 @@ test - + io.modelcontextprotocol.sdk mcp-spring-webflux 0.12.0-SNAPSHOT diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java new file mode 100644 index 000000000..6a5ebfe2d --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java @@ -0,0 +1,320 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpClient.SyncSpec; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.TomcatTestUtil; +import io.modelcontextprotocol.server.TomcatTestUtil.TomcatServer; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport; +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Import; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP clients and + * servers using Spring WebMVC transport implementations. + * + *

+ * This test class validates the end-to-end flow of transport context propagation across + * different MCP transport mechanisms in a Spring WebMVC environment. It demonstrates how + * contextual information can be passed from client to server through HTTP headers and + * properly extracted and utilized on the server side. + * + *

Transport Types Tested

+ *
    + *
  • Stateless: Tests context propagation with + * {@link WebMvcStatelessServerTransport} where each request is independent
  • + *
  • Streamable HTTP: Tests context propagation with + * {@link WebMvcStreamableServerTransportProvider} supporting stateful server + * sessions
  • + *
  • Server-Sent Events (SSE): Tests context propagation with + * {@link WebMvcSseServerTransportProvider} for long-lived connections
  • + *
+ * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + * @see McpTransportContext + * @see McpTransportContextExtractor + * @see McpSyncHttpClientRequestCustomizer + */ +@Timeout(15) +public class McpTransportContextIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private TomcatServer tomcatServer; + + private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + private final Supplier clientContextProvider = () -> { + var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final McpSyncHttpClientRequestCustomizer clientRequestCustomizer = (builder, method, endpoint, body, + context) -> { + var headerValue = context.get("client-side-header-value"); + if (headerValue != null) { + builder.header(HEADER_NAME, headerValue.toString()); + } + }; + + private final BiFunction statelessHandler = ( + transportContext, + request) -> new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null); + + private final BiFunction statefulHandler = ( + exchange, request) -> statelessHandler.apply(exchange.transportContext(), request); + + @Configuration + static class TestCommonConfig { + + @Bean + public McpTransportContextExtractor serverContextExtractor() { + return (ServerRequest r) -> { + String headerValue = r.servletRequest().getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + }; + + } + + @Configuration + @EnableWebMvc + @Import(TestCommonConfig.class) + static class TestStatelessConfig { + + @Bean + public WebMvcStatelessServerTransport webMvcStatelessServerTransport( + McpTransportContextExtractor serverContextExtractor) { + + return WebMvcStatelessServerTransport.builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcStatelessServerTransport transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + @Configuration + @EnableWebMvc + @Import(TestCommonConfig.class) + static class TestStreamableHttpConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransport( + McpTransportContextExtractor serverContextExtractor) { + + return WebMvcStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + @Configuration + @EnableWebMvc + @Import(TestCommonConfig.class) + static class TestSseConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransport( + McpTransportContextExtractor serverContextExtractor) { + + return WebMvcSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .messageEndpoint("/mcp/message") + .build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private final McpSyncClient streamableClient = McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSyncClient sseClient = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + @AfterEach + public void after() { + CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + stopTomcat(); + } + + @Test + void statelessServer() { + startTomcat(TestStatelessConfig.class); + + var statelessServerTransport = tomcatServer.appContext().getBean(WebMvcStatelessServerTransport.class); + + var mcpServer = McpServer.sync(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void streamableServer() { + + startTomcat(TestStreamableHttpConfig.class); + + var streamableServerTransport = tomcatServer.appContext() + .getBean(WebMvcStreamableServerTransportProvider.class); + + var mcpServer = McpServer.sync(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = streamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = streamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void sseServer() { + startTomcat(TestSseConfig.class); + + var sseServerTransport = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); + + var mcpServer = McpServer.sync(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = sseClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + private void startTomcat(Class componentClass) { + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, componentClass); + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcatServer != null && tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java new file mode 100644 index 000000000..dc7873297 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java @@ -0,0 +1,393 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.common; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.server.transport.HttpServletSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import jakarta.servlet.Servlet; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link McpTransportContext} propagation between MCP clients and + * async servers. + * + *

+ * This test class validates the end-to-end flow of transport context propagation in MCP + * communication, demonstrating how contextual information can be passed from client to + * server through HTTP headers and accessed within server-side handlers. + * + *

Test Scenarios

+ *

+ * The tests cover multiple transport configurations with async servers: + *

    + *
  • Stateless server with async/sync streamable HTTP clients
  • + *
  • Streamable server with async/sync streamable HTTP clients
  • + *
  • SSE (Server-Sent Events) server with async/sync SSE clients
  • + *
+ * + *

Context Propagation Flow

+ *
    + *
  1. Client-side: Context data is stored (either in ThreadLocal for sync clients or + * Reactor Context for async clients) and injected into HTTP headers via + * {@link McpSyncHttpClientRequestCustomizer}
  2. + *
  3. Transport: The context travels as HTTP headers (specifically "x-test" header in + * these tests)
  4. + *
  5. Server-side: A {@link McpTransportContextExtractor} extracts the header value and + * makes it available to request handlers through {@link McpTransportContext}
  6. + *
  7. Verification: The server echoes back the received context value as the tool call + * result
  8. + *
+ * + *

Key Components

+ *
    + *
  • {@link McpTransportContext} - Container for contextual data
  • + *
  • {@link McpSyncHttpClientRequestCustomizer} - Customizes HTTP requests to include + * context headers
  • + *
  • {@link McpTransportContextExtractor} - Extracts context from incoming HTTP + * requests
  • + *
  • ThreadLocal (sync) / Reactor Context (async) - Storage mechanisms for context + * data
  • + *
+ * + *

+ * All tests use an embedded Tomcat server running on a dynamically allocated port to + * ensure isolation and prevent port conflicts during parallel test execution. + * + * @author Daniel Garnier-Moiroux + * @author Christian Tzolov + * @see McpTransportContext + * @see McpTransportContextExtractor + * @see SyncServerMcpTransportContextIntegrationTests + */ +@Timeout(15) +public class AsyncServerMcpTransportContextIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private Tomcat tomcat; + + private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); + + private static final String HEADER_NAME = "x-test"; + + private final Supplier clientContextProvider = () -> { + var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); + return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final McpSyncHttpClientRequestCustomizer clientRequestCustomizer = (builder, method, endpoint, body, + context) -> { + var headerValue = context.get("client-side-header-value"); + if (headerValue != null) { + builder.header(HEADER_NAME, headerValue.toString()); + } + }; + + private final McpTransportContextExtractor serverContextExtractor = (HttpServletRequest r) -> { + var headerValue = r.getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; + + private final HttpServletStatelessServerTransport statelessServerTransport = HttpServletStatelessServerTransport + .builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletStreamableServerTransportProvider streamableServerTransport = HttpServletStreamableServerTransportProvider + .builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .build(); + + private final HttpServletSseServerTransportProvider sseServerTransport = HttpServletSseServerTransportProvider + .builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .messageEndpoint("/message") + .build(); + + private final McpAsyncClient asyncStreamableClient = McpClient + .async(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .build(); + + private final McpSyncClient syncStreamableClient = McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpAsyncClient asyncSseClient = McpClient + .async(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .build(); + + private final McpSyncClient sseClient = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .httpRequestCustomizer(clientRequestCustomizer) + .build()) + .transportContextProvider(clientContextProvider) + .build(); + + private final McpSchema.Tool tool = McpSchema.Tool.builder() + .name("test-tool") + .description("return the value of the x-test header from call tool request") + .build(); + + private final BiFunction> asyncStatelessHandler = ( + transportContext, request) -> { + return Mono + .just(new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null)); + }; + + private final BiFunction> asyncStatefulHandler = ( + exchange, request) -> { + return asyncStatelessHandler.apply(exchange.transportContext(), request); + }; + + @AfterEach + public void after() { + CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (statelessServerTransport != null) { + statelessServerTransport.closeGracefully().block(); + } + if (streamableServerTransport != null) { + streamableServerTransport.closeGracefully().block(); + } + if (sseServerTransport != null) { + sseServerTransport.closeGracefully().block(); + } + stopTomcat(); + } + + @Test + void asyncClinetStatelessServer() { + startTomcat(statelessServerTransport); + + var mcpServer = McpServer.async(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void syncClientStatelessServer() { + startTomcat(statelessServerTransport); + + var mcpServer = McpServer.async(statelessServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler)) + .build(); + + McpSchema.InitializeResult initResult = syncStreamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = syncStreamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void asyncClientStreamableServer() { + startTomcat(streamableServerTransport); + + var mcpServer = McpServer.async(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncStreamableClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncStreamableClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void syncClientStreamableServer() { + startTomcat(streamableServerTransport); + + var mcpServer = McpServer.async(streamableServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = syncStreamableClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = syncStreamableClient + .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + @Test + void asyncClientSseServer() { + startTomcat(sseServerTransport); + + var mcpServer = McpServer.async(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + StepVerifier.create(asyncSseClient.initialize()).assertNext(initResult -> { + assertThat(initResult).isNotNull(); + }).verifyComplete(); + + // Test tool call with context + StepVerifier + .create(asyncSseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, + McpTransportContext.create(Map.of("client-side-header-value", "some important value"))))) + .assertNext(response -> { + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + }) + .verifyComplete(); + + mcpServer.close(); + } + + @Test + void syncClientSseServer() { + startTomcat(sseServerTransport); + + var mcpServer = McpServer.async(sseServerTransport) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) + .build(); + + McpSchema.InitializeResult initResult = sseClient.initialize(); + assertThat(initResult).isNotNull(); + + CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); + McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response.content()).hasSize(1) + .first() + .extracting(McpSchema.TextContent.class::cast) + .extracting(McpSchema.TextContent::text) + .isEqualTo("some important value"); + + mcpServer.close(); + } + + private void startTomcat(Servlet transport) { + tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + } + + private void stopTomcat() { + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java similarity index 99% rename from mcp/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java rename to mcp/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java index 8d75b8479..e590ca3ad 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java @@ -48,7 +48,7 @@ * @author Daniel Garnier-Moiroux */ @Timeout(15) -public class McpTransportContextIntegrationTests { +public class SyncServerMcpTransportContextIntegrationTests { private static final int PORT = TomcatTestUtil.findAvailablePort(); From 077171a277fd7bd6294c70fb4106d78fcc38f798 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Wed, 3 Sep 2025 18:00:03 +0200 Subject: [PATCH 2/2] test: Remove sync client tests from async server test suites and improve resource cleanup - Remove synchronous client tests from AsyncServerMcpTransportContextIntegrationTests - Clean up unused sync client imports and ThreadLocal context providers - Add proper resource cleanup for async clients in teardown methods - Update documentation to reflect async-only client/server focus - Refactor WebMVC tests to use Spring beans for MCP server configuration - Simplify test architecture by separating sync and async concerns Signed-off-by: Christian Tzolov --- ...erMcpTransportContextIntegrationTests.java | 132 +---------- ...erMcpTransportContextIntegrationTests.java | 6 + .../McpTransportContextIntegrationTests.java | 220 ++++++++---------- ...erMcpTransportContextIntegrationTests.java | 141 ++--------- ...erMcpTransportContextIntegrationTests.java | 6 + 5 files changed, 143 insertions(+), 362 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java index 0fb59b22e..f3e2d3626 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java @@ -6,12 +6,10 @@ import java.util.Map; import java.util.function.BiFunction; -import java.util.function.Supplier; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; import io.modelcontextprotocol.server.McpAsyncServerExchange; @@ -49,14 +47,13 @@ * *

* This test class validates the end-to-end flow of transport context propagation in MCP - * communication for asynchronous server implementations. It tests various combinations of - * client types (sync/async) and server transport mechanisms (stateless, streamable, SSE) - * to ensure proper context handling across different configurations. + * communication for asynchronous client and server implementations. It tests various + * combinations of client types and server transport mechanisms (stateless, streamable, + * SSE) to ensure proper context handling across different configurations. * *

Context Propagation Flow

*
    - *
  1. Client sets a value in its transport context (either via thread-local for sync or - * Reactor context for async)
  2. + *
  3. Client sets a value in its transport context via thread-local Reactor context
  4. *
  5. Client-side context provider extracts the value and adds it as an HTTP header to * the request
  6. *
  7. Server-side context extractor reads the header from the incoming request
  8. @@ -67,33 +64,19 @@ * * @author Daniel Garnier-Moiroux * @author Christian Tzolov - * @see McpTransportContext - * @see McpTransportContextExtractor - * @see WebFluxStatelessServerTransport - * @see WebFluxStreamableServerTransportProvider - * @see WebFluxSseServerTransportProvider */ @Timeout(15) public class AsyncServerMcpTransportContextIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); - private static final ThreadLocal SYNC_CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); - private static final String HEADER_NAME = "x-test"; - // Sync client context provider - private final Supplier syncClientContextProvider = () -> { - var headerValue = SYNC_CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); - return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) - : McpTransportContext.EMPTY; - }; - // Async client context provider ExchangeFilterFunction asyncClientContextProvider = (request, next) -> Mono.deferContextual(ctx -> { - var context = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + var transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); // // do stuff with the context - var headerValue = context.get("client-side-header-value"); + var headerValue = transportContext.get("client-side-header-value"); if (headerValue == null) { return next.exchange(request); } @@ -156,26 +139,10 @@ public class AsyncServerMcpTransportContextIntegrationTests { .build()) .build(); - // Sync clients - private final McpSyncClient syncStreamableClient = McpClient - .sync(WebClientStreamableHttpTransport - .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider)) - .build()) - .transportContextProvider(syncClientContextProvider) - .build(); - - private final McpSyncClient syncSseClient = McpClient - .sync(WebFluxSseClientTransport - .builder(WebClient.builder().baseUrl("http://localhost:" + PORT).filter(asyncClientContextProvider)) - .build()) - .transportContextProvider(syncClientContextProvider) - .build(); - private DisposableServer httpServer; @AfterEach public void after() { - SYNC_CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); if (statelessServerTransport != null) { statelessServerTransport.closeGracefully().block(); } @@ -185,36 +152,15 @@ public void after() { if (sseServerTransport != null) { sseServerTransport.closeGracefully().block(); } + if (asyncStreamableClient != null) { + asyncStreamableClient.closeGracefully().block(); + } + if (asyncSseClient != null) { + asyncSseClient.closeGracefully().block(); + } stopHttpServer(); } - @Test - void syncClientStatelessServer() { - - startHttpServer(statelessServerTransport.getRouterFunction()); - - var mcpServer = McpServer.async(statelessServerTransport) - .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) - .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler)) - .build(); - - McpSchema.InitializeResult initResult = syncStreamableClient.initialize(); - assertThat(initResult).isNotNull(); - - SYNC_CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); - McpSchema.CallToolResult response = syncStreamableClient - .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response.content()).hasSize(1) - .first() - .extracting(McpSchema.TextContent.class::cast) - .extracting(McpSchema.TextContent::text) - .isEqualTo("some important value"); - - mcpServer.close(); - } - @Test void asyncClientStatelessServer() { @@ -247,33 +193,6 @@ void asyncClientStatelessServer() { mcpServer.close(); } - @Test - void syncClientStreamableServer() { - - startHttpServer(streamableServerTransport.getRouterFunction()); - - var mcpServer = McpServer.async(streamableServerTransport) - .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) - .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) - .build(); - - McpSchema.InitializeResult initResult = syncStreamableClient.initialize(); - assertThat(initResult).isNotNull(); - - SYNC_CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); - McpSchema.CallToolResult response = syncStreamableClient - .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response.content()).hasSize(1) - .first() - .extracting(McpSchema.TextContent.class::cast) - .extracting(McpSchema.TextContent::text) - .isEqualTo("some important value"); - - mcpServer.close(); - } - @Test void asyncClientStreamableServer() { @@ -306,33 +225,6 @@ void asyncClientStreamableServer() { mcpServer.close(); } - @Test - void syncClientSseServer() { - - startHttpServer(sseServerTransport.getRouterFunction()); - - var mcpServer = McpServer.async(sseServerTransport) - .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) - .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) - .build(); - - McpSchema.InitializeResult initResult = syncSseClient.initialize(); - assertThat(initResult).isNotNull(); - - SYNC_CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); - McpSchema.CallToolResult response = syncSseClient - .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response.content()).hasSize(1) - .first() - .extracting(McpSchema.TextContent.class::cast) - .extracting(McpSchema.TextContent::text) - .isEqualTo("some important value"); - - mcpServer.close(); - } - @Test void asyncClientSseServer() { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java index fefae35cb..865192489 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java @@ -169,6 +169,12 @@ public void after() { if (sseServerTransport != null) { sseServerTransport.closeGracefully().block(); } + if (streamableClient != null) { + streamableClient.closeGracefully(); + } + if (sseClient != null) { + sseClient.closeGracefully(); + } stopHttpServer(); } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java index 6a5ebfe2d..1f5f1cc0c 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/common/McpTransportContextIntegrationTests.java @@ -10,7 +10,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.McpClient.SyncSpec; import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; @@ -18,6 +17,8 @@ import io.modelcontextprotocol.server.McpServer; import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpStatelessSyncServer; +import io.modelcontextprotocol.server.McpSyncServer; import io.modelcontextprotocol.server.McpSyncServerExchange; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.server.TestUtil; @@ -35,7 +36,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.context.annotation.Import; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerRequest; @@ -66,9 +66,6 @@ * * @author Daniel Garnier-Moiroux * @author Christian Tzolov - * @see McpTransportContext - * @see McpTransportContextExtractor - * @see McpSyncHttpClientRequestCustomizer */ @Timeout(15) public class McpTransportContextIntegrationTests { @@ -95,94 +92,18 @@ public class McpTransportContextIntegrationTests { } }; - private final BiFunction statelessHandler = ( + private static final BiFunction statelessHandler = ( transportContext, request) -> new McpSchema.CallToolResult(transportContext.get("server-side-header-value").toString(), null); - private final BiFunction statefulHandler = ( + private static final BiFunction statefulHandler = ( exchange, request) -> statelessHandler.apply(exchange.transportContext(), request); - @Configuration - static class TestCommonConfig { - - @Bean - public McpTransportContextExtractor serverContextExtractor() { - return (ServerRequest r) -> { - String headerValue = r.servletRequest().getHeader(HEADER_NAME); - return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) - : McpTransportContext.EMPTY; - }; - }; - - } - - @Configuration - @EnableWebMvc - @Import(TestCommonConfig.class) - static class TestStatelessConfig { - - @Bean - public WebMvcStatelessServerTransport webMvcStatelessServerTransport( - McpTransportContextExtractor serverContextExtractor) { - - return WebMvcStatelessServerTransport.builder() - .objectMapper(new ObjectMapper()) - .contextExtractor(serverContextExtractor) - .build(); - } - - @Bean - public RouterFunction routerFunction(WebMvcStatelessServerTransport transportProvider) { - return transportProvider.getRouterFunction(); - } - - } - - @Configuration - @EnableWebMvc - @Import(TestCommonConfig.class) - static class TestStreamableHttpConfig { - - @Bean - public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransport( - McpTransportContextExtractor serverContextExtractor) { - - return WebMvcStreamableServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .contextExtractor(serverContextExtractor) - .build(); - } - - @Bean - public RouterFunction routerFunction( - WebMvcStreamableServerTransportProvider transportProvider) { - return transportProvider.getRouterFunction(); - } - - } - - @Configuration - @EnableWebMvc - @Import(TestCommonConfig.class) - static class TestSseConfig { - - @Bean - public WebMvcSseServerTransportProvider webMvcSseServerTransport( - McpTransportContextExtractor serverContextExtractor) { - - return WebMvcSseServerTransportProvider.builder() - .objectMapper(new ObjectMapper()) - .contextExtractor(serverContextExtractor) - .messageEndpoint("/mcp/message") - .build(); - } - - @Bean - public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { - return transportProvider.getRouterFunction(); - } - - } + private static McpTransportContextExtractor serverContextExtractor = (ServerRequest r) -> { + String headerValue = r.servletRequest().getHeader(HEADER_NAME); + return headerValue != null ? McpTransportContext.create(Map.of("server-side-header-value", headerValue)) + : McpTransportContext.EMPTY; + }; private final McpSyncClient streamableClient = McpClient .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) @@ -198,7 +119,7 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro .transportContextProvider(clientContextProvider) .build(); - private final McpSchema.Tool tool = McpSchema.Tool.builder() + private static final McpSchema.Tool tool = McpSchema.Tool.builder() .name("test-tool") .description("return the value of the x-test header from call tool request") .build(); @@ -206,6 +127,12 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro @AfterEach public void after() { CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); + if (streamableClient != null) { + streamableClient.closeGracefully(); + } + if (sseClient != null) { + sseClient.closeGracefully(); + } stopTomcat(); } @@ -213,13 +140,6 @@ public void after() { void statelessServer() { startTomcat(TestStatelessConfig.class); - var statelessServerTransport = tomcatServer.appContext().getBean(WebMvcStatelessServerTransport.class); - - var mcpServer = McpServer.sync(statelessServerTransport) - .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) - .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) - .build(); - McpSchema.InitializeResult initResult = streamableClient.initialize(); assertThat(initResult).isNotNull(); @@ -233,8 +153,6 @@ void statelessServer() { .extracting(McpSchema.TextContent.class::cast) .extracting(McpSchema.TextContent::text) .isEqualTo("some important value"); - - mcpServer.close(); } @Test @@ -242,14 +160,6 @@ void streamableServer() { startTomcat(TestStreamableHttpConfig.class); - var streamableServerTransport = tomcatServer.appContext() - .getBean(WebMvcStreamableServerTransportProvider.class); - - var mcpServer = McpServer.sync(streamableServerTransport) - .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) - .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) - .build(); - McpSchema.InitializeResult initResult = streamableClient.initialize(); assertThat(initResult).isNotNull(); @@ -263,21 +173,12 @@ void streamableServer() { .extracting(McpSchema.TextContent.class::cast) .extracting(McpSchema.TextContent::text) .isEqualTo("some important value"); - - mcpServer.close(); } @Test void sseServer() { startTomcat(TestSseConfig.class); - var sseServerTransport = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); - - var mcpServer = McpServer.sync(sseServerTransport) - .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) - .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) - .build(); - McpSchema.InitializeResult initResult = sseClient.initialize(); assertThat(initResult).isNotNull(); @@ -290,8 +191,6 @@ void sseServer() { .extracting(McpSchema.TextContent.class::cast) .extracting(McpSchema.TextContent::text) .isEqualTo("some important value"); - - mcpServer.close(); } private void startTomcat(Class componentClass) { @@ -317,4 +216,91 @@ private void stopTomcat() { } } + @Configuration + @EnableWebMvc + static class TestStatelessConfig { + + @Bean + public WebMvcStatelessServerTransport webMvcStatelessServerTransport() { + + return WebMvcStatelessServerTransport.builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcStatelessServerTransport transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpStatelessSyncServer mcpStatelessServer(WebMvcStatelessServerTransport transportProvider) { + return McpServer.sync(transportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpStatelessServerFeatures.SyncToolSpecification(tool, statelessHandler)) + .build(); + } + + } + + @Configuration + @EnableWebMvc + static class TestStreamableHttpConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransport() { + + return WebMvcStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpSyncServer mcpStreamableServer(WebMvcStreamableServerTransportProvider transportProvider) { + return McpServer.sync(transportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + } + + } + + @Configuration + @EnableWebMvc + static class TestSseConfig { + + @Bean + public WebMvcSseServerTransportProvider webMvcSseServerTransport() { + + return WebMvcSseServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .contextExtractor(serverContextExtractor) + .messageEndpoint("/mcp/message") + .build(); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + @Bean + public McpSyncServer mcpSseServer(WebMvcSseServerTransportProvider transportProvider) { + return McpServer.sync(transportProvider) + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(new McpServerFeatures.SyncToolSpecification(tool, null, statefulHandler)) + .build(); + + } + + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java index dc7873297..fb19c62f7 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/common/AsyncServerMcpTransportContextIntegrationTests.java @@ -6,14 +6,13 @@ import java.util.Map; import java.util.function.BiFunction; -import java.util.function.Supplier; import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.client.McpAsyncClient; import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.McpSyncClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.customizer.McpAsyncHttpClientRequestCustomizer; import io.modelcontextprotocol.client.transport.customizer.McpSyncHttpClientRequestCustomizer; import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.server.McpServer; @@ -51,16 +50,15 @@ *

    * The tests cover multiple transport configurations with async servers: *

      - *
    • Stateless server with async/sync streamable HTTP clients
    • - *
    • Streamable server with async/sync streamable HTTP clients
    • - *
    • SSE (Server-Sent Events) server with async/sync SSE clients
    • + *
    • Stateless server with async streamable HTTP clients
    • + *
    • Streamable server with async streamable HTTP clients
    • + *
    • SSE (Server-Sent Events) server with async SSE clients
    • *
    * *

    Context Propagation Flow

    *
      - *
    1. Client-side: Context data is stored (either in ThreadLocal for sync clients or - * Reactor Context for async clients) and injected into HTTP headers via - * {@link McpSyncHttpClientRequestCustomizer}
    2. + *
    3. Client-side: Context data is stored in the Reactor Context and injected into HTTP + * headers via {@link McpSyncHttpClientRequestCustomizer}
    4. *
    5. Transport: The context travels as HTTP headers (specifically "x-test" header in * these tests)
    6. *
    7. Server-side: A {@link McpTransportContextExtractor} extracts the header value and @@ -69,26 +67,12 @@ * result
    8. *
    * - *

    Key Components

    - *
      - *
    • {@link McpTransportContext} - Container for contextual data
    • - *
    • {@link McpSyncHttpClientRequestCustomizer} - Customizes HTTP requests to include - * context headers
    • - *
    • {@link McpTransportContextExtractor} - Extracts context from incoming HTTP - * requests
    • - *
    • ThreadLocal (sync) / Reactor Context (async) - Storage mechanisms for context - * data
    • - *
    - * *

    * All tests use an embedded Tomcat server running on a dynamically allocated port to * ensure isolation and prevent port conflicts during parallel test execution. * * @author Daniel Garnier-Moiroux * @author Christian Tzolov - * @see McpTransportContext - * @see McpTransportContextExtractor - * @see SyncServerMcpTransportContextIntegrationTests */ @Timeout(15) public class AsyncServerMcpTransportContextIntegrationTests { @@ -97,22 +81,15 @@ public class AsyncServerMcpTransportContextIntegrationTests { private Tomcat tomcat; - private static final ThreadLocal CLIENT_SIDE_HEADER_VALUE_HOLDER = new ThreadLocal<>(); - private static final String HEADER_NAME = "x-test"; - private final Supplier clientContextProvider = () -> { - var headerValue = CLIENT_SIDE_HEADER_VALUE_HOLDER.get(); - return headerValue != null ? McpTransportContext.create(Map.of("client-side-header-value", headerValue)) - : McpTransportContext.EMPTY; - }; - - private final McpSyncHttpClientRequestCustomizer clientRequestCustomizer = (builder, method, endpoint, body, + private final McpAsyncHttpClientRequestCustomizer asyncClientRequestCustomizer = (builder, method, endpoint, body, context) -> { var headerValue = context.get("client-side-header-value"); if (headerValue != null) { builder.header(HEADER_NAME, headerValue.toString()); } + return Mono.just(builder); }; private final McpTransportContextExtractor serverContextExtractor = (HttpServletRequest r) -> { @@ -142,30 +119,16 @@ public class AsyncServerMcpTransportContextIntegrationTests { private final McpAsyncClient asyncStreamableClient = McpClient .async(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) - .httpRequestCustomizer(clientRequestCustomizer) + .asyncHttpRequestCustomizer(asyncClientRequestCustomizer) .build()) .build(); - private final McpSyncClient syncStreamableClient = McpClient - .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) - .httpRequestCustomizer(clientRequestCustomizer) - .build()) - .transportContextProvider(clientContextProvider) - .build(); - private final McpAsyncClient asyncSseClient = McpClient .async(HttpClientSseClientTransport.builder("http://localhost:" + PORT) - .httpRequestCustomizer(clientRequestCustomizer) + .asyncHttpRequestCustomizer(asyncClientRequestCustomizer) .build()) .build(); - private final McpSyncClient sseClient = McpClient - .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT) - .httpRequestCustomizer(clientRequestCustomizer) - .build()) - .transportContextProvider(clientContextProvider) - .build(); - private final McpSchema.Tool tool = McpSchema.Tool.builder() .name("test-tool") .description("return the value of the x-test header from call tool request") @@ -184,7 +147,6 @@ public class AsyncServerMcpTransportContextIntegrationTests { @AfterEach public void after() { - CLIENT_SIDE_HEADER_VALUE_HOLDER.remove(); if (statelessServerTransport != null) { statelessServerTransport.closeGracefully().block(); } @@ -194,6 +156,12 @@ public void after() { if (sseServerTransport != null) { sseServerTransport.closeGracefully().block(); } + if (asyncStreamableClient != null) { + asyncStreamableClient.closeGracefully().block(); + } + if (asyncSseClient != null) { + asyncSseClient.closeGracefully().block(); + } stopTomcat(); } @@ -228,32 +196,6 @@ void asyncClinetStatelessServer() { mcpServer.close(); } - @Test - void syncClientStatelessServer() { - startTomcat(statelessServerTransport); - - var mcpServer = McpServer.async(statelessServerTransport) - .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) - .tools(new McpStatelessServerFeatures.AsyncToolSpecification(tool, asyncStatelessHandler)) - .build(); - - McpSchema.InitializeResult initResult = syncStreamableClient.initialize(); - assertThat(initResult).isNotNull(); - - CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); - McpSchema.CallToolResult response = syncStreamableClient - .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response.content()).hasSize(1) - .first() - .extracting(McpSchema.TextContent.class::cast) - .extracting(McpSchema.TextContent::text) - .isEqualTo("some important value"); - - mcpServer.close(); - } - @Test void asyncClientStreamableServer() { startTomcat(streamableServerTransport); @@ -285,32 +227,6 @@ void asyncClientStreamableServer() { mcpServer.close(); } - @Test - void syncClientStreamableServer() { - startTomcat(streamableServerTransport); - - var mcpServer = McpServer.async(streamableServerTransport) - .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) - .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) - .build(); - - McpSchema.InitializeResult initResult = syncStreamableClient.initialize(); - assertThat(initResult).isNotNull(); - - CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); - McpSchema.CallToolResult response = syncStreamableClient - .callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response.content()).hasSize(1) - .first() - .extracting(McpSchema.TextContent.class::cast) - .extracting(McpSchema.TextContent::text) - .isEqualTo("some important value"); - - mcpServer.close(); - } - @Test void asyncClientSseServer() { startTomcat(sseServerTransport); @@ -342,31 +258,6 @@ void asyncClientSseServer() { mcpServer.close(); } - @Test - void syncClientSseServer() { - startTomcat(sseServerTransport); - - var mcpServer = McpServer.async(sseServerTransport) - .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) - .tools(new McpServerFeatures.AsyncToolSpecification(tool, null, asyncStatefulHandler)) - .build(); - - McpSchema.InitializeResult initResult = sseClient.initialize(); - assertThat(initResult).isNotNull(); - - CLIENT_SIDE_HEADER_VALUE_HOLDER.set("some important value"); - McpSchema.CallToolResult response = sseClient.callTool(new McpSchema.CallToolRequest("test-tool", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response.content()).hasSize(1) - .first() - .extracting(McpSchema.TextContent.class::cast) - .extracting(McpSchema.TextContent::text) - .isEqualTo("some important value"); - - mcpServer.close(); - } - private void startTomcat(Servlet transport) { tomcat = TomcatTestUtil.createTomcatServer("", PORT, transport); try { diff --git a/mcp/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java index e590ca3ad..42747f717 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/common/SyncServerMcpTransportContextIntegrationTests.java @@ -135,6 +135,12 @@ public void after() { if (sseServerTransport != null) { sseServerTransport.closeGracefully().block(); } + if (streamableClient != null) { + streamableClient.closeGracefully(); + } + if (sseClient != null) { + sseClient.closeGracefully(); + } stopTomcat(); }