diff --git a/.gitignore b/.gitignore
index 7f5be4bf18..f0840c001e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,4 +42,5 @@ cmd/thv-operator/.task/checksum/crdref-gen

 # Test coverage
 coverage*
-crd-helm-wrapper
\ No newline at end of file
+crd-helm-wrapper
+cmd/vmcp/__debug_bin*
diff --git a/.golangci.yml b/.golangci.yml
index 62c3611473..ff2b3d54e9 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -139,6 +139,7 @@ linters:
       - third_party$
       - builtin$
       - examples$
+      - scripts$
 formatters:
   enable:
     - gci
@@ -155,3 +156,4 @@ formatters:
       - third_party$
       - builtin$
       - examples$
+      - scripts$
diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go
index f9c0aa8a70..aa55537e3a 100644
--- a/cmd/vmcp/app/commands.go
+++ b/cmd/vmcp/app/commands.go
@@ -12,6 +12,7 @@ import (
 	"github.com/spf13/cobra"
 	"github.com/spf13/viper"
+	"go.opentelemetry.io/otel/trace"
 	"k8s.io/client-go/rest"

 	"github.com/stacklok/toolhive/pkg/audit"
@@ -310,8 +311,27 @@ func runServe(cmd *cobra.Command, _ []string) error {
 		return fmt.Errorf("failed to create conflict resolver: %w", err)
 	}

-	// Create aggregator
-	agg := aggregator.NewDefaultAggregator(backendClient, conflictResolver, cfg.Aggregation.Tools)
+	// If telemetry is configured, create the provider early so aggregator can use it
+	var telemetryProvider *telemetry.Provider
+	if cfg.Telemetry != nil {
+		telemetryProvider, err = telemetry.NewProvider(ctx, *cfg.Telemetry)
+		if err != nil {
+			return fmt.Errorf("failed to create telemetry provider: %w", err)
+		}
+		defer func() {
+			err := telemetryProvider.Shutdown(ctx)
+			if err != nil {
+				logger.Errorf("failed to shutdown telemetry provider: %v", err)
+			}
+		}()
+	}
+
+	// Create aggregator with tracer provider (nil if telemetry not configured)
+	var tracerProvider trace.TracerProvider
+	if telemetryProvider != nil {
+		tracerProvider = telemetryProvider.TracerProvider()
+	}
+	agg := aggregator.NewDefaultAggregator(backendClient, conflictResolver, cfg.Aggregation.Tools, tracerProvider)

 	// Use DynamicRegistry for version-based cache invalidation
 	// Works in both standalone (CLI with YAML config) and Kubernetes (operator-deployed) modes
@@ -381,21 +401,8 @@ func runServe(cmd *cobra.Command, _ []string) error {
 	host, _ := cmd.Flags().GetString("host")
 	port, _ := cmd.Flags().GetInt("port")

-	// If telemetry is configured, create the provider.
-	var telemetryProvider *telemetry.Provider
-	if cfg.Telemetry != nil {
-		var err error
-		telemetryProvider, err = telemetry.NewProvider(ctx, *cfg.Telemetry)
-		if err != nil {
-			return fmt.Errorf("failed to create telemetry provider: %w", err)
-		}
-		defer func() {
-			err := telemetryProvider.Shutdown(ctx)
-			if err != nil {
-				logger.Errorf("failed to shutdown telemetry provider: %v", err)
-			}
-		}()
-	}
+	// Note: telemetryProvider was already created earlier (before aggregator creation)
+	// to enable tracing in the aggregator

 	// Configure health monitoring if enabled
 	var healthMonitorConfig *health.MonitorConfig
diff --git a/codecov.yaml b/codecov.yaml
index 1a8032e484..410f9ae7ee 100644
--- a/codecov.yaml
+++ b/codecov.yaml
@@ -13,6 +13,8 @@ coverage:
     - "**/mocks/**/*"
    - "**/mock_*.go"
     - "**/zz_generated.deepcopy.go"
+    - "**/*_test.go"
+    - "**/*_test_coverage.go"
   status:
     project:
       default:
diff --git a/pkg/runner/config_builder_test.go b/pkg/runner/config_builder_test.go
index 3cdb58c217..18b44095ac 100644
--- a/pkg/runner/config_builder_test.go
+++ b/pkg/runner/config_builder_test.go
@@ -1079,8 +1079,8 @@ func TestRunConfigBuilder_WithRegistryProxyPort(t *testing.T) {
 				ProxyPort:  testPort,
 				TargetPort: testPort,
 			},
-			cliProxyPort:      9000,
-			expectedProxyPort: 9000,
+			cliProxyPort:      9999,
+			expectedProxyPort: 9999,
 		},
 		{
 			name: "random port when neither CLI nor registry specified",
diff --git a/pkg/vmcp/aggregator/default_aggregator.go b/pkg/vmcp/aggregator/default_aggregator.go
index 95be734af0..ca51d207d8 100644
--- a/pkg/vmcp/aggregator/default_aggregator.go
+++ b/pkg/vmcp/aggregator/default_aggregator.go
@@ -8,6 +8,10 @@ import (
 	"fmt"
 	"sync"

+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/codes"
+	"go.opentelemetry.io/otel/trace"
+	"go.opentelemetry.io/otel/trace/noop"
 	"golang.org/x/sync/errgroup"

 	"github.com/stacklok/toolhive/pkg/logger"
@@ -21,15 +25,18 @@ type defaultAggregator struct {
 	backendClient    vmcp.BackendClient
 	conflictResolver ConflictResolver
 	toolConfigMap    map[string]*config.WorkloadToolConfig // Maps backend ID to tool config
+	tracer           trace.Tracer
 }

 // NewDefaultAggregator creates a new default aggregator implementation.
 // conflictResolver handles tool name conflicts across backends.
 // workloadConfigs specifies per-backend tool filtering and overrides.
+// tracerProvider is used to create a tracer for distributed tracing (pass nil for no tracing).
 func NewDefaultAggregator(
 	backendClient vmcp.BackendClient,
 	conflictResolver ConflictResolver,
 	workloadConfigs []*config.WorkloadToolConfig,
+	tracerProvider trace.TracerProvider,
 ) Aggregator {
 	// Build tool config map for quick lookup by backend ID
 	toolConfigMap := make(map[string]*config.WorkloadToolConfig)
@@ -39,16 +46,38 @@ func NewDefaultAggregator(
 		}
 	}

+	// Create tracer from provider (use noop tracer if provider is nil)
+	var tracer trace.Tracer
+	if tracerProvider != nil {
+		tracer = tracerProvider.Tracer("github.com/stacklok/toolhive/pkg/vmcp/aggregator")
+	} else {
+		tracer = noop.NewTracerProvider().Tracer("github.com/stacklok/toolhive/pkg/vmcp/aggregator")
+	}
+
 	return &defaultAggregator{
 		backendClient:    backendClient,
 		conflictResolver: conflictResolver,
 		toolConfigMap:    toolConfigMap,
+		tracer:           tracer,
 	}
 }

 // QueryCapabilities queries a single backend for its MCP capabilities.
 // Returns the raw capabilities (tools, resources, prompts) from the backend.
-func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp.Backend) (*BackendCapabilities, error) {
+func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp.Backend) (_ *BackendCapabilities, retErr error) {
+	ctx, span := a.tracer.Start(ctx, "aggregator.QueryCapabilities",
+		trace.WithAttributes(
+			attribute.String("backend.id", backend.ID),
+		),
+	)
+	defer func() {
+		if retErr != nil {
+			span.RecordError(retErr)
+			span.SetStatus(codes.Error, retErr.Error())
+		}
+		span.End()
+	}()
+
 	logger.Debugf("Querying capabilities from backend %s", backend.ID)

 	// Create a BackendTarget from the Backend
@@ -74,6 +103,12 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp.
 		SupportsSampling: capabilities.SupportsSampling,
 	}

+	span.SetAttributes(
+		attribute.Int("tools.count", len(result.Tools)),
+		attribute.Int("resources.count", len(result.Resources)),
+		attribute.Int("prompts.count", len(result.Prompts)),
+	)
+
 	logger.Debugf("Backend %s: %d tools (after filtering/overrides), %d resources, %d prompts",
 		backend.ID, len(result.Tools), len(result.Resources), len(result.Prompts))

@@ -85,7 +120,20 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp.
 func (a *defaultAggregator) QueryAllCapabilities(
 	ctx context.Context,
 	backends []vmcp.Backend,
-) (map[string]*BackendCapabilities, error) {
+) (_ map[string]*BackendCapabilities, retErr error) {
+	ctx, span := a.tracer.Start(ctx, "aggregator.QueryAllCapabilities",
+		trace.WithAttributes(
+			attribute.Int("backends.count", len(backends)),
+		),
+	)
+	defer func() {
+		if retErr != nil {
+			span.RecordError(retErr)
+			span.SetStatus(codes.Error, retErr.Error())
+		}
+		span.End()
+	}()
+
 	logger.Infof("Querying capabilities from %d backends", len(backends))

 	// Use errgroup for parallel queries with context cancellation
@@ -125,6 +173,10 @@ func (a *defaultAggregator) QueryAllCapabilities(
 		return nil, fmt.Errorf("no backends returned capabilities")
 	}

+	span.SetAttributes(
+		attribute.Int("successful.backends", len(capabilities)),
+	)
+
 	logger.Infof("Successfully queried %d/%d backends", len(capabilities), len(backends))
 	return capabilities, nil
 }
@@ -134,7 +186,20 @@ func (a *defaultAggregator) QueryAllCapabilities(
 func (a *defaultAggregator) ResolveConflicts(
 	ctx context.Context,
 	capabilities map[string]*BackendCapabilities,
-) (*ResolvedCapabilities, error) {
+) (_ *ResolvedCapabilities, retErr error) {
+	ctx, span := a.tracer.Start(ctx, "aggregator.ResolveConflicts",
+		trace.WithAttributes(
+			attribute.Int("backends.count", len(capabilities)),
+		),
+	)
+	defer func() {
+		if retErr != nil {
+			span.RecordError(retErr)
+			span.SetStatus(codes.Error, retErr.Error())
+		}
+		span.End()
+	}()
+
 	logger.Debugf("Resolving conflicts across %d backends", len(capabilities))

 	// Group tools by backend for conflict resolution
@@ -191,6 +256,12 @@ func (a *defaultAggregator) ResolveConflicts(
 		resolved.SupportsSampling = resolved.SupportsSampling || caps.SupportsSampling
 	}

+	span.SetAttributes(
+		attribute.Int("resolved.tools", len(resolved.Tools)),
+		attribute.Int("resolved.resources", len(resolved.Resources)),
+		attribute.Int("resolved.prompts", len(resolved.Prompts)),
+	)
+
 	logger.Debugf("Resolved %d unique tools, %d resources, %d prompts",
 		len(resolved.Tools), len(resolved.Resources), len(resolved.Prompts))

@@ -199,11 +270,26 @@
 // MergeCapabilities creates the final unified capability view and routing table.
 // Uses the backend registry to populate full BackendTarget information for routing.
-func (*defaultAggregator) MergeCapabilities(
+func (a *defaultAggregator) MergeCapabilities(
 	ctx context.Context,
 	resolved *ResolvedCapabilities,
 	registry vmcp.BackendRegistry,
-) (*AggregatedCapabilities, error) {
+) (_ *AggregatedCapabilities, retErr error) {
+	ctx, span := a.tracer.Start(ctx, "aggregator.MergeCapabilities",
+		trace.WithAttributes(
+			attribute.Int("resolved.tools", len(resolved.Tools)),
+			attribute.Int("resolved.resources", len(resolved.Resources)),
+			attribute.Int("resolved.prompts", len(resolved.Prompts)),
+		),
+	)
+	defer func() {
+		if retErr != nil {
+			span.RecordError(retErr)
+			span.SetStatus(codes.Error, retErr.Error())
+		}
+		span.End()
+	}()
+
 	logger.Debugf("Merging capabilities into final view")

 	// Create routing table
@@ -304,6 +390,13 @@ func (*defaultAggregator) MergeCapabilities(
 		},
 	}

+	span.SetAttributes(
+		attribute.Int("aggregated.tools", aggregated.Metadata.ToolCount),
+		attribute.Int("aggregated.resources", aggregated.Metadata.ResourceCount),
+		attribute.Int("aggregated.prompts", aggregated.Metadata.PromptCount),
+		attribute.String("conflict.strategy", string(aggregated.Metadata.ConflictStrategy)),
+	)
+
 	logger.Infof("Merged capabilities: %d tools, %d resources, %d prompts",
 		aggregated.Metadata.ToolCount, aggregated.Metadata.ResourceCount, aggregated.Metadata.PromptCount)

@@ -315,7 +408,23 @@
 // 2. Query all backends
 // 3. Resolve conflicts
 // 4. Merge into final view with full backend information
-func (a *defaultAggregator) AggregateCapabilities(ctx context.Context, backends []vmcp.Backend) (*AggregatedCapabilities, error) {
+func (a *defaultAggregator) AggregateCapabilities(
+	ctx context.Context,
+	backends []vmcp.Backend,
+) (_ *AggregatedCapabilities, retErr error) {
+	ctx, span := a.tracer.Start(ctx, "aggregator.AggregateCapabilities",
+		trace.WithAttributes(
+			attribute.Int("backends.count", len(backends)),
+		),
+	)
+	defer func() {
+		if retErr != nil {
+			span.RecordError(retErr)
+			span.SetStatus(codes.Error, retErr.Error())
+		}
+		span.End()
+	}()
+
 	logger.Infof("Starting capability aggregation for %d backends", len(backends))

 	// Step 1: Create registry from discovered backends
@@ -343,6 +452,14 @@ func (a *defaultAggregator) AggregateCapabilities(ctx context.Context, backends
 	// Update metadata with backend count
 	aggregated.Metadata.BackendCount = len(backends)

+	span.SetAttributes(
+		attribute.Int("aggregated.backends", aggregated.Metadata.BackendCount),
+		attribute.Int("aggregated.tools", aggregated.Metadata.ToolCount),
+		attribute.Int("aggregated.resources", aggregated.Metadata.ResourceCount),
+		attribute.Int("aggregated.prompts", aggregated.Metadata.PromptCount),
+		attribute.String("conflict.strategy", string(aggregated.Metadata.ConflictStrategy)),
+	)
+
 	logger.Infof("Capability aggregation complete: %d backends, %d tools, %d resources, %d prompts",
 		aggregated.Metadata.BackendCount, aggregated.Metadata.ToolCount,
 		aggregated.Metadata.ResourceCount, aggregated.Metadata.PromptCount)
diff --git a/pkg/vmcp/aggregator/default_aggregator_test.go b/pkg/vmcp/aggregator/default_aggregator_test.go
index 3798f5324a..6d07c23af5 100644
--- a/pkg/vmcp/aggregator/default_aggregator_test.go
+++ b/pkg/vmcp/aggregator/default_aggregator_test.go
@@ -35,7 +35,7 @@ func TestDefaultAggregator_QueryCapabilities(t *testing.T) {

 		mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(expectedCaps, nil)

-		agg := NewDefaultAggregator(mockClient, nil, nil)
+		agg := NewDefaultAggregator(mockClient, nil, nil, nil)
 		result, err := agg.QueryCapabilities(context.Background(), backend)

 		require.NoError(t, err)
@@ -59,7 +59,7 @@ func TestDefaultAggregator_QueryCapabilities(t *testing.T) {
 		mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).
 			Return(nil, errors.New("connection failed"))

-		agg := NewDefaultAggregator(mockClient, nil, nil)
+		agg := NewDefaultAggregator(mockClient, nil, nil, nil)
 		result, err := agg.QueryCapabilities(context.Background(), backend)

 		require.Error(t, err)
@@ -90,7 +90,7 @@ func TestDefaultAggregator_QueryAllCapabilities(t *testing.T) {
 		mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps1, nil)
 		mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps2, nil)

-		agg := NewDefaultAggregator(mockClient, nil, nil)
+		agg := NewDefaultAggregator(mockClient, nil, nil, nil)
 		result, err := agg.QueryAllCapabilities(context.Background(), backends)

 		require.NoError(t, err)
@@ -120,7 +120,7 @@ func TestDefaultAggregator_QueryAllCapabilities(t *testing.T) {
 				return nil, errors.New("connection timeout")
 			}).Times(2)

-		agg := NewDefaultAggregator(mockClient, nil, nil)
+		agg := NewDefaultAggregator(mockClient, nil, nil, nil)
 		result, err := agg.QueryAllCapabilities(context.Background(), backends)

 		require.NoError(t, err)
@@ -140,7 +140,7 @@ func TestDefaultAggregator_QueryAllCapabilities(t *testing.T) {
 		mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).
 			Return(nil, errors.New("connection failed"))

-		agg := NewDefaultAggregator(mockClient, nil, nil)
+		agg := NewDefaultAggregator(mockClient, nil, nil, nil)
 		result, err := agg.QueryAllCapabilities(context.Background(), backends)

 		require.Error(t, err)
@@ -171,7 +171,7 @@ func TestDefaultAggregator_ResolveConflicts(t *testing.T) {
 			},
 		}

-		agg := NewDefaultAggregator(nil, nil, nil)
+		agg := NewDefaultAggregator(nil, nil, nil, nil)
 		resolved, err := agg.ResolveConflicts(context.Background(), capabilities)

 		require.NoError(t, err)
@@ -204,7 +204,7 @@ func TestDefaultAggregator_ResolveConflicts(t *testing.T) {
 			},
 		}

-		agg := NewDefaultAggregator(nil, nil, nil)
+		agg := NewDefaultAggregator(nil, nil, nil, nil)
 		resolved, err := agg.ResolveConflicts(context.Background(), capabilities)

 		require.NoError(t, err)
@@ -263,7 +263,7 @@ func TestDefaultAggregator_MergeCapabilities(t *testing.T) {
 		}
 		registry := vmcp.NewImmutableRegistry(backends)

-		agg := NewDefaultAggregator(nil, nil, nil)
+		agg := NewDefaultAggregator(nil, nil, nil, nil)
 		aggregated, err := agg.MergeCapabilities(context.Background(), resolved, registry)

 		require.NoError(t, err)
@@ -331,7 +331,7 @@ func TestDefaultAggregator_AggregateCapabilities(t *testing.T) {
 		mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps1, nil)
 		mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps2, nil)

-		agg := NewDefaultAggregator(mockClient, nil, nil)
+		agg := NewDefaultAggregator(mockClient, nil, nil, nil)
 		result, err := agg.AggregateCapabilities(context.Background(), backends)

 		require.NoError(t, err)
diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go
index e8e5bf0ab1..a30b717ce1 100644
--- a/pkg/vmcp/client/client.go
+++ b/pkg/vmcp/client/client.go
@@ -15,6 +15,7 @@ import (
 	"io"
 	"net"
 	"net/http"
+	"time"

 	"github.com/mark3labs/mcp-go/client"
 	"github.com/mark3labs/mcp-go/client/transport"
@@ -127,8 +128,6 @@ func (a *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
 		return nil, fmt.Errorf("authentication failed for backend %s: %w", a.target.WorkloadID, err)
 	}

-	logger.Debugf("Applied authentication strategy %q to backend %s", a.authStrategy.Name(), a.target.WorkloadID)
-
 	return a.base.RoundTrip(reqClone)
 }

@@ -170,6 +169,8 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm
 			target.WorkloadID, err)
 	}

+	logger.Debugf("Applied authentication strategy %q to backend %s", authStrategy.Name(), target.WorkloadID)
+
 	// Add authentication layer with pre-resolved strategy
 	baseTransport = &authRoundTripper{
 		base:         baseTransport,
@@ -204,8 +205,10 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm
 	})

 	// Create HTTP client with configured transport chain
+	// Set timeouts to prevent long-lived connections that require continuous listening
 	httpClient := &http.Client{
 		Transport: sizeLimitedTransport,
+		Timeout:   30 * time.Second, // Prevent hanging on connections
 	}

 	var c *client.Client
@@ -214,8 +217,7 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm
 	case "streamable-http", "streamable":
 		c, err = client.NewStreamableHttpClient(
 			target.BaseURL,
-			transport.WithHTTPTimeout(0),
-			transport.WithContinuousListening(),
+			transport.WithHTTPTimeout(30*time.Second), // Set timeout instead of 0
 			transport.WithHTTPBasicClient(httpClient),
 		)
 		if err != nil {
@@ -696,10 +698,10 @@ func (h *httpBackendClient) ReadResource(
 	}

 	// Extract _meta field from backend response
-	// Note: Due to MCP SDK limitations, the SDK's ReadResourceResult may not include Meta.
-	// This preserves it for future SDK improvements.
 	meta := conversion.FromMCPMeta(result.Meta)

+	// Note: Due to MCP SDK limitations, the SDK's ReadResourceResult may not include Meta.
+	// This preserves it for future SDK improvements.
 	return &vmcp.ResourceReadResult{
 		Contents: data,
 		MimeType: mimeType,
diff --git a/pkg/vmcp/discovery/manager_test_coverage.go b/pkg/vmcp/discovery/manager_test_coverage.go
new file mode 100644
index 0000000000..3826fc2849
--- /dev/null
+++ b/pkg/vmcp/discovery/manager_test_coverage.go
@@ -0,0 +1,176 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package discovery
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+	"go.uber.org/mock/gomock"
+
+	"github.com/stacklok/toolhive/pkg/auth"
+	"github.com/stacklok/toolhive/pkg/vmcp"
+	"github.com/stacklok/toolhive/pkg/vmcp/aggregator"
+	aggmocks "github.com/stacklok/toolhive/pkg/vmcp/aggregator/mocks"
+	vmcpmocks "github.com/stacklok/toolhive/pkg/vmcp/mocks"
+)
+
+// TestDefaultManager_CacheVersionMismatch tests cache invalidation on version mismatch
+func TestDefaultManager_CacheVersionMismatch(t *testing.T) {
+	t.Parallel()
+
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockAggregator := aggmocks.NewMockAggregator(ctrl)
+	mockRegistry := vmcpmocks.NewMockDynamicRegistry(ctrl)
+
+	// First call - version 1
+	mockRegistry.EXPECT().Version().Return(uint64(1)).Times(2)
+	mockAggregator.EXPECT().
+		AggregateCapabilities(gomock.Any(), gomock.Any()).
+		Return(&aggregator.AggregatedCapabilities{}, nil).
+		Times(1)
+
+	manager, err := NewManagerWithRegistry(mockAggregator, mockRegistry)
+	require.NoError(t, err)
+	defer manager.Stop()
+
+	ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{
+		Subject: "user-1",
+	})
+
+	backends := []vmcp.Backend{
+		{ID: "backend-1", Name: "Backend 1"},
+	}
+
+	// First discovery - should cache
+	caps1, err := manager.Discover(ctx, backends)
+	require.NoError(t, err)
+	require.NotNil(t, caps1)
+
+	// Second discovery with same version - should use cache
+	mockRegistry.EXPECT().Version().Return(uint64(1)).Times(1)
+	caps2, err := manager.Discover(ctx, backends)
+	require.NoError(t, err)
+	require.NotNil(t, caps2)
+
+	// Third discovery with different version - should invalidate cache
+	mockRegistry.EXPECT().Version().Return(uint64(2)).Times(2)
+	mockAggregator.EXPECT().
+		AggregateCapabilities(gomock.Any(), gomock.Any()).
+		Return(&aggregator.AggregatedCapabilities{}, nil).
+		Times(1)
+
+	caps3, err := manager.Discover(ctx, backends)
+	require.NoError(t, err)
+	require.NotNil(t, caps3)
+}
+
+// TestDefaultManager_CacheAtCapacity tests cache eviction when at capacity
+func TestDefaultManager_CacheAtCapacity(t *testing.T) {
+	t.Parallel()
+
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockAggregator := aggmocks.NewMockAggregator(ctrl)
+
+	// Create many different cache keys to fill cache
+	mockAggregator.EXPECT().
+		AggregateCapabilities(gomock.Any(), gomock.Any()).
+		Return(&aggregator.AggregatedCapabilities{}, nil).
+		Times(maxCacheSize + 1) // One more than capacity
+
+	manager, err := NewManager(mockAggregator)
+	require.NoError(t, err)
+	defer manager.Stop()
+
+	// Fill cache to capacity
+	for i := 0; i < maxCacheSize; i++ {
+		ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{
+			Subject: "user-" + string(rune(i)),
+		})
+
+		backends := []vmcp.Backend{
+			{ID: "backend-" + string(rune(i)), Name: "Backend"},
+		}
+
+		_, err := manager.Discover(ctx, backends)
+		require.NoError(t, err)
+	}
+
+	// Next discovery should not cache (at capacity)
+	ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{
+		Subject: "user-new",
+	})
+
+	backends := []vmcp.Backend{
+		{ID: "backend-new", Name: "Backend"},
+	}
+
+	_, err = manager.Discover(ctx, backends)
+	require.NoError(t, err)
+}
+
+// TestDefaultManager_CacheAtCapacity_ExistingKey tests cache update when at capacity but key exists
+func TestDefaultManager_CacheAtCapacity_ExistingKey(t *testing.T) {
+	t.Parallel()
+
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockAggregator := aggmocks.NewMockAggregator(ctrl)
+
+	// First call
+	mockAggregator.EXPECT().
+		AggregateCapabilities(gomock.Any(), gomock.Any()).
+		Return(&aggregator.AggregatedCapabilities{}, nil).
+		Times(1)
+
+	manager, err := NewManager(mockAggregator)
+	require.NoError(t, err)
+	defer manager.Stop()
+
+	ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{
+		Subject: "user-1",
+	})
+
+	backends := []vmcp.Backend{
+		{ID: "backend-1", Name: "Backend 1"},
+	}
+
+	// First discovery
+	_, err = manager.Discover(ctx, backends)
+	require.NoError(t, err)
+
+	// Fill cache to capacity with other keys
+	for i := 0; i < maxCacheSize-1; i++ {
+		ctxOther := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{
+			Subject: "user-" + string(rune(i+2)),
+		})
+
+		backendsOther := []vmcp.Backend{
+			{ID: "backend-" + string(rune(i+2)), Name: "Backend"},
+		}
+
+		mockAggregator.EXPECT().
+			AggregateCapabilities(gomock.Any(), gomock.Any()).
+			Return(&aggregator.AggregatedCapabilities{}, nil).
+			Times(1)
+
+		_, err := manager.Discover(ctxOther, backendsOther)
+		require.NoError(t, err)
+	}
+
+	// Update existing key should work even at capacity
+	mockAggregator.EXPECT().
+		AggregateCapabilities(gomock.Any(), gomock.Any()).
+		Return(&aggregator.AggregatedCapabilities{}, nil).
+		Times(1)
+
+	_, err = manager.Discover(ctx, backends)
+	require.NoError(t, err)
+}
diff --git a/pkg/vmcp/health/checker.go b/pkg/vmcp/health/checker.go
index d593aa0401..ccc3a8effc 100644
--- a/pkg/vmcp/health/checker.go
+++ b/pkg/vmcp/health/checker.go
@@ -41,7 +41,11 @@ type healthChecker struct {
 // - degradedThreshold: Response time threshold for marking backend as degraded (0 = disabled)
 //
 // Returns a new HealthChecker implementation.
-func NewHealthChecker(client vmcp.BackendClient, timeout time.Duration, degradedThreshold time.Duration) vmcp.HealthChecker {
+func NewHealthChecker(
+	client vmcp.BackendClient,
+	timeout time.Duration,
+	degradedThreshold time.Duration,
+) vmcp.HealthChecker {
 	return &healthChecker{
 		client:  client,
 		timeout: timeout,
@@ -62,11 +66,15 @@ func NewHealthChecker(client vmcp.BackendClient, timeout time.Duration, degraded
 // The error return is informational and provides context about what failed.
 // The BackendHealthStatus return indicates the categorized health state.
 func (h *healthChecker) CheckHealth(ctx context.Context, target *vmcp.BackendTarget) (vmcp.BackendHealthStatus, error) {
-	// Apply timeout if configured
-	checkCtx := ctx
+	// Mark context as health check to bypass authentication logging
+	// Health checks verify backend availability and should not require user credentials
+	healthCheckCtx := WithHealthCheckMarker(ctx)
+
+	// Apply timeout if configured (after adding health check marker)
+	checkCtx := healthCheckCtx
 	var cancel context.CancelFunc
 	if h.timeout > 0 {
-		checkCtx, cancel = context.WithTimeout(ctx, h.timeout)
+		checkCtx, cancel = context.WithTimeout(healthCheckCtx, h.timeout)
 		defer cancel()
 	}
diff --git a/pkg/vmcp/server/integration_test.go b/pkg/vmcp/server/integration_test.go
index 0086a8b4a9..8bb5e40231 100644
--- a/pkg/vmcp/server/integration_test.go
+++ b/pkg/vmcp/server/integration_test.go
@@ -119,6 +119,7 @@ func TestIntegration_AggregatorToRouterToServer(t *testing.T) {
 		mockBackendClient,
 		conflictResolver,
 		nil, // no tool configs
+		nil, // no tracer provider in tests
 	)

 	// Step 3: Run aggregation on mock backends
@@ -311,7 +312,7 @@ func TestIntegration_HTTPRequestFlowWithRoutingTable(t *testing.T) {
 	// Create discovery manager
 	conflictResolver := aggregator.NewPrefixConflictResolver("{workload}_")
-	agg := aggregator.NewDefaultAggregator(mockBackendClient, conflictResolver, nil)
+	agg := aggregator.NewDefaultAggregator(mockBackendClient, conflictResolver, nil, nil)
 	discoveryMgr, err := discovery.NewManager(agg)
 	require.NoError(t, err)
@@ -501,7 +502,7 @@ func TestIntegration_ConflictResolutionStrategies(t *testing.T) {
 			Times(2)

 		resolver := aggregator.NewPrefixConflictResolver("{workload}_")
-		agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil)
+		agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil, nil)

 		result, err := agg.AggregateCapabilities(ctx, createBackendsWithConflicts())
 		require.NoError(t, err)
@@ -539,7 +540,7 @@ func TestIntegration_ConflictResolutionStrategies(t *testing.T) {
 		resolver, err := aggregator.NewPriorityConflictResolver([]string{"backend1", "backend2"})
 		require.NoError(t, err)
-		agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil)
+		agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil, nil)

 		result, err := agg.AggregateCapabilities(ctx, createBackendsWithConflicts())
 		require.NoError(t, err)
@@ -659,7 +660,7 @@ func TestIntegration_AuditLogging(t *testing.T) {
 		Discover(gomock.Any(), gomock.Any()).
 		DoAndReturn(func(_ context.Context, _ []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) {
 			resolver := aggregator.NewPrefixConflictResolver("{workload}_")
-			agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil)
+			agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil, nil)
 			return agg.AggregateCapabilities(ctx, backends)
 		}).
 		AnyTimes()
diff --git a/test/e2e/thv-operator/virtualmcp/helpers.go b/test/e2e/thv-operator/virtualmcp/helpers.go
index ca73e206f2..56b9bbf52f 100644
--- a/test/e2e/thv-operator/virtualmcp/helpers.go
+++ b/test/e2e/thv-operator/virtualmcp/helpers.go
@@ -37,7 +37,8 @@ import (
 )

 // WaitForVirtualMCPServerReady waits for a VirtualMCPServer to reach Ready status
-// and ensures the associated pods are actually running and ready
+// and ensures at least one associated pod is actually running and ready.
+// This is used when waiting for a single expected pod (e.g., one replica deployment).
 func WaitForVirtualMCPServerReady(
 	ctx context.Context,
 	c client.Client,
@@ -58,7 +59,7 @@ func WaitForVirtualMCPServerReady(
 		for _, condition := range vmcpServer.Status.Conditions {
 			if condition.Type == "Ready" {
 				if condition.Status == "True" {
-					// Also check that the pods are actually running and ready
+					// Also check that at least one pod is actually running and ready
 					labels := map[string]string{
 						"app.kubernetes.io/name":     "virtualmcpserver",
 						"app.kubernetes.io/instance": name,
@@ -75,7 +76,9 @@ func WaitForVirtualMCPServerReady(
 	}, timeout, pollingInterval).Should(gomega.Succeed())
 }

-// checkPodsReady checks if all pods matching the given labels are ready
+// checkPodsReady checks that at least one pod matching the given labels is running and that every running pod is ready.
+// It is used when checking for a single expected pod (e.g., a one-replica deployment).
+// Pods not in the Running phase are skipped (e.g., Succeeded or Failed pods from previous deployments).
 func checkPodsReady(ctx context.Context, c client.Client, namespace string, labels map[string]string) error {
 	podList := &corev1.PodList{}
 	if err := c.List(ctx, podList,
@@ -89,8 +92,9 @@ func checkPodsReady(ctx context.Context, c client.Client, namespace string, labe
 	}

 	for _, pod := range podList.Items {
+		// Skip pods that are not running (e.g., Succeeded, Failed from old deployments)
 		if pod.Status.Phase != corev1.PodRunning {
-			return fmt.Errorf("pod %s is in phase %s", pod.Name, pod.Status.Phase)
+			continue
 		}

 		containerReady := false
@@ -114,6 +118,17 @@ func checkPodsReady(ctx context.Context, c client.Client, namespace string, labe
 			return fmt.Errorf("pod %s not ready", pod.Name)
 		}
 	}
+
+	// After filtering, ensure we found at least one running pod
+	runningPods := 0
+	for _, pod := range podList.Items {
+		if pod.Status.Phase == corev1.PodRunning {
+			runningPods++
+		}
+	}
+	if runningPods == 0 {
+		return fmt.Errorf("no running pods found with labels %v", labels)
+	}
 	return nil
 }

@@ -234,7 +249,8 @@ func GetVirtualMCPServerPods(ctx context.Context, c client.Client, vmcpServerNam
 	return podList, err
 }

-// WaitForPodsReady waits for all pods matching labels to be ready
+// WaitForPodsReady waits until at least one pod matching the labels is running and all running pods are ready.
+// It is used when waiting for a single expected pod (e.g., a one-replica deployment).
 func WaitForPodsReady(
 	ctx context.Context,
 	c client.Client,
diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go
index e7e33fd623..18af2c94df 100644
--- a/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go
+++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go
@@ -1162,12 +1162,28 @@ with socketserver.TCPServer(("", PORT), OIDCHandler) as httpd:
 		}

 		It("should list and call tools from all backends with discovered auth", func() {
+			By("Verifying vMCP pods are still running and ready before health check")
+			vmcpLabels := map[string]string{
+				"app.kubernetes.io/name":     "virtualmcpserver",
+				"app.kubernetes.io/instance": vmcpServerName,
+			}
+			WaitForPodsReady(ctx, k8sClient, testNamespace, vmcpLabels, 30*time.Second, 2*time.Second)
+
+			// Create HTTP client with timeout for health checks
+			healthCheckClient := &http.Client{
+				Timeout: 10 * time.Second,
+			}
+
 			By("Verifying HTTP connectivity to VirtualMCPServer health endpoint")
 			Eventually(func() error {
+				// Re-check pod readiness before each health check attempt
+				if err := checkPodsReady(ctx, k8sClient, testNamespace, vmcpLabels); err != nil {
+					return fmt.Errorf("pods not ready: %w", err)
+				}
 				url := fmt.Sprintf("http://localhost:%d/health", vmcpNodePort)
-				resp, err := http.Get(url)
+				resp, err := healthCheckClient.Get(url)
 				if err != nil {
-					return err
+					return fmt.Errorf("health check failed: %w", err)
 				}
 				defer resp.Body.Close()
 				if resp.StatusCode != http.StatusOK {
diff --git a/test/integration/vmcp/helpers/vmcp_server.go b/test/integration/vmcp/helpers/vmcp_server.go
index 95aa5f5142..99caa6c086 100644
--- a/test/integration/vmcp/helpers/vmcp_server.go
+++ b/test/integration/vmcp/helpers/vmcp_server.go
@@ -161,7 +161,7 @@ func NewVMCPServer(
 	}

 	// Create aggregator
-	agg := aggregator.NewDefaultAggregator(backendClient, conflictResolver, nil)
+	agg := aggregator.NewDefaultAggregator(backendClient, conflictResolver, nil, nil)

 	// Create discovery manager
 	discoveryMgr, err := discovery.NewManager(agg)
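
Usage sketch (illustrative, not part of the patch): the signature change to aggregator.NewDefaultAggregator adds a trailing trace.TracerProvider argument. Callers wire it the way the cmd/vmcp/app/commands.go hunk above does; here backendClient, conflictResolver, and toolConfigs stand in for values the caller already holds, and passing nil keeps tracing disabled because the aggregator falls back to a noop tracer internally.

	// Reuse the telemetry provider's tracer provider when telemetry is configured;
	// otherwise pass nil and the aggregator creates a noop tracer.
	var tracerProvider trace.TracerProvider
	if telemetryProvider != nil {
		tracerProvider = telemetryProvider.TracerProvider()
	}
	agg := aggregator.NewDefaultAggregator(backendClient, conflictResolver, toolConfigs, tracerProvider)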