@@ -724,28 +724,7 @@ func TestFallthrough(t *testing.T) {
724724 }
725725}
726726
727- // setupMCPServerProxiesForTest creates a mock MCP server, initializes the MCP bridge, and returns the tools
728- func setupMCPServerProxiesForTest (t * testing.T , tracer trace.Tracer ) (map [string ]mcp.ServerProxier , * callAccumulator ) {
729- t .Helper ()
730-
731- // Setup Coder MCP integration
732- srv , acc := createMockMCPSrv (t )
733- mcpSrv := httptest .NewServer (srv )
734- t .Cleanup (mcpSrv .Close )
735-
736- logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
737- proxy , err := mcp .NewStreamableHTTPServerProxy ("coder" , mcpSrv .URL , nil , nil , nil , logger , tracer )
738- require .NoError (t , err )
739727
740- // Initialize MCP client, fetch tools, and inject into bridge
741- ctx , cancel := context .WithTimeout (t .Context (), time .Second * 30 )
742- t .Cleanup (cancel )
743- require .NoError (t , proxy .Init (ctx ))
744- tools := proxy .ListTools ()
745- require .NotEmpty (t , tools )
746-
747- return map [string ]mcp.ServerProxier {proxy .Name (): proxy }, acc
748- }
749728
750729type (
751730 configureFunc func (string , aibridge.Recorder , * mcp.ServerProxyManager ) (* aibridge.RequestBridge , error )
@@ -766,7 +745,7 @@ func TestAnthropicInjectedTools(t *testing.T) {
766745 }
767746
768747 // Build the requirements & make the assertions which are common to all providers.
769- recorderClient , mcpCalls , _ , resp := setupInjectedToolTest (t , antSingleInjectedTool , streaming , configureFn , createAnthropicMessagesReq )
748+ recorderClient , mcpMock , resp := setupInjectedToolTest (t , antSingleInjectedTool , streaming , configureFn , createAnthropicMessagesReq )
770749
771750 // Ensure expected tool was invoked with expected input.
772751 require .Len (t , recorderClient .toolUsages , 1 )
@@ -776,7 +755,7 @@ func TestAnthropicInjectedTools(t *testing.T) {
776755 actual , err := json .Marshal (recorderClient .toolUsages [0 ].Args )
777756 require .NoError (t , err )
778757 require .EqualValues (t , expected , actual )
779- invocations := mcpCalls . getCallsByTool (mockToolName )
758+ invocations := mcpMock . GetToolCalls (mockToolName )
780759 require .Len (t , invocations , 1 )
781760 actual , err = json .Marshal (invocations [0 ])
782761 require .NoError (t , err )
@@ -853,7 +832,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
853832 }
854833
855834 // Build the requirements & make the assertions which are common to all providers.
856- recorderClient , mcpCalls , _ , resp := setupInjectedToolTest (t , oaiSingleInjectedTool , streaming , configureFn , createOpenAIChatCompletionsReq )
835+ recorderClient , mcpMock , resp := setupInjectedToolTest (t , oaiSingleInjectedTool , streaming , configureFn , createOpenAIChatCompletionsReq )
857836
858837 // Ensure expected tool was invoked with expected input.
859838 require .Len (t , recorderClient .toolUsages , 1 )
@@ -863,7 +842,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
863842 actual , err := json .Marshal (recorderClient .toolUsages [0 ].Args )
864843 require .NoError (t , err )
865844 require .EqualValues (t , expected , actual )
866- invocations := mcpCalls . getCallsByTool (mockToolName )
845+ invocations := mcpMock . GetToolCalls (mockToolName )
867846 require .Len (t , invocations , 1 )
868847 actual , err = json .Marshal (invocations [0 ])
869848 require .NoError (t , err )
@@ -942,8 +921,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
942921}
943922
944923// setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests.
945- // Kinda fugly right now, we can refactor this later.
946- func setupInjectedToolTest (t * testing.T , fixture []byte , streaming bool , configureFn configureFunc , createRequestFn func (* testing.T , string , []byte ) * http.Request ) (* mockRecorderClient , * callAccumulator , map [string ]mcp.ServerProxier , * http.Response ) {
924+ func setupInjectedToolTest (t * testing.T , fixture []byte , streaming bool , configureFn configureFunc , createRequestFn func (* testing.T , string , []byte ) * http.Request ) (* mockRecorderClient , * mockMCPServer , * http.Response ) {
947925 t .Helper ()
948926
949927 arc := txtar .Parse (fixture )
@@ -988,11 +966,11 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
988966
989967 recorderClient := & mockRecorderClient {}
990968
991- // Setup MCP mcpProxiers .
992- mcpProxiers , acc := setupMCPServerProxiesForTest (t , testTracer )
969+ // Setup MCP server with integrated call tracking .
970+ mcpMock := newMockMCPServer (t , testTracer )
993971
994972 // Configure the bridge with injected tools.
995- mcpMgr := mcp .NewServerProxyManager (mcpProxiers , testTracer )
973+ mcpMgr := mcp .NewServerProxyManager (mcpMock . Proxies , testTracer )
996974 require .NoError (t , mcpMgr .Init (ctx ))
997975 b , err := configureFn (mockSrv .URL , recorderClient , mcpMgr )
998976 require .NoError (t , err )
@@ -1019,7 +997,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
1019997 return mockSrv .callCount .Load () == 2
1020998 }, time .Second * 10 , time .Millisecond * 50 )
1021999
1022- return recorderClient , acc , mcpProxiers , resp
1000+ return recorderClient , mcpMock , resp
10231001}
10241002
10251003func TestErrorHandling (t * testing.T ) {
@@ -1277,10 +1255,10 @@ func TestStableRequestEncoding(t *testing.T) {
12771255 t .Cleanup (cancel )
12781256
12791257 // Setup MCP tools.
1280- mcpProxiers , _ := setupMCPServerProxiesForTest (t , testTracer )
1258+ mcpMock := newMockMCPServer (t , testTracer )
12811259
12821260 // Configure the bridge with injected tools.
1283- mcpMgr := mcp .NewServerProxyManager (mcpProxiers , testTracer )
1261+ mcpMgr := mcp .NewServerProxyManager (mcpMock . Proxies , testTracer )
12841262 require .NoError (t , mcpMgr .Init (ctx ))
12851263
12861264 arc := txtar .Parse (tc .fixture )
@@ -1689,58 +1667,93 @@ func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) {
16891667
16901668const mockToolName = "coder_list_workspaces"
16911669
1692- // callAccumulator tracks all tool invocations by name and each instance's arguments.
1693- type callAccumulator struct {
1670+ // mockMCPServer combines the MCP server proxy with call tracking.
1671+ // This addresses the tech debt of having callAccumulator as a separate return value.
1672+ type mockMCPServer struct {
1673+ Proxies map [string ]mcp.ServerProxier
1674+
16941675 calls map [string ][]any
16951676 callsMu sync.Mutex
16961677}
16971678
1698- func newCallAccumulator () * callAccumulator {
1699- return & callAccumulator {
1700- calls : make ( map [ string ][] any ),
1701- }
1679+ func ( m * mockMCPServer ) addCall ( tool string , args any ) {
1680+ m . callsMu . Lock ()
1681+ defer m . callsMu . Unlock ()
1682+ m . calls [ tool ] = append ( m . calls [ tool ], args )
17021683}
17031684
1704- func (a * callAccumulator ) addCall (tool string , args any ) {
1705- a .callsMu .Lock ()
1706- defer a .callsMu .Unlock ()
1707-
1708- a .calls [tool ] = append (a .calls [tool ], args )
1709- }
1710-
1711- func (a * callAccumulator ) getCallsByTool (name string ) []any {
1712- a .callsMu .Lock ()
1713- defer a .callsMu .Unlock ()
1685+ // GetToolCalls returns all recorded invocations for a specific tool.
1686+ func (m * mockMCPServer ) GetToolCalls (name string ) []any {
1687+ m .callsMu .Lock ()
1688+ defer m .callsMu .Unlock ()
17141689
17151690 // Protect against concurrent access of the slice.
1716- result := make ([]any , len (a .calls [name ]))
1717- copy (result , a .calls [name ])
1691+ result := make ([]any , len (m .calls [name ]))
1692+ copy (result , m .calls [name ])
17181693 return result
17191694}
17201695
1721- func createMockMCPSrv (t * testing.T ) (http. Handler , * callAccumulator ) {
1696+ func newMockMCPServer (t * testing.T , tracer trace. Tracer ) * mockMCPServer {
17221697 t .Helper ()
17231698
1699+ mock := & mockMCPServer {
1700+ calls : make (map [string ][]any ),
1701+ }
1702+
17241703 s := server .NewMCPServer (
17251704 "Mock coder MCP server" ,
17261705 "1.0.0" ,
17271706 server .WithToolCapabilities (true ),
17281707 )
17291708
1730- // Accumulate tool calls & their arguments.
1731- acc := newCallAccumulator ()
1709+ for _ , name := range []string {mockToolName , "coder_list_templates" , "coder_template_version_parameters" , "coder_get_authenticated_user" , "coder_create_workspace_build" } {
1710+ tool := mcplib .NewTool (name ,
1711+ mcplib .WithDescription (fmt .Sprintf ("Mock of the %s tool" , name )),
1712+ )
1713+ s .AddTool (tool , func (ctx context.Context , request mcplib.CallToolRequest ) (* mcplib.CallToolResult , error ) {
1714+ mock .addCall (request .Params .Name , request .Params .Arguments )
1715+ return mcplib .NewToolResultText ("mock" ), nil
1716+ })
1717+ }
1718+
1719+ mcpSrv := httptest .NewServer (server .NewStreamableHTTPServer (s ))
1720+ t .Cleanup (mcpSrv .Close )
1721+
1722+ logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
1723+ proxy , err := mcp .NewStreamableHTTPServerProxy ("coder" , mcpSrv .URL , nil , nil , nil , logger , tracer )
1724+ require .NoError (t , err )
1725+
1726+ ctx , cancel := context .WithTimeout (t .Context (), time .Second * 30 )
1727+ t .Cleanup (cancel )
1728+ require .NoError (t , proxy .Init (ctx ))
1729+ tools := proxy .ListTools ()
1730+ require .NotEmpty (t , tools )
1731+
1732+ mock .Proxies = map [string ]mcp.ServerProxier {proxy .Name (): proxy }
1733+ return mock
1734+ }
1735+
1736+ // createMockMCPSrvHandler creates just the HTTP handler for the mock MCP server.
1737+ // Use this when you need custom server configuration (e.g., custom server name).
1738+ func createMockMCPSrvHandler (t * testing.T ) http.Handler {
1739+ t .Helper ()
1740+
1741+ s := server .NewMCPServer (
1742+ "Mock coder MCP server" ,
1743+ "1.0.0" ,
1744+ server .WithToolCapabilities (true ),
1745+ )
17321746
17331747 for _ , name := range []string {mockToolName , "coder_list_templates" , "coder_template_version_parameters" , "coder_get_authenticated_user" , "coder_create_workspace_build" } {
17341748 tool := mcplib .NewTool (name ,
17351749 mcplib .WithDescription (fmt .Sprintf ("Mock of the %s tool" , name )),
17361750 )
17371751 s .AddTool (tool , func (ctx context.Context , request mcplib.CallToolRequest ) (* mcplib.CallToolResult , error ) {
1738- acc .addCall (request .Params .Name , request .Params .Arguments )
17391752 return mcplib .NewToolResultText ("mock" ), nil
17401753 })
17411754 }
17421755
1743- return server .NewStreamableHTTPServer (s ), acc
1756+ return server .NewStreamableHTTPServer (s )
17441757}
17451758
17461759func openaiCfg (url , key string ) aibridge.OpenAIConfig {
0 commit comments