Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 69 additions & 56 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -724,28 +724,7 @@ func TestFallthrough(t *testing.T) {
}
}

// setupMCPServerProxiesForTest starts the mock MCP server behind an httptest
// server, creates a streamable-HTTP server proxy named "coder" that points at
// it, initializes that proxy, and returns it in a map keyed by proxy.Name()
// together with the call accumulator produced by createMockMCPSrv.
//
// The test fails immediately if the proxy cannot be created or initialized, or
// if the initialized proxy lists no tools. The httptest server is closed and
// the init context is cancelled through t.Cleanup.
func setupMCPServerProxiesForTest(t *testing.T, tracer trace.Tracer) (map[string]mcp.ServerProxier, *callAccumulator) {
t.Helper()

// Setup Coder MCP integration
srv, acc := createMockMCPSrv(t)
mcpSrv := httptest.NewServer(srv)
t.Cleanup(mcpSrv.Close)

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer)
require.NoError(t, err)

// Initialize the proxy under a 30s timeout and verify it lists at least one tool.
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)
require.NoError(t, proxy.Init(ctx))
tools := proxy.ListTools()
require.NotEmpty(t, tools)

return map[string]mcp.ServerProxier{proxy.Name(): proxy}, acc
}

type (
configureFunc func(string, aibridge.Recorder, *mcp.ServerProxyManager) (*aibridge.RequestBridge, error)
Expand All @@ -766,7 +745,7 @@ func TestAnthropicInjectedTools(t *testing.T) {
}

// Build the requirements & make the assertions which are common to all providers.
recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq)
recorderClient, mcpMock, resp := setupInjectedToolTest(t, antSingleInjectedTool, streaming, configureFn, createAnthropicMessagesReq)

// Ensure expected tool was invoked with expected input.
require.Len(t, recorderClient.toolUsages, 1)
Expand All @@ -776,7 +755,7 @@ func TestAnthropicInjectedTools(t *testing.T) {
actual, err := json.Marshal(recorderClient.toolUsages[0].Args)
require.NoError(t, err)
require.EqualValues(t, expected, actual)
invocations := mcpCalls.getCallsByTool(mockToolName)
invocations := mcpMock.GetToolCalls(mockToolName)
require.Len(t, invocations, 1)
actual, err = json.Marshal(invocations[0])
require.NoError(t, err)
Expand Down Expand Up @@ -853,7 +832,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
}

// Build the requirements & make the assertions which are common to all providers.
recorderClient, mcpCalls, _, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq)
recorderClient, mcpMock, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, createOpenAIChatCompletionsReq)

// Ensure expected tool was invoked with expected input.
require.Len(t, recorderClient.toolUsages, 1)
Expand All @@ -863,7 +842,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
actual, err := json.Marshal(recorderClient.toolUsages[0].Args)
require.NoError(t, err)
require.EqualValues(t, expected, actual)
invocations := mcpCalls.getCallsByTool(mockToolName)
invocations := mcpMock.GetToolCalls(mockToolName)
require.Len(t, invocations, 1)
actual, err = json.Marshal(invocations[0])
require.NoError(t, err)
Expand Down Expand Up @@ -942,8 +921,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
}

// setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests.
// Kinda fugly right now, we can refactor this later.
func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) {
func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request) (*mockRecorderClient, *mockMCPServer, *http.Response) {
t.Helper()

arc := txtar.Parse(fixture)
Expand Down Expand Up @@ -988,11 +966,11 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu

recorderClient := &mockRecorderClient{}

// Setup MCP mcpProxiers.
mcpProxiers, acc := setupMCPServerProxiesForTest(t, testTracer)
// Setup MCP server with integrated call tracking.
mcpMock := newMockMCPServer(t, testTracer)

// Configure the bridge with injected tools.
mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer)
mcpMgr := mcp.NewServerProxyManager(mcpMock.Proxies, testTracer)
require.NoError(t, mcpMgr.Init(ctx))
b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr)
require.NoError(t, err)
Expand All @@ -1019,7 +997,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
return mockSrv.callCount.Load() == 2
}, time.Second*10, time.Millisecond*50)

return recorderClient, acc, mcpProxiers, resp
return recorderClient, mcpMock, resp
}

func TestErrorHandling(t *testing.T) {
Expand Down Expand Up @@ -1277,10 +1255,10 @@ func TestStableRequestEncoding(t *testing.T) {
t.Cleanup(cancel)

// Setup MCP tools.
mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer)
mcpMock := newMockMCPServer(t, testTracer)

// Configure the bridge with injected tools.
mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer)
mcpMgr := mcp.NewServerProxyManager(mcpMock.Proxies, testTracer)
require.NoError(t, mcpMgr.Init(ctx))

arc := txtar.Parse(tc.fixture)
Expand Down Expand Up @@ -1689,58 +1667,93 @@ func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) {

const mockToolName = "coder_list_workspaces"

// callAccumulator tracks all tool invocations by name and each instance's arguments.
type callAccumulator struct {
// mockMCPServer combines the MCP server proxy with call tracking.
// This addresses the tech debt of having callAccumulator as a separate return value.
type mockMCPServer struct {
Proxies map[string]mcp.ServerProxier

calls map[string][]any
callsMu sync.Mutex
}

func newCallAccumulator() *callAccumulator {
return &callAccumulator{
calls: make(map[string][]any),
}
// addCall appends args to the recorded invocations for tool.
// callsMu is held for the duration of the append.
func (m *mockMCPServer) addCall(tool string, args any) {
m.callsMu.Lock()
defer m.callsMu.Unlock()
m.calls[tool] = append(m.calls[tool], args)
}

// addCall appends args to the recorded invocations for tool.
// callsMu is held for the duration of the append.
func (a *callAccumulator) addCall(tool string, args any) {
a.callsMu.Lock()
defer a.callsMu.Unlock()

a.calls[tool] = append(a.calls[tool], args)
}

func (a *callAccumulator) getCallsByTool(name string) []any {
a.callsMu.Lock()
defer a.callsMu.Unlock()
// GetToolCalls returns all recorded invocations for a specific tool.
func (m *mockMCPServer) GetToolCalls(name string) []any {
m.callsMu.Lock()
defer m.callsMu.Unlock()

// Protect against concurrent access of the slice.
result := make([]any, len(a.calls[name]))
copy(result, a.calls[name])
result := make([]any, len(m.calls[name]))
copy(result, m.calls[name])
return result
}

func createMockMCPSrv(t *testing.T) (http.Handler, *callAccumulator) {
func newMockMCPServer(t *testing.T, tracer trace.Tracer) *mockMCPServer {
t.Helper()

mock := &mockMCPServer{
calls: make(map[string][]any),
}

s := server.NewMCPServer(
"Mock coder MCP server",
"1.0.0",
server.WithToolCapabilities(true),
)

// Accumulate tool calls & their arguments.
acc := newCallAccumulator()
for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} {
tool := mcplib.NewTool(name,
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
)
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
mock.addCall(request.Params.Name, request.Params.Arguments)
return mcplib.NewToolResultText("mock"), nil
})
}

mcpSrv := httptest.NewServer(server.NewStreamableHTTPServer(s))
t.Cleanup(mcpSrv.Close)

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
proxy, err := mcp.NewStreamableHTTPServerProxy("coder", mcpSrv.URL, nil, nil, nil, logger, tracer)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)
require.NoError(t, proxy.Init(ctx))
tools := proxy.ListTools()
require.NotEmpty(t, tools)

mock.Proxies = map[string]mcp.ServerProxier{proxy.Name(): proxy}
return mock
}

// createMockMCPSrvHandler creates just the HTTP handler for the mock MCP server.
// Use this when you need custom server configuration (e.g., custom server name).
func createMockMCPSrvHandler(t *testing.T) http.Handler {
t.Helper()

s := server.NewMCPServer(
"Mock coder MCP server",
"1.0.0",
server.WithToolCapabilities(true),
)

for _, name := range []string{mockToolName, "coder_list_templates", "coder_template_version_parameters", "coder_get_authenticated_user", "coder_create_workspace_build"} {
tool := mcplib.NewTool(name,
mcplib.WithDescription(fmt.Sprintf("Mock of the %s tool", name)),
)
s.AddTool(tool, func(ctx context.Context, request mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
acc.addCall(request.Params.Name, request.Params.Arguments)
return mcplib.NewToolResultText("mock"), nil
})
}

return server.NewStreamableHTTPServer(s), acc
return server.NewStreamableHTTPServer(s)
}

func openaiCfg(url, key string) aibridge.OpenAIConfig {
Expand Down
4 changes: 2 additions & 2 deletions metrics_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) {
provider := aibridge.NewAnthropicProvider(anthropicCfg(mockAPI.URL, apiKey), nil)

// Setup mocked MCP server & tools.
mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer)
mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer)
mcpMock := newMockMCPServer(t, testTracer)
mcpMgr := mcp.NewServerProxyManager(mcpMock.Proxies, testTracer)
require.NoError(t, mcpMgr.Init(ctx))

bridge, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, recorder, mcpMgr, logger, metrics, testTracer)
Expand Down
10 changes: 5 additions & 5 deletions trace_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) {
}

// Build the requirements & make the assertions which are common to all providers.
recorderClient, _, proxies, resp := setupInjectedToolTest(t, antSingleInjectedTool, tc.streaming, configureFn, reqFunc)
recorderClient, mcpMock, resp := setupInjectedToolTest(t, antSingleInjectedTool, tc.streaming, configureFn, reqFunc)

defer resp.Body.Close()

Expand All @@ -358,7 +358,7 @@ func TestAnthropicInjectedToolsTrace(t *testing.T) {
model = "beddel"
}

for _, proxy := range proxies {
for _, proxy := range mcpMock.Proxies {
require.NotEmpty(t, proxy.ListTools())
tool := proxy.ListTools()[0]

Expand Down Expand Up @@ -607,14 +607,14 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) {
}

// Build the requirements & make the assertions which are common to all providers.
recorderClient, _, proxies, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, reqFunc)
recorderClient, mcpMock, resp := setupInjectedToolTest(t, oaiSingleInjectedTool, streaming, configureFn, reqFunc)

defer resp.Body.Close()

require.Len(t, recorderClient.interceptions, 1)
intcID := recorderClient.interceptions[0].ID

for _, proxy := range proxies {
for _, proxy := range mcpMock.Proxies {
require.NotEmpty(t, proxy.ListTools())
tool := proxy.ListTools()[0]

Expand Down Expand Up @@ -687,7 +687,7 @@ func TestNewServerProxyManagerTraces(t *testing.T) {
defer func() { _ = tp.Shutdown(t.Context()) }()

serverName := "serverName"
srv, _ := createMockMCPSrv(t)
srv := createMockMCPSrvHandler(t)
mcpSrv := httptest.NewServer(srv)
t.Cleanup(mcpSrv.Close)

Expand Down
Loading