diff --git a/bridge_integration_test.go b/bridge_integration_test.go index b4ea460..b84c8dd 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -724,28 +724,7 @@ func TestFallthrough(t *testing.T) { } } -// setupMCPServerProxiesForTest creates a mock MCP server, initializes the MCP bridge, and returns the tools -func setupMCPServerProxiesForTest(t *testing.T, tracer trace.Tracer) (map[string]mcp.ServerProxier, *callAccumulator) { - t.Helper() - - // Setup Coder MCP integration - srv, acc := createMockMCPSrv(t) - mcpSrv := httptest.NewServer(srv) - t.Cleanup(mcpSrv.Close) - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer) - require.NoError(t, err) - // Initialize MCP client, fetch tools, and inject into bridge - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) - t.Cleanup(cancel) - require.NoError(t, proxy.Init(ctx)) - tools := proxy.ListTools() - require.NotEmpty(t, tools) - - return map[string]mcp.ServerProxier{proxy.Name(): proxy}, acc -} type ( configureFunc func(string, aibridge.Recorder, *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) @@ -766,7 +745,7 @@ func TestAnthropicInjectedTools(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) + recorderClient, mcpMock, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq) // Ensure expected tool was invoked with expected input. require.Len(t, recorderClient.toolUsages, 1) @@ -776,7 +755,7 @@ func TestAnthropicInjectedTools(t *testing.T) { actual, err := json.Marshal(recorderClient.toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mcpCalls.getCallsByTool(mockToolName) + invocations := mcpMock.GetToolCalls(mockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -853,7 +832,7 @@ func TestOpenAIInjectedTools(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) + recorderClient, mcpMock, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq) // Ensure expected tool was invoked with expected input. require.Len(t, recorderClient.toolUsages, 1) @@ -863,7 +842,7 @@ func TestOpenAIInjectedTools(t *testing.T) { actual, err := json.Marshal(recorderClient.toolUsages[0].Args) require.NoError(t, err) require.EqualValues(t, expected, actual) - invocations := mcpCalls.getCallsByTool(mockToolName) + invocations := mcpMock.GetToolCalls(mockToolName) require.Len(t, invocations, 1) actual, err = json.Marshal(invocations[0]) require.NoError(t, err) @@ -942,8 +921,7 @@ func TestOpenAIInjectedTools(t *testing.T) { } // setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests. -// Kinda fugly right now, we can refactor this later. -func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) { +func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *mockMCPServer, *http.Response) { t.Helper() arc := txtar.Parse(fixture) @@ -988,11 +966,11 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu recorderClient := &mockRecorderClient{} - // Setup MCP mcpProxiers. - mcpProxiers, acc := setupMCPServerProxiesForTest(t, testTracer) + // Setup MCP server with integrated call tracking. + mcpMock := newMockMCPServer(t, testTracer) // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) + mcpMgr := mcp.NewServerProxyManager(mcpMock.Proxies, testTracer) require.NoError(t, mcpMgr.Init(ctx)) b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr) require.NoError(t, err) @@ -1019,7 +997,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu return mockSrv.callCount.Load() == 2 }, time.Second*10, time.Millisecond*50) - return recorderClient, acc, mcpProxiers, resp + return recorderClient, mcpMock, resp } func TestErrorHandling(t *testing.T) { @@ -1277,10 +1255,10 @@ func TestStableRequestEncoding(t *testing.T) { t.Cleanup(cancel) // Setup MCP tools. - mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer) + mcpMock := newMockMCPServer(t, testTracer) // Configure the bridge with injected tools. - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) + mcpMgr := mcp.NewServerProxyManager(mcpMock.Proxies, testTracer) require.NoError(t, mcpMgr.Init(ctx)) arc := txtar.Parse(tc.fixture) @@ -1689,58 +1667,93 @@ func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) { const mockToolName = "coder_list_workspaces" -// callAccumulator tracks all tool invocations by name and each instance's arguments. -type callAccumulator struct { +// mockMCPServer combines the MCP server proxy with call tracking. +// This addresses the tech debt of having callAccumulator as a separate return value. +type mockMCPServer struct { + Proxies map[string]mcp.ServerProxier + calls map[string][]any callsMu sync.Mutex } -func newCallAccumulator() *callAccumulator { - return &callAccumulator{ - calls: make(map[string][]any), - } +func (m *mockMCPServer) addCall(tool string, args any) { + m.callsMu.Lock() + defer m.callsMu.Unlock() + m.calls[tool] = append(m.calls[tool], args) } -func (a *callAccumulator) addCall(tool string, args any) { - a.callsMu.Lock() - defer a.callsMu.Unlock() - - a.calls[tool] = append(a.calls[tool], args) -} - -func (a *callAccumulator) getCallsByTool(name string) []any { - a.callsMu.Lock() - defer a.callsMu.Unlock() +// GetToolCalls returns all recorded invocations for a specific tool. +func (m *mockMCPServer) GetToolCalls(name string) []any { + m.callsMu.Lock() + defer m.callsMu.Unlock() // Protect against concurrent access of the slice. - result := make([]any, len(a.calls[name])) - copy(result, a.calls[name]) + result := make([]any, len(m.calls[name])) + copy(result, m.calls[name]) return result } -func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) { +func newMockMCPServer(t *testing.T, tracer trace.Tracer) *mockMCPServer { t.Helper() + mock := &mockMCPServer{ + calls: make(map[string][]any), + } + s := server.NewMCPServer( "Mock coder MCP server", "1.0.0", server.WithToolCapabilities(true), ) - // Accumulate tool calls & their arguments. - acc := newCallAccumulator() + for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} { + tool := mcplib.NewTool(name, + mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)), + ) + s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + mock.addCall(request.Params.Name, request.Params.Arguments) + return mcplib.NewToolResultText("mock"), nil + }) + } + + mcpSrv := httptest.NewServer(server.NewStreamableHTTPServer(s)) + t.Cleanup(mcpSrv.Close) + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + require.NoError(t, proxy.Init(ctx)) + tools := proxy.ListTools() + require.NotEmpty(t, tools) + + mock.Proxies = map[string]mcp.ServerProxier{proxy.Name(): proxy} + return mock +} + +// createMockMCPSrvHandler creates just the HTTP handler for the mock MCP server. +// Use this when you need custom server configuration (e.g., custom server name). +func createMockMCPSrvHandler(t *testing.T) http.Handler { + t.Helper() + + s := server.NewMCPServer( + "Mock coder MCP server", + "1.0.0", + server.WithToolCapabilities(true), + ) for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} { tool := mcplib.NewTool(name, mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)), ) s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { - acc.addCall(request.Params.Name, request.Params.Arguments) return mcplib.NewToolResultText("mock"), nil }) } - return server.NewStreamableHTTPServer(s), acc + return server.NewStreamableHTTPServer(s) } func openaiCfg(url, key string) aibridge.OpenAIConfig { diff --git a/metrics_integration_test.go b/metrics_integration_test.go index f326dec..387a6f6 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -237,8 +237,8 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil) // Setup mocked MCP server & tools. - mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer) - mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) + mcpMock := newMockMCPServer(t, testTracer) + mcpMgr := mcp.NewServerProxyManager(mcpMock.Proxies, testTracer) require.NoError(t, mcpMgr.Init(ctx)) bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, logger, metrics, testTracer) diff --git a/trace_integration_test.go b/trace_integration_test.go index ee6574d..e7eadc5 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -346,7 +346,7 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, _, proxies, resp := setupInjectedToolTest(t, antSingleInjectedTool, tc.streaming, configureFn, reqFunc) + recorderClient, mcpMock, resp := setupInjectedToolTest(t, antSingleInjectedTool, tc.streaming, configureFn, reqFunc) defer resp.Body.Close() @@ -358,7 +358,7 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) { model = "beddel" } - for _, proxy := range proxies { + for _, proxy := range mcpMock.Proxies { require.NotEmpty(t, proxy.ListTools()) tool := proxy.ListTools()[0] @@ -607,14 +607,14 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) { } // Build the requirements & make the assertions which are common to all providers. - recorderClient, _, proxies, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, reqFunc) + recorderClient, mcpMock, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, reqFunc) defer resp.Body.Close() require.Len(t, recorderClient.interceptions, 1) intcID := recorderClient.interceptions[0].ID - for _, proxy := range proxies { + for _, proxy := range mcpMock.Proxies { require.NotEmpty(t, proxy.ListTools()) tool := proxy.ListTools()[0] @@ -687,7 +687,7 @@ func TestNewServerProxyManagerTraces(t *testing.T) { defer func() { _ = tp.Shutdown(t.Context()) }() serverName := "serverName" - srv, _ := createMockMCPSrv(t) + srv := createMockMCPSrvHandler(t) mcpSrv := httptest.NewServer(srv) t.Cleanup(mcpSrv.Close)