Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import modelengine.fel.tool.mcp.client.McpClient;
import modelengine.fel.tool.mcp.entity.ClientSchema;
import modelengine.fel.tool.mcp.entity.Event;
import modelengine.fel.tool.mcp.entity.JsonRpc;
import modelengine.fel.tool.mcp.entity.Method;
import modelengine.fel.tool.mcp.entity.ServerSchema;
Expand Down Expand Up @@ -43,6 +44,7 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

/**
Expand Down Expand Up @@ -129,8 +131,8 @@ public void initialize() {
.runnable(this::pingServer)
.policy(ExecutePolicy.fixedDelay(DELAY_MILLIS))
.build(), DELAY_MILLIS);
while (!this.waitInitialized()) {
ThreadUtils.sleep(100);
if (!this.waitInitialized()) {
throw new IllegalStateException("Failed to initialize.");
}
}

Expand All @@ -142,7 +144,7 @@ private void consumeTextEvent(TextEvent textEvent) {
if (StringUtils.isBlank(textEvent.event()) || StringUtils.isBlank((String) textEvent.data())) {
return;
}
if (Objects.equals(textEvent.event(), "endpoint")) {
if (Objects.equals(textEvent.event(), Event.ENDPOINT.code())) {
this.initializeMcpServer(textEvent);
return;
}
Expand Down Expand Up @@ -173,64 +175,20 @@ private void pingServer() {
log.info("MCP client is not initialized and {} method will be delayed.", Method.PING.code());
return;
}
HttpClassicClientRequest request =
this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint);
long currentId = this.getNextId();
JsonRpc.Request<Long> rpcRequest = JsonRpc.createRequest(currentId, Method.PING.code());
request.entity(Entity.createObject(request, rpcRequest));
log.info("Send {} method to MCP server. [sessionId={}, request={}]",
Method.PING.code(),
this.sessionId,
rpcRequest);
try (HttpClassicClientResponse<Object> exchange = request.exchange(Object.class)) {
if (exchange.statusCode() >= 200 && exchange.statusCode() < 300) {
log.info("Send {} method to MCP server successfully. [sessionId={}, statusCode={}]",
Method.PING.code(),
this.sessionId,
exchange.statusCode());
} else {
log.error("Failed to {} MCP server. [sessionId={}, statusCode={}]",
Method.PING.code(),
this.sessionId,
exchange.statusCode());
}
} catch (IOException e) {
throw new IllegalStateException(e);
}
this.post2McpServer(Method.PING, null, null);
}

private void initializeMcpServer(TextEvent textEvent) {
this.messageEndpoint = textEvent.data().toString();
HttpClassicClientRequest request =
this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint);
this.sessionId =
request.queries().first("session_id").orElseThrow(() -> new IllegalStateException("no session_id"));
long currentId = this.getNextId();
this.responseConsumers.put(currentId, this::initializedMcpServer);
ClientSchema schema = new ClientSchema("2024-11-05",
new ClientSchema.Capabilities(),
new ClientSchema.Info("FIT MCP Client", "3.5.0-SNAPSHOT"));
JsonRpc.Request<Long> rpcRequest = JsonRpc.createRequest(currentId, Method.INITIALIZE.code(), schema);
request.entity(Entity.createObject(request, rpcRequest));
log.info("Send {} method to MCP server. [sessionId={}, request={}]",
Method.INITIALIZE.code(),
this.sessionId,
rpcRequest);
try (HttpClassicClientResponse<Object> exchange = request.exchange(Object.class)) {
if (exchange.statusCode() >= 200 && exchange.statusCode() < 300) {
log.info("Send {} method to MCP server successfully. [sessionId={}, statusCode={}]",
Method.INITIALIZE.code(),
this.sessionId,
exchange.statusCode());
} else {
log.error("Failed to {} MCP server. [sessionId={}, statusCode={}]",
Method.INITIALIZE.code(),
this.sessionId,
exchange.statusCode());
}
} catch (IOException e) {
throw new IllegalStateException(e);
}
this.post2McpServer(Method.INITIALIZE, schema, (request, currentId) -> {
this.sessionId = request.queries()
.first("session_id")
.orElseThrow(() -> new IllegalStateException("The session_id cannot be empty."));
this.responseConsumers.put(currentId, this::initializedMcpServer);
});
}

private void initializedMcpServer(JsonRpc.Response<Long> response) {
Expand Down Expand Up @@ -282,33 +240,11 @@ public List<Tool> getTools() {
if (this.isNotInitialized()) {
throw new IllegalStateException("MCP client is not initialized. Please wait a moment.");
}
HttpClassicClientRequest request =
this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint);
long currentId = this.getNextId();
this.responseConsumers.put(currentId, this::getTools0);
this.pendingRequests.put(currentId, true);
JsonRpc.Request<Long> rpcRequest = JsonRpc.createRequest(currentId, Method.TOOLS_LIST.code());
request.entity(Entity.createObject(request, rpcRequest));
log.info("Send {} method to MCP server. [sessionId={}, request={}]",
Method.TOOLS_LIST.code(),
this.sessionId,
rpcRequest);
try (HttpClassicClientResponse<Object> exchange = request.exchange(Object.class)) {
if (exchange.statusCode() >= 200 && exchange.statusCode() < 300) {
log.info("Send {} method to MCP server successfully. [sessionId={}, statusCode={}]",
Method.TOOLS_LIST.code(),
this.sessionId,
exchange.statusCode());
} else {
log.error("Failed to {} MCP server. [sessionId={}, statusCode={}]",
Method.TOOLS_LIST.code(),
this.sessionId,
exchange.statusCode());
}
} catch (IOException e) {
throw new IllegalStateException(e);
}
while (this.pendingRequests.get(currentId)) {
long requestId = this.post2McpServer(Method.TOOLS_LIST, null, (request, currentId) -> {
this.responseConsumers.put(currentId, this::getTools0);
this.pendingRequests.put(currentId, true);
});
while (this.pendingRequests.get(requestId)) {
ThreadUtils.sleep(100);
}
synchronized (this.toolsLock) {
Expand Down Expand Up @@ -340,38 +276,16 @@ public Object callTool(String name, Map<String, Object> arguments) {
if (this.isNotInitialized()) {
throw new IllegalStateException("MCP client is not initialized. Please wait a moment.");
}
HttpClassicClientRequest request =
this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint);
long currentId = this.getNextId();
this.responseConsumers.put(currentId, this::callTools0);
this.pendingRequests.put(currentId, true);
JsonRpc.Request<Long> rpcRequest = JsonRpc.createRequest(currentId,
Method.TOOLS_CALL.code(),
MapBuilder.<String, Object>get().put("name", name).put("arguments", arguments).build());
request.entity(Entity.createObject(request, rpcRequest));
log.info("Send {} method to MCP server. [sessionId={}, request={}]",
Method.TOOLS_CALL.code(),
this.sessionId,
rpcRequest);
try (HttpClassicClientResponse<Object> exchange = request.exchange(Object.class)) {
if (exchange.statusCode() >= 200 && exchange.statusCode() < 300) {
log.info("Send {} method to MCP server successfully. [sessionId={}, statusCode={}]",
Method.TOOLS_CALL.code(),
this.sessionId,
exchange.statusCode());
} else {
log.error("Failed to {} MCP server. [sessionId={}, statusCode={}]",
Method.TOOLS_CALL.code(),
this.sessionId,
exchange.statusCode());
}
} catch (IOException e) {
throw new IllegalStateException(e);
}
while (this.pendingRequests.get(currentId)) {
long requestId = this.post2McpServer(Method.TOOLS_CALL,
MapBuilder.<String, Object>get().put("name", name).put("arguments", arguments).build(),
(request, currentId) -> {
this.responseConsumers.put(currentId, this::callTools0);
this.pendingRequests.put(currentId, true);
});
while (this.pendingRequests.get(requestId)) {
ThreadUtils.sleep(100);
}
return this.pendingResults.get(currentId);
return this.pendingResults.get(requestId);
}

private void callTools0(JsonRpc.Response<Long> response) {
Expand Down Expand Up @@ -400,6 +314,37 @@ private void callTools0(JsonRpc.Response<Long> response) {
this.pendingRequests.put(response.id(), false);
}

private long post2McpServer(Method method, Object requestParams,
BiConsumer<HttpClassicClientRequest, Long> requestConsumer) {
HttpClassicClientRequest request =
this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint);
long currentId = this.getNextId();
if (requestConsumer != null) {
requestConsumer.accept(request, currentId);
}
JsonRpc.Request<Long> rpcRequest = JsonRpc.createRequest(currentId, method.code(), requestParams);
request.entity(Entity.createObject(request, rpcRequest));
log.info("Send {} method to MCP server. [sessionId={}, request={}]", method.code(), this.sessionId, rpcRequest);
try (HttpClassicClientResponse<Object> exchange = request.exchange(Object.class)) {
if (exchange.statusCode() >= 200 && exchange.statusCode() < 300) {
log.info("Send {} method to MCP server successfully. [sessionId={}, statusCode={}]",
method.code(),
this.sessionId,
exchange.statusCode());
} else {
log.error("Failed to {} MCP server. [sessionId={}, statusCode={}]",
method.code(),
this.sessionId,
exchange.statusCode());
}
} catch (IOException e) {
throw new IllegalStateException(StringUtils.format("Failed to {0} MCP server. [sessionId={1}]",
method.code(),
this.sessionId), e);
}
return currentId;
}

private long getNextId() {
long tmpId = this.id.getAndIncrement();
if (tmpId < 0) {
Expand All @@ -422,10 +367,10 @@ private boolean waitInitialized() {
return true;
}
try {
this.initializedLock.wait();
this.initializedLock.wait(60_000L);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new IllegalStateException(e);
throw new IllegalStateException("Failed to initialize.", e);
}
}
return this.initialized;
Expand Down