@@ -23,6 +23,7 @@ import (
2323 "github.com/anthropics/anthropic-sdk-go"
2424 "github.com/anthropics/anthropic-sdk-go/packages/ssestream"
2525 "github.com/coder/aibridge"
26+ "github.com/coder/aibridge/aibtest"
2627 "github.com/coder/aibridge/mcp"
2728 "github.com/google/uuid"
2829 mcplib "github.com/mark3labs/mcp-go/mcp"
@@ -33,7 +34,6 @@ import (
3334 "github.com/stretchr/testify/require"
3435 "github.com/tidwall/gjson"
3536 "github.com/tidwall/sjson"
36- "go.opentelemetry.io/otel"
3737 "go.opentelemetry.io/otel/trace"
3838 "go.uber.org/goleak"
3939 "golang.org/x/tools/txtar"
@@ -66,19 +66,23 @@ var (
6666 //go:embed fixtures/openai/non_stream_error.txtar
6767 oaiNonStreamErr []byte
6868
69- testTracer = otel . Tracer ( "forTesting" )
69+ testTracer = aibtest . TestTracer ( )
7070)
7171
7272const (
73+ // Legacy constants - use aibtest.Fixture* in new code.
7374 fixtureRequest = "request"
7475 fixtureStreamingResponse = "streaming"
7576 fixtureNonStreamingResponse = "non-streaming"
7677 fixtureStreamingToolResponse = "streaming/tool-call"
7778 fixtureNonStreamingToolResponse = "non-streaming/tool-call"
7879 fixtureResponse = "response"
7980
81+ // Legacy constants - use aibtest.DefaultAPIKey/DefaultUserID in new code.
8082 apiKey = "api-key"
8183 userID = "ae235cc1-9f8f-417d-a636-a7b170bac62e"
84+
85+ mockToolName = "coder_list_workspaces"
8286)
8387
8488func TestMain (m * testing.M ) {
@@ -112,46 +116,17 @@ func TestAnthropicMessages(t *testing.T) {
112116 t .Run (fmt .Sprintf ("%s/streaming=%v" , t .Name (), tc .streaming ), func (t * testing.T ) {
113117 t .Parallel ()
114118
115- arc := txtar .Parse (antSingleBuiltinTool )
116- t .Logf ("%s: %s" , t .Name (), arc .Comment )
117-
118- files := filesMap (arc )
119- require .Len (t , files , 3 )
120- require .Contains (t , files , fixtureRequest )
121- require .Contains (t , files , fixtureStreamingResponse )
122- require .Contains (t , files , fixtureNonStreamingResponse )
123-
124- reqBody := files [fixtureRequest ]
125-
126- // Add the stream param to the request.
127- newBody , err := setJSON (reqBody , "stream" , tc .streaming )
128- require .NoError (t , err )
129- reqBody = newBody
130-
131119 ctx , cancel := context .WithTimeout (t .Context (), time .Second * 30 )
132120 t .Cleanup (cancel )
133- srv := newMockServer (ctx , t , files , nil )
134- t .Cleanup (srv .Close )
135121
136- recorderClient := & mockRecorderClient {}
137-
138- logger := slogtest .Make (t , & slogtest.Options {}).Leveled (slog .LevelDebug )
139- providers := []aibridge.Provider {aibridge .NewAnthropicProvider (anthropicCfg (srv .URL , apiKey ), nil )}
140- b , err := aibridge .NewRequestBridge (ctx , providers , recorderClient , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
141- require .NoError (t , err )
122+ upstream := aibtest .NewMockUpstreamServer (t , ctx , antSingleBuiltinTool )
123+ reqBody := aibtest .SetStreamingInRequest (t , upstream .Files ()[aibtest .FixtureRequest ], tc .streaming )
142124
143- mockSrv := httptest .NewUnstartedServer (b )
144- t .Cleanup (mockSrv .Close )
145- mockSrv .Config .BaseContext = func (_ net.Listener ) context.Context {
146- return aibridge .AsActor (ctx , userID , nil )
147- }
148- mockSrv .Start ()
125+ bridge := aibtest .NewTestBridge (t , ctx , aibtest.TestBridgeOptions {
126+ Provider : aibridge .NewAnthropicProvider (aibtest .AnthropicConfig (upstream .URL , aibtest .DefaultAPIKey ), nil ),
127+ })
149128
150- // Make API call to aibridge for Anthropic /v1/messages
151- req := createAnthropicMessagesReq (t , mockSrv .URL , reqBody )
152- client := & http.Client {}
153- resp , err := client .Do (req )
154- require .NoError (t , err )
129+ resp := bridge .DoAnthropicRequest (t , reqBody )
155130 require .Equal (t , http .StatusOK , resp .StatusCode )
156131 defer resp .Body .Close ()
157132
@@ -170,23 +145,25 @@ func TestAnthropicMessages(t *testing.T) {
170145 // One for message_start, one for message_delta.
171146 expectedTokenRecordings = 2
172147 }
173- require .Len (t , recorderClient . tokenUsages , expectedTokenRecordings )
148+ require .Len (t , bridge . Recorder . TokenUsages () , expectedTokenRecordings )
174149
175- assert .EqualValues (t , tc .expectedInputTokens , calculateTotalInputTokens ( recorderClient . tokenUsages ), "input tokens miscalculated" )
176- assert .EqualValues (t , tc .expectedOutputTokens , calculateTotalOutputTokens ( recorderClient . tokenUsages ), "output tokens miscalculated" )
150+ assert .EqualValues (t , tc .expectedInputTokens , bridge . Recorder . TotalInputTokens ( ), "input tokens miscalculated" )
151+ assert .EqualValues (t , tc .expectedOutputTokens , bridge . Recorder . TotalOutputTokens ( ), "output tokens miscalculated" )
177152
178- require .Len (t , recorderClient .toolUsages , 1 )
179- assert .Equal (t , "Read" , recorderClient .toolUsages [0 ].Tool )
180- require .IsType (t , json.RawMessage {}, recorderClient .toolUsages [0 ].Args )
153+ toolUsages := bridge .Recorder .ToolUsages ()
154+ require .Len (t , toolUsages , 1 )
155+ assert .Equal (t , "Read" , toolUsages [0 ].Tool )
156+ require .IsType (t , json.RawMessage {}, toolUsages [0 ].Args )
181157 var args map [string ]any
182- require .NoError (t , json .Unmarshal (recorderClient . toolUsages [0 ].Args .(json.RawMessage ), & args ))
158+ require .NoError (t , json .Unmarshal (toolUsages [0 ].Args .(json.RawMessage ), & args ))
183159 require .Contains (t , args , "file_path" )
184160 assert .Equal (t , "/tmp/blah/foo" , args ["file_path" ])
185161
186- require .Len (t , recorderClient .userPrompts , 1 )
187- assert .Equal (t , "read the foo file" , recorderClient .userPrompts [0 ].Prompt )
162+ promptUsages := bridge .Recorder .PromptUsages ()
163+ require .Len (t , promptUsages , 1 )
164+ assert .Equal (t , "read the foo file" , promptUsages [0 ].Prompt )
188165
189- recorderClient . verifyAllInterceptionsEnded (t )
166+ bridge . Recorder . VerifyAllInterceptionsEnded (t )
190167 })
191168 }
192169 })
@@ -198,9 +175,8 @@ func TestAWSBedrockIntegration(t *testing.T) {
198175 t .Run ("invalid config" , func (t * testing.T ) {
199176 t .Parallel ()
200177
201- arc := txtar .Parse (antSingleBuiltinTool )
202- files := filesMap (arc )
203- reqBody := files [fixtureRequest ]
178+ upstream := aibtest .NewMockUpstreamServer (t , t .Context (), antSingleBuiltinTool )
179+ reqBody := upstream .Files ()[aibtest .FixtureRequest ]
204180
205181 ctx , cancel := context .WithTimeout (t .Context (), time .Second * 30 )
206182 t .Cleanup (cancel )
@@ -214,21 +190,21 @@ func TestAWSBedrockIntegration(t *testing.T) {
214190 SmallFastModel : "test-haiku" ,
215191 }
216192
217- recorderClient := & mockRecorderClient {}
193+ recorder := aibtest . NewMockRecorder ()
218194 logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : true }).Leveled (slog .LevelDebug )
219195 b , err := aibridge .NewRequestBridge (ctx , []aibridge.Provider {
220- aibridge .NewAnthropicProvider (anthropicCfg ("http://unused" , apiKey ), bedrockCfg ),
221- }, recorderClient , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
196+ aibridge .NewAnthropicProvider (aibtest . AnthropicConfig ("http://unused" , aibtest . DefaultAPIKey ), bedrockCfg ),
197+ }, recorder , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
222198 require .NoError (t , err )
223199
224200 mockSrv := httptest .NewUnstartedServer (b )
225201 t .Cleanup (mockSrv .Close )
226202 mockSrv .Config .BaseContext = func (_ net.Listener ) context.Context {
227- return aibridge .AsActor (ctx , userID , nil )
203+ return aibridge .AsActor (ctx , aibtest . DefaultUserID , nil )
228204 }
229205 mockSrv .Start ()
230206
231- req := createAnthropicMessagesReq (t , mockSrv .URL , reqBody )
207+ req := aibtest . CreateAnthropicMessagesRequest (t , mockSrv .URL , reqBody )
232208 resp , err := http .DefaultClient .Do (req )
233209 require .NoError (t , err )
234210 defer resp .Body .Close ()
@@ -245,18 +221,9 @@ func TestAWSBedrockIntegration(t *testing.T) {
245221 t .Run (fmt .Sprintf ("%s/streaming=%v" , t .Name (), streaming ), func (t * testing.T ) {
246222 t .Parallel ()
247223
248- arc := txtar .Parse (antSingleBuiltinTool )
249- t .Logf ("%s: %s" , t .Name (), arc .Comment )
250-
251- files := filesMap (arc )
252- require .Len (t , files , 3 )
253- require .Contains (t , files , fixtureRequest )
254-
255- reqBody := files [fixtureRequest ]
256-
257- newBody , err := setJSON (reqBody , "stream" , streaming )
258- require .NoError (t , err )
259- reqBody = newBody
224+ upstream := aibtest .NewMockUpstreamServer (t , t .Context (), antSingleBuiltinTool )
225+ files := upstream .Files ()
226+ reqBody := aibtest .SetStreamingInRequest (t , files [aibtest .FixtureRequest ], streaming )
260227
261228 ctx , cancel := context .WithTimeout (t .Context (), time .Second * 30 )
262229 t .Cleanup (cancel )
@@ -265,6 +232,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
265232 var requestCount int
266233
267234 // Create a mock server that intercepts requests to capture model name and return fixtures.
235+ // This is specific to testing Bedrock URL routing (model in path).
268236 srv := httptest .NewUnstartedServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
269237 requestCount ++
270238 t .Logf ("Mock server received request #%d: %s %s (streaming=%v)" , requestCount , r .Method , r .URL .Path , streaming )
@@ -281,12 +249,12 @@ func TestAWSBedrockIntegration(t *testing.T) {
281249 // Return appropriate fixture response.
282250 var respBody []byte
283251 if streaming {
284- respBody = files [fixtureStreamingResponse ]
252+ respBody = files [aibtest . FixtureStreamingResponse ]
285253 w .Header ().Set ("Content-Type" , "text/event-stream" )
286254 w .Header ().Set ("Cache-Control" , "no-cache" )
287255 w .Header ().Set ("Connection" , "keep-alive" )
288256 } else {
289- respBody = files [fixtureNonStreamingResponse ]
257+ respBody = files [aibtest . FixtureNonStreamingResponse ]
290258 w .Header ().Set ("Content-Type" , "application/json" )
291259 }
292260
@@ -311,26 +279,24 @@ func TestAWSBedrockIntegration(t *testing.T) {
311279 EndpointOverride : srv .URL ,
312280 }
313281
314- recorderClient := & mockRecorderClient {}
315-
282+ recorder := aibtest .NewMockRecorder ()
316283 logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : true }).Leveled (slog .LevelDebug )
317284 b , err := aibridge .NewRequestBridge (
318- ctx , []aibridge.Provider {aibridge .NewAnthropicProvider (anthropicCfg (srv .URL , apiKey ), bedrockCfg )},
319- recorderClient , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
285+ ctx , []aibridge.Provider {aibridge .NewAnthropicProvider (aibtest . AnthropicConfig (srv .URL , aibtest . DefaultAPIKey ), bedrockCfg )},
286+ recorder , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
320287 require .NoError (t , err )
321288
322289 mockBridgeSrv := httptest .NewUnstartedServer (b )
323290 t .Cleanup (mockBridgeSrv .Close )
324291 mockBridgeSrv .Config .BaseContext = func (_ net.Listener ) context.Context {
325- return aibridge .AsActor (ctx , userID , nil )
292+ return aibridge .AsActor (ctx , aibtest . DefaultUserID , nil )
326293 }
327294 mockBridgeSrv .Start ()
328295
329296 // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock.
330297 // We override the AWS Bedrock client to route requests through our mock server.
331- req := createAnthropicMessagesReq (t , mockBridgeSrv .URL , reqBody )
332- client := & http.Client {}
333- resp , err := client .Do (req )
298+ req := aibtest .CreateAnthropicMessagesRequest (t , mockBridgeSrv .URL , reqBody )
299+ resp , err := http .DefaultClient .Do (req )
334300 require .NoError (t , err )
335301 defer resp .Body .Close ()
336302
@@ -345,9 +311,10 @@ func TestAWSBedrockIntegration(t *testing.T) {
345311 // and the interception data.
346312 require .Equal (t , requestCount , 1 )
347313 require .Equal (t , bedrockCfg .Model , receivedModelName )
348- require .Len (t , recorderClient .interceptions , 1 )
349- require .Equal (t , recorderClient .interceptions [0 ].Model , bedrockCfg .Model )
350- recorderClient .verifyAllInterceptionsEnded (t )
314+ interceptions := recorder .Interceptions ()
315+ require .Len (t , interceptions , 1 )
316+ require .Equal (t , interceptions [0 ].Model , bedrockCfg .Model )
317+ recorder .VerifyAllInterceptionsEnded (t )
351318 })
352319 }
353320 })
@@ -379,46 +346,17 @@ func TestOpenAIChatCompletions(t *testing.T) {
379346 t .Run (fmt .Sprintf ("%s/streaming=%v" , t .Name (), tc .streaming ), func (t * testing.T ) {
380347 t .Parallel ()
381348
382- arc := txtar .Parse (oaiSingleBuiltinTool )
383- t .Logf ("%s: %s" , t .Name (), arc .Comment )
384-
385- files := filesMap (arc )
386- require .Len (t , files , 3 )
387- require .Contains (t , files , fixtureRequest )
388- require .Contains (t , files , fixtureStreamingResponse )
389- require .Contains (t , files , fixtureNonStreamingResponse )
390-
391- reqBody := files [fixtureRequest ]
392-
393- // Add the stream param to the request.
394- newBody , err := setJSON (reqBody , "stream" , tc .streaming )
395- require .NoError (t , err )
396- reqBody = newBody
397-
398349 ctx , cancel := context .WithTimeout (t .Context (), time .Second * 30 )
399350 t .Cleanup (cancel )
400- srv := newMockServer (ctx , t , files , nil )
401- t .Cleanup (srv .Close )
402-
403- recorderClient := & mockRecorderClient {}
404351
405- logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
406- providers := []aibridge.Provider {aibridge .NewOpenAIProvider (openaiCfg (srv .URL , apiKey ))}
407- b , err := aibridge .NewRequestBridge (t .Context (), providers , recorderClient , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
408- require .NoError (t , err )
352+ upstream := aibtest .NewMockUpstreamServer (t , ctx , oaiSingleBuiltinTool )
353+ reqBody := aibtest .SetStreamingInRequest (t , upstream .Files ()[aibtest .FixtureRequest ], tc .streaming )
409354
410- mockSrv := httptest .NewUnstartedServer (b )
411- t .Cleanup (mockSrv .Close )
412- mockSrv .Config .BaseContext = func (_ net.Listener ) context.Context {
413- return aibridge .AsActor (ctx , userID , nil )
414- }
415- mockSrv .Start ()
416- // Make API call to aibridge for OpenAI /v1/chat/completions
417- req := createOpenAIChatCompletionsReq (t , mockSrv .URL , reqBody )
355+ bridge := aibtest .NewTestBridge (t , ctx , aibtest.TestBridgeOptions {
356+ Provider : aibridge .NewOpenAIProvider (aibtest .OpenAIConfig (upstream .URL , aibtest .DefaultAPIKey )),
357+ })
418358
419- client := & http.Client {}
420- resp , err := client .Do (req )
421- require .NoError (t , err )
359+ resp := bridge .DoOpenAIRequest (t , reqBody )
422360 require .Equal (t , http .StatusOK , resp .StatusCode )
423361 defer resp .Body .Close ()
424362
@@ -436,20 +374,22 @@ func TestOpenAIChatCompletions(t *testing.T) {
436374 assert .Equal (t , "[DONE]" , lastEvent .Data )
437375 }
438376
439- require .Len (t , recorderClient . tokenUsages , 1 )
440- assert .EqualValues (t , tc .expectedInputTokens , calculateTotalInputTokens ( recorderClient . tokenUsages ), "input tokens miscalculated" )
441- assert .EqualValues (t , tc .expectedOutputTokens , calculateTotalOutputTokens ( recorderClient . tokenUsages ), "output tokens miscalculated" )
377+ require .Len (t , bridge . Recorder . TokenUsages () , 1 )
378+ assert .EqualValues (t , tc .expectedInputTokens , bridge . Recorder . TotalInputTokens ( ), "input tokens miscalculated" )
379+ assert .EqualValues (t , tc .expectedOutputTokens , bridge . Recorder . TotalOutputTokens ( ), "output tokens miscalculated" )
442380
443- require .Len (t , recorderClient .toolUsages , 1 )
444- assert .Equal (t , "read_file" , recorderClient .toolUsages [0 ].Tool )
445- require .IsType (t , map [string ]any {}, recorderClient .toolUsages [0 ].Args )
446- require .Contains (t , recorderClient .toolUsages [0 ].Args , "path" )
447- assert .Equal (t , "README.md" , recorderClient .toolUsages [0 ].Args .(map [string ]any )["path" ])
381+ toolUsages := bridge .Recorder .ToolUsages ()
382+ require .Len (t , toolUsages , 1 )
383+ assert .Equal (t , "read_file" , toolUsages [0 ].Tool )
384+ require .IsType (t , map [string ]any {}, toolUsages [0 ].Args )
385+ require .Contains (t , toolUsages [0 ].Args , "path" )
386+ assert .Equal (t , "README.md" , toolUsages [0 ].Args .(map [string ]any )["path" ])
448387
449- require .Len (t , recorderClient .userPrompts , 1 )
450- assert .Equal (t , "how large is the README.md file in my current path" , recorderClient .userPrompts [0 ].Prompt )
388+ promptUsages := bridge .Recorder .PromptUsages ()
389+ require .Len (t , promptUsages , 1 )
390+ assert .Equal (t , "how large is the README.md file in my current path" , promptUsages [0 ].Prompt )
451391
452- recorderClient . verifyAllInterceptionsEnded (t )
392+ bridge . Recorder . VerifyAllInterceptionsEnded (t )
453393 })
454394 }
455395 })
@@ -471,7 +411,7 @@ func TestSimple(t *testing.T) {
471411 fixture : antSimple ,
472412 configureFunc : func (addr string , client aibridge.Recorder ) (* aibridge.RequestBridge , error ) {
473413 logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
474- provider := []aibridge.Provider {aibridge .NewAnthropicProvider (anthropicCfg (addr , apiKey ), nil )}
414+ provider := []aibridge.Provider {aibridge .NewAnthropicProvider (aibtest . AnthropicConfig (addr , aibtest . DefaultAPIKey ), nil )}
475415 return aibridge .NewRequestBridge (t .Context (), provider , client , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
476416 },
477417 getResponseIDFunc : func (streaming bool , resp * http.Response ) (string , error ) {
@@ -502,15 +442,15 @@ func TestSimple(t *testing.T) {
502442 }
503443 return message .ID , nil
504444 },
505- createRequest : createAnthropicMessagesReq ,
445+ createRequest : aibtest . CreateAnthropicMessagesRequest ,
506446 expectedMsgID : "msg_01Pvyf26bY17RcjmWfJsXGBn" ,
507447 },
508448 {
509449 name : aibridge .ProviderOpenAI ,
510450 fixture : oaiSimple ,
511451 configureFunc : func (addr string , client aibridge.Recorder ) (* aibridge.RequestBridge , error ) {
512452 logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : false }).Leveled (slog .LevelDebug )
513- providers := []aibridge.Provider {aibridge .NewOpenAIProvider (openaiCfg (addr , apiKey ))}
453+ providers := []aibridge.Provider {aibridge .NewOpenAIProvider (aibtest . OpenAIConfig (addr , aibtest . DefaultAPIKey ))}
514454 return aibridge .NewRequestBridge (t .Context (), providers , client , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
515455 },
516456 getResponseIDFunc : func (streaming bool , resp * http.Response ) (string , error ) {
@@ -1373,7 +1313,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
13731313 providers := []aibridge.Provider {aibridge .NewAnthropicProvider (anthropicCfg (addr , apiKey ), nil )}
13741314 return aibridge .NewRequestBridge (t .Context (), providers , client , mcp .NewServerProxyManager (nil , testTracer ), logger , nil , testTracer )
13751315 },
1376- createRequest : createAnthropicMessagesReq ,
1316+ createRequest : aibtest . CreateAnthropicMessagesRequest ,
13771317 envVars : map [string ]string {
13781318 "ANTHROPIC_AUTH_TOKEN" : "should-not-leak" ,
13791319 },
@@ -1687,8 +1627,6 @@ func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) {
16871627 }
16881628}
16891629
1690- const mockToolName = "coder_list_workspaces"
1691-
16921630// callAccumulator tracks all tool invocations by name and each instance's arguments.
16931631type callAccumulator struct {
16941632 calls map [string ][]any
0 commit comments