Skip to content

Commit 1b59025

Browse files
committed
refactor: migrate key tests to use aibtest package
Refactored the following tests to use the new aibtest helpers: - TestAnthropicMessages (single builtin tool) - TestAWSBedrockIntegration - TestOpenAIChatCompletions (single builtin tool) - TestMetrics_Interception The migration demonstrates the simplified patterns: - aibtest.NewMockUpstreamServer replaces manual txtar parsing + newMockServer - aibtest.NewTestBridge replaces manual bridge/server setup - aibtest.MockRecorder replaces mockRecorderClient - aibtest.SetStreamingInRequest replaces setJSON - aibtest.CreateAnthropicMessagesRequest/CreateOpenAIChatCompletionsRequest replace createAnthropicMessagesReq/createOpenAIChatCompletionsReq - aibtest.AnthropicConfig/OpenAIConfig replace anthropicCfg/openaiCfg The old helpers are preserved as 'legacy' for gradual migration of remaining tests. New tests should use the aibtest package. Key improvements demonstrated: - ~30 lines of setup reduced to ~5 lines per test - Recorder access via bridge.Recorder instead of separate variable - Automatic cleanup handling via TestBridge Continues work on #73
1 parent 1d2f22a commit 1b59025

File tree

2 files changed

+78
-143
lines changed

2 files changed

+78
-143
lines changed

bridge_integration_test.go

Lines changed: 70 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/anthropics/anthropic-sdk-go"
2424
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
2525
"github.com/coder/aibridge"
26+
"github.com/coder/aibridge/aibtest"
2627
"github.com/coder/aibridge/mcp"
2728
"github.com/google/uuid"
2829
mcplib "github.com/mark3labs/mcp-go/mcp"
@@ -33,7 +34,6 @@ import (
3334
"github.com/stretchr/testify/require"
3435
"github.com/tidwall/gjson"
3536
"github.com/tidwall/sjson"
36-
"go.opentelemetry.io/otel"
3737
"go.opentelemetry.io/otel/trace"
3838
"go.uber.org/goleak"
3939
"golang.org/x/tools/txtar"
@@ -66,19 +66,23 @@ var (
6666
//go:embed fixtures/openai/non_stream_error.txtar
6767
oaiNonStreamErr []byte
6868

69-
testTracer = otel.Tracer("forTesting")
69+
testTracer = aibtest.TestTracer()
7070
)
7171

7272
const (
73+
// Legacy constants - use aibtest.Fixture* in new code.
7374
fixtureRequest = "request"
7475
fixtureStreamingResponse = "streaming"
7576
fixtureNonStreamingResponse = "non-streaming"
7677
fixtureStreamingToolResponse = "streaming/tool-call"
7778
fixtureNonStreamingToolResponse = "non-streaming/tool-call"
7879
fixtureResponse = "response"
7980

81+
// Legacy constants - use aibtest.DefaultAPIKey/DefaultUserID in new code.
8082
apiKey = "api-key"
8183
userID = "ae235cc1-9f8f-417d-a636-a7b170bac62e"
84+
85+
mockToolName = "coder_list_workspaces"
8286
)
8387

8488
func TestMain(m *testing.M) {
@@ -112,46 +116,17 @@ func TestAnthropicMessages(t *testing.T) {
112116
t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) {
113117
t.Parallel()
114118

115-
arc := txtar.Parse(antSingleBuiltinTool)
116-
t.Logf("%s: %s", t.Name(), arc.Comment)
117-
118-
files := filesMap(arc)
119-
require.Len(t, files, 3)
120-
require.Contains(t, files, fixtureRequest)
121-
require.Contains(t, files, fixtureStreamingResponse)
122-
require.Contains(t, files, fixtureNonStreamingResponse)
123-
124-
reqBody := files[fixtureRequest]
125-
126-
// Add the stream param to the request.
127-
newBody, err := setJSON(reqBody, "stream", tc.streaming)
128-
require.NoError(t, err)
129-
reqBody = newBody
130-
131119
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
132120
t.Cleanup(cancel)
133-
srv := newMockServer(ctx, t, files, nil)
134-
t.Cleanup(srv.Close)
135121

136-
recorderClient := &mockRecorderClient{}
137-
138-
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
139-
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)}
140-
b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
141-
require.NoError(t, err)
122+
upstream := aibtest.NewMockUpstreamServer(t, ctx, antSingleBuiltinTool)
123+
reqBody := aibtest.SetStreamingInRequest(t, upstream.Files()[aibtest.FixtureRequest], tc.streaming)
142124

143-
mockSrv := httptest.NewUnstartedServer(b)
144-
t.Cleanup(mockSrv.Close)
145-
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
146-
return aibridge.AsActor(ctx, userID, nil)
147-
}
148-
mockSrv.Start()
125+
bridge := aibtest.NewTestBridge(t, ctx, aibtest.TestBridgeOptions{
126+
Provider: aibridge.NewAnthropicProvider(aibtest.AnthropicConfig(upstream.URL, aibtest.DefaultAPIKey), nil),
127+
})
149128

150-
// Make API call to aibridge for Anthropic /v1/messages
151-
req := createAnthropicMessagesReq(t, mockSrv.URL, reqBody)
152-
client := &http.Client{}
153-
resp, err := client.Do(req)
154-
require.NoError(t, err)
129+
resp := bridge.DoAnthropicRequest(t, reqBody)
155130
require.Equal(t, http.StatusOK, resp.StatusCode)
156131
defer resp.Body.Close()
157132

@@ -170,23 +145,25 @@ func TestAnthropicMessages(t *testing.T) {
170145
// One for message_start, one for message_delta.
171146
expectedTokenRecordings = 2
172147
}
173-
require.Len(t, recorderClient.tokenUsages, expectedTokenRecordings)
148+
require.Len(t, bridge.Recorder.TokenUsages(), expectedTokenRecordings)
174149

175-
assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(recorderClient.tokenUsages), "input tokens miscalculated")
176-
assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(recorderClient.tokenUsages), "output tokens miscalculated")
150+
assert.EqualValues(t, tc.expectedInputTokens, bridge.Recorder.TotalInputTokens(), "input tokens miscalculated")
151+
assert.EqualValues(t, tc.expectedOutputTokens, bridge.Recorder.TotalOutputTokens(), "output tokens miscalculated")
177152

178-
require.Len(t, recorderClient.toolUsages, 1)
179-
assert.Equal(t, "Read", recorderClient.toolUsages[0].Tool)
180-
require.IsType(t, json.RawMessage{}, recorderClient.toolUsages[0].Args)
153+
toolUsages := bridge.Recorder.ToolUsages()
154+
require.Len(t, toolUsages, 1)
155+
assert.Equal(t, "Read", toolUsages[0].Tool)
156+
require.IsType(t, json.RawMessage{}, toolUsages[0].Args)
181157
var args map[string]any
182-
require.NoError(t, json.Unmarshal(recorderClient.toolUsages[0].Args.(json.RawMessage), &args))
158+
require.NoError(t, json.Unmarshal(toolUsages[0].Args.(json.RawMessage), &args))
183159
require.Contains(t, args, "file_path")
184160
assert.Equal(t, "/tmp/blah/foo", args["file_path"])
185161

186-
require.Len(t, recorderClient.userPrompts, 1)
187-
assert.Equal(t, "read the foo file", recorderClient.userPrompts[0].Prompt)
162+
promptUsages := bridge.Recorder.PromptUsages()
163+
require.Len(t, promptUsages, 1)
164+
assert.Equal(t, "read the foo file", promptUsages[0].Prompt)
188165

189-
recorderClient.verifyAllInterceptionsEnded(t)
166+
bridge.Recorder.VerifyAllInterceptionsEnded(t)
190167
})
191168
}
192169
})
@@ -198,9 +175,8 @@ func TestAWSBedrockIntegration(t *testing.T) {
198175
t.Run("invalid config", func(t *testing.T) {
199176
t.Parallel()
200177

201-
arc := txtar.Parse(antSingleBuiltinTool)
202-
files := filesMap(arc)
203-
reqBody := files[fixtureRequest]
178+
upstream := aibtest.NewMockUpstreamServer(t, t.Context(), antSingleBuiltinTool)
179+
reqBody := upstream.Files()[aibtest.FixtureRequest]
204180

205181
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
206182
t.Cleanup(cancel)
@@ -214,21 +190,21 @@ func TestAWSBedrockIntegration(t *testing.T) {
214190
SmallFastModel: "test-haiku",
215191
}
216192

217-
recorderClient := &mockRecorderClient{}
193+
recorder := aibtest.NewMockRecorder()
218194
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
219195
b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{
220-
aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg),
221-
}, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
196+
aibridge.NewAnthropicProvider(aibtest.AnthropicConfig("http://unused", aibtest.DefaultAPIKey), bedrockCfg),
197+
}, recorder, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
222198
require.NoError(t, err)
223199

224200
mockSrv := httptest.NewUnstartedServer(b)
225201
t.Cleanup(mockSrv.Close)
226202
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
227-
return aibridge.AsActor(ctx, userID, nil)
203+
return aibridge.AsActor(ctx, aibtest.DefaultUserID, nil)
228204
}
229205
mockSrv.Start()
230206

231-
req := createAnthropicMessagesReq(t, mockSrv.URL, reqBody)
207+
req := aibtest.CreateAnthropicMessagesRequest(t, mockSrv.URL, reqBody)
232208
resp, err := http.DefaultClient.Do(req)
233209
require.NoError(t, err)
234210
defer resp.Body.Close()
@@ -245,18 +221,9 @@ func TestAWSBedrockIntegration(t *testing.T) {
245221
t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) {
246222
t.Parallel()
247223

248-
arc := txtar.Parse(antSingleBuiltinTool)
249-
t.Logf("%s: %s", t.Name(), arc.Comment)
250-
251-
files := filesMap(arc)
252-
require.Len(t, files, 3)
253-
require.Contains(t, files, fixtureRequest)
254-
255-
reqBody := files[fixtureRequest]
256-
257-
newBody, err := setJSON(reqBody, "stream", streaming)
258-
require.NoError(t, err)
259-
reqBody = newBody
224+
upstream := aibtest.NewMockUpstreamServer(t, t.Context(), antSingleBuiltinTool)
225+
files := upstream.Files()
226+
reqBody := aibtest.SetStreamingInRequest(t, files[aibtest.FixtureRequest], streaming)
260227

261228
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
262229
t.Cleanup(cancel)
@@ -265,6 +232,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
265232
var requestCount int
266233

267234
// Create a mock server that intercepts requests to capture model name and return fixtures.
235+
// This is specific to testing Bedrock URL routing (model in path).
268236
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
269237
requestCount++
270238
t.Logf("Mock server received request #%d: %s %s (streaming=%v)", requestCount, r.Method, r.URL.Path, streaming)
@@ -281,12 +249,12 @@ func TestAWSBedrockIntegration(t *testing.T) {
281249
// Return appropriate fixture response.
282250
var respBody []byte
283251
if streaming {
284-
respBody = files[fixtureStreamingResponse]
252+
respBody = files[aibtest.FixtureStreamingResponse]
285253
w.Header().Set("Content-Type", "text/event-stream")
286254
w.Header().Set("Cache-Control", "no-cache")
287255
w.Header().Set("Connection", "keep-alive")
288256
} else {
289-
respBody = files[fixtureNonStreamingResponse]
257+
respBody = files[aibtest.FixtureNonStreamingResponse]
290258
w.Header().Set("Content-Type", "application/json")
291259
}
292260

@@ -311,26 +279,24 @@ func TestAWSBedrockIntegration(t *testing.T) {
311279
EndpointOverride: srv.URL,
312280
}
313281

314-
recorderClient := &mockRecorderClient{}
315-
282+
recorder := aibtest.NewMockRecorder()
316283
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
317284
b, err := aibridge.NewRequestBridge(
318-
ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), bedrockCfg)},
319-
recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
285+
ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(aibtest.AnthropicConfig(srv.URL, aibtest.DefaultAPIKey), bedrockCfg)},
286+
recorder, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
320287
require.NoError(t, err)
321288

322289
mockBridgeSrv := httptest.NewUnstartedServer(b)
323290
t.Cleanup(mockBridgeSrv.Close)
324291
mockBridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context {
325-
return aibridge.AsActor(ctx, userID, nil)
292+
return aibridge.AsActor(ctx, aibtest.DefaultUserID, nil)
326293
}
327294
mockBridgeSrv.Start()
328295

329296
// Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock.
330297
// We override the AWS Bedrock client to route requests through our mock server.
331-
req := createAnthropicMessagesReq(t, mockBridgeSrv.URL, reqBody)
332-
client := &http.Client{}
333-
resp, err := client.Do(req)
298+
req := aibtest.CreateAnthropicMessagesRequest(t, mockBridgeSrv.URL, reqBody)
299+
resp, err := http.DefaultClient.Do(req)
334300
require.NoError(t, err)
335301
defer resp.Body.Close()
336302

@@ -345,9 +311,10 @@ func TestAWSBedrockIntegration(t *testing.T) {
345311
// and the interception data.
346312
require.Equal(t, requestCount, 1)
347313
require.Equal(t, bedrockCfg.Model, receivedModelName)
348-
require.Len(t, recorderClient.interceptions, 1)
349-
require.Equal(t, recorderClient.interceptions[0].Model, bedrockCfg.Model)
350-
recorderClient.verifyAllInterceptionsEnded(t)
314+
interceptions := recorder.Interceptions()
315+
require.Len(t, interceptions, 1)
316+
require.Equal(t, interceptions[0].Model, bedrockCfg.Model)
317+
recorder.VerifyAllInterceptionsEnded(t)
351318
})
352319
}
353320
})
@@ -379,46 +346,17 @@ func TestOpenAIChatCompletions(t *testing.T) {
379346
t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) {
380347
t.Parallel()
381348

382-
arc := txtar.Parse(oaiSingleBuiltinTool)
383-
t.Logf("%s: %s", t.Name(), arc.Comment)
384-
385-
files := filesMap(arc)
386-
require.Len(t, files, 3)
387-
require.Contains(t, files, fixtureRequest)
388-
require.Contains(t, files, fixtureStreamingResponse)
389-
require.Contains(t, files, fixtureNonStreamingResponse)
390-
391-
reqBody := files[fixtureRequest]
392-
393-
// Add the stream param to the request.
394-
newBody, err := setJSON(reqBody, "stream", tc.streaming)
395-
require.NoError(t, err)
396-
reqBody = newBody
397-
398349
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
399350
t.Cleanup(cancel)
400-
srv := newMockServer(ctx, t, files, nil)
401-
t.Cleanup(srv.Close)
402-
403-
recorderClient := &mockRecorderClient{}
404351

405-
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
406-
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))}
407-
b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
408-
require.NoError(t, err)
352+
upstream := aibtest.NewMockUpstreamServer(t, ctx, oaiSingleBuiltinTool)
353+
reqBody := aibtest.SetStreamingInRequest(t, upstream.Files()[aibtest.FixtureRequest], tc.streaming)
409354

410-
mockSrv := httptest.NewUnstartedServer(b)
411-
t.Cleanup(mockSrv.Close)
412-
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
413-
return aibridge.AsActor(ctx, userID, nil)
414-
}
415-
mockSrv.Start()
416-
// Make API call to aibridge for OpenAI /v1/chat/completions
417-
req := createOpenAIChatCompletionsReq(t, mockSrv.URL, reqBody)
355+
bridge := aibtest.NewTestBridge(t, ctx, aibtest.TestBridgeOptions{
356+
Provider: aibridge.NewOpenAIProvider(aibtest.OpenAIConfig(upstream.URL, aibtest.DefaultAPIKey)),
357+
})
418358

419-
client := &http.Client{}
420-
resp, err := client.Do(req)
421-
require.NoError(t, err)
359+
resp := bridge.DoOpenAIRequest(t, reqBody)
422360
require.Equal(t, http.StatusOK, resp.StatusCode)
423361
defer resp.Body.Close()
424362

@@ -436,20 +374,22 @@ func TestOpenAIChatCompletions(t *testing.T) {
436374
assert.Equal(t, "[DONE]", lastEvent.Data)
437375
}
438376

439-
require.Len(t, recorderClient.tokenUsages, 1)
440-
assert.EqualValues(t, tc.expectedInputTokens, calculateTotalInputTokens(recorderClient.tokenUsages), "input tokens miscalculated")
441-
assert.EqualValues(t, tc.expectedOutputTokens, calculateTotalOutputTokens(recorderClient.tokenUsages), "output tokens miscalculated")
377+
require.Len(t, bridge.Recorder.TokenUsages(), 1)
378+
assert.EqualValues(t, tc.expectedInputTokens, bridge.Recorder.TotalInputTokens(), "input tokens miscalculated")
379+
assert.EqualValues(t, tc.expectedOutputTokens, bridge.Recorder.TotalOutputTokens(), "output tokens miscalculated")
442380

443-
require.Len(t, recorderClient.toolUsages, 1)
444-
assert.Equal(t, "read_file", recorderClient.toolUsages[0].Tool)
445-
require.IsType(t, map[string]any{}, recorderClient.toolUsages[0].Args)
446-
require.Contains(t, recorderClient.toolUsages[0].Args, "path")
447-
assert.Equal(t, "README.md", recorderClient.toolUsages[0].Args.(map[string]any)["path"])
381+
toolUsages := bridge.Recorder.ToolUsages()
382+
require.Len(t, toolUsages, 1)
383+
assert.Equal(t, "read_file", toolUsages[0].Tool)
384+
require.IsType(t, map[string]any{}, toolUsages[0].Args)
385+
require.Contains(t, toolUsages[0].Args, "path")
386+
assert.Equal(t, "README.md", toolUsages[0].Args.(map[string]any)["path"])
448387

449-
require.Len(t, recorderClient.userPrompts, 1)
450-
assert.Equal(t, "how large is the README.md file in my current path", recorderClient.userPrompts[0].Prompt)
388+
promptUsages := bridge.Recorder.PromptUsages()
389+
require.Len(t, promptUsages, 1)
390+
assert.Equal(t, "how large is the README.md file in my current path", promptUsages[0].Prompt)
451391

452-
recorderClient.verifyAllInterceptionsEnded(t)
392+
bridge.Recorder.VerifyAllInterceptionsEnded(t)
453393
})
454394
}
455395
})
@@ -471,7 +411,7 @@ func TestSimple(t *testing.T) {
471411
fixture: antSimple,
472412
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
473413
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
474-
provider := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
414+
provider := []aibridge.Provider{aibridge.NewAnthropicProvider(aibtest.AnthropicConfig(addr, aibtest.DefaultAPIKey), nil)}
475415
return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
476416
},
477417
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
@@ -502,15 +442,15 @@ func TestSimple(t *testing.T) {
502442
}
503443
return message.ID, nil
504444
},
505-
createRequest: createAnthropicMessagesReq,
445+
createRequest: aibtest.CreateAnthropicMessagesRequest,
506446
expectedMsgID: "msg_01Pvyf26bY17RcjmWfJsXGBn",
507447
},
508448
{
509449
name: aibridge.ProviderOpenAI,
510450
fixture: oaiSimple,
511451
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
512452
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
513-
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
453+
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(aibtest.OpenAIConfig(addr, aibtest.DefaultAPIKey))}
514454
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
515455
},
516456
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
@@ -1373,7 +1313,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
13731313
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
13741314
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
13751315
},
1376-
createRequest: createAnthropicMessagesReq,
1316+
createRequest: aibtest.CreateAnthropicMessagesRequest,
13771317
envVars: map[string]string{
13781318
"ANTHROPIC_AUTH_TOKEN": "should-not-leak",
13791319
},
@@ -1687,8 +1627,6 @@ func (m *mockRecorderClient) verifyAllInterceptionsEnded(t *testing.T) {
16871627
}
16881628
}
16891629

1690-
const mockToolName = "coder_list_workspaces"
1691-
16921630
// callAccumulator tracks all tool invocations by name and each instance's arguments.
16931631
type callAccumulator struct {
16941632
calls map[string][]any

0 commit comments

Comments
 (0)