Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,21 @@ static String identify(ToolInfo toolInfo) {
static String identify(String namespace, String toolName) {
return StringUtils.format("{0}:{1}", namespace, toolName);
}

/**
* Parses the tool identifier.
*
* @param identifier The identifier to be parsed.
* @return An array containing the namespace and tool name.
*/
static String[] parseIdentifier(String identifier) {
if (identifier == null || identifier.isEmpty()) {
throw new IllegalArgumentException("Identifier cannot be null or empty.");
}
String[] parts = identifier.split(":", 2);
if (parts.length != 2) {
throw new IllegalArgumentException("Invalid identifier format. Expected 'namespace:name'.");
}
return parts;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

import static modelengine.fitframework.inspection.Validation.notNull;

import modelengine.fel.tool.ToolInfoEntity;
import modelengine.fel.core.tool.ToolInfo;
import modelengine.fel.tool.Tool;
import modelengine.fel.tool.ToolFactory;
import modelengine.fel.tool.ToolFactoryRepository;
import modelengine.fel.tool.ToolInfoEntity;
import modelengine.fel.tool.service.ToolExecuteService;
import modelengine.fel.tool.service.ToolRepository;
import modelengine.fitframework.annotation.Component;
Expand Down Expand Up @@ -70,13 +71,15 @@ public String execute(String group, String toolName, Map<String, Object> jsonObj
@Override
@Fitable(id = "standard")
public String execute(String uniqueName, String jsonArgs) {
return this.execute("Common", uniqueName, jsonArgs);
String[] strings = ToolInfo.parseIdentifier(uniqueName);
return this.execute(strings[0], strings[1], jsonArgs);
}

@Override
@Fitable(id = "standard")
public String execute(String uniqueName, Map<String, Object> jsonObject) {
return this.execute("Common", uniqueName, jsonObject);
String[] strings = ToolInfo.parseIdentifier(uniqueName);
return this.execute(strings[0], strings[1], jsonObject);
}

private Tool getTool(String group, String toolName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import static modelengine.fitframework.util.ObjectUtils.cast;

import modelengine.fel.tool.mcp.client.McpClient;
import modelengine.fel.tool.mcp.entity.ClientSchema;
import modelengine.fel.tool.mcp.entity.JsonRpc;
import modelengine.fel.tool.mcp.entity.Method;
import modelengine.fel.tool.mcp.entity.Server;
import modelengine.fel.tool.mcp.entity.ServerSchema;
import modelengine.fel.tool.mcp.entity.Tool;
import modelengine.fit.http.client.HttpClassicClient;
import modelengine.fit.http.client.HttpClassicClientRequest;
Expand All @@ -26,8 +27,11 @@
import modelengine.fitframework.schedule.ThreadPoolExecutor;
import modelengine.fitframework.schedule.ThreadPoolScheduler;
import modelengine.fitframework.serialization.ObjectSerializer;
import modelengine.fitframework.util.CollectionUtils;
import modelengine.fitframework.util.LockUtils;
import modelengine.fitframework.util.MapBuilder;
import modelengine.fitframework.util.ObjectUtils;
import modelengine.fitframework.util.StringUtils;
import modelengine.fitframework.util.ThreadUtils;
import modelengine.fitframework.util.UuidUtils;

Expand All @@ -54,37 +58,43 @@ public class DefaultMcpClient implements McpClient {

private final ObjectSerializer jsonSerializer;
private final HttpClassicClient client;
private final String connectionString;
private final String baseUri;
private final String sseEndpoint;
private final String name;
private final AtomicLong id = new AtomicLong(0);

private volatile String messageUrl;
private volatile String messageEndpoint;
private volatile String sessionId;
private volatile Server server;
private volatile ServerSchema serverSchema;
private volatile boolean initialized = false;
private final List<Tool> tools = new ArrayList<>();
private final Object initializedLock = LockUtils.newSynchronizedLock();
private final Object toolsLock = LockUtils.newSynchronizedLock();
private final Map<Long, Consumer<JsonRpc.Response<Long>>> responseConsumers = new ConcurrentHashMap<>();
private final Map<Long, Boolean> pendingRequests = new ConcurrentHashMap<>();
private final Map<Long, Object> pendingResults = new ConcurrentHashMap<>();

/**
* Constructs a new instance of the DefaultMcpClient.
*
* @param jsonSerializer The serializer used for JSON serialization and deserialization.
* @param client The HTTP client used for communication with the MCP server.
* @param connectionString The connection string used to establish the initial connection.
* @param baseUri The base URI of the MCP server.
* @param sseEndpoint The endpoint for the Server-Sent Events (SSE) connection.
*/
public DefaultMcpClient(ObjectSerializer jsonSerializer, HttpClassicClient client, String connectionString) {
public DefaultMcpClient(ObjectSerializer jsonSerializer, HttpClassicClient client, String baseUri,
String sseEndpoint) {
this.jsonSerializer = jsonSerializer;
this.client = client;
this.connectionString = connectionString;
this.baseUri = baseUri;
this.sseEndpoint = sseEndpoint;
this.name = UuidUtils.randomUuidString();
}

@Override
public void initialize() {
HttpClassicClientRequest request = this.client.createRequest(HttpRequestMethod.GET, connectionString);
HttpClassicClientRequest request =
this.client.createRequest(HttpRequestMethod.GET, this.baseUri + this.sseEndpoint);
Choir<TextEvent> messages = this.client.exchangeStream(request, TextEvent.class);
ThreadPoolExecutor threadPool = ThreadPoolExecutor.custom()
.threadPoolName("mcp-client-" + this.name)
Expand Down Expand Up @@ -125,7 +135,13 @@ public void initialize() {
}

private void consumeTextEvent(TextEvent textEvent) {
log.info("Receive message from MCP server. [message={}]", textEvent.data());
log.info("Receive message from MCP server. [id={}, event={}, message={}]",
textEvent.id(),
textEvent.event(),
textEvent.data());
if (StringUtils.isBlank(textEvent.event()) || StringUtils.isBlank((String) textEvent.data())) {
return;
}
if (Objects.equals(textEvent.event(), "endpoint")) {
this.initializeMcpServer(textEvent);
return;
Expand Down Expand Up @@ -157,7 +173,8 @@ private void pingServer() {
log.info("MCP client is not initialized and {} method will be delayed.", Method.PING.code());
return;
}
HttpClassicClientRequest request = this.client.createRequest(HttpRequestMethod.POST, this.messageUrl);
HttpClassicClientRequest request =
this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint);
long currentId = this.getNextId();
JsonRpc.Request<Long> rpcRequest = JsonRpc.createRequest(currentId, Method.PING.code());
request.entity(Entity.createObject(request, rpcRequest));
Expand All @@ -183,12 +200,17 @@ private void pingServer() {
}

private void initializeMcpServer(TextEvent textEvent) {
this.messageUrl = textEvent.data().toString();
this.sessionId = textEvent.id();
HttpClassicClientRequest request = this.client.createRequest(HttpRequestMethod.POST, this.messageUrl);
this.messageEndpoint = textEvent.data().toString();
HttpClassicClientRequest request =
this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint);
this.sessionId =
request.queries().first("session_id").orElseThrow(() -> new IllegalStateException("no session_id"));
long currentId = this.getNextId();
this.responseConsumers.put(currentId, this::initializedMcpServer);
JsonRpc.Request<Long> rpcRequest = JsonRpc.createRequest(currentId, Method.INITIALIZE.code());
ClientSchema schema = new ClientSchema("2024-11-05",
new ClientSchema.Capabilities(),
new ClientSchema.Info("FIT MCP Client", "3.5.0-SNAPSHOT"));
JsonRpc.Request<Long> rpcRequest = JsonRpc.createRequest(currentId, Method.INITIALIZE.code(), schema);
request.entity(Entity.createObject(request, rpcRequest));
log.info("Send {} method to MCP server. [sessionId={}, request={}]",
Method.INITIALIZE.code(),
Expand Down Expand Up @@ -223,9 +245,9 @@ private void initializedMcpServer(JsonRpc.Response<Long> response) {
this.initialized = true;
this.initializedLock.notifyAll();
}
this.server = ObjectUtils.toCustomObject(response.result(), Server.class);
log.info("MCP server has initialized. [server={}]", this.server);
HttpClassicClientRequest request = this.client.createRequest(HttpRequestMethod.POST, this.messageUrl);
this.recordServerSchema(response);
HttpClassicClientRequest request =
this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint);
JsonRpc.Notification notification = JsonRpc.createNotification(Method.NOTIFICATION_INITIALIZED.code());
request.entity(Entity.createObject(request, notification));
log.info("Send {} method to MCP server. [sessionId={}, notification={}]",
Expand All @@ -249,12 +271,19 @@ private void initializedMcpServer(JsonRpc.Response<Long> response) {
}
}

private void recordServerSchema(JsonRpc.Response<Long> response) {
Map<String, Object> mapResult = cast(response.result());
this.serverSchema = ServerSchema.create(mapResult);
log.info("MCP server has initialized. [server={}]", this.serverSchema);
}

@Override
public List<Tool> getTools() {
if (this.isNotInitialized()) {
throw new IllegalStateException("MCP client is not initialized. Please wait a moment.");
}
HttpClassicClientRequest request = this.client.createRequest(HttpRequestMethod.POST, this.messageUrl);
HttpClassicClientRequest request =
this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint);
long currentId = this.getNextId();
this.responseConsumers.put(currentId, this::getTools0);
this.pendingRequests.put(currentId, true);
Expand Down Expand Up @@ -292,6 +321,7 @@ private void getTools0(JsonRpc.Response<Long> response) {
log.error("Failed to get tools list from MCP server. [sessionId={}, response={}]",
this.sessionId,
response);
this.pendingRequests.put(response.id(), false);
return;
}
Map<String, Object> result = cast(response.result());
Expand All @@ -301,16 +331,73 @@ private void getTools0(JsonRpc.Response<Long> response) {
this.tools.addAll(rawTools.stream()
.map(rawTool -> ObjectUtils.<Tool>toCustomObject(rawTool, Tool.class))
.toList());
this.pendingRequests.put(response.id(), false);
}
this.pendingRequests.put(response.id(), false);
}

@Override
public Object callTool(String name, Map<String, Object> arguments) {
if (this.isNotInitialized()) {
throw new IllegalStateException("MCP client is not initialized. Please wait a moment.");
}
return null;
HttpClassicClientRequest request =
this.client.createRequest(HttpRequestMethod.POST, this.baseUri + this.messageEndpoint);
long currentId = this.getNextId();
this.responseConsumers.put(currentId, this::callTools0);
this.pendingRequests.put(currentId, true);
JsonRpc.Request<Long> rpcRequest = JsonRpc.createRequest(currentId,
Method.TOOLS_CALL.code(),
MapBuilder.<String, Object>get().put("name", name).put("arguments", arguments).build());
request.entity(Entity.createObject(request, rpcRequest));
log.info("Send {} method to MCP server. [sessionId={}, request={}]",
Method.TOOLS_CALL.code(),
this.sessionId,
rpcRequest);
try (HttpClassicClientResponse<Object> exchange = request.exchange(Object.class)) {
if (exchange.statusCode() >= 200 && exchange.statusCode() < 300) {
log.info("Send {} method to MCP server successfully. [sessionId={}, statusCode={}]",
Method.TOOLS_CALL.code(),
this.sessionId,
exchange.statusCode());
} else {
log.error("Failed to {} MCP server. [sessionId={}, statusCode={}]",
Method.TOOLS_CALL.code(),
this.sessionId,
exchange.statusCode());
}
} catch (IOException e) {
throw new IllegalStateException(e);
}
while (this.pendingRequests.get(currentId)) {
ThreadUtils.sleep(100);
}
return this.pendingResults.get(currentId);
}

private void callTools0(JsonRpc.Response<Long> response) {
if (response.error() != null) {
log.error("Failed to call tool from MCP server. [sessionId={}, response={}]", this.sessionId, response);
this.pendingRequests.put(response.id(), false);
return;
}
Map<String, Object> result = cast(response.result());
boolean isError = cast(result.get("isError"));
if (isError) {
log.error("Failed to call tool from MCP server. [sessionId={}, result={}]", this.sessionId, result);
this.pendingRequests.put(response.id(), false);
return;
}
List<Map<String, Object>> rawContents = cast(result.get("content"));
if (CollectionUtils.isEmpty(rawContents)) {
log.error("Failed to call tool from MCP server: no result returned. [sessionId={}, result={}]",
this.sessionId,
result);
this.pendingRequests.put(response.id(), false);
return;
}
Map<String, Object> rawContent = rawContents.get(0);
this.pendingResults.put(response.id(), rawContent.get("text"));
this.pendingRequests.put(response.id(), false);
}

private long getNextId() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public DefaultMcpClientFactory(HttpClassicClientFactory clientFactory,
}

@Override
public McpClient create(String connectionString) {
return new DefaultMcpClient(this.jsonSerializer, this.client, connectionString);
public McpClient create(String baseUri, String sseEndpoint) {
return new DefaultMcpClient(this.jsonSerializer, this.client, baseUri, sseEndpoint);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

package modelengine.fel.tool.mcp.server;

import modelengine.fel.tool.mcp.entity.Server;
import modelengine.fel.tool.mcp.entity.ServerSchema;
import modelengine.fel.tool.mcp.entity.Tool;

import java.util.List;
Expand All @@ -20,21 +20,21 @@
*/
public interface McpServer {
/**
* Gets MCP Server Info.
* Gets MCP server schema.
*
* @return The MCP Server Info as a {@link Map}{@code <}{@link String}{@code , }{@link Object}{@code >}.
* @return The MCP server schema as a {@link ServerSchema}.
*/
Server getInfo();
ServerSchema getSchema();

/**
* Gets MCP Server Tools.
* Gets MCP server tools.
*
* @return The MCP Server Tools as a {@link List}{@code <}{@link Tool}{@code >}.
* @return The MCP server tools as a {@link List}{@code <}{@link Tool}{@code >}.
*/
List<Tool> getTools();

/**
* Calls MCP Server Tool.
* Calls MCP server tool.
*
* @param name The tool name as a {@link String}.
* @param arguments The tool arguments as a {@link Map}{@code <}{@link String}{@code , }{@link Object}{@code >}.
Expand All @@ -43,18 +43,18 @@ public interface McpServer {
Object callTool(String name, Map<String, Object> arguments);

/**
* Registers MCP Server Tools Changed Observer.
* Registers MCP server tools changed observer.
*
* @param observer The MCP Server Tools Changed Observer as a {@link ToolsChangedObserver}.
* @param observer The MCP server tools changed observer as a {@link ToolsChangedObserver}.
*/
void registerToolsChangedObserver(ToolsChangedObserver observer);

/**
* Represents the MCP Server Tools Changed Observer.
* Represents the MCP server tools changed observer.
*/
interface ToolsChangedObserver {
/**
* Called when MCP Server Tools changed.
* Called when MCP server tools changed.
*/
void onToolsChanged();
}
Expand Down
Loading