Skip to content

Commit 48612e8

Browse files
committed
chore: implement integration test for sampling
1 parent 07e7b8f commit 48612e8

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,23 @@
55
package io.modelcontextprotocol.client;
66

77
import java.time.Duration;
8+
import java.util.Map;
89
import java.util.concurrent.CountDownLatch;
910
import java.util.concurrent.TimeUnit;
11+
import java.util.concurrent.atomic.AtomicInteger;
1012
import java.util.concurrent.atomic.AtomicReference;
1113

1214
import io.modelcontextprotocol.client.transport.ServerParameters;
1315
import io.modelcontextprotocol.client.transport.StdioClientTransport;
1416
import io.modelcontextprotocol.spec.McpClientTransport;
17+
import io.modelcontextprotocol.spec.McpSchema;
1518
import org.junit.jupiter.api.Test;
1619
import org.junit.jupiter.api.Timeout;
1720
import reactor.core.publisher.Sinks;
1821
import reactor.test.StepVerifier;
1922

2023
import static org.assertj.core.api.Assertions.assertThat;
24+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
2125

2226
/**
2327
* Tests for the {@link McpSyncClient} with {@link StdioClientTransport}.
@@ -67,6 +71,50 @@ void customErrorHandlerShouldReceiveErrors() throws InterruptedException {
6771
StepVerifier.create(transport.closeGracefully()).expectComplete().verify(Duration.ofSeconds(5));
6872
}
6973

74+
@Test
75+
void testSampling() {
76+
McpClientTransport transport = createMcpTransport();
77+
78+
final String message = "Hello, world!";
79+
final String response = "Goodbye, world!";
80+
final int maxTokens = 100;
81+
82+
AtomicReference<String> receivedPrompt = new AtomicReference<>();
83+
AtomicReference<String> receivedMessage = new AtomicReference<>();
84+
AtomicInteger receivedMaxTokens = new AtomicInteger();
85+
86+
withClient(transport, spec -> spec.capabilities(McpSchema.ClientCapabilities.builder().sampling().build())
87+
.sampling(request -> {
88+
McpSchema.TextContent messageText = assertInstanceOf(McpSchema.TextContent.class,
89+
request.messages().get(0).content());
90+
receivedPrompt.set(request.systemPrompt());
91+
receivedMessage.set(messageText.text());
92+
receivedMaxTokens.set(request.maxTokens());
93+
94+
return new McpSchema.CreateMessageResult(McpSchema.Role.ASSISTANT, new McpSchema.TextContent(response),
95+
"modelId", McpSchema.CreateMessageResult.StopReason.END_TURN);
96+
}), client -> {
97+
client.initialize();
98+
99+
McpSchema.CallToolResult result = client.callTool(
100+
new McpSchema.CallToolRequest("sampleLLM", Map.of("prompt", message, "maxTokens", maxTokens)));
101+
102+
// Verify tool response to ensure our sampling response was passed through
103+
assertThat(result.content()).hasAtLeastOneElementOfType(McpSchema.TextContent.class);
104+
assertThat(result.content()).allSatisfy(content -> {
105+
if (!(content instanceof McpSchema.TextContent text))
106+
return;
107+
108+
assertThat(text.text()).endsWith(response); // Prefixed
109+
});
110+
111+
// Verify sampling request parameters received in our callback
112+
assertThat(receivedPrompt.get()).isNotEmpty();
113+
assertThat(receivedMessage.get()).endsWith(message); // Prefixed
114+
assertThat(receivedMaxTokens.get()).isEqualTo(maxTokens);
115+
});
116+
}
117+
70118
protected Duration getInitializationTimeout() {
71119
return Duration.ofSeconds(6);
72120
}

0 commit comments

Comments
 (0)