Skip to content

Commit 4c05e1a

Browse files
committed
feat: Add Pagination for requesting list of prompts
Adds the Pagination feature to the `prompts/list` feature as described in the specification. To make this possible mainly two changes are made: 1. The logic for cursor handling is added. 2. Handling for invalid parameters (MCP error code `-32602 (Invalid params)`) is added to the `McpServerSession`. For now the cursor is the base64 encoded start index of the next page. The page size is set to 10. When parameters are found to be invalid the newly introduced `McpParamsValidationError` is returned to handle it properly in the `McpServerSession`.
1 parent 84adde1 commit 4c05e1a

File tree

4 files changed

+183
-9
lines changed

4 files changed

+183
-9
lines changed

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import java.util.concurrent.atomic.AtomicReference;
1313
import java.util.function.Function;
1414
import java.util.stream.Collectors;
15+
import java.util.stream.Stream;
1516

1617
import com.fasterxml.jackson.databind.ObjectMapper;
1718
import io.modelcontextprotocol.client.McpClient;
@@ -36,6 +37,8 @@
3637
import org.junit.jupiter.api.AfterEach;
3738
import org.junit.jupiter.api.BeforeEach;
3839
import org.junit.jupiter.params.ParameterizedTest;
40+
import org.junit.jupiter.params.provider.Arguments;
41+
import org.junit.jupiter.params.provider.MethodSource;
3942
import org.junit.jupiter.params.provider.ValueSource;
4043
import reactor.netty.DisposableServer;
4144
import reactor.netty.http.server.HttpServer;
@@ -46,9 +49,11 @@
4649
import org.springframework.web.reactive.function.client.WebClient;
4750
import org.springframework.web.reactive.function.server.RouterFunctions;
4851

52+
import static io.modelcontextprotocol.spec.McpSchema.ErrorCodes.INVALID_PARAMS;
4953
import static org.assertj.core.api.Assertions.assertThat;
5054
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
5155
import static org.assertj.core.api.Assertions.assertWith;
56+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
5257
import static org.awaitility.Awaitility.await;
5358
import static org.mockito.Mockito.mock;
5459

@@ -759,4 +764,99 @@ void testLoggingNotification(String clientType) {
759764
mcpServer.close();
760765
}
761766

767+
// ---------------------------------------
768+
// Prompt List Tests
769+
// ---------------------------------------
770+
771+
static Stream<Arguments> providePaginationTestParams() {
772+
return Stream.of(Arguments.of("httpclient", 0), Arguments.of("httpclient", 1), Arguments.of("httpclient", 21),
773+
Arguments.of("webflux", 0), Arguments.of("webflux", 1), Arguments.of("webflux", 21));
774+
}
775+
776+
@ParameterizedTest(name = "{0} ({1}) : {displayName} ")
777+
@MethodSource("providePaginationTestParams")
778+
void testListPromptSuccess(String clientType, int availablePrompts) {
779+
780+
var clientBuilder = clientBuilders.get(clientType);
781+
782+
// Setup list of prompts
783+
List<McpServerFeatures.SyncPromptSpecification> prompts = new ArrayList<>();
784+
785+
for (int i = 0; i < availablePrompts; i++) {
786+
McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt-" + i, "Test Prompt Description",
787+
List.of(new McpSchema.PromptArgument("arg1", "Test argument", true)));
788+
789+
var promptSpec = new McpServerFeatures.SyncPromptSpecification(mockPrompt, null);
790+
791+
prompts.add(promptSpec);
792+
}
793+
794+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
795+
.capabilities(ServerCapabilities.builder().prompts(true).build())
796+
.prompts(prompts)
797+
.build();
798+
799+
try (var mcpClient = clientBuilder.build()) {
800+
801+
InitializeResult initResult = mcpClient.initialize();
802+
assertThat(initResult).isNotNull();
803+
804+
// Iterate through list
805+
var returnedPromptsSum = 0;
806+
807+
var hasEntries = true;
808+
String nextCursor = null;
809+
810+
while (hasEntries) {
811+
var res = mcpClient.listPrompts(nextCursor);
812+
returnedPromptsSum += res.prompts().size();
813+
814+
nextCursor = res.nextCursor();
815+
816+
if (nextCursor == null) {
817+
hasEntries = false;
818+
}
819+
}
820+
821+
assertThat(returnedPromptsSum).isEqualTo(availablePrompts);
822+
823+
}
824+
825+
mcpServer.close();
826+
}
827+
828+
@ParameterizedTest(name = "{0} : {displayName} ")
829+
@ValueSource(strings = { "httpclient", "webflux" })
830+
void testListPromptInvalidCursor(String clientType) {
831+
832+
var clientBuilder = clientBuilders.get(clientType);
833+
834+
McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt", "Test Prompt Description",
835+
List.of(new McpSchema.PromptArgument("arg1", "Test argument", true)));
836+
837+
var promptSpec = new McpServerFeatures.SyncPromptSpecification(mockPrompt, null);
838+
839+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
840+
.capabilities(ServerCapabilities.builder().prompts(true).build())
841+
.prompts(promptSpec)
842+
.build();
843+
844+
try (var mcpClient = clientBuilder.build()) {
845+
846+
InitializeResult initResult = mcpClient.initialize();
847+
assertThat(initResult).isNotNull();
848+
849+
assertThatThrownBy(() -> mcpClient.listPrompts("INVALID")).isInstanceOf(McpError.class)
850+
.hasMessage("Invalid cursor")
851+
.satisfies(exception -> {
852+
var error = (McpError) exception;
853+
assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS);
854+
assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor");
855+
});
856+
857+
}
858+
859+
mcpServer.close();
860+
}
861+
762862
}

mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.Map;
1111
import java.util.Optional;
1212
import java.util.UUID;
13+
import java.util.Base64;
1314
import java.util.concurrent.ConcurrentHashMap;
1415
import java.util.concurrent.CopyOnWriteArrayList;
1516
import java.util.function.BiFunction;
@@ -18,6 +19,7 @@
1819
import com.fasterxml.jackson.databind.ObjectMapper;
1920
import io.modelcontextprotocol.spec.McpClientSession;
2021
import io.modelcontextprotocol.spec.McpError;
22+
import io.modelcontextprotocol.spec.McpParamsValidationError;
2123
import io.modelcontextprotocol.spec.McpSchema;
2224
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
2325
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
@@ -265,6 +267,8 @@ private static class AsyncServerImpl extends McpAsyncServer {
265267

266268
private final ConcurrentHashMap<String, McpServerFeatures.AsyncPromptSpecification> prompts = new ConcurrentHashMap<>();
267269

270+
private static final int PAGE_SIZE = 10;
271+
268272
// FIXME: this field is deprecated and should be remvoed together with the
269273
// broadcasting loggingNotification.
270274
private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG;
@@ -638,20 +642,67 @@ public Mono<Void> notifyPromptsListChanged() {
638642

639643
private McpServerSession.RequestHandler<McpSchema.ListPromptsResult> promptsListRequestHandler() {
640644
return (exchange, params) -> {
641-
// TODO: Implement pagination
642-
// McpSchema.PaginatedRequest request = objectMapper.convertValue(params,
643-
// new TypeReference<McpSchema.PaginatedRequest>() {
644-
// });
645+
McpSchema.PaginatedRequest request = objectMapper.convertValue(params,
646+
new TypeReference<McpSchema.PaginatedRequest>() {
647+
});
648+
649+
if (!isCursorValid(request.cursor(), this.prompts.size())) {
650+
return Mono.error(new McpParamsValidationError("Invalid cursor"));
651+
}
652+
653+
int requestedStartIndex = 0;
654+
655+
if (request.cursor() != null) {
656+
requestedStartIndex = decodeCursor(request.cursor());
657+
}
658+
659+
int endIndex = Math.min(requestedStartIndex + PAGE_SIZE, this.prompts.size());
645660

646661
var promptList = this.prompts.values()
647662
.stream()
663+
.skip(requestedStartIndex)
664+
.limit(endIndex - requestedStartIndex)
648665
.map(McpServerFeatures.AsyncPromptSpecification::prompt)
649666
.toList();
650667

651-
return Mono.just(new McpSchema.ListPromptsResult(promptList, null));
668+
String nextCursor = null;
669+
670+
if (endIndex < this.prompts.size()) {
671+
nextCursor = encodeCursor(endIndex);
672+
}
673+
674+
return Mono.just(new McpSchema.ListPromptsResult(promptList, nextCursor));
652675
};
653676
}
654677

678+
private boolean isCursorValid(String cursor, int maxPageSize) {
679+
if (cursor == null) {
680+
return true;
681+
}
682+
683+
try {
684+
var decoded = decodeCursor(cursor);
685+
686+
if (decoded < 0 || decoded > maxPageSize) {
687+
return false;
688+
}
689+
690+
return true;
691+
}
692+
catch (NumberFormatException e) {
693+
return false;
694+
}
695+
}
696+
697+
private String encodeCursor(int index) {
698+
return Base64.getEncoder().encodeToString(String.valueOf(index).getBytes());
699+
}
700+
701+
private int decodeCursor(String cursor) {
702+
String decoded = new String(Base64.getDecoder().decode(cursor));
703+
return Integer.parseInt(decoded);
704+
}
705+
655706
private McpServerSession.RequestHandler<McpSchema.GetPromptResult> promptsGetRequestHandler() {
656707
return (exchange, params) -> {
657708
McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package io.modelcontextprotocol.spec;
2+
3+
public class McpParamsValidationError extends McpError {
4+
5+
public McpParamsValidationError(McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError) {
6+
super(jsonRpcError.message());
7+
}
8+
9+
public McpParamsValidationError(Object error) {
10+
super(error.toString());
11+
}
12+
13+
}

mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,20 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
225225
}
226226
return resultMono
227227
.map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null))
228-
.onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(),
229-
null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
230-
error.getMessage(), null)))); // TODO: add error message
231-
// through the data field
228+
.onErrorResume(error -> {
229+
230+
var errorCode = McpSchema.ErrorCodes.INTERNAL_ERROR;
231+
232+
if (error instanceof McpParamsValidationError) {
233+
errorCode = McpSchema.ErrorCodes.INVALID_PARAMS;
234+
}
235+
236+
// TODO: add error message through the data field
237+
var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null,
238+
new McpSchema.JSONRPCResponse.JSONRPCError(errorCode, error.getMessage(), null));
239+
240+
return Mono.just(errorResponse);
241+
});
232242
});
233243
}
234244

0 commit comments

Comments
 (0)