Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions MCPForUnity/Editor/Services/ToolDiscoveryService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,6 @@ private void EnsurePreferenceInitialized(ToolMetadata metadata)
{
bool defaultValue = metadata.AutoRegister || metadata.IsBuiltIn;
EditorPrefs.SetBool(key, defaultValue);
return;
}

if (metadata.IsBuiltIn && !metadata.AutoRegister)
{
bool currentValue = EditorPrefs.GetBool(key, metadata.AutoRegister);
if (currentValue == metadata.AutoRegister)
{
EditorPrefs.SetBool(key, true);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ public interface IMcpTransportClient
Task<bool> StartAsync();
Task StopAsync();
Task<bool> VerifyAsync();
Task ReregisterToolsAsync();
}
}
24 changes: 14 additions & 10 deletions MCPForUnity/Editor/Services/Transport/TransportManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,6 @@ private IMcpTransportClient GetOrCreateClient(TransportMode mode)
};
}

private IMcpTransportClient GetClient(TransportMode mode)
{
return mode switch
{
TransportMode.Http => _httpClient,
TransportMode.Stdio => _stdioClient,
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
};
}

public async Task<bool> StartAsync(TransportMode mode)
{
IMcpTransportClient client = GetOrCreateClient(mode);
Expand Down Expand Up @@ -128,6 +118,20 @@ public TransportState GetState(TransportMode mode)

public bool IsRunning(TransportMode mode) => GetState(mode).IsConnected;

/// <summary>
/// Gets the active transport client for the specified mode.
/// Returns null if the client hasn't been created yet.
/// </summary>
public IMcpTransportClient GetClient(TransportMode mode)
{
return mode switch
{
TransportMode.Http => _httpClient,
TransportMode.Stdio => _stdioClient,
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
};
}

private void UpdateState(TransportMode mode, TransportState state)
{
switch (mode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,12 @@ public Task<bool> VerifyAsync()
return Task.FromResult(running);
}

public Task ReregisterToolsAsync()
{
// Stdio transport doesn't support dynamic tool reregistration
// Tools are registered at server startup
return Task.CompletedTask;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,29 @@ private async Task SendRegisterToolsAsync(CancellationToken token)
McpLog.Info($"[WebSocket] Sent {tools.Count} tools registration", false);
}

public async Task ReregisterToolsAsync()
{
if (!IsConnected || _lifecycleCts == null)
{
McpLog.Warn("[WebSocket] Cannot reregister tools: not connected");
return;
}

try
{
await SendRegisterToolsAsync(_lifecycleCts.Token).ConfigureAwait(false);
McpLog.Info("[WebSocket] Tool reregistration completed", false);
}
catch (System.OperationCanceledException)
{
McpLog.Warn("[WebSocket] Tool reregistration cancelled");
}
catch (System.Exception ex)
{
McpLog.Error($"[WebSocket] Tool reregistration failed: {ex.Message}");
}
}

private async Task HandleExecuteAsync(JObject payload, CancellationToken token)
{
string commandId = payload.Value<string>("id");
Expand Down
56 changes: 53 additions & 3 deletions MCPForUnity/Editor/Windows/Components/Tools/McpToolsSection.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using MCPForUnity.Editor.Constants;
using MCPForUnity.Editor.Helpers;
using MCPForUnity.Editor.Services;
using MCPForUnity.Editor.Services.Transport;
using MCPForUnity.Editor.Tools;
using UnityEditor;
using UnityEditor.UIElements;
using UnityEngine.UIElements;

namespace MCPForUnity.Editor.Windows.Components.Tools
Expand Down Expand Up @@ -223,23 +226,63 @@ private VisualElement CreateToolRow(ToolMetadata tool)
return row;
}

private void HandleToggleChange(ToolMetadata tool, bool enabled, bool updateSummary = true)
private void HandleToggleChange(
ToolMetadata tool,
bool enabled,
bool updateSummary = true,
bool reregisterTools = true)
{
MCPServiceLocator.ToolDiscovery.SetToolEnabled(tool.Name, enabled);

if (updateSummary)
{
UpdateSummary();
}

if (reregisterTools)
{
// Trigger tool reregistration with connected MCP server
ReregisterToolsAsync();
}
}

private void ReregisterToolsAsync()
{
// Fire and forget - don't block UI thread
var transportManager = MCPServiceLocator.TransportManager;
var client = transportManager.GetClient(TransportMode.Http);
if (client == null || !client.IsConnected)
{
return;
}

_ = Task.Run(async () =>
{
try
{
await client.ReregisterToolsAsync().ConfigureAwait(false);
}
catch (Exception ex)
{
McpLog.Warn($"Failed to reregister tools: {ex}");
}
});
}

private void SetAllToolsState(bool enabled)
{
bool hasChanges = false;

foreach (var tool in allTools)
{
if (!toolToggleMap.TryGetValue(tool.Name, out var toggle))
{
MCPServiceLocator.ToolDiscovery.SetToolEnabled(tool.Name, enabled);
bool currentEnabled = MCPServiceLocator.ToolDiscovery.IsToolEnabled(tool.Name);
if (currentEnabled != enabled)
{
MCPServiceLocator.ToolDiscovery.SetToolEnabled(tool.Name, enabled);
hasChanges = true;
}
continue;
}

Expand All @@ -249,10 +292,17 @@ private void SetAllToolsState(bool enabled)
}

toggle.SetValueWithoutNotify(enabled);
HandleToggleChange(tool, enabled, updateSummary: false);
HandleToggleChange(tool, enabled, updateSummary: false, reregisterTools: false);
hasChanges = true;
}

UpdateSummary();

if (hasChanges)
{
// Trigger a single reregistration after bulk change
ReregisterToolsAsync();
}
}

private void UpdateSummary()
Expand Down
56 changes: 49 additions & 7 deletions Server/src/services/custom_tool_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from starlette.requests import Request
from starlette.responses import JSONResponse

from core.config import config
from models.models import MCPResponse, ToolDefinitionModel, ToolParameterModel
from core.logging_decorator import log_execution
from core.telemetry_decorator import telemetry_tool
Expand All @@ -28,6 +29,23 @@
_MAX_POLL_SECONDS = 600


def get_user_id_from_context(ctx: Context) -> str | None:
"""Read user_id from request-scoped context in remote-hosted mode."""
if not config.http_remote_hosted:
return None

get_state = getattr(ctx, "get_state", None)
if not callable(get_state):
return None

try:
user_id = get_state("user_id")
except Exception:
return None

return user_id if isinstance(user_id, str) and user_id else None


class RegisterToolsPayload(BaseModel):
project_id: str
project_hash: str | None = None
Expand Down Expand Up @@ -84,30 +102,40 @@ async def register_tools(request: Request) -> JSONResponse:
return JSONResponse(response.model_dump())

# --- Public API for MCP tools ---------------------------------------
async def list_registered_tools(self, project_id: str) -> list[ToolDefinitionModel]:
async def list_registered_tools(
self,
project_id: str,
user_id: str | None = None,
) -> list[ToolDefinitionModel]:
legacy = list(self._project_tools.get(project_id, {}).values())
hub_tools = await PluginHub.get_tools_for_project(project_id)
hub_tools = await PluginHub.get_tools_for_project(project_id, user_id=user_id)
return legacy + hub_tools

async def get_tool_definition(self, project_id: str, tool_name: str) -> ToolDefinitionModel | None:
async def get_tool_definition(
self,
project_id: str,
tool_name: str,
user_id: str | None = None,
) -> ToolDefinitionModel | None:
tool = self._project_tools.get(project_id, {}).get(tool_name)
if tool:
return tool
return await PluginHub.get_tool_definition(project_id, tool_name)
return await PluginHub.get_tool_definition(project_id, tool_name, user_id=user_id)

async def execute_tool(
self,
project_id: str,
tool_name: str,
unity_instance: str | None,
params: dict[str, object] | None = None,
user_id: str | None = None,
) -> MCPResponse:
params = params or {}
logger.info(
f"Executing tool '{tool_name}' for project '{project_id}' (instance={unity_instance}) with params: {params}"
)

definition = await self.get_tool_definition(project_id, tool_name)
definition = await self.get_tool_definition(project_id, tool_name, user_id=user_id)
if definition is None:
return MCPResponse(
success=False,
Expand All @@ -119,6 +147,7 @@ async def execute_tool(
unity_instance,
tool_name,
params,
user_id=user_id,
)

if not definition.requires_polling:
Expand All @@ -132,6 +161,7 @@ async def execute_tool(
params,
response,
definition.poll_action or "status",
user_id=user_id,
)
logger.info(f"Tool '{tool_name}' polled response: {result}")
return result
Expand All @@ -156,6 +186,7 @@ async def _poll_until_complete(
initial_params: dict[str, object],
initial_response,
poll_action: str,
user_id: str | None = None,
) -> MCPResponse:
poll_params = dict(initial_params)
poll_params["action"] = poll_action or "status"
Expand All @@ -180,7 +211,11 @@ async def _poll_until_complete(

try:
response = await send_with_unity_instance(
async_send_command_with_retry, unity_instance, tool_name, poll_params
async_send_command_with_retry,
unity_instance,
tool_name,
poll_params,
user_id=user_id,
)
except Exception as exc: # pragma: no cover - network/domain reload variability
logger.debug(f"Polling {tool_name} failed, will retry: {exc}")
Expand Down Expand Up @@ -347,8 +382,15 @@ async def _handler(ctx: Context, **kwargs) -> MCPResponse:
)

params = {k: v for k, v in kwargs.items() if v is not None}
user_id = get_user_id_from_context(ctx)
service = CustomToolService.get_instance()
return await service.execute_tool(project_id, definition.name, unity_instance, params)
return await service.execute_tool(
project_id,
definition.name,
unity_instance,
params,
user_id=user_id,
)

_handler.__name__ = f"custom_tool_{definition.name}"
_handler.__doc__ = definition.description or ""
Expand Down
25 changes: 24 additions & 1 deletion Server/src/services/registry/tool_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
def mcp_for_unity_tool(
name: str | None = None,
description: str | None = None,
unity_target: str | None = "self",
**kwargs
) -> Callable:
"""
Expand All @@ -20,6 +21,10 @@ def mcp_for_unity_tool(
Args:
name: Tool name (defaults to function name)
description: Tool description
unity_target: Visibility target used by middleware filtering.
- "self" (default): tool follows its own enabled state.
- None: server-only tool, always visible in tool listing.
- "<tool_name>": alias tool that follows another Unity tool state.
**kwargs: Additional arguments passed to @mcp.tool()

Example:
Expand All @@ -29,11 +34,29 @@ async def my_custom_tool(ctx: Context, ...):
"""
def decorator(func: Callable) -> Callable:
tool_name = name if name is not None else func.__name__
# Safety guard: unity_target is internal metadata and must never leak into mcp.tool kwargs.
tool_kwargs = dict(kwargs) # Create a copy to avoid side effects
if "unity_target" in tool_kwargs:
del tool_kwargs["unity_target"]

if unity_target is None:
normalized_unity_target: str | None = None
elif isinstance(unity_target, str) and unity_target.strip():
normalized_unity_target = (
tool_name if unity_target == "self" else unity_target.strip()
)
else:
raise ValueError(
f"Invalid unity_target for tool '{tool_name}': {unity_target!r}. "
"Expected None or a non-empty string."
)

_tool_registry.append({
'func': func,
'name': tool_name,
'description': description,
'kwargs': kwargs
'unity_target': normalized_unity_target,
'kwargs': tool_kwargs,
})

return func
Expand Down
Loading