Commit ba198a7

fix: prevent additional headers from being added to upstream requests (#61)
* chore: drive-by refactor
* chore: prevent default options adding unwanted headers
* chore: add test

Signed-off-by: Danny Kopping <danny@coder.com>
1 parent: 64eff22 · commit: ba198a7

7 files changed: +117 -22 lines
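Background on the fix: both SDKs' top-level client constructors apply default options that read the environment, so values such as OPENAI_ORG_ID (sent upstream as an OpenAI-Organization header) or ANTHROPIC_AUTH_TOKEN (sent as Authorization) were riding along on proxied requests; constructing the narrower per-service types applies only the options passed explicitly. A minimal sketch of that distinction on the OpenAI side, assuming openai-go's constructor behavior as implied by issue #60 and the new test below; the values are illustrative:

package main

import (
	"fmt"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)

func main() {
	baseURL, key := "http://127.0.0.1:8080/v1", "test-key"

	// Top-level client: the SDK's default options run first and (per the
	// behavior this commit guards against) can pick up environment values
	// such as OPENAI_ORG_ID, which then become headers on every request.
	client := openai.NewClient(option.WithAPIKey(key), option.WithBaseURL(baseURL))

	// Service-level constructor: only the options passed explicitly apply,
	// so nothing extra is read from the environment.
	svc := openai.NewChatCompletionService(option.WithAPIKey(key), option.WithBaseURL(baseURL))

	fmt.Printf("%T %T\n", client, svc)
}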

bridge_integration_test.go

Lines changed: 103 additions & 6 deletions
@@ -504,7 +504,7 @@ func TestSimple(t *testing.T) {
 			fixture: oaiSimple,
 			configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
 				logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
-				return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, mcp.NewServerProxyManager(nil))
+				return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil))
 			},
 			getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
 				if streaming {
@@ -655,7 +655,7 @@ func TestFallthrough(t *testing.T) {
 			fixture: oaiFallthrough,
 			configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
 				logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
-				provider := aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))
+				provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))
 				bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil))
 				require.NoError(t, err)
 				return provider, bridge
@@ -843,7 +843,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
 
 	configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
 		logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
-		return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
+		return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
 	}
 
 	// Build the requirements & make the assertions which are common to all providers.
@@ -1046,7 +1046,7 @@ func TestErrorHandling(t *testing.T) {
 			createRequestFunc: createOpenAIChatCompletionsReq,
 			configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
 				logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
-				return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
+				return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
 			},
 			responseHandlerFn: func(resp *http.Response) {
 				require.Equal(t, http.StatusBadRequest, resp.StatusCode)
@@ -1152,7 +1152,7 @@ func TestErrorHandling(t *testing.T) {
 			createRequestFunc: createOpenAIChatCompletionsReq,
 			configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
 				logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
-				return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
+				return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
 			},
 			responseHandlerFn: func(resp *http.Response) {
 				// Server responds first with 200 OK then starts streaming.
@@ -1246,7 +1246,7 @@ func TestStableRequestEncoding(t *testing.T) {
 			fixture: oaiSimple,
 			createRequestFunc: createOpenAIChatCompletionsReq,
 			configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
-				return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(aibridge.OpenAIConfig(anthropicCfg(addr, apiKey)))}, logger, client, srvProxyMgr)
+				return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, srvProxyMgr)
 			},
 		},
 	}
@@ -1334,6 +1334,103 @@ func TestStableRequestEncoding(t *testing.T) {
 	}
 }
 
+func TestEnvironmentDoNotLeak(t *testing.T) {
+	// NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution.
+
+	// Test that environment variables containing API keys/tokens are not leaked to upstream requests.
+	// See https://github.com/coder/aibridge/issues/60.
+	testCases := []struct {
+		name          string
+		fixture       []byte
+		configureFunc func(string, aibridge.Recorder) (*aibridge.RequestBridge, error)
+		createRequest func(*testing.T, string, []byte) *http.Request
+		envVars       map[string]string
+		headerName    string
+	}{
+		{
+			name:    aibridge.ProviderAnthropic,
+			fixture: antSimple,
+			configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
+				logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
+				return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}, logger, client, mcp.NewServerProxyManager(nil))
+			},
+			createRequest: createAnthropicMessagesReq,
+			envVars: map[string]string{
+				"ANTHROPIC_AUTH_TOKEN": "should-not-leak",
+			},
+			headerName: "Authorization", // We only send through the X-Api-Key, so this one should not be present.
+		},
+		{
+			name:    aibridge.ProviderOpenAI,
+			fixture: oaiSimple,
+			configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
+				logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
+				return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil))
+			},
+			createRequest: createOpenAIChatCompletionsReq,
+			envVars: map[string]string{
+				"OPENAI_ORG_ID": "should-not-leak",
+			},
+			headerName: "OpenAI-Organization",
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// NOTE: Cannot use t.Parallel() here because t.Setenv requires sequential execution.
+
+			arc := txtar.Parse(tc.fixture)
+			files := filesMap(arc)
+			reqBody := files[fixtureRequest]
+
+			ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
+			t.Cleanup(cancel)
+
+			// Track headers received by the upstream server.
+			var receivedHeaders http.Header
+			srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+				receivedHeaders = r.Header.Clone()
+				w.Header().Set("Content-Type", "application/json")
+				w.WriteHeader(http.StatusOK)
+				_, _ = w.Write(files[fixtureNonStreamingResponse])
+			}))
+			srv.Config.BaseContext = func(_ net.Listener) context.Context {
+				return ctx
+			}
+			srv.Start()
+			t.Cleanup(srv.Close)
+
+			// Set environment variables that the SDK would automatically read.
+			// These should NOT leak into upstream requests.
+			for key, val := range tc.envVars {
+				t.Setenv(key, val)
+			}
+
+			recorderClient := &mockRecorderClient{}
+			b, err := tc.configureFunc(srv.URL, recorderClient)
+			require.NoError(t, err)
+
+			mockSrv := httptest.NewUnstartedServer(b)
+			t.Cleanup(mockSrv.Close)
+			mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
+				return aibridge.AsActor(ctx, userID, nil)
+			}
+			mockSrv.Start()
+
+			req := tc.createRequest(t, mockSrv.URL, reqBody)
+			client := &http.Client{}
+			resp, err := client.Do(req)
+			require.NoError(t, err)
+			require.Equal(t, http.StatusOK, resp.StatusCode)
+			defer resp.Body.Close()
+
+			// Verify that environment values did not leak.
+			require.NotNil(t, receivedHeaders)
+			require.Empty(t, receivedHeaders.Get(tc.headerName))
+		})
+	}
+}
+
 func calculateTotalInputTokens(in []*aibridge.TokenUsageRecord) int64 {
 	var total int64
 	for _, el := range in {
intercept_anthropic_messages_base.go

Lines changed: 3 additions & 3 deletions
@@ -96,7 +96,7 @@ func (i *AnthropicMessagesInterceptionBase) isSmallFastModel() bool {
 	return strings.Contains(string(i.req.Model), "haiku")
 }
 
-func (i *AnthropicMessagesInterceptionBase) newAnthropicClient(ctx context.Context, opts ...option.RequestOption) (anthropic.Client, error) {
+func (i *AnthropicMessagesInterceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) {
 	opts = append(opts, option.WithAPIKey(i.cfg.Key))
 	opts = append(opts, option.WithBaseURL(i.cfg.BaseURL))
 
@@ -105,7 +105,7 @@ func (i *AnthropicMessagesInterceptionBase) newAnthropicClient(ctx context.Conte
 		defer cancel()
 		bedrockOpt, err := i.withAWSBedrock(ctx, i.bedrockCfg)
 		if err != nil {
-			return anthropic.Client{}, err
+			return anthropic.MessageService{}, err
 		}
 		opts = append(opts, bedrockOpt)
 		i.augmentRequestForBedrock()
@@ -122,7 +122,7 @@ func (i *AnthropicMessagesInterceptionBase) newAnthropicClient(ctx context.Conte
 		}
 	}
 
-	return anthropic.NewClient(opts...), nil
+	return anthropic.NewMessageService(opts...), nil
 }
 
 func (i *AnthropicMessagesInterceptionBase) withAWSBedrock(ctx context.Context, cfg *AWSBedrockConfig) (option.RequestOption, error) {
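The same pattern on the Anthropic side: the interception now builds an anthropic.MessageService directly, so only explicitly supplied options reach the upstream request. A minimal standalone sketch of the narrowed constructor, assuming anthropic-sdk-go's NewMessageService behavior; buildMessagesService and the values are illustrative, and the real code above also wires in an optional Bedrock request option:

package main

import (
	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/option"
)

// buildMessagesService mirrors newMessagesService from the diff above, minus
// the Bedrock branch: the service is built only from the options passed in,
// never from ambient environment variables such as ANTHROPIC_AUTH_TOKEN.
func buildMessagesService(baseURL, key string) anthropic.MessageService {
	opts := []option.RequestOption{
		option.WithAPIKey(key),
		option.WithBaseURL(baseURL),
	}
	return anthropic.NewMessageService(opts...)
}

func main() {
	svc := buildMessagesService("http://127.0.0.1:8080", "test-key")
	_ = svc // svc.New(ctx, params) / svc.NewStreaming(ctx, params), as in the interceptors below.
}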

intercept_anthropic_messages_blocking.go

Lines changed: 2 additions & 2 deletions
@@ -58,7 +58,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr
 
 	opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 60)} // TODO: configurable timeout
 
-	client, err := i.newAnthropicClient(ctx, opts...)
+	svc, err := i.newMessagesService(ctx, opts...)
 	if err != nil {
 		err = fmt.Errorf("create anthropic client: %w", err)
 		http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -73,7 +73,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr
 	var cumulativeUsage anthropic.Usage
 
 	for {
-		resp, err = client.Messages.New(ctx, messages)
+		resp, err = svc.New(ctx, messages)
 		if err != nil {
 			if isConnError(err) {
 				// Can't write a response, just error out.

intercept_anthropic_messages_streaming.go

Lines changed: 2 additions & 2 deletions
@@ -88,7 +88,7 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW
 	streamCtx, streamCancel := context.WithCancelCause(ctx)
 	defer streamCancel(errors.New("deferred"))
 
-	client, err := i.newAnthropicClient(streamCtx)
+	svc, err := i.newMessagesService(streamCtx)
 	if err != nil {
 		err = fmt.Errorf("create anthropic client: %w", err)
 		http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -118,7 +118,7 @@ newStream:
 			break
 		}
 
-		stream := client.Messages.NewStreaming(streamCtx, messages)
+		stream := svc.NewStreaming(streamCtx, messages)
 
 		var message anthropic.Message
 		var lastToolName string

intercept_openai_chat_base.go

Lines changed: 3 additions & 5 deletions
@@ -26,12 +26,10 @@ type OpenAIChatInterceptionBase struct {
 	mcpProxy mcp.ServerProxier
 }
 
-func (i *OpenAIChatInterceptionBase) newOpenAIClient(baseURL, key string) openai.Client {
-	var opts []option.RequestOption
-	opts = append(opts, option.WithAPIKey(key))
-	opts = append(opts, option.WithBaseURL(baseURL))
+func (i *OpenAIChatInterceptionBase) newCompletionsService(baseURL, key string) openai.ChatCompletionService {
+	opts := []option.RequestOption{option.WithAPIKey(key), option.WithBaseURL(baseURL)}
 
-	return openai.NewClient(opts...)
+	return openai.NewChatCompletionService(opts...)
 }
 
 func (i *OpenAIChatInterceptionBase) ID() uuid.UUID {

intercept_openai_chat_blocking.go

Lines changed: 2 additions & 2 deletions
@@ -41,7 +41,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
 	}
 
 	ctx := r.Context()
-	client := i.newOpenAIClient(i.baseURL, i.key)
+	svc := i.newCompletionsService(i.baseURL, i.key)
 	logger := i.logger.With(slog.F("model", i.req.Model))
 
 	var (
@@ -61,7 +61,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
 		var opts []option.RequestOption
 		opts = append(opts, option.WithRequestTimeout(time.Second*60)) // TODO: configurable timeout
 
-		completion, err = client.Chat.Completions.New(ctx, i.req.ChatCompletionNewParams, opts...)
+		completion, err = svc.New(ctx, i.req.ChatCompletionNewParams, opts...)
 		if err != nil {
 			break
 		}

intercept_openai_chat_streaming.go

Lines changed: 2 additions & 2 deletions
@@ -65,7 +65,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
 	defer cancel()
 	r = r.WithContext(ctx) // Rewire context for SSE cancellation.
 
-	client := i.newOpenAIClient(i.baseURL, i.key)
+	svc := i.newCompletionsService(i.baseURL, i.key)
 	logger := i.logger.With(slog.F("model", i.req.Model))
 
 	streamCtx, streamCancel := context.WithCancelCause(ctx)
@@ -100,7 +100,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
 		interceptionErr error
 	)
 	for {
-		stream = client.Chat.Completions.NewStreaming(streamCtx, i.req.ChatCompletionNewParams)
+		stream = svc.NewStreaming(streamCtx, i.req.ChatCompletionNewParams)
 		processor := newStreamProcessor(streamCtx, i.logger.Named("stream-processor"), i.getInjectedToolByName)
 
 		var toolCall *openai.FinishedChatCompletionToolCall
