Skip to content

Commit 1728be9

Browse files
committed
refactor: integrate callAccumulator into mockMCPServer struct
This addresses the tech debt identified in issue #73 by: 1. Creating mockMCPServer struct that combines: - MCP server proxies (Proxies map) - Call tracking (previously separate callAccumulator) 2. Replacing setupMCPServerProxiesForTest with newMockMCPServer: - Returns single struct instead of multiple values - Integrates all MCP server setup into one place 3. Updating setupInjectedToolTest: - Returns 3 values instead of 4 (recorder, mcpMock, resp) - Callers use mcpMock.GetToolCalls() instead of separate accumulator 4. Adding createMockMCPSrvHandler for cases needing custom setup (e.g., custom server name in trace tests) The changes reduce cognitive overhead by keeping related data together and simplifying function signatures. Resolves #73
1 parent b202549 commit 1728be9

File tree

3 files changed

+76
-63
lines changed

3 files changed

+76
-63
lines changed

bridge_integration_test.go

Lines changed: 69 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -724,28 +724,7 @@ func TestFallthrough(t *testing.T) {
724724
}
725725
}
726726

727-
// setupMCPServerProxiesForTest creates a mock MCP server, initializes the MCP bridge, and returns the tools
728-
func setupMCPServerProxiesForTest(t *testing.T, tracer trace.Tracer) (map[string]mcp.ServerProxier, *callAccumulator) {
729-
t.Helper()
730-
731-
// Setup Coder MCP integration
732-
srv, acc := createMockMCPSrv(t)
733-
mcpSrv := httptest.NewServer(srv)
734-
t.Cleanup(mcpSrv.Close)
735-
736-
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
737-
proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer)
738-
require.NoError(t, err)
739727

740-
// Initialize MCP client, fetch tools, and inject into bridge
741-
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
742-
t.Cleanup(cancel)
743-
require.NoError(t, proxy.Init(ctx))
744-
tools := proxy.ListTools()
745-
require.NotEmpty(t, tools)
746-
747-
return map[string]mcp.ServerProxier{proxy.Name(): proxy}, acc
748-
}
749728

750729
type (
751730
configureFunc func(string, aibridge.Recorder, *mcp.ServerProxyManager) (*aibridge.RequestBridge, error)
@@ -766,7 +745,7 @@ func TestAnthropicInjectedTools(t *testing.T) {
766745
}
767746

768747
// Build the requirements & make the assertions which are common to all providers.
769-
recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq)
748+
recorderClient, mcpMock, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq)
770749

771750
// Ensure expected tool was invoked with expected input.
772751
require.Len(t, recorderClient.toolUsages, 1)
@@ -776,7 +755,7 @@ func TestAnthropicInjectedTools(t *testing.T) {
776755
actual, err := json.Marshal(recorderClient.toolUsages[0].Args)
777756
require.NoError(t, err)
778757
require.EqualValues(t, expected, actual)
779-
invocations := mcpCalls.getCallsByTool(mockToolName)
758+
invocations := mcpMock.GetToolCalls(mockToolName)
780759
require.Len(t, invocations, 1)
781760
actual, err = json.Marshal(invocations[0])
782761
require.NoError(t, err)
@@ -853,7 +832,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
853832
}
854833

855834
// Build the requirements & make the assertions which are common to all providers.
856-
recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq)
835+
recorderClient, mcpMock, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq)
857836

858837
// Ensure expected tool was invoked with expected input.
859838
require.Len(t, recorderClient.toolUsages, 1)
@@ -863,7 +842,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
863842
actual, err := json.Marshal(recorderClient.toolUsages[0].Args)
864843
require.NoError(t, err)
865844
require.EqualValues(t, expected, actual)
866-
invocations := mcpCalls.getCallsByTool(mockToolName)
845+
invocations := mcpMock.GetToolCalls(mockToolName)
867846
require.Len(t, invocations, 1)
868847
actual, err = json.Marshal(invocations[0])
869848
require.NoError(t, err)
@@ -942,8 +921,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
942921
}
943922

944923
// setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests.
945-
// Kinda fugly right now, we can refactor this later.
946-
func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) {
924+
func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *mockMCPServer, *http.Response) {
947925
t.Helper()
948926

949927
arc := txtar.Parse(fixture)
@@ -988,11 +966,11 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
988966

989967
recorderClient := &mockRecorderClient{}
990968

991-
// Setup MCP mcpProxiers.
992-
mcpProxiers, acc := setupMCPServerProxiesForTest(t, testTracer)
969+
// Setup MCP server with integrated call tracking.
970+
mcpMock := newMockMCPServer(t, testTracer)
993971

994972
// Configure the bridge with injected tools.
995-
mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer)
973+
mcpMgr := mcp.NewServerProxyManager(mcpMock.Proxies, testTracer)
996974
require.NoError(t, mcpMgr.Init(ctx))
997975
b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr)
998976
require.NoError(t, err)
@@ -1019,7 +997,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
1019997
return mockSrv.callCount.Load() == 2
1020998
}, time.Second*10, time.Millisecond*50)
1021999

1022-
return recorderClient, acc, mcpProxiers, resp
1000+
return recorderClient, mcpMock, resp
10231001
}
10241002

10251003
func TestErrorHandling(t *testing.T) {
@@ -1277,10 +1255,10 @@ func TestStableRequestEncoding(t *testing.T) {
12771255
t.Cleanup(cancel)
12781256

12791257
// Setup MCP tools.
1280-
mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer)
1258+
mcpMock := newMockMCPServer(t, testTracer)
12811259

12821260
// Configure the bridge with injected tools.
1283-
mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer)
1261+
mcpMgr := mcp.NewServerProxyManager(mcpMock.Proxies, testTracer)
12841262
require.NoError(t, mcpMgr.Init(ctx))
12851263

12861264
arc := txtar.Parse(tc.fixture)
@@ -1689,58 +1667,93 @@ func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) {
16891667

16901668
const mockToolName = "coder_list_workspaces"
16911669

1692-
// callAccumulator tracks all tool invocations by name and each instance's arguments.
1693-
type callAccumulator struct {
1670+
// mockMCPServer combines the MCP server proxy with call tracking.
1671+
// This addresses the tech debt of having callAccumulator as a separate return value.
1672+
type mockMCPServer struct {
1673+
Proxies map[string]mcp.ServerProxier
1674+
16941675
calls map[string][]any
16951676
callsMu sync.Mutex
16961677
}
16971678

1698-
func newCallAccumulator() *callAccumulator {
1699-
return &callAccumulator{
1700-
calls: make(map[string][]any),
1701-
}
1679+
func (m *mockMCPServer) addCall(tool string, args any) {
1680+
m.callsMu.Lock()
1681+
defer m.callsMu.Unlock()
1682+
m.calls[tool] = append(m.calls[tool], args)
17021683
}
17031684

1704-
func (a *callAccumulator) addCall(tool string, args any) {
1705-
a.callsMu.Lock()
1706-
defer a.callsMu.Unlock()
1707-
1708-
a.calls[tool] = append(a.calls[tool], args)
1709-
}
1710-
1711-
func (a *callAccumulator) getCallsByTool(name string) []any {
1712-
a.callsMu.Lock()
1713-
defer a.callsMu.Unlock()
1685+
// GetToolCalls returns all recorded invocations for a specific tool.
1686+
func (m *mockMCPServer) GetToolCalls(name string) []any {
1687+
m.callsMu.Lock()
1688+
defer m.callsMu.Unlock()
17141689

17151690
// Protect against concurrent access of the slice.
1716-
result := make([]any, len(a.calls[name]))
1717-
copy(result, a.calls[name])
1691+
result := make([]any, len(m.calls[name]))
1692+
copy(result, m.calls[name])
17181693
return result
17191694
}
17201695

1721-
func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) {
1696+
func newMockMCPServer(t *testing.T, tracer trace.Tracer) *mockMCPServer {
17221697
t.Helper()
17231698

1699+
mock := &mockMCPServer{
1700+
calls: make(map[string][]any),
1701+
}
1702+
17241703
s := server.NewMCPServer(
17251704
"Mock coder MCP server",
17261705
"1.0.0",
17271706
server.WithToolCapabilities(true),
17281707
)
17291708

1730-
// Accumulate tool calls & their arguments.
1731-
acc := newCallAccumulator()
1709+
for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} {
1710+
tool := mcplib.NewTool(name,
1711+
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
1712+
)
1713+
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
1714+
mock.addCall(request.Params.Name, request.Params.Arguments)
1715+
return mcplib.NewToolResultText("mock"), nil
1716+
})
1717+
}
1718+
1719+
mcpSrv := httptest.NewServer(server.NewStreamableHTTPServer(s))
1720+
t.Cleanup(mcpSrv.Close)
1721+
1722+
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
1723+
proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer)
1724+
require.NoError(t, err)
1725+
1726+
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
1727+
t.Cleanup(cancel)
1728+
require.NoError(t, proxy.Init(ctx))
1729+
tools := proxy.ListTools()
1730+
require.NotEmpty(t, tools)
1731+
1732+
mock.Proxies = map[string]mcp.ServerProxier{proxy.Name(): proxy}
1733+
return mock
1734+
}
1735+
1736+
// createMockMCPSrvHandler creates just the HTTP handler for the mock MCP server.
1737+
// Use this when you need custom server configuration (e.g., custom server name).
1738+
func createMockMCPSrvHandler(t *testing.T) http.Handler {
1739+
t.Helper()
1740+
1741+
s := server.NewMCPServer(
1742+
"Mock coder MCP server",
1743+
"1.0.0",
1744+
server.WithToolCapabilities(true),
1745+
)
17321746

17331747
for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} {
17341748
tool := mcplib.NewTool(name,
17351749
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
17361750
)
17371751
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
1738-
acc.addCall(request.Params.Name, request.Params.Arguments)
17391752
return mcplib.NewToolResultText("mock"), nil
17401753
})
17411754
}
17421755

1743-
return server.NewStreamableHTTPServer(s), acc
1756+
return server.NewStreamableHTTPServer(s)
17441757
}
17451758

17461759
func openaiCfg(url, key string) aibridge.OpenAIConfig {

metrics_integration_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,8 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) {
237237
provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil)
238238

239239
// Setup mocked MCP server & tools.
240-
mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer)
241-
mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer)
240+
mcpMock := newMockMCPServer(t, testTracer)
241+
mcpMgr := mcp.NewServerProxyManager(mcpMock.Proxies, testTracer)
242242
require.NoError(t, mcpMgr.Init(ctx))
243243

244244
bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, logger, metrics, testTracer)

trace_integration_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) {
346346
}
347347

348348
// Build the requirements & make the assertions which are common to all providers.
349-
recorderClient, _, proxies, resp := setupInjectedToolTest(t, antSingleInjectedTool, tc.streaming, configureFn, reqFunc)
349+
recorderClient, mcpMock, resp := setupInjectedToolTest(t, antSingleInjectedTool, tc.streaming, configureFn, reqFunc)
350350

351351
defer resp.Body.Close()
352352

@@ -358,7 +358,7 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) {
358358
model = "beddel"
359359
}
360360

361-
for _, proxy := range proxies {
361+
for _, proxy := range mcpMock.Proxies {
362362
require.NotEmpty(t, proxy.ListTools())
363363
tool := proxy.ListTools()[0]
364364

@@ -607,14 +607,14 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) {
607607
}
608608

609609
// Build the requirements & make the assertions which are common to all providers.
610-
recorderClient, _, proxies, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, reqFunc)
610+
recorderClient, mcpMock, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, reqFunc)
611611

612612
defer resp.Body.Close()
613613

614614
require.Len(t, recorderClient.interceptions, 1)
615615
intcID := recorderClient.interceptions[0].ID
616616

617-
for _, proxy := range proxies {
617+
for _, proxy := range mcpMock.Proxies {
618618
require.NotEmpty(t, proxy.ListTools())
619619
tool := proxy.ListTools()[0]
620620

@@ -687,7 +687,7 @@ func TestNewServerProxyManagerTraces(t *testing.T) {
687687
defer func() { _ = tp.Shutdown(t.Context()) }()
688688

689689
serverName := "serverName"
690-
srv, _ := createMockMCPSrv(t)
690+
srv := createMockMCPSrvHandler(t)
691691
mcpSrv := httptest.NewServer(srv)
692692
t.Cleanup(mcpSrv.Close)
693693

0 commit comments

Comments
 (0)