diff --git a/framework/fel/java/plugins/tool-mcp-client/src/main/java/modelengine/fel/tool/mcp/client/support/DefaultMcpClient.java b/framework/fel/java/plugins/tool-mcp-client/src/main/java/modelengine/fel/tool/mcp/client/support/DefaultMcpClient.java index f22ef103..77b32091 100644 --- a/framework/fel/java/plugins/tool-mcp-client/src/main/java/modelengine/fel/tool/mcp/client/support/DefaultMcpClient.java +++ b/framework/fel/java/plugins/tool-mcp-client/src/main/java/modelengine/fel/tool/mcp/client/support/DefaultMcpClient.java @@ -10,6 +10,7 @@ import modelengine.fel.tool.mcp.client.McpClient; import modelengine.fel.tool.mcp.entity.ClientSchema; +import modelengine.fel.tool.mcp.entity.Event; import modelengine.fel.tool.mcp.entity.JsonRpc; import modelengine.fel.tool.mcp.entity.Method; import modelengine.fel.tool.mcp.entity.ServerSchema; @@ -43,6 +44,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiConsumer; import java.util.function.Consumer; /** @@ -129,8 +131,8 @@ public void initialize() { .runnable(this::pingServer) .policy(ExecutePolicy.fixedDelay(DELAY_MILLIS)) .build(), DELAY_MILLIS); - while (!this.waitInitialized()) { - ThreadUtils.sleep(100); + if (!this.waitInitialized()) { + throw new IllegalStateException("Failed to initialize."); } } @@ -142,7 +144,7 @@ private void consumeTextEvent(TextEvent textEvent) { if (StringUtils.isBlank(textEvent.event()) || StringUtils.isBlank((String) textEvent.data())) { return; } - if (Objects.equals(textEvent.event(), "endpoint")) { + if (Objects.equals(textEvent.event(), Event.ENDPOINT.code())) { this.initializeMcpServer(textEvent); return; } @@ -173,64 +175,20 @@ private void pingServer() { log.info("MCP client is not initialized and {} method will be delayed.", Method.PING.code()); return; } - HttpClassicClientRequest request = - this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint); - long currentId = this.getNextId(); - JsonRpc.Request rpcRequest = JsonRpc.createRequest(currentId, Method.PING.code()); - request.entity(Entity.createObject(request, rpcRequest)); - log.info("Send {} method to MCP server. [sessionId={}, request={}]", - Method.PING.code(), - this.sessionId, - rpcRequest); - try (HttpClassicClientResponse exchange = request.exchange(Object.class)) { - if (exchange.statusCode() >= 200 && exchange.statusCode() < 300) { - log.info("Send {} method to MCP server successfully. [sessionId={}, statusCode={}]", - Method.PING.code(), - this.sessionId, - exchange.statusCode()); - } else { - log.error("Failed to {} MCP server. [sessionId={}, statusCode={}]", - Method.PING.code(), - this.sessionId, - exchange.statusCode()); - } - } catch (IOException e) { - throw new IllegalStateException(e); - } + this.post2McpServer(Method.PING, null, null); } private void initializeMcpServer(TextEvent textEvent) { this.messageEndpoint = textEvent.data().toString(); - HttpClassicClientRequest request = - this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint); - this.sessionId = - request.queries().first("session_id").orElseThrow(() -> new IllegalStateException("no session_id")); - long currentId = this.getNextId(); - this.responseConsumers.put(currentId, this::initializedMcpServer); ClientSchema schema = new ClientSchema("2024-11-05", new ClientSchema.Capabilities(), new ClientSchema.Info("FIT MCP Client", "3.5.0-SNAPSHOT")); - JsonRpc.Request rpcRequest = JsonRpc.createRequest(currentId, Method.INITIALIZE.code(), schema); - request.entity(Entity.createObject(request, rpcRequest)); - log.info("Send {} method to MCP server. [sessionId={}, request={}]", - Method.INITIALIZE.code(), - this.sessionId, - rpcRequest); - try (HttpClassicClientResponse exchange = request.exchange(Object.class)) { - if (exchange.statusCode() >= 200 && exchange.statusCode() < 300) { - log.info("Send {} method to MCP server successfully. [sessionId={}, statusCode={}]", - Method.INITIALIZE.code(), - this.sessionId, - exchange.statusCode()); - } else { - log.error("Failed to {} MCP server. [sessionId={}, statusCode={}]", - Method.INITIALIZE.code(), - this.sessionId, - exchange.statusCode()); - } - } catch (IOException e) { - throw new IllegalStateException(e); - } + this.post2McpServer(Method.INITIALIZE, schema, (request, currentId) -> { + this.sessionId = request.queries() + .first("session_id") + .orElseThrow(() -> new IllegalStateException("The session_id cannot be empty.")); + this.responseConsumers.put(currentId, this::initializedMcpServer); + }); } private void initializedMcpServer(JsonRpc.Response response) { @@ -282,33 +240,11 @@ public List getTools() { if (this.isNotInitialized()) { throw new IllegalStateException("MCP client is not initialized. Please wait a moment."); } - HttpClassicClientRequest request = - this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint); - long currentId = this.getNextId(); - this.responseConsumers.put(currentId, this::getTools0); - this.pendingRequests.put(currentId, true); - JsonRpc.Request rpcRequest = JsonRpc.createRequest(currentId, Method.TOOLS_LIST.code()); - request.entity(Entity.createObject(request, rpcRequest)); - log.info("Send {} method to MCP server. [sessionId={}, request={}]", - Method.TOOLS_LIST.code(), - this.sessionId, - rpcRequest); - try (HttpClassicClientResponse exchange = request.exchange(Object.class)) { - if (exchange.statusCode() >= 200 && exchange.statusCode() < 300) { - log.info("Send {} method to MCP server successfully. [sessionId={}, statusCode={}]", - Method.TOOLS_LIST.code(), - this.sessionId, - exchange.statusCode()); - } else { - log.error("Failed to {} MCP server. [sessionId={}, statusCode={}]", - Method.TOOLS_LIST.code(), - this.sessionId, - exchange.statusCode()); - } - } catch (IOException e) { - throw new IllegalStateException(e); - } - while (this.pendingRequests.get(currentId)) { + long requestId = this.post2McpServer(Method.TOOLS_LIST, null, (request, currentId) -> { + this.responseConsumers.put(currentId, this::getTools0); + this.pendingRequests.put(currentId, true); + }); + while (this.pendingRequests.get(requestId)) { ThreadUtils.sleep(100); } synchronized (this.toolsLock) { @@ -340,38 +276,16 @@ public Object callTool(String name, Map arguments) { if (this.isNotInitialized()) { throw new IllegalStateException("MCP client is not initialized. Please wait a moment."); } - HttpClassicClientRequest request = - this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint); - long currentId = this.getNextId(); - this.responseConsumers.put(currentId, this::callTools0); - this.pendingRequests.put(currentId, true); - JsonRpc.Request rpcRequest = JsonRpc.createRequest(currentId, - Method.TOOLS_CALL.code(), - MapBuilder.get().put("name", name).put("arguments", arguments).build()); - request.entity(Entity.createObject(request, rpcRequest)); - log.info("Send {} method to MCP server. [sessionId={}, request={}]", - Method.TOOLS_CALL.code(), - this.sessionId, - rpcRequest); - try (HttpClassicClientResponse exchange = request.exchange(Object.class)) { - if (exchange.statusCode() >= 200 && exchange.statusCode() < 300) { - log.info("Send {} method to MCP server successfully. [sessionId={}, statusCode={}]", - Method.TOOLS_CALL.code(), - this.sessionId, - exchange.statusCode()); - } else { - log.error("Failed to {} MCP server. [sessionId={}, statusCode={}]", - Method.TOOLS_CALL.code(), - this.sessionId, - exchange.statusCode()); - } - } catch (IOException e) { - throw new IllegalStateException(e); - } - while (this.pendingRequests.get(currentId)) { + long requestId = this.post2McpServer(Method.TOOLS_CALL, + MapBuilder.get().put("name", name).put("arguments", arguments).build(), + (request, currentId) -> { + this.responseConsumers.put(currentId, this::callTools0); + this.pendingRequests.put(currentId, true); + }); + while (this.pendingRequests.get(requestId)) { ThreadUtils.sleep(100); } - return this.pendingResults.get(currentId); + return this.pendingResults.get(requestId); } private void callTools0(JsonRpc.Response response) { @@ -400,6 +314,37 @@ private void callTools0(JsonRpc.Response response) { this.pendingRequests.put(response.id(), false); } + private long post2McpServer(Method method, Object requestParams, + BiConsumer requestConsumer) { + HttpClassicClientRequest request = + this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint); + long currentId = this.getNextId(); + if (requestConsumer != null) { + requestConsumer.accept(request, currentId); + } + JsonRpc.Request rpcRequest = JsonRpc.createRequest(currentId, method.code(), requestParams); + request.entity(Entity.createObject(request, rpcRequest)); + log.info("Send {} method to MCP server. [sessionId={}, request={}]", method.code(), this.sessionId, rpcRequest); + try (HttpClassicClientResponse exchange = request.exchange(Object.class)) { + if (exchange.statusCode() >= 200 && exchange.statusCode() < 300) { + log.info("Send {} method to MCP server successfully. [sessionId={}, statusCode={}]", + method.code(), + this.sessionId, + exchange.statusCode()); + } else { + log.error("Failed to {} MCP server. [sessionId={}, statusCode={}]", + method.code(), + this.sessionId, + exchange.statusCode()); + } + } catch (IOException e) { + throw new IllegalStateException(StringUtils.format("Failed to {0} MCP server. [sessionId={1}]", + method.code(), + this.sessionId), e); + } + return currentId; + } + private long getNextId() { long tmpId = this.id.getAndIncrement(); if (tmpId < 0) { @@ -422,10 +367,10 @@ private boolean waitInitialized() { return true; } try { - this.initializedLock.wait(); + this.initializedLock.wait(60_000L); } catch (InterruptedException e) { Thread.currentThread().interrupt(); - throw new IllegalStateException(e); + throw new IllegalStateException("Failed to initialize.", e); } } return this.initialized;