diff --git a/MCPForUnity/Editor/Services/ToolDiscoveryService.cs b/MCPForUnity/Editor/Services/ToolDiscoveryService.cs index b5b86c0a2..ac7cc1455 100644 --- a/MCPForUnity/Editor/Services/ToolDiscoveryService.cs +++ b/MCPForUnity/Editor/Services/ToolDiscoveryService.cs @@ -226,16 +226,6 @@ private void EnsurePreferenceInitialized(ToolMetadata metadata) { bool defaultValue = metadata.AutoRegister || metadata.IsBuiltIn; EditorPrefs.SetBool(key, defaultValue); - return; - } - - if (metadata.IsBuiltIn && !metadata.AutoRegister) - { - bool currentValue = EditorPrefs.GetBool(key, metadata.AutoRegister); - if (currentValue == metadata.AutoRegister) - { - EditorPrefs.SetBool(key, true); - } } } diff --git a/MCPForUnity/Editor/Services/Transport/IMcpTransportClient.cs b/MCPForUnity/Editor/Services/Transport/IMcpTransportClient.cs index 3d8584fd9..44f1d3b8c 100644 --- a/MCPForUnity/Editor/Services/Transport/IMcpTransportClient.cs +++ b/MCPForUnity/Editor/Services/Transport/IMcpTransportClient.cs @@ -14,5 +14,6 @@ public interface IMcpTransportClient Task StartAsync(); Task StopAsync(); Task VerifyAsync(); + Task ReregisterToolsAsync(); } } diff --git a/MCPForUnity/Editor/Services/Transport/TransportManager.cs b/MCPForUnity/Editor/Services/Transport/TransportManager.cs index 1204e7014..b544273b4 100644 --- a/MCPForUnity/Editor/Services/Transport/TransportManager.cs +++ b/MCPForUnity/Editor/Services/Transport/TransportManager.cs @@ -42,16 +42,6 @@ private IMcpTransportClient GetOrCreateClient(TransportMode mode) }; } - private IMcpTransportClient GetClient(TransportMode mode) - { - return mode switch - { - TransportMode.Http => _httpClient, - TransportMode.Stdio => _stdioClient, - _ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"), - }; - } - public async Task StartAsync(TransportMode mode) { IMcpTransportClient client = GetOrCreateClient(mode); @@ -128,6 +118,20 @@ public TransportState GetState(TransportMode mode) public bool IsRunning(TransportMode mode) => GetState(mode).IsConnected; + /// + /// Gets the active transport client for the specified mode. + /// Returns null if the client hasn't been created yet. + /// + public IMcpTransportClient GetClient(TransportMode mode) + { + return mode switch + { + TransportMode.Http => _httpClient, + TransportMode.Stdio => _stdioClient, + _ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"), + }; + } + private void UpdateState(TransportMode mode, TransportState state) { switch (mode) diff --git a/MCPForUnity/Editor/Services/Transport/Transports/StdioTransportClient.cs b/MCPForUnity/Editor/Services/Transport/Transports/StdioTransportClient.cs index ea3ed1a22..9dadc4336 100644 --- a/MCPForUnity/Editor/Services/Transport/Transports/StdioTransportClient.cs +++ b/MCPForUnity/Editor/Services/Transport/Transports/StdioTransportClient.cs @@ -46,5 +46,12 @@ public Task VerifyAsync() return Task.FromResult(running); } + public Task ReregisterToolsAsync() + { + // Stdio transport doesn't support dynamic tool reregistration + // Tools are registered at server startup + return Task.CompletedTask; + } + } } diff --git a/MCPForUnity/Editor/Services/Transport/Transports/WebSocketTransportClient.cs b/MCPForUnity/Editor/Services/Transport/Transports/WebSocketTransportClient.cs index b94c0836f..6d2f76838 100644 --- a/MCPForUnity/Editor/Services/Transport/Transports/WebSocketTransportClient.cs +++ b/MCPForUnity/Editor/Services/Transport/Transports/WebSocketTransportClient.cs @@ -506,6 +506,29 @@ private async Task SendRegisterToolsAsync(CancellationToken token) McpLog.Info($"[WebSocket] Sent {tools.Count} tools registration", false); } + public async Task ReregisterToolsAsync() + { + if (!IsConnected || _lifecycleCts == null) + { + McpLog.Warn("[WebSocket] Cannot reregister tools: not connected"); + return; + } + + try + { + await SendRegisterToolsAsync(_lifecycleCts.Token).ConfigureAwait(false); + McpLog.Info("[WebSocket] Tool reregistration completed", false); + } + catch (System.OperationCanceledException) + { + McpLog.Warn("[WebSocket] Tool reregistration cancelled"); + } + catch (System.Exception ex) + { + McpLog.Error($"[WebSocket] Tool reregistration failed: {ex.Message}"); + } + } + private async Task HandleExecuteAsync(JObject payload, CancellationToken token) { string commandId = payload.Value("id"); diff --git a/MCPForUnity/Editor/Windows/Components/Tools/McpToolsSection.cs b/MCPForUnity/Editor/Windows/Components/Tools/McpToolsSection.cs index dc5a3eabf..687fbf2eb 100644 --- a/MCPForUnity/Editor/Windows/Components/Tools/McpToolsSection.cs +++ b/MCPForUnity/Editor/Windows/Components/Tools/McpToolsSection.cs @@ -1,9 +1,11 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using MCPForUnity.Editor.Constants; using MCPForUnity.Editor.Helpers; using MCPForUnity.Editor.Services; +using MCPForUnity.Editor.Services.Transport; using MCPForUnity.Editor.Tools; using UnityEditor; using UnityEngine.UIElements; @@ -223,7 +225,11 @@ private VisualElement CreateToolRow(ToolMetadata tool) return row; } - private void HandleToggleChange(ToolMetadata tool, bool enabled, bool updateSummary = true) + private void HandleToggleChange( + ToolMetadata tool, + bool enabled, + bool updateSummary = true, + bool reregisterTools = true) { MCPServiceLocator.ToolDiscovery.SetToolEnabled(tool.Name, enabled); @@ -231,15 +237,49 @@ private void HandleToggleChange(ToolMetadata tool, bool enabled, bool updateSumm { UpdateSummary(); } + + if (reregisterTools) + { + // Trigger tool reregistration with connected MCP server + ReregisterToolsAsync(); + } + } + + private void ReregisterToolsAsync() + { + // Fire and forget - don't block UI + ThreadPool.QueueUserWorkItem(_ => + { + try + { + var transportManager = MCPServiceLocator.TransportManager; + var client = transportManager.GetClient(TransportMode.Http); + if (client != null && client.IsConnected) + { + client.ReregisterToolsAsync().Wait(); + } + } + catch (Exception ex) + { + McpLog.Warn($"Failed to reregister tools: {ex.Message}"); + } + }); } private void SetAllToolsState(bool enabled) { + bool hasChanges = false; + foreach (var tool in allTools) { if (!toolToggleMap.TryGetValue(tool.Name, out var toggle)) { - MCPServiceLocator.ToolDiscovery.SetToolEnabled(tool.Name, enabled); + bool currentEnabled = MCPServiceLocator.ToolDiscovery.IsToolEnabled(tool.Name); + if (currentEnabled != enabled) + { + MCPServiceLocator.ToolDiscovery.SetToolEnabled(tool.Name, enabled); + hasChanges = true; + } continue; } @@ -249,10 +289,17 @@ private void SetAllToolsState(bool enabled) } toggle.SetValueWithoutNotify(enabled); - HandleToggleChange(tool, enabled, updateSummary: false); + HandleToggleChange(tool, enabled, updateSummary: false, reregisterTools: false); + hasChanges = true; } UpdateSummary(); + + if (hasChanges) + { + // Trigger a single reregistration after bulk change + ReregisterToolsAsync(); + } } private void UpdateSummary() diff --git a/Server/src/transport/unity_instance_middleware.py b/Server/src/transport/unity_instance_middleware.py index 7adbe31ab..d172a304d 100644 --- a/Server/src/transport/unity_instance_middleware.py +++ b/Server/src/transport/unity_instance_middleware.py @@ -55,6 +55,43 @@ def __init__(self): super().__init__() self._active_by_key: dict[str, str] = {} self._lock = RLock() + self._unity_managed_tool_names = { + "batch_execute", + "execute_menu_item", + "find_gameobjects", + "get_test_job", + "manage_asset", + "manage_components", + "manage_editor", + "manage_gameobject", + "manage_material", + "manage_prefabs", + "manage_scene", + "manage_script", + "manage_scriptable_object", + "manage_shader", + "manage_texture", + "manage_vfx", + "read_console", + "refresh_unity", + "run_tests", + } + self._tool_alias_to_unity_target = { + # Server-side script helpers route to Unity's manage_script command. + "apply_text_edits": "manage_script", + "create_script": "manage_script", + "delete_script": "manage_script", + "find_in_file": "manage_script", + "get_sha": "manage_script", + "script_apply_edits": "manage_script", + "validate_script": "manage_script", + } + self._server_only_tool_names = { + "debug_request_context", + "execute_custom_tool", + "manage_script_capabilities", + "set_active_instance", + } def get_session_key(self, ctx) -> str: """ @@ -260,3 +297,108 @@ async def on_read_resource(self, context: MiddlewareContext, call_next): """Inject active Unity instance into resource context if available.""" await self._inject_unity_instance(context) return await call_next(context) + + async def on_list_tools(self, context: MiddlewareContext, call_next): + """Filter MCP tool listing to the Unity-enabled set when session data is available.""" + await self._inject_unity_instance(context) + tools = await call_next(context) + + if not self._should_filter_tool_listing(): + return tools + + enabled_tool_names = await self._resolve_enabled_tool_names_for_context(context) + if not enabled_tool_names: + return tools + + filtered = [] + for tool in tools: + tool_name = getattr(tool, "name", None) + if self._is_tool_visible(tool_name, enabled_tool_names): + filtered.append(tool) + + return filtered + + def _should_filter_tool_listing(self) -> bool: + transport = (config.transport_mode or "stdio").lower() + return transport == "http" and PluginHub.is_configured() + + async def _resolve_enabled_tool_names_for_context( + self, + context: MiddlewareContext, + ) -> set[str] | None: + ctx = context.fastmcp_context + user_id = ctx.get_state("user_id") if config.http_remote_hosted else None + active_instance = ctx.get_state("unity_instance") + + project_hashes = self._resolve_candidate_project_hashes(active_instance) + if not project_hashes: + try: + sessions_data = await PluginHub.get_sessions(user_id=user_id) + sessions = sessions_data.sessions if sessions_data else {} + except Exception: + return None + + if not sessions: + return None + + if len(sessions) == 1: + only_session = next(iter(sessions.values())) + only_hash = getattr(only_session, "hash", None) + if only_hash: + project_hashes = [only_hash] + else: + # Multiple sessions without explicit selection: use a union so we don't + # hide tools that are valid in at least one visible Unity instance. + project_hashes = [ + session.hash + for session in sessions.values() + if getattr(session, "hash", None) + ] + + if not project_hashes: + return None + + enabled_tool_names: set[str] = set() + for project_hash in project_hashes: + try: + registered_tools = await PluginHub.get_tools_for_project(project_hash) + except Exception: + continue + + for tool in registered_tools: + tool_name = getattr(tool, "name", None) + if isinstance(tool_name, str) and tool_name: + enabled_tool_names.add(tool_name) + + return enabled_tool_names or None + + @staticmethod + def _resolve_candidate_project_hashes(active_instance: str | None) -> list[str]: + if not active_instance: + return [] + + if "@" in active_instance: + _, _, suffix = active_instance.rpartition("@") + return [suffix] if suffix else [] + + return [active_instance] + + def _is_tool_visible(self, tool_name: str | None, enabled_tool_names: set[str]) -> bool: + if not isinstance(tool_name, str) or not tool_name: + return True + + if tool_name in self._server_only_tool_names: + return True + + if tool_name in enabled_tool_names: + return True + + unity_target = self._tool_alias_to_unity_target.get(tool_name) + if unity_target: + return unity_target in enabled_tool_names + + # Keep unknown tools visible for forward compatibility. + if tool_name not in self._unity_managed_tool_names: + return True + + return False diff --git a/Server/tests/test_transport_characterization.py b/Server/tests/test_transport_characterization.py index a0b617934..f1b698bcb 100644 --- a/Server/tests/test_transport_characterization.py +++ b/Server/tests/test_transport_characterization.py @@ -18,6 +18,7 @@ from unittest.mock import AsyncMock, Mock, MagicMock, patch, call from datetime import datetime, timezone import uuid +from types import SimpleNamespace from transport.unity_instance_middleware import UnityInstanceMiddleware, get_unity_instance_middleware, set_unity_instance_middleware from transport.plugin_registry import PluginRegistry, PluginSession @@ -31,6 +32,7 @@ SessionDetails, ) from models.models import ToolDefinitionModel +from core.config import config # ============================================================================ @@ -266,6 +268,78 @@ async def mock_call_next(_ctx): if len(c[0]) > 0 and c[0][0] == "unity_instance"] assert len(calls) == 0 + @pytest.mark.asyncio + async def test_list_tools_filters_disabled_unity_tools_and_aliases(self, mock_context, monkeypatch): + """ + Current behavior: in HTTP mode with a connected Unity session, on_list_tools() + uses PluginHub-registered tool names to hide disabled Unity tools while keeping + server-only tools visible. Aliases like create_script follow manage_script state. + """ + middleware = UnityInstanceMiddleware() + middleware_ctx = Mock() + middleware_ctx.fastmcp_context = mock_context + + mock_context.set_state("unity_instance", "Project@abc123") + monkeypatch.setattr(config, "transport_mode", "http") + + available_tools = [ + SimpleNamespace(name="manage_scene"), + SimpleNamespace(name="manage_script"), + SimpleNamespace(name="set_active_instance"), + SimpleNamespace(name="manage_asset"), + SimpleNamespace(name="create_script"), + ] + + async def call_next(_ctx): + return available_tools + + with patch.object(middleware, "_inject_unity_instance", new=AsyncMock()): + with patch("transport.unity_instance_middleware.PluginHub.is_configured", return_value=True): + with patch("transport.unity_instance_middleware.PluginHub.get_tools_for_project", new_callable=AsyncMock) as mock_get_tools: + mock_get_tools.return_value = [ + SimpleNamespace(name="manage_scene"), + SimpleNamespace(name="manage_script"), + ] + + filtered = await middleware.on_list_tools(middleware_ctx, call_next) + + names = [tool.name for tool in filtered] + assert "manage_scene" in names + assert "create_script" in names + assert "set_active_instance" in names + assert "manage_asset" not in names + + @pytest.mark.asyncio + async def test_list_tools_skips_filter_when_no_enabled_set_available(self, mock_context, monkeypatch): + """ + Current behavior: if Unity has not yet registered tool definitions for a + session, on_list_tools() leaves the FastMCP list unchanged. + """ + middleware = UnityInstanceMiddleware() + middleware_ctx = Mock() + middleware_ctx.fastmcp_context = mock_context + + mock_context.set_state("unity_instance", "Project@abc123") + monkeypatch.setattr(config, "transport_mode", "http") + + original_tools = [ + SimpleNamespace(name="manage_scene"), + SimpleNamespace(name="manage_asset"), + SimpleNamespace(name="set_active_instance"), + ] + + async def call_next(_ctx): + return original_tools + + with patch.object(middleware, "_inject_unity_instance", new=AsyncMock()): + with patch("transport.unity_instance_middleware.PluginHub.is_configured", return_value=True): + with patch("transport.unity_instance_middleware.PluginHub.get_tools_for_project", new_callable=AsyncMock) as mock_get_tools: + mock_get_tools.return_value = [] + + filtered = await middleware.on_list_tools(middleware_ctx, call_next) + + assert [tool.name for tool in filtered] == [tool.name for tool in original_tools] + # ============================================================================ # AUTO-SELECT INSTANCE TESTS diff --git a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/ToolDiscoveryServiceTests.cs b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/ToolDiscoveryServiceTests.cs new file mode 100644 index 000000000..7e267c225 --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/ToolDiscoveryServiceTests.cs @@ -0,0 +1,155 @@ +using System.Linq; +using NUnit.Framework; +using MCPForUnity.Editor.Constants; +using MCPForUnity.Editor.Services; +using UnityEditor; + +namespace MCPForUnity.Editor.Tests.EditMode.Services +{ + [TestFixture] + public class ToolDiscoveryServiceTests + { + private const string TestToolName = "test_tool_for_testing"; + + [SetUp] + public void SetUp() + { + // Clean up any test preferences + string testKey = EditorPrefKeys.ToolEnabledPrefix + TestToolName; + if (EditorPrefs.HasKey(testKey)) + { + EditorPrefs.DeleteKey(testKey); + } + } + + [TearDown] + public void TearDown() + { + // Clean up test preferences after each test + string testKey = EditorPrefKeys.ToolEnabledPrefix + TestToolName; + if (EditorPrefs.HasKey(testKey)) + { + EditorPrefs.DeleteKey(testKey); + } + } + + [Test] + public void SetToolEnabled_WritesToEditorPrefs() + { + // Arrange + var service = new ToolDiscoveryService(); + + // Act + service.SetToolEnabled(TestToolName, false); + + // Assert + string key = EditorPrefKeys.ToolEnabledPrefix + TestToolName; + Assert.IsTrue(EditorPrefs.HasKey(key), "Preference key should exist after SetToolEnabled"); + Assert.IsFalse(EditorPrefs.GetBool(key, true), "Preference should be set to false"); + } + + [Test] + public void IsToolEnabled_ReturnsFalse_WhenToolDoesNotExist() + { + // Arrange - Ensure no preference exists + string key = EditorPrefKeys.ToolEnabledPrefix + TestToolName; + if (EditorPrefs.HasKey(key)) + { + EditorPrefs.DeleteKey(key); + } + + var service = new ToolDiscoveryService(); + + // Act - For a non-existent tool, IsToolEnabled should return false + // (since metadata.AutoRegister defaults to false for non-existent tools) + bool result = service.IsToolEnabled(TestToolName); + + // Assert - Non-existent tools return false (no metadata found) + Assert.IsFalse(result, "Non-existent tool should return false"); + } + + [Test] + public void IsToolEnabled_ReturnsStoredValue_WhenPreferenceExists() + { + // Arrange + string key = EditorPrefKeys.ToolEnabledPrefix + TestToolName; + EditorPrefs.SetBool(key, false); // Store false value + var service = new ToolDiscoveryService(); + + // Act + bool result = service.IsToolEnabled(TestToolName); + + // Assert + Assert.IsFalse(result, "Should return the stored preference value (false)"); + } + + [Test] + public void IsToolEnabled_ReturnsTrue_WhenPreferenceSetToTrue() + { + // Arrange + string key = EditorPrefKeys.ToolEnabledPrefix + TestToolName; + EditorPrefs.SetBool(key, true); + var service = new ToolDiscoveryService(); + + // Act + bool result = service.IsToolEnabled(TestToolName); + + // Assert + Assert.IsTrue(result, "Should return the stored preference value (true)"); + } + + [Test] + public void ToolToggle_PersistsAcrossServiceInstances() + { + // Arrange + var service1 = new ToolDiscoveryService(); + service1.SetToolEnabled(TestToolName, false); + + // Act - Create a new service instance + var service2 = new ToolDiscoveryService(); + bool result = service2.IsToolEnabled(TestToolName); + + // Assert - The disabled state should persist + Assert.IsFalse(result, "Tool state should persist across service instances"); + } + + [Test] + public void DiscoverAllTools_DoesNotOverrideStoredFalse_ForBuiltInAutoRegisterFalseTool() + { + // Arrange + var service = new ToolDiscoveryService(); + var builtInTool = service.DiscoverAllTools() + .FirstOrDefault(tool => tool.IsBuiltIn && !tool.AutoRegister); + + Assert.IsNotNull(builtInTool, "Expected at least one built-in tool with AutoRegister=false."); + + string key = EditorPrefKeys.ToolEnabledPrefix + builtInTool.Name; + bool hadOriginalKey = EditorPrefs.HasKey(key); + bool originalValue = hadOriginalKey && EditorPrefs.GetBool(key, true); + + try + { + EditorPrefs.SetBool(key, false); + service.InvalidateCache(); + + // Act + service.DiscoverAllTools(); + bool enabled = service.IsToolEnabled(builtInTool.Name); + + // Assert + Assert.IsFalse(enabled, $"Built-in tool '{builtInTool.Name}' should remain disabled when preference is false."); + } + finally + { + if (hadOriginalKey) + { + EditorPrefs.SetBool(key, originalValue); + } + else + { + EditorPrefs.DeleteKey(key); + } + } + } + } +} diff --git a/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/ToolDiscoveryServiceTests.cs.meta b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/ToolDiscoveryServiceTests.cs.meta new file mode 100644 index 000000000..cea36ebb1 --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/EditMode/Services/ToolDiscoveryServiceTests.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 7a8b9c0d1e2f3a4b5c6d7e8f9a0b1c2d +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/TestProjects/UnityMCPTests/Assets/Tests/Editor.meta b/TestProjects/UnityMCPTests/Assets/Tests/Editor.meta new file mode 100644 index 000000000..b6e169d0d --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/Editor.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: ea478f98f983e4c4daafd9a07cf03e3b +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: diff --git a/TestProjects/UnityMCPTests/Assets/Tests/Editor/State.meta b/TestProjects/UnityMCPTests/Assets/Tests/Editor/State.meta new file mode 100644 index 000000000..a4a8007ae --- /dev/null +++ b/TestProjects/UnityMCPTests/Assets/Tests/Editor/State.meta @@ -0,0 +1,8 @@ +fileFormatVersion: 2 +guid: 778e8246ab09497419985acaee595f6c +folderAsset: yes +DefaultImporter: + externalObjects: {} + userData: + assetBundleName: + assetBundleVariant: