Skip to content

Commit d5b0614

Browse files
committed
review 1: streaming upstream fix, tool attrs, request path
1 parent 597f3c0 commit d5b0614

18 files changed

+239
-157
lines changed

aibtrace/aibtrace.go

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,36 @@ import (
99
)
1010

1111
type traceInterceptionAttrsContextKey struct{}
12+
type traceRequestBridgeAttrsContextKey struct{}
1213

1314
const (
1415
// trace attribute key constants
15-
InterceptionID = "interception_id"
16-
UserID = "user_id"
17-
Provider = "provider"
18-
Model = "model"
19-
Streaming = "streaming"
20-
IsBedrock = "aws_bedrock"
21-
MCPProxyName = "mcp_proxy_name"
22-
MCPToolName = "mcp_tool_name"
16+
RequestPath = "request_path"
17+
18+
InterceptionID = "interception_id"
19+
UserID = "user_id"
20+
Provider = "provider"
21+
Model = "model"
22+
Streaming = "streaming"
23+
IsBedrock = "aws_bedrock"
24+
2325
PassthroughURL = "passthrough_url"
2426
PassthroughMethod = "passthrough_method"
27+
28+
MCPInput = "mcp_input"
29+
MCPProxyName = "mcp_proxy_name"
30+
MCPToolName = "mcp_tool_name"
31+
MCPServerName = "mcp_server_name"
32+
MCPServerURL = "mcp_server_url"
33+
34+
APIKeyID = "api_key_id"
2535
)
2636

27-
func WithTraceInterceptionAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context {
37+
func WithInterceptionAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context {
2838
return context.WithValue(ctx, traceInterceptionAttrsContextKey{}, traceAttrs)
2939
}
3040

31-
func TraceInterceptionAttributesFromContext(ctx context.Context) []attribute.KeyValue {
41+
func InterceptionAttributesFromContext(ctx context.Context) []attribute.KeyValue {
3242
attrs, ok := ctx.Value(traceInterceptionAttrsContextKey{}).([]attribute.KeyValue)
3343
if !ok {
3444
return nil
@@ -37,6 +47,19 @@ func TraceInterceptionAttributesFromContext(ctx context.Context) []attribute.Key
3747
return attrs
3848
}
3949

50+
func WithRequestBridgeAttributesInContext(ctx context.Context, traceAttrs []attribute.KeyValue) context.Context {
51+
return context.WithValue(ctx, traceRequestBridgeAttrsContextKey{}, traceAttrs)
52+
}
53+
54+
func RequestBridgeAttributesFromContext(ctx context.Context) []attribute.KeyValue {
55+
attrs, ok := ctx.Value(traceRequestBridgeAttrsContextKey{}).([]attribute.KeyValue)
56+
if !ok {
57+
return nil
58+
}
59+
60+
return attrs
61+
}
62+
4063
func EndSpanErr(span trace.Span, err *error) {
4164
if span == nil {
4265
return

bridge_integration_test.go

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import (
3434
"github.com/tidwall/gjson"
3535
"github.com/tidwall/sjson"
3636
"go.opentelemetry.io/otel"
37+
"go.opentelemetry.io/otel/trace"
3738
"go.uber.org/goleak"
3839
"golang.org/x/tools/txtar"
3940
)
@@ -65,7 +66,7 @@ var (
6566
//go:embed fixtures/openai/non_stream_error.txtar
6667
oaiNonStreamErr []byte
6768

68-
defaultTracer = otel.Tracer("github.com/coder/aibridge")
69+
testTracer = otel.Tracer("forTesting")
6970
)
7071

7172
const (
@@ -136,7 +137,7 @@ func TestAnthropicMessages(t *testing.T) {
136137

137138
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
138139
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), nil)}
139-
b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger)
140+
b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger)
140141
require.NoError(t, err)
141142

142143
mockSrv := httptest.NewUnstartedServer(b)
@@ -217,7 +218,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
217218
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
218219
b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{
219220
aibridge.NewAnthropicProvider(anthropicCfg("http://unused", apiKey), bedrockCfg),
220-
}, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger)
221+
}, recorderClient, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger)
221222
require.NoError(t, err)
222223

223224
mockSrv := httptest.NewUnstartedServer(b)
@@ -315,7 +316,7 @@ func TestAWSBedrockIntegration(t *testing.T) {
315316
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
316317
b, err := aibridge.NewRequestBridge(
317318
ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(srv.URL, apiKey), bedrockCfg)},
318-
recorderClient, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger)
319+
recorderClient, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger)
319320
require.NoError(t, err)
320321

321322
mockBridgeSrv := httptest.NewUnstartedServer(b)
@@ -403,7 +404,7 @@ func TestOpenAIChatCompletions(t *testing.T) {
403404

404405
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
405406
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(srv.URL, apiKey))}
406-
b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger)
407+
b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger)
407408
require.NoError(t, err)
408409

409410
mockSrv := httptest.NewUnstartedServer(b)
@@ -471,7 +472,7 @@ func TestSimple(t *testing.T) {
471472
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
472473
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
473474
provider := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
474-
return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger)
475+
return aibridge.NewRequestBridge(t.Context(), provider, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger)
475476
},
476477
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
477478
if streaming {
@@ -510,7 +511,7 @@ func TestSimple(t *testing.T) {
510511
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
511512
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
512513
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
513-
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger)
514+
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger)
514515
},
515516
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
516517
if streaming {
@@ -642,7 +643,7 @@ func TestFallthrough(t *testing.T) {
642643
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
643644
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
644645
provider := aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)
645-
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger)
646+
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger)
646647
require.NoError(t, err)
647648
return provider, bridge
648649
},
@@ -653,7 +654,7 @@ func TestFallthrough(t *testing.T) {
653654
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
654655
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
655656
provider := aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))
656-
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger)
657+
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger)
657658
require.NoError(t, err)
658659
return provider, bridge
659660
},
@@ -724,15 +725,15 @@ func TestFallthrough(t *testing.T) {
724725
}
725726

726727
// setupMCPServerProxiesForTest creates a mock MCP server, initializes the MCP bridge, and returns the tools
727-
func setupMCPServerProxiesForTest(t *testing.T) map[string]mcp.ServerProxier {
728+
func setupMCPServerProxiesForTest(t *testing.T, tracer trace.Tracer) map[string]mcp.ServerProxier {
728729
t.Helper()
729730

730731
// Setup Coder MCP integration
731732
mcpSrv := httptest.NewServer(createMockMCPSrv(t))
732733
t.Cleanup(mcpSrv.Close)
733734

734735
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
735-
proxy, err := mcp.NewStreamableHTTPServerProxy(logger, "coder", mcpSrv.URL, nil, nil, nil)
736+
proxy, err := mcp.NewStreamableHTTPServerProxy(logger, tracer, "coder", mcpSrv.URL, nil, nil, nil)
736737
require.NoError(t, err)
737738

738739
// Initialize MCP client, fetch tools, and inject into bridge
@@ -760,7 +761,7 @@ func TestAnthropicInjectedTools(t *testing.T) {
760761
configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
761762
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
762763
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
763-
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
764+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger)
764765
}
765766

766767
// Build the requirements & make the assertions which are common to all providers.
@@ -842,7 +843,7 @@ func TestOpenAIInjectedTools(t *testing.T) {
842843
configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
843844
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
844845
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
845-
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
846+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger)
846847
}
847848

848849
// Build the requirements & make the assertions which are common to all providers.
@@ -977,10 +978,10 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu
977978
recorderClient := &mockRecorderClient{}
978979

979980
// Setup MCP tools.
980-
tools := setupMCPServerProxiesForTest(t)
981+
tools := setupMCPServerProxiesForTest(t, testTracer)
981982

982983
// Configure the bridge with injected tools.
983-
mcpMgr := mcp.NewServerProxyManager(tools, defaultTracer)
984+
mcpMgr := mcp.NewServerProxyManager(tools, testTracer)
984985
require.NoError(t, mcpMgr.Init(ctx))
985986
b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr)
986987
require.NoError(t, err)
@@ -1029,7 +1030,7 @@ func TestErrorHandling(t *testing.T) {
10291030
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
10301031
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
10311032
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
1032-
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
1033+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger)
10331034
},
10341035
responseHandlerFn: func(resp *http.Response) {
10351036
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
@@ -1047,7 +1048,7 @@ func TestErrorHandling(t *testing.T) {
10471048
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
10481049
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
10491050
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
1050-
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
1051+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger)
10511052
},
10521053
responseHandlerFn: func(resp *http.Response) {
10531054
require.Equal(t, http.StatusBadRequest, resp.StatusCode)
@@ -1096,7 +1097,7 @@ func TestErrorHandling(t *testing.T) {
10961097

10971098
recorderClient := &mockRecorderClient{}
10981099

1099-
b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer))
1100+
b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer))
11001101
require.NoError(t, err)
11011102

11021103
// Invoke request to mocked API via aibridge.
@@ -1136,7 +1137,7 @@ func TestErrorHandling(t *testing.T) {
11361137
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
11371138
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
11381139
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
1139-
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
1140+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger)
11401141
},
11411142
responseHandlerFn: func(resp *http.Response) {
11421143
// Server responds first with 200 OK then starts streaming.
@@ -1155,7 +1156,7 @@ func TestErrorHandling(t *testing.T) {
11551156
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
11561157
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
11571158
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
1158-
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
1159+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger)
11591160
},
11601161
responseHandlerFn: func(resp *http.Response) {
11611162
// Server responds first with 200 OK then starts streaming.
@@ -1198,7 +1199,7 @@ func TestErrorHandling(t *testing.T) {
11981199

11991200
recorderClient := &mockRecorderClient{}
12001201

1201-
b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, defaultTracer))
1202+
b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer))
12021203
require.NoError(t, err)
12031204

12041205
// Invoke request to mocked API via aibridge.
@@ -1242,7 +1243,7 @@ func TestStableRequestEncoding(t *testing.T) {
12421243
createRequestFunc: createAnthropicMessagesReq,
12431244
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
12441245
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
1245-
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
1246+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger)
12461247
},
12471248
},
12481249
{
@@ -1251,7 +1252,7 @@ func TestStableRequestEncoding(t *testing.T) {
12511252
createRequestFunc: createOpenAIChatCompletionsReq,
12521253
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
12531254
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
1254-
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, defaultTracer, logger)
1255+
return aibridge.NewRequestBridge(t.Context(), providers, client, srvProxyMgr, nil, testTracer, logger)
12551256
},
12561257
},
12571258
}
@@ -1264,10 +1265,10 @@ func TestStableRequestEncoding(t *testing.T) {
12641265
t.Cleanup(cancel)
12651266

12661267
// Setup MCP tools.
1267-
tools := setupMCPServerProxiesForTest(t)
1268+
tools := setupMCPServerProxiesForTest(t, testTracer)
12681269

12691270
// Configure the bridge with injected tools.
1270-
mcpMgr := mcp.NewServerProxyManager(tools, defaultTracer)
1271+
mcpMgr := mcp.NewServerProxyManager(tools, testTracer)
12711272
require.NoError(t, mcpMgr.Init(ctx))
12721273

12731274
arc := txtar.Parse(tc.fixture)
@@ -1358,7 +1359,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
13581359
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
13591360
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
13601361
providers := []aibridge.Provider{aibridge.NewAnthropicProvider(anthropicCfg(addr, apiKey), nil)}
1361-
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger)
1362+
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger)
13621363
},
13631364
createRequest: createAnthropicMessagesReq,
13641365
envVars: map[string]string{
@@ -1372,7 +1373,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) {
13721373
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
13731374
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
13741375
providers := []aibridge.Provider{aibridge.NewOpenAIProvider(openaiCfg(addr, apiKey))}
1375-
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, defaultTracer), nil, defaultTracer, logger)
1376+
return aibridge.NewRequestBridge(t.Context(), providers, client, mcp.NewServerProxyManager(nil, testTracer), nil, testTracer, logger)
13761377
},
13771378
createRequest: createOpenAIChatCompletionsReq,
13781379
envVars: map[string]string{

intercept_anthropic_messages_base.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,13 @@ func (i *AnthropicMessagesInterceptionBase) Model() string {
6363
return string(i.req.Model)
6464
}
6565

66-
func (s *AnthropicMessagesInterceptionBase) baseTraceAttributes(ctx context.Context, streaming bool) []attribute.KeyValue {
66+
func (s *AnthropicMessagesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
6767
return []attribute.KeyValue{
68-
attribute.String(aibtrace.Provider, ProviderAnthropic),
68+
attribute.String(aibtrace.RequestPath, r.URL.Path),
6969
attribute.String(aibtrace.InterceptionID, s.id.String()),
70+
attribute.String(aibtrace.UserID, actorFromContext(r.Context()).id),
71+
attribute.String(aibtrace.Provider, ProviderAnthropic),
7072
attribute.String(aibtrace.Model, s.Model()),
71-
attribute.String(aibtrace.UserID, actorFromContext(ctx).id),
7273
attribute.Bool(aibtrace.Streaming, streaming),
7374
attribute.Bool(aibtrace.IsBedrock, s.bedrockCfg != nil),
7475
}

0 commit comments

Comments
 (0)