diff --git a/framework/fel/java/plugins/tool-mcp-server/pom.xml b/framework/fel/java/plugins/tool-mcp-server/pom.xml index 597ceeaf..123964db 100644 --- a/framework/fel/java/plugins/tool-mcp-server/pom.xml +++ b/framework/fel/java/plugins/tool-mcp-server/pom.xml @@ -32,9 +32,20 @@ + + org.junit.jupiter + junit-jupiter + test + + + org.mockito + mockito-core + test + org.assertj assertj-core + test diff --git a/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/McpController.java b/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/McpController.java index dd588f05..cbf3b539 100644 --- a/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/McpController.java +++ b/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/McpController.java @@ -48,7 +48,7 @@ * @since 2025-05-13 */ @Component -public class McpController { +public class McpController implements McpServer.ToolsChangedObserver { private static final Logger log = Logger.get(McpController.class); private static final String MESSAGE_PATH = "/mcp/message"; private static final String EVENT_ENDPOINT = "endpoint"; @@ -56,6 +56,7 @@ public class McpController { private static final String METHOD_INITIALIZE = "initialize"; private static final String METHOD_TOOLS_LIST = "tools/list"; private static final String METHOD_TOOLS_CALL = "tools/call"; + private static final String METHOD_NOTIFICATION_TOOLS_CHANGED = "notifications/tools/list_changed"; private static final String RESPONSE_OK = StringUtils.EMPTY; private final Map> emitters = new ConcurrentHashMap<>(); @@ -79,6 +80,7 @@ public McpController(@Value("${base-url}") String baseUrl, @Fit(alias = "json") this.baseUrl = notBlank(baseUrl, "The base URL for MCP server cannot be blank."); this.serializer = notNull(serializer, "The json serializer cannot be null."); notNull(mcpServer, "The MCP server cannot be null."); + mcpServer.registerToolsChangedObserver(this); this.methodHandlers.put(METHOD_INITIALIZE, new InitializeHandler(mcpServer)); this.methodHandlers.put(METHOD_TOOLS_LIST, new ToolListHandler(mcpServer)); @@ -170,4 +172,16 @@ public Object receiveMcpMessage(@RequestQuery(name = "sessionId") String session log.info("Send MCP message. [response={}]", serialized); return RESPONSE_OK; } + + @Override + public void onToolsChanged() { + JsonRpcEntity notification = new JsonRpcEntity(); + notification.setMethod(METHOD_NOTIFICATION_TOOLS_CHANGED); + String serialized = this.serializer.serialize(notification); + this.emitters.forEach((sessionId, emitter) -> { + TextEvent textEvent = TextEvent.custom().id(sessionId).event(EVENT_MESSAGE).data(serialized).build(); + emitter.emit(textEvent); + log.info("Send MCP notification: tools changed. [sessionId={}]", sessionId); + }); + } } diff --git a/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/McpServer.java b/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/McpServer.java index 8d6fb3d1..121a4686 100644 --- a/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/McpServer.java +++ b/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/McpServer.java @@ -40,4 +40,21 @@ public interface McpServer { * @return The tool result as a {@link Object}. */ Object callTool(String name, Map arguments); + + /** + * Registers MCP Server Tools Changed Observer. + * + * @param observer The MCP Server Tools Changed Observer as a {@link ToolsChangedObserver}. + */ + void registerToolsChangedObserver(ToolsChangedObserver observer); + + /** + * Represents the MCP Server Tools Changed Observer. + */ + interface ToolsChangedObserver { + /** + * Called when MCP Server Tools changed. + */ + void onToolsChanged(); + } } diff --git a/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/handler/InitializeHandler.java b/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/handler/InitializeHandler.java index 8d86f170..0ce01532 100644 --- a/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/handler/InitializeHandler.java +++ b/framework/fel/java/plugins/tool-mcp-server/src/main/java/modelengine/fel/tool/mcp/server/handler/InitializeHandler.java @@ -26,7 +26,7 @@ public class InitializeHandler extends AbstractMessageHandler tools = new ConcurrentHashMap<>(); + private final List toolsChangedObservers = new ArrayList<>(); /** * Constructs a new instance of the DefaultMcpServer class. * * @param toolExecuteService The service used to execute tools when handling tool call requests. - * @throws IllegalStateException If {@code toolExecuteService} is null. + * @throws IllegalArgumentException If {@code toolExecuteService} is null. */ public DefaultMcpServer(ToolExecuteService toolExecuteService) { this.toolExecuteService = notNull(toolExecuteService, "The tool execute service cannot be null."); @@ -72,6 +74,13 @@ public Object callTool(String name, Map arguments) { return result; } + @Override + public void registerToolsChangedObserver(ToolsChangedObserver observer) { + if (observer != null) { + this.toolsChangedObservers.add(observer); + } + } + @Override public void onToolAdded(String name, String description, Map schema) { if (StringUtils.isBlank(name)) { @@ -92,6 +101,7 @@ public void onToolAdded(String name, String description, Map sch tool.setInputSchema(schema); this.tools.put(name, tool); log.info("Tool added to MCP server. [toolName={}, description={}, schema={}]", name, description, schema); + this.toolsChangedObservers.forEach(ToolsChangedObserver::onToolsChanged); } @Override @@ -102,5 +112,6 @@ public void onToolRemoved(String name) { } this.tools.remove(name); log.info("Tool removed from MCP server. [toolName={}]", name); + this.toolsChangedObservers.forEach(ToolsChangedObserver::onToolsChanged); } } diff --git a/framework/fel/java/plugins/tool-mcp-server/src/test/java/modelengine/fel/tool/mcp/server/McpControllerTest.java b/framework/fel/java/plugins/tool-mcp-server/src/test/java/modelengine/fel/tool/mcp/server/McpControllerTest.java new file mode 100644 index 00000000..e1baa97d --- /dev/null +++ b/framework/fel/java/plugins/tool-mcp-server/src/test/java/modelengine/fel/tool/mcp/server/McpControllerTest.java @@ -0,0 +1,72 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.tool.mcp.server; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.catchThrowableOfType; +import static org.mockito.Mockito.mock; + +import modelengine.fitframework.serialization.ObjectSerializer; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +/** + * Unit test for {@link McpController}. + * + * @author 季聿阶 + * @since 2025-05-20 + */ +@DisplayName("Unit tests for McpController") +public class McpControllerTest { + private ObjectSerializer objectSerializer; + private McpServer mcpServer; + private String baseUrl; + + @BeforeEach + void setup() { + this.objectSerializer = mock(ObjectSerializer.class); + this.mcpServer = mock(McpServer.class); + this.baseUrl = "http://localhost:8080"; + } + + @Nested + @DisplayName("Constructor Tests") + class GivenConstructor { + @Test + @DisplayName("Should throw exception when base URL is null or blank") + void shouldThrowExceptionWhenBaseUrlIsNullOrEmpty() { + // Null + var exception1 = catchThrowableOfType(IllegalArgumentException.class, + () -> new McpController(null, objectSerializer, mcpServer)); + assertThat(exception1).hasMessage("The base URL for MCP server cannot be blank."); + + // Blank + var exception2 = catchThrowableOfType(IllegalArgumentException.class, + () -> new McpController("", objectSerializer, mcpServer)); + assertThat(exception2).hasMessage("The base URL for MCP server cannot be blank."); + } + + @Test + @DisplayName("Should throw exception when serializer is null") + void shouldThrowExceptionWhenSerializerIsNull() { + var exception = catchThrowableOfType(IllegalArgumentException.class, + () -> new McpController(baseUrl, null, mcpServer)); + assertThat(exception).hasMessage("The json serializer cannot be null."); + } + + @Test + @DisplayName("Should throw exception when mcpServer is null") + void shouldThrowExceptionWhenMcpServerIsNull() { + var exception = catchThrowableOfType(IllegalArgumentException.class, + () -> new McpController(baseUrl, objectSerializer, null)); + assertThat(exception).hasMessage("The MCP server cannot be null."); + } + } +} diff --git a/framework/fel/java/plugins/tool-mcp-server/src/test/java/modelengine/fel/tool/mcp/server/support/DefaultMcpServerTest.java b/framework/fel/java/plugins/tool-mcp-server/src/test/java/modelengine/fel/tool/mcp/server/support/DefaultMcpServerTest.java new file mode 100644 index 00000000..ec85628b --- /dev/null +++ b/framework/fel/java/plugins/tool-mcp-server/src/test/java/modelengine/fel/tool/mcp/server/support/DefaultMcpServerTest.java @@ -0,0 +1,182 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. + * This file is a part of the ModelEngine Project. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +package modelengine.fel.tool.mcp.server.support; + +import static modelengine.fitframework.util.ObjectUtils.cast; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.catchThrowableOfType; +import static org.mockito.Mockito.anyMap; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import modelengine.fel.tool.mcp.server.McpServer; +import modelengine.fel.tool.mcp.server.entity.ToolEntity; +import modelengine.fel.tool.service.ToolExecuteService; +import modelengine.fitframework.util.MapBuilder; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +/** + * Unit test for {@link DefaultMcpServer}. + * + * @author 季聿阶 + * @since 2025-05-20 + */ +@DisplayName("Unit tests for DefaultMcpServer") +public class DefaultMcpServerTest { + private ToolExecuteService toolExecuteService; + + @BeforeEach + void setup() { + this.toolExecuteService = mock(ToolExecuteService.class); + } + + @Nested + @DisplayName("Constructor Tests") + class GivenConstructor { + @Test + @DisplayName("Should throw IllegalArgumentException when toolExecuteService is null") + void throwIllegalArgumentExceptionWhenToolExecuteServiceIsNull() { + IllegalArgumentException exception = + catchThrowableOfType(IllegalArgumentException.class, () -> new DefaultMcpServer(null)); + assertThat(exception).isNotNull().hasMessage("The tool execute service cannot be null."); + } + } + + @Nested + @DisplayName("getInfo Method Tests") + class GivenGetInfo { + @Test + @DisplayName("Should return expected server information") + void returnExpectedServerInfo() { + McpServer server = new DefaultMcpServer(toolExecuteService); + Map info = server.getInfo(); + + assertThat(info).containsKey("protocolVersion").containsValue("2025-03-26"); + + Map capabilities = cast(info.get("capabilities")); + assertThat(capabilities).containsKey("logging").containsKey("tools"); + + Map toolsCapability = cast(capabilities.get("tools")); + assertThat(toolsCapability).containsEntry("listChanged", true); + + Map serverInfo = cast(info.get("serverInfo")); + assertThat(serverInfo).containsEntry("name", "FIT Store MCP Server") + .containsEntry("version", "3.5.0-SNAPSHOT"); + } + } + + @Nested + @DisplayName("registerToolsChangedObserver and Notification Tests") + class GivenRegisterAndNotify { + @Test + @DisplayName("Should notify observers when tools are added or removed") + void notifyObserversOnToolAddOrRemove() { + DefaultMcpServer server = new DefaultMcpServer(toolExecuteService); + McpServer.ToolsChangedObserver observer = mock(McpServer.ToolsChangedObserver.class); + server.registerToolsChangedObserver(observer); + + server.onToolAdded("tool1", + "description1", + MapBuilder.get().put("schema", "value1").build()); + verify(observer, times(1)).onToolsChanged(); + + server.onToolRemoved("tool1"); + verify(observer, times(2)).onToolsChanged(); + } + } + + @Nested + @DisplayName("onToolAdded Method Tests") + class GivenOnToolAdded { + @Test + @DisplayName("Should add tool successfully with valid parameters") + void addToolSuccessfully() { + DefaultMcpServer server = new DefaultMcpServer(toolExecuteService); + String name = "tool1"; + String description = "description1"; + Map schema = MapBuilder.get().put("input", "value").build(); + + server.onToolAdded(name, description, schema); + + List tools = server.getTools(); + assertThat(tools).hasSize(1); + + ToolEntity tool = tools.get(0); + assertThat(tool.getName()).isEqualTo(name); + assertThat(tool.getDescription()).isEqualTo(description); + assertThat(tool.getInputSchema()).isEqualTo(schema); + } + + @Test + @DisplayName("Should ignore invalid parameters and not add any tool") + void ignoreInvalidParameters() { + DefaultMcpServer server = new DefaultMcpServer(toolExecuteService); + + server.onToolAdded("", "description", MapBuilder.get().put("input", "value").build()); + assertThat(server.getTools()).isEmpty(); + + server.onToolAdded("tool1", "", MapBuilder.get().put("input", "value").build()); + assertThat(server.getTools()).isEmpty(); + + server.onToolAdded("tool1", "description", null); + assertThat(server.getTools()).isEmpty(); + } + } + + @Nested + @DisplayName("onToolRemoved Method Tests") + class GivenOnToolRemoved { + @Test + @DisplayName("Should remove an added tool correctly") + void removeToolSuccessfully() { + DefaultMcpServer server = new DefaultMcpServer(toolExecuteService); + server.onToolAdded("tool1", "desc", MapBuilder.get().put("input", "value").build()); + + server.onToolRemoved("tool1"); + + assertThat(server.getTools()).isEmpty(); + } + + @Test + @DisplayName("Should ignore removal if name is blank") + void ignoreBlankName() { + DefaultMcpServer server = new DefaultMcpServer(toolExecuteService); + server.onToolAdded("tool1", "desc", MapBuilder.get().put("input", "value").build()); + + server.onToolRemoved(""); + + assertThat(server.getTools()).hasSize(1); + } + } + + @Nested + @DisplayName("callTool Method Tests") + class GivenCallTool { + @Test + @DisplayName("Should call the tool and return correct result") + void callToolSuccessfully() { + when(toolExecuteService.execute(anyString(), anyMap())).thenReturn("result"); + McpServer server = new DefaultMcpServer(toolExecuteService); + + Object result = server.callTool("tool1", Map.of("arg1", "value1")); + + assertThat(result).isEqualTo("result"); + verify(toolExecuteService, times(1)).execute(eq("tool1"), anyMap()); + } + } +} \ No newline at end of file