|
5 | 5 | package io.modelcontextprotocol.client; |
6 | 6 |
|
7 | 7 | import java.time.Duration; |
| 8 | +import java.util.Map; |
8 | 9 | import java.util.concurrent.CountDownLatch; |
9 | 10 | import java.util.concurrent.TimeUnit; |
| 11 | +import java.util.concurrent.atomic.AtomicInteger; |
10 | 12 | import java.util.concurrent.atomic.AtomicReference; |
11 | 13 |
|
12 | 14 | import io.modelcontextprotocol.client.transport.ServerParameters; |
13 | 15 | import io.modelcontextprotocol.client.transport.StdioClientTransport; |
14 | 16 | import io.modelcontextprotocol.spec.McpClientTransport; |
| 17 | +import io.modelcontextprotocol.spec.McpSchema; |
15 | 18 | import org.junit.jupiter.api.Test; |
16 | 19 | import org.junit.jupiter.api.Timeout; |
17 | 20 | import reactor.core.publisher.Sinks; |
18 | 21 | import reactor.test.StepVerifier; |
19 | 22 |
|
20 | 23 | import static org.assertj.core.api.Assertions.assertThat; |
| 24 | +import static org.junit.jupiter.api.Assertions.assertInstanceOf; |
21 | 25 |
|
22 | 26 | /** |
23 | 27 | * Tests for the {@link McpSyncClient} with {@link StdioClientTransport}. |
@@ -67,6 +71,50 @@ void customErrorHandlerShouldReceiveErrors() throws InterruptedException { |
67 | 71 | StepVerifier.create(transport.closeGracefully()).expectComplete().verify(Duration.ofSeconds(5)); |
68 | 72 | } |
69 | 73 |
|
| 74 | + @Test |
| 75 | + void testSampling() { |
| 76 | + McpClientTransport transport = createMcpTransport(); |
| 77 | + |
| 78 | + final String message = "Hello, world!"; |
| 79 | + final String response = "Goodbye, world!"; |
| 80 | + final int maxTokens = 100; |
| 81 | + |
| 82 | + AtomicReference<String> receivedPrompt = new AtomicReference<>(); |
| 83 | + AtomicReference<String> receivedMessage = new AtomicReference<>(); |
| 84 | + AtomicInteger receivedMaxTokens = new AtomicInteger(); |
| 85 | + |
| 86 | + withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build()) |
| 87 | + .sampling(request -> { |
| 88 | + McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class, |
| 89 | + request.messages().get(0).content()); |
| 90 | + receivedPrompt.set(request.systemPrompt()); |
| 91 | + receivedMessage.set(messageText.text()); |
| 92 | + receivedMaxTokens.set(request.maxTokens()); |
| 93 | + |
| 94 | + return new McpSchema.CreateMessageResult(McpSchema.Role.ASSISTANT, new McpSchema.TextContent(response), |
| 95 | + "modelId", McpSchema.CreateMessageResult.StopReason.END_TURN); |
| 96 | + }), client -> { |
| 97 | + client.initialize(); |
| 98 | + |
| 99 | + McpSchema.CallToolResult result = client.callTool( |
| 100 | + new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens))); |
| 101 | + |
| 102 | + // Verify tool response to ensure our sampling response was passed through |
| 103 | + assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class); |
| 104 | + assertThat(result.content()).allSatisfy(content -> { |
| 105 | + if (!(content instanceof McpSchema.TextContent text)) |
| 106 | + return; |
| 107 | + |
| 108 | + assertThat(text.text()).endsWith(response); // Prefixed |
| 109 | + }); |
| 110 | + |
| 111 | + // Verify sampling request parameters received in our callback |
| 112 | + assertThat(receivedPrompt.get()).isNotEmpty(); |
| 113 | + assertThat(receivedMessage.get()).endsWith(message); // Prefixed |
| 114 | + assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens); |
| 115 | + }); |
| 116 | + } |
| 117 | + |
70 | 118 | protected Duration getInitializationTimeout() { |
71 | 119 | return Duration.ofSeconds(6); |
72 | 120 | } |
|
0 commit comments