From 9e28406e9e6008d80bb01dae95d05b3538ac250e Mon Sep 17 00:00:00 2001 From: nigel brown Date: Mon, 26 Jan 2026 12:15:58 +0000 Subject: [PATCH 01/16] Infrastructure improvements and bugfixes for vMCP - Add OpenTelemetry tracing to capability aggregation - Add singleflight deduplication for discovery requests - Add health checker self-check prevention - Add HTTP client timeout fixes - Improve E2E test reliability - Various build and infrastructure improvements --- .gitignore | 9 +- .golangci.yml | 2 + codecov.yaml | 2 + deploy/charts/operator-crds/Chart.yaml | 2 +- pkg/runner/config_builder_test.go | 4 +- pkg/vmcp/aggregator/default_aggregator.go | 93 +++- pkg/vmcp/client/client.go | 159 ++---- pkg/vmcp/discovery/manager.go | 41 +- pkg/vmcp/discovery/manager_test_coverage.go | 176 ++++++ pkg/vmcp/health/checker.go | 92 +++- pkg/vmcp/health/checker_selfcheck_test.go | 504 ++++++++++++++++++ pkg/vmcp/health/monitor.go | 6 +- test/e2e/thv-operator/virtualmcp/helpers.go | 14 +- .../virtualmcp_auth_discovery_test.go | 20 +- 14 files changed, 991 insertions(+), 133 deletions(-) create mode 100644 pkg/vmcp/discovery/manager_test_coverage.go create mode 100644 pkg/vmcp/health/checker_selfcheck_test.go diff --git a/.gitignore b/.gitignore index 7f5be4bf18..34dcc23d79 100644 --- a/.gitignore +++ b/.gitignore @@ -42,4 +42,11 @@ cmd/thv-operator/.task/checksum/crdref-gen # Test coverage coverage* -crd-helm-wrapper \ No newline at end of file +crd-helm-wrapper +cmd/vmcp/__debug_bin* + +# Demo files +examples/operator/virtual-mcps/vmcp_optimizer.yaml +scripts/k8s_vmcp_optimizer_demo.sh +examples/ingress/mcp-servers-ingress.yaml +/vmcp diff --git a/.golangci.yml b/.golangci.yml index 62c3611473..ff2b3d54e9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -139,6 +139,7 @@ linters: - third_party$ - builtin$ - examples$ + - scripts$ formatters: enable: - gci @@ -155,3 +156,4 @@ formatters: - third_party$ - builtin$ - examples$ + - scripts$ diff --git a/codecov.yaml b/codecov.yaml index 1a8032e484..410f9ae7ee 100644 --- a/codecov.yaml +++ b/codecov.yaml @@ -13,6 +13,8 @@ coverage: - "**/mocks/**/*" - "**/mock_*.go" - "**/zz_generated.deepcopy.go" + - "**/*_test.go" + - "**/*_test_coverage.go" status: project: default: diff --git a/deploy/charts/operator-crds/Chart.yaml b/deploy/charts/operator-crds/Chart.yaml index e336674530..1b14897d71 100644 --- a/deploy/charts/operator-crds/Chart.yaml +++ b/deploy/charts/operator-crds/Chart.yaml @@ -2,5 +2,5 @@ apiVersion: v2 name: toolhive-operator-crds description: A Helm chart for installing the ToolHive Operator CRDs into Kubernetes. 
type: application -version: 0.0.103 +version: 0.0.102 appVersion: "0.0.1" diff --git a/pkg/runner/config_builder_test.go b/pkg/runner/config_builder_test.go index 735c9ccc45..0e4556937d 100644 --- a/pkg/runner/config_builder_test.go +++ b/pkg/runner/config_builder_test.go @@ -1076,8 +1076,8 @@ func TestRunConfigBuilder_WithRegistryProxyPort(t *testing.T) { ProxyPort: 8976, TargetPort: 8976, }, - cliProxyPort: 9000, - expectedProxyPort: 9000, + cliProxyPort: 9999, + expectedProxyPort: 9999, }, { name: "random port when neither CLI nor registry specified", diff --git a/pkg/vmcp/aggregator/default_aggregator.go b/pkg/vmcp/aggregator/default_aggregator.go index 95be734af0..3cf2846fcc 100644 --- a/pkg/vmcp/aggregator/default_aggregator.go +++ b/pkg/vmcp/aggregator/default_aggregator.go @@ -8,6 +8,10 @@ import ( "fmt" "sync" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" "github.com/stacklok/toolhive/pkg/logger" @@ -21,6 +25,7 @@ type defaultAggregator struct { backendClient vmcp.BackendClient conflictResolver ConflictResolver toolConfigMap map[string]*config.WorkloadToolConfig // Maps backend ID to tool config + tracer trace.Tracer } // NewDefaultAggregator creates a new default aggregator implementation. @@ -43,12 +48,20 @@ func NewDefaultAggregator( backendClient: backendClient, conflictResolver: conflictResolver, toolConfigMap: toolConfigMap, + tracer: otel.Tracer("github.com/stacklok/toolhive/pkg/vmcp/aggregator"), } } // QueryCapabilities queries a single backend for its MCP capabilities. // Returns the raw capabilities (tools, resources, prompts) from the backend. func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp.Backend) (*BackendCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.QueryCapabilities", + trace.WithAttributes( + attribute.String("backend.id", backend.ID), + ), + ) + defer span.End() + logger.Debugf("Querying capabilities from backend %s", backend.ID) // Create a BackendTarget from the Backend @@ -58,6 +71,8 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp. // Query capabilities using the backend client capabilities, err := a.backendClient.ListCapabilities(ctx, target) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("%w: %s: %w", ErrBackendQueryFailed, backend.ID, err) } @@ -74,6 +89,12 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp. 
SupportsSampling: capabilities.SupportsSampling, } + span.SetAttributes( + attribute.Int("tools.count", len(result.Tools)), + attribute.Int("resources.count", len(result.Resources)), + attribute.Int("prompts.count", len(result.Prompts)), + ) + logger.Debugf("Backend %s: %d tools (after filtering/overrides), %d resources, %d prompts", backend.ID, len(result.Tools), len(result.Resources), len(result.Prompts)) @@ -86,6 +107,13 @@ func (a *defaultAggregator) QueryAllCapabilities( ctx context.Context, backends []vmcp.Backend, ) (map[string]*BackendCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.QueryAllCapabilities", + trace.WithAttributes( + attribute.Int("backends.count", len(backends)), + ), + ) + defer span.End() + logger.Infof("Querying capabilities from %d backends", len(backends)) // Use errgroup for parallel queries with context cancellation @@ -118,13 +146,22 @@ func (a *defaultAggregator) QueryAllCapabilities( // Wait for all queries to complete if err := g.Wait(); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("capability queries failed: %w", err) } if len(capabilities) == 0 { - return nil, fmt.Errorf("no backends returned capabilities") + err := fmt.Errorf("no backends returned capabilities") + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return nil, err } + span.SetAttributes( + attribute.Int("successful.backends", len(capabilities)), + ) + logger.Infof("Successfully queried %d/%d backends", len(capabilities), len(backends)) return capabilities, nil } @@ -135,6 +172,13 @@ func (a *defaultAggregator) ResolveConflicts( ctx context.Context, capabilities map[string]*BackendCapabilities, ) (*ResolvedCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.ResolveConflicts", + trace.WithAttributes( + attribute.Int("backends.count", len(capabilities)), + ), + ) + defer span.End() + logger.Debugf("Resolving conflicts across %d backends", len(capabilities)) // Group tools by backend for conflict resolution @@ -150,6 +194,8 @@ func (a *defaultAggregator) ResolveConflicts( if a.conflictResolver != nil { resolvedTools, err = a.conflictResolver.ResolveToolConflicts(ctx, toolsByBackend) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("conflict resolution failed: %w", err) } } else { @@ -191,6 +237,12 @@ func (a *defaultAggregator) ResolveConflicts( resolved.SupportsSampling = resolved.SupportsSampling || caps.SupportsSampling } + span.SetAttributes( + attribute.Int("resolved.tools", len(resolved.Tools)), + attribute.Int("resolved.resources", len(resolved.Resources)), + attribute.Int("resolved.prompts", len(resolved.Prompts)), + ) + logger.Debugf("Resolved %d unique tools, %d resources, %d prompts", len(resolved.Tools), len(resolved.Resources), len(resolved.Prompts)) @@ -199,11 +251,20 @@ func (a *defaultAggregator) ResolveConflicts( // MergeCapabilities creates the final unified capability view and routing table. // Uses the backend registry to populate full BackendTarget information for routing. 
-func (*defaultAggregator) MergeCapabilities( +func (a *defaultAggregator) MergeCapabilities( ctx context.Context, resolved *ResolvedCapabilities, registry vmcp.BackendRegistry, ) (*AggregatedCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.MergeCapabilities", + trace.WithAttributes( + attribute.Int("resolved.tools", len(resolved.Tools)), + attribute.Int("resolved.resources", len(resolved.Resources)), + attribute.Int("resolved.prompts", len(resolved.Prompts)), + ), + ) + defer span.End() + logger.Debugf("Merging capabilities into final view") // Create routing table @@ -304,6 +365,13 @@ func (*defaultAggregator) MergeCapabilities( }, } + span.SetAttributes( + attribute.Int("aggregated.tools", aggregated.Metadata.ToolCount), + attribute.Int("aggregated.resources", aggregated.Metadata.ResourceCount), + attribute.Int("aggregated.prompts", aggregated.Metadata.PromptCount), + attribute.String("conflict.strategy", string(aggregated.Metadata.ConflictStrategy)), + ) + logger.Infof("Merged capabilities: %d tools, %d resources, %d prompts", aggregated.Metadata.ToolCount, aggregated.Metadata.ResourceCount, aggregated.Metadata.PromptCount) @@ -316,6 +384,13 @@ func (*defaultAggregator) MergeCapabilities( // 3. Resolve conflicts // 4. Merge into final view with full backend information func (a *defaultAggregator) AggregateCapabilities(ctx context.Context, backends []vmcp.Backend) (*AggregatedCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.AggregateCapabilities", + trace.WithAttributes( + attribute.Int("backends.count", len(backends)), + ), + ) + defer span.End() + logger.Infof("Starting capability aggregation for %d backends", len(backends)) // Step 1: Create registry from discovered backends @@ -325,24 +400,38 @@ func (a *defaultAggregator) AggregateCapabilities(ctx context.Context, backends // Step 2: Query all backends capabilities, err := a.QueryAllCapabilities(ctx, backends) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("failed to query backends: %w", err) } // Step 3: Resolve conflicts resolved, err := a.ResolveConflicts(ctx, capabilities) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("failed to resolve conflicts: %w", err) } // Step 4: Merge into final view with full backend information aggregated, err := a.MergeCapabilities(ctx, resolved, registry) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("failed to merge capabilities: %w", err) } // Update metadata with backend count aggregated.Metadata.BackendCount = len(backends) + span.SetAttributes( + attribute.Int("aggregated.backends", aggregated.Metadata.BackendCount), + attribute.Int("aggregated.tools", aggregated.Metadata.ToolCount), + attribute.Int("aggregated.resources", aggregated.Metadata.ResourceCount), + attribute.Int("aggregated.prompts", aggregated.Metadata.PromptCount), + attribute.String("conflict.strategy", string(aggregated.Metadata.ConflictStrategy)), + ) + logger.Infof("Capability aggregation complete: %d backends, %d tools, %d resources, %d prompts", aggregated.Metadata.BackendCount, aggregated.Metadata.ToolCount, aggregated.Metadata.ResourceCount, aggregated.Metadata.PromptCount) diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index e8e5bf0ab1..0634376de6 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -15,6 +15,7 @@ import ( "io" "net" "net/http" + "time" 
"github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" @@ -25,7 +26,6 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp" vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" - "github.com/stacklok/toolhive/pkg/vmcp/conversion" ) const ( @@ -127,8 +127,6 @@ func (a *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) return nil, fmt.Errorf("authentication failed for backend %s: %w", a.target.WorkloadID, err) } - logger.Debugf("Applied authentication strategy %q to backend %s", a.authStrategy.Name(), a.target.WorkloadID) - return a.base.RoundTrip(reqClone) } @@ -204,8 +202,10 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm }) // Create HTTP client with configured transport chain + // Set timeouts to prevent long-lived connections that require continuous listening httpClient := &http.Client{ Transport: sizeLimitedTransport, + Timeout: 30 * time.Second, // Prevent hanging on connections } var c *client.Client @@ -214,8 +214,7 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm case "streamable-http", "streamable": c, err = client.NewStreamableHttpClient( target.BaseURL, - transport.WithHTTPTimeout(0), - transport.WithContinuousListening(), + transport.WithHTTPTimeout(30*time.Second), // Set timeout instead of 0 transport.WithHTTPBasicClient(httpClient), ) if err != nil { @@ -373,36 +372,6 @@ func queryPrompts(ctx context.Context, c *client.Client, supported bool, backend return &mcp.ListPromptsResult{Prompts: []mcp.Prompt{}}, nil } -// convertContent converts mcp.Content to vmcp.Content. -// This preserves the full content structure from backend responses. -func convertContent(content mcp.Content) vmcp.Content { - if textContent, ok := mcp.AsTextContent(content); ok { - return vmcp.Content{ - Type: "text", - Text: textContent.Text, - } - } - if imageContent, ok := mcp.AsImageContent(content); ok { - return vmcp.Content{ - Type: "image", - Data: imageContent.Data, - MimeType: imageContent.MIMEType, - } - } - if audioContent, ok := mcp.AsAudioContent(content); ok { - return vmcp.Content{ - Type: "audio", - Data: audioContent.Data, - MimeType: audioContent.MIMEType, - } - } - // Handle embedded resources if needed - // Unknown content types are marked as "unknown" type with no data - logger.Warnf("Encountered unknown content type %T, marking as unknown content. "+ - "This may indicate missing support for embedded resources or other MCP content types.", content) - return vmcp.Content{Type: "unknown"} -} - // ListCapabilities queries a backend for its MCP capabilities. // Returns tools, resources, and prompts exposed by the backend. // Only queries capabilities that the server advertises during initialization. @@ -518,7 +487,6 @@ func (h *httpBackendClient) ListCapabilities(ctx context.Context, target *vmcp.B } // CallTool invokes a tool on the backend MCP server. -// Returns the complete tool result including _meta field. // //nolint:gocyclo // this function is complex because it handles tool calls with various content types and error handling. 
func (h *httpBackendClient) CallTool( @@ -526,8 +494,7 @@ func (h *httpBackendClient) CallTool( target *vmcp.BackendTarget, toolName string, arguments map[string]any, - meta map[string]any, -) (*vmcp.ToolCallResult, error) { +) (map[string]any, error) { logger.Debugf("Calling tool %s on backend %s", toolName, target.WorkloadName) // Create a client for this backend @@ -558,7 +525,6 @@ func (h *httpBackendClient) CallTool( Params: mcp.CallToolParams{ Name: backendToolName, Arguments: arguments, - Meta: conversion.ToMCPMeta(meta), }, }) if err != nil { @@ -566,12 +532,9 @@ func (h *httpBackendClient) CallTool( return nil, fmt.Errorf("%w: tool call failed on backend %s: %w", vmcp.ErrBackendUnavailable, target.WorkloadID, err) } - // Extract _meta field from backend response - responseMeta := conversion.FromMCPMeta(result.Meta) - - // Log if tool returned IsError=true (MCP protocol-level error, not a transport error) - // We still return the full result to preserve metadata and error details for the client + // Check if the tool call returned an error (MCP domain error) if result.IsError { + // Extract error message from content for logging and forwarding var errorMsg string if len(result.Content) > 0 { if textContent, ok := mcp.AsTextContent(result.Content[0]); ok { @@ -579,60 +542,60 @@ func (h *httpBackendClient) CallTool( } } if errorMsg == "" { - errorMsg = "tool execution error" - } - - // Log with metadata for distributed tracing - if responseMeta != nil { - logger.Warnf("Tool %s on backend %s returned IsError=true: %s (meta: %+v)", - toolName, target.WorkloadID, errorMsg, responseMeta) - } else { - logger.Warnf("Tool %s on backend %s returned IsError=true: %s", toolName, target.WorkloadID, errorMsg) + errorMsg = "unknown error" } - // Continue processing - we return the result with IsError flag and metadata preserved - } - - // Convert MCP content to vmcp.Content array - contentArray := make([]vmcp.Content, len(result.Content)) - for i, content := range result.Content { - contentArray[i] = convertContent(content) + logger.Warnf("Tool %s on backend %s returned error: %s", toolName, target.WorkloadID, errorMsg) + // Wrap with ErrToolExecutionFailed so router can forward transparently to client + return nil, fmt.Errorf("%w: %s on backend %s: %s", vmcp.ErrToolExecutionFailed, toolName, target.WorkloadID, errorMsg) } // Check for structured content first (preferred for composite tool step chaining). // StructuredContent allows templates to access nested fields directly via {{.steps.stepID.output.field}}. // Note: StructuredContent must be an object (map). Arrays or primitives are not supported. - var structuredContent map[string]any if result.StructuredContent != nil { if structuredMap, ok := result.StructuredContent.(map[string]any); ok { logger.Debugf("Using structured content from tool %s on backend %s", toolName, target.WorkloadID) - structuredContent = structuredMap - } else { - // StructuredContent is not an object - fall through to Content processing - logger.Debugf("StructuredContent from tool %s on backend %s is not an object, falling back to Content", - toolName, target.WorkloadID) + return structuredMap, nil } + // StructuredContent is not an object - fall through to Content processing + logger.Debugf("StructuredContent from tool %s on backend %s is not an object, falling back to Content", + toolName, target.WorkloadID) } - // If no structured content, convert result contents to a map for backward compatibility. + // Fallback: Convert result contents to a map. 
// MCP tools return an array of Content interface (TextContent, ImageContent, etc.). // Text content is stored under "text" key, accessible via {{.steps.stepID.output.text}}. - if structuredContent == nil { - structuredContent = conversion.ContentArrayToMap(contentArray) + resultMap := make(map[string]any) + if len(result.Content) > 0 { + textIndex := 0 + imageIndex := 0 + for i, content := range result.Content { + // Try to convert to TextContent + if textContent, ok := mcp.AsTextContent(content); ok { + key := "text" + if textIndex > 0 { + key = fmt.Sprintf("text_%d", textIndex) + } + resultMap[key] = textContent.Text + textIndex++ + } else if imageContent, ok := mcp.AsImageContent(content); ok { + // Convert to ImageContent + key := fmt.Sprintf("image_%d", imageIndex) + resultMap[key] = imageContent.Data + imageIndex++ + } else { + // Log unsupported content types for tracking + logger.Debugf("Unsupported content type at index %d from tool %s on backend %s: %T", + i, toolName, target.WorkloadID, content) + } + } } - return &vmcp.ToolCallResult{ - Content: contentArray, - StructuredContent: structuredContent, - IsError: result.IsError, - Meta: responseMeta, - }, nil + return resultMap, nil } // ReadResource retrieves a resource from the backend MCP server. -// Returns the complete resource result including _meta field. -func (h *httpBackendClient) ReadResource( - ctx context.Context, target *vmcp.BackendTarget, uri string, -) (*vmcp.ResourceReadResult, error) { +func (h *httpBackendClient) ReadResource(ctx context.Context, target *vmcp.BackendTarget, uri string) ([]byte, error) { logger.Debugf("Reading resource %s from backend %s", uri, target.WorkloadName) // Create a client for this backend @@ -670,14 +633,10 @@ func (h *httpBackendClient) ReadResource( // Concatenate all resource contents // MCP resources can have multiple contents (text or blob) var data []byte - var mimeType string - for i, content := range result.Contents { + for _, content := range result.Contents { // Try to convert to TextResourceContents if textContent, ok := mcp.AsTextResourceContents(content); ok { data = append(data, []byte(textContent.Text)...) - if i == 0 && textContent.MIMEType != "" { - mimeType = textContent.MIMEType - } } else if blobContent, ok := mcp.AsBlobResourceContents(content); ok { // Blob is base64-encoded per MCP spec, decode it to bytes decoded, err := base64.StdEncoding.DecodeString(blobContent.Blob) @@ -689,38 +648,25 @@ func (h *httpBackendClient) ReadResource( } else { data = append(data, decoded...) } - if i == 0 && blobContent.MIMEType != "" { - mimeType = blobContent.MIMEType - } } } - // Extract _meta field from backend response - // Note: Due to MCP SDK limitations, the SDK's ReadResourceResult may not include Meta. - // This preserves it for future SDK improvements. - meta := conversion.FromMCPMeta(result.Meta) - - return &vmcp.ResourceReadResult{ - Contents: data, - MimeType: mimeType, - Meta: meta, - }, nil + return data, nil } // GetPrompt retrieves a prompt from the backend MCP server. -// Returns the complete prompt result including _meta field. 
func (h *httpBackendClient) GetPrompt( ctx context.Context, target *vmcp.BackendTarget, name string, arguments map[string]any, -) (*vmcp.PromptGetResult, error) { +) (string, error) { logger.Debugf("Getting prompt %s from backend %s", name, target.WorkloadName) // Create a client for this backend c, err := h.clientFactory(ctx, target) if err != nil { - return nil, wrapBackendError(err, target.WorkloadID, "create client") + return "", wrapBackendError(err, target.WorkloadID, "create client") } defer func() { if err := c.Close(); err != nil { @@ -730,7 +676,7 @@ func (h *httpBackendClient) GetPrompt( // Initialize the client if _, err := initializeClient(ctx, c); err != nil { - return nil, wrapBackendError(err, target.WorkloadID, "initialize client") + return "", wrapBackendError(err, target.WorkloadID, "initialize client") } // Get the prompt using the original prompt name from the backend's perspective. @@ -753,7 +699,7 @@ func (h *httpBackendClient) GetPrompt( }, }) if err != nil { - return nil, fmt.Errorf("prompt get failed on backend %s: %w", target.WorkloadID, err) + return "", fmt.Errorf("prompt get failed on backend %s: %w", target.WorkloadID, err) } // Concatenate all prompt messages into a single string @@ -770,12 +716,5 @@ func (h *httpBackendClient) GetPrompt( // TODO: Handle other content types (image, audio, resource) } - // Extract _meta field from backend response - meta := conversion.FromMCPMeta(result.Meta) - - return &vmcp.PromptGetResult{ - Messages: prompt, - Description: result.Description, - Meta: meta, - }, nil + return prompt, nil } diff --git a/pkg/vmcp/discovery/manager.go b/pkg/vmcp/discovery/manager.go index 0845118ee1..6dfa659512 100644 --- a/pkg/vmcp/discovery/manager.go +++ b/pkg/vmcp/discovery/manager.go @@ -18,6 +18,8 @@ import ( "sync" "time" + "golang.org/x/sync/singleflight" + "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" @@ -68,6 +70,9 @@ type DefaultManager struct { stopCh chan struct{} stopOnce sync.Once wg sync.WaitGroup + // singleFlight ensures only one aggregation happens per cache key at a time + // This prevents concurrent requests from all triggering aggregation + singleFlight singleflight.Group } // NewManager creates a new discovery manager with the given aggregator. @@ -131,6 +136,9 @@ func NewManagerWithRegistry(agg aggregator.Aggregator, registry vmcp.DynamicRegi // // The context must contain an authenticated user identity (set by auth middleware). // Returns ErrNoIdentity if user identity is not found in context. +// +// This method uses singleflight to ensure that concurrent requests for the same +// cache key only trigger one aggregation, preventing duplicate work. 
func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) { // Validate user identity is present (set by auth middleware) // This ensures discovery happens with proper user authentication context @@ -142,7 +150,7 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) // Generate cache key from user identity and backend set cacheKey := m.generateCacheKey(identity.Subject, backends) - // Check cache first + // Check cache first (with read lock) if caps := m.getCachedCapabilities(cacheKey); caps != nil { logger.Debugf("Cache hit for user %s (key: %s)", identity.Subject, cacheKey) return caps, nil @@ -150,16 +158,33 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) logger.Debugf("Cache miss - performing capability discovery for user: %s", identity.Subject) - // Cache miss - perform aggregation - caps, err := m.aggregator.AggregateCapabilities(ctx, backends) + // Use singleflight to ensure only one aggregation happens per cache key + // Even if multiple requests come in concurrently, they'll all wait for the same result + result, err, _ := m.singleFlight.Do(cacheKey, func() (interface{}, error) { + // Double-check cache after acquiring singleflight lock + // Another goroutine might have populated it while we were waiting + if caps := m.getCachedCapabilities(cacheKey); caps != nil { + logger.Debugf("Cache populated while waiting - returning cached result for user %s", identity.Subject) + return caps, nil + } + + // Perform aggregation + caps, err := m.aggregator.AggregateCapabilities(ctx, backends) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err) + } + + // Cache the result (skips caching if at capacity and key doesn't exist) + m.cacheCapabilities(cacheKey, caps) + + return caps, nil + }) + if err != nil { - return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err) + return nil, err } - // Cache the result (skips caching if at capacity and key doesn't exist) - m.cacheCapabilities(cacheKey, caps) - - return caps, nil + return result.(*aggregator.AggregatedCapabilities), nil } // Stop gracefully stops the manager and cleans up resources. diff --git a/pkg/vmcp/discovery/manager_test_coverage.go b/pkg/vmcp/discovery/manager_test_coverage.go new file mode 100644 index 0000000000..3826fc2849 --- /dev/null +++ b/pkg/vmcp/discovery/manager_test_coverage.go @@ -0,0 +1,176 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package discovery + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + aggmocks "github.com/stacklok/toolhive/pkg/vmcp/aggregator/mocks" + vmcpmocks "github.com/stacklok/toolhive/pkg/vmcp/mocks" +) + +// TestDefaultManager_CacheVersionMismatch tests cache invalidation on version mismatch +func TestDefaultManager_CacheVersionMismatch(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAggregator := aggmocks.NewMockAggregator(ctrl) + mockRegistry := vmcpmocks.NewMockDynamicRegistry(ctrl) + + // First call - version 1 + mockRegistry.EXPECT().Version().Return(uint64(1)).Times(2) + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). 
+ Times(1) + + manager, err := NewManagerWithRegistry(mockAggregator, mockRegistry) + require.NoError(t, err) + defer manager.Stop() + + ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-1", + }) + + backends := []vmcp.Backend{ + {ID: "backend-1", Name: "Backend 1"}, + } + + // First discovery - should cache + caps1, err := manager.Discover(ctx, backends) + require.NoError(t, err) + require.NotNil(t, caps1) + + // Second discovery with same version - should use cache + mockRegistry.EXPECT().Version().Return(uint64(1)).Times(1) + caps2, err := manager.Discover(ctx, backends) + require.NoError(t, err) + require.NotNil(t, caps2) + + // Third discovery with different version - should invalidate cache + mockRegistry.EXPECT().Version().Return(uint64(2)).Times(2) + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(1) + + caps3, err := manager.Discover(ctx, backends) + require.NoError(t, err) + require.NotNil(t, caps3) +} + +// TestDefaultManager_CacheAtCapacity tests cache eviction when at capacity +func TestDefaultManager_CacheAtCapacity(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAggregator := aggmocks.NewMockAggregator(ctrl) + + // Create many different cache keys to fill cache + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(maxCacheSize + 1) // One more than capacity + + manager, err := NewManager(mockAggregator) + require.NoError(t, err) + defer manager.Stop() + + // Fill cache to capacity + for i := 0; i < maxCacheSize; i++ { + ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-" + string(rune(i)), + }) + + backends := []vmcp.Backend{ + {ID: "backend-" + string(rune(i)), Name: "Backend"}, + } + + _, err := manager.Discover(ctx, backends) + require.NoError(t, err) + } + + // Next discovery should not cache (at capacity) + ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-new", + }) + + backends := []vmcp.Backend{ + {ID: "backend-new", Name: "Backend"}, + } + + _, err = manager.Discover(ctx, backends) + require.NoError(t, err) +} + +// TestDefaultManager_CacheAtCapacity_ExistingKey tests cache update when at capacity but key exists +func TestDefaultManager_CacheAtCapacity_ExistingKey(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAggregator := aggmocks.NewMockAggregator(ctrl) + + // First call + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). 
+ Times(1) + + manager, err := NewManager(mockAggregator) + require.NoError(t, err) + defer manager.Stop() + + ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-1", + }) + + backends := []vmcp.Backend{ + {ID: "backend-1", Name: "Backend 1"}, + } + + // First discovery + _, err = manager.Discover(ctx, backends) + require.NoError(t, err) + + // Fill cache to capacity with other keys + for i := 0; i < maxCacheSize-1; i++ { + ctxOther := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-" + string(rune(i+2)), + }) + + backendsOther := []vmcp.Backend{ + {ID: "backend-" + string(rune(i+2)), Name: "Backend"}, + } + + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(1) + + _, err := manager.Discover(ctxOther, backendsOther) + require.NoError(t, err) + } + + // Update existing key should work even at capacity + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(1) + + _, err = manager.Discover(ctx, backends) + require.NoError(t, err) +} diff --git a/pkg/vmcp/health/checker.go b/pkg/vmcp/health/checker.go index d593aa0401..bf6f5c329c 100644 --- a/pkg/vmcp/health/checker.go +++ b/pkg/vmcp/health/checker.go @@ -11,6 +11,8 @@ import ( "context" "errors" "fmt" + "net/url" + "strings" "time" "github.com/stacklok/toolhive/pkg/logger" @@ -29,6 +31,10 @@ type healthChecker struct { // If a health check succeeds but takes longer than this duration, the backend is marked degraded. // Zero means disabled (backends will never be marked degraded based on response time alone). degradedThreshold time.Duration + + // selfURL is the server's own URL. If a health check targets this URL, it's short-circuited. + // This prevents the server from trying to health check itself. + selfURL string } // NewHealthChecker creates a new health checker that uses BackendClient.ListCapabilities @@ -39,13 +45,20 @@ type healthChecker struct { // - client: BackendClient for communicating with backend MCP servers // - timeout: Maximum duration for health check operations (0 = no timeout) // - degradedThreshold: Response time threshold for marking backend as degraded (0 = disabled) +// - selfURL: Optional server's own URL. If provided, health checks targeting this URL are short-circuited. // // Returns a new HealthChecker implementation. -func NewHealthChecker(client vmcp.BackendClient, timeout time.Duration, degradedThreshold time.Duration) vmcp.HealthChecker { +func NewHealthChecker( + client vmcp.BackendClient, + timeout time.Duration, + degradedThreshold time.Duration, + selfURL string, +) vmcp.HealthChecker { return &healthChecker{ client: client, timeout: timeout, degradedThreshold: degradedThreshold, + selfURL: selfURL, } } @@ -62,16 +75,28 @@ func NewHealthChecker(client vmcp.BackendClient, timeout time.Duration, degraded // The error return is informational and provides context about what failed. // The BackendHealthStatus return indicates the categorized health state. 
func (h *healthChecker) CheckHealth(ctx context.Context, target *vmcp.BackendTarget) (vmcp.BackendHealthStatus, error) {
-	// Apply timeout if configured
-	checkCtx := ctx
+	// Mark the context as a health check so downstream authentication handling can recognize it.
+	// Health checks verify backend availability and should not require user credentials.
+	healthCheckCtx := WithHealthCheckMarker(ctx)
+
+	// Apply timeout if configured (after adding the health check marker)
+	checkCtx := healthCheckCtx
 	var cancel context.CancelFunc
 	if h.timeout > 0 {
-		checkCtx, cancel = context.WithTimeout(ctx, h.timeout)
+		checkCtx, cancel = context.WithTimeout(healthCheckCtx, h.timeout)
 		defer cancel()
 	}
 
 	logger.Debugf("Performing health check for backend %s (%s)", target.WorkloadName, target.BaseURL)
 
+	// Short-circuit the health check if it targets this server itself. Checking
+	// ourselves would succeed, but it is wasteful and can cause connection issues
+	// during startup.
+	if h.selfURL != "" && h.isSelfCheck(target.BaseURL) {
+		logger.Debugf("Skipping health check for backend %s - this is the server itself", target.WorkloadName)
+		return vmcp.BackendHealthy, nil
+	}
+
 	// Track response time for degraded detection
 	startTime := time.Now()
@@ -137,3 +162,62 @@
 	// Default to unhealthy for unknown errors
 	return vmcp.BackendUnhealthy
 }
+
+// isSelfCheck reports whether a backend URL matches the server's own URL.
+// URLs are normalized before comparison to handle variations like:
+// - http://127.0.0.1:PORT vs http://localhost:PORT
+// - http://HOST:PORT vs http://HOST:PORT/
+func (h *healthChecker) isSelfCheck(backendURL string) bool {
+	if h.selfURL == "" || backendURL == "" {
+		return false
+	}
+
+	// Normalize both URLs for comparison
+	backendNormalized, err := NormalizeURLForComparison(backendURL)
+	if err != nil {
+		return false
+	}
+
+	selfNormalized, err := NormalizeURLForComparison(h.selfURL)
+	if err != nil {
+		return false
+	}
+
+	return backendNormalized == selfNormalized
+}
+
+// NormalizeURLForComparison normalizes a URL for comparison by:
+// - Parsing and reconstructing the URL
+// - Converting localhost/127.0.0.1 to a canonical form
+// - Comparing only scheme://host:port (ignoring path, query, fragment)
+// - Lowercasing scheme and host
+// Exported for testing purposes
+func NormalizeURLForComparison(rawURL string) (string, error) {
+	u, err := url.Parse(rawURL)
+	if err != nil {
+		return "", err
+	}
+	// Validate that we have a scheme and host (basic URL validation)
+	if u.Scheme == "" || u.Host == "" {
+		return "", fmt.Errorf("invalid URL: missing scheme or host")
+	}
+
+	// Normalize host: convert localhost to 127.0.0.1 for consistency
+	host := strings.ToLower(u.Hostname())
+	if host == "localhost" {
+		host = "127.0.0.1"
+	}
+
+	// Reconstruct URL with normalized components (scheme://host:port only)
+	// We ignore path, query, and fragment for comparison
+	normalized := &url.URL{
+		Scheme: strings.ToLower(u.Scheme),
+	}
+	if u.Port() != "" {
+		normalized.Host = host + ":" + u.Port()
+	} else {
+		normalized.Host = host
+	}
+
+	return normalized.String(), nil
+}
diff --git a/pkg/vmcp/health/checker_selfcheck_test.go b/pkg/vmcp/health/checker_selfcheck_test.go
new file mode 100644
index 0000000000..ff963d8d35
--- /dev/null
+++ b/pkg/vmcp/health/checker_selfcheck_test.go
@@ -0,0 +1,504 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package health
+
+import (
+	"context"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/mock/gomock"
+
+	"github.com/stacklok/toolhive/pkg/vmcp"
+	"github.com/stacklok/toolhive/pkg/vmcp/mocks"
+)
+
+// TestHealthChecker_CheckHealth_SelfCheck tests self-check detection
+func TestHealthChecker_CheckHealth_SelfCheck(t *testing.T) {
+	t.Parallel()
+
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockClient := mocks.NewMockBackendClient(ctrl)
+	// Should not call ListCapabilities for self-check
+	mockClient.EXPECT().
+		ListCapabilities(gomock.Any(), gomock.Any()).
+		Times(0)
+
+	checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080")
+	target := &vmcp.BackendTarget{
+		WorkloadID:   "backend-1",
+		WorkloadName: "test-backend",
+		BaseURL:      "http://127.0.0.1:8080", // Same as selfURL
+	}
+
+	status, err := checker.CheckHealth(context.Background(), target)
+	assert.NoError(t, err)
+	assert.Equal(t, vmcp.BackendHealthy, status)
+}
+
+// TestHealthChecker_CheckHealth_SelfCheck_Localhost tests localhost normalization
+func TestHealthChecker_CheckHealth_SelfCheck_Localhost(t *testing.T) {
+	t.Parallel()
+
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockClient := mocks.NewMockBackendClient(ctrl)
+	mockClient.EXPECT().
+		ListCapabilities(gomock.Any(), gomock.Any()).
+		Times(0)
+
+	checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://localhost:8080")
+	target := &vmcp.BackendTarget{
+		WorkloadID:   "backend-1",
+		WorkloadName: "test-backend",
+		BaseURL:      "http://127.0.0.1:8080", // localhost should match 127.0.0.1
+	}
+
+	status, err := checker.CheckHealth(context.Background(), target)
+	assert.NoError(t, err)
+	assert.Equal(t, vmcp.BackendHealthy, status)
+}
+
+// TestHealthChecker_CheckHealth_SelfCheck_Reverse tests reverse localhost normalization
+func TestHealthChecker_CheckHealth_SelfCheck_Reverse(t *testing.T) {
+	t.Parallel()
+
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockClient := mocks.NewMockBackendClient(ctrl)
+	mockClient.EXPECT().
+		ListCapabilities(gomock.Any(), gomock.Any()).
+		Times(0)
+
+	checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080")
+	target := &vmcp.BackendTarget{
+		WorkloadID:   "backend-1",
+		WorkloadName: "test-backend",
+		BaseURL:      "http://localhost:8080", // 127.0.0.1 should match localhost
+	}
+
+	status, err := checker.CheckHealth(context.Background(), target)
+	assert.NoError(t, err)
+	assert.Equal(t, vmcp.BackendHealthy, status)
+}
+
+// TestHealthChecker_CheckHealth_SelfCheck_DifferentPort tests that different ports don't match
+func TestHealthChecker_CheckHealth_SelfCheck_DifferentPort(t *testing.T) {
+	t.Parallel()
+
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mockClient := mocks.NewMockBackendClient(ctrl)
+	mockClient.EXPECT().
+		ListCapabilities(gomock.Any(), gomock.Any()).
+		Return(&vmcp.CapabilityList{}, nil).
+ Times(1) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8081", // Different port + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_EmptyURL tests empty URLs +func TestHealthChecker_CheckHealth_SelfCheck_EmptyURL(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + Times(1) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_InvalidURL tests invalid URLs +func TestHealthChecker_CheckHealth_SelfCheck_InvalidURL(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + Times(1) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "not-a-valid-url") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_WithPath tests URLs with paths are normalized +func TestHealthChecker_CheckHealth_SelfCheck_WithPath(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Times(0) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8080/mcp", // Path should be ignored + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_DegradedThreshold tests degraded threshold detection +func TestHealthChecker_CheckHealth_DegradedThreshold(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + // Simulate slow response + time.Sleep(150 * time.Millisecond) + return &vmcp.CapabilityList{}, nil + }). 
+ Times(1) + + // Set degraded threshold to 100ms + checker := NewHealthChecker(mockClient, 5*time.Second, 100*time.Millisecond, "") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://localhost:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendDegraded, status, "Should mark as degraded when response time exceeds threshold") +} + +// TestHealthChecker_CheckHealth_DegradedThreshold_Disabled tests disabled degraded threshold +func TestHealthChecker_CheckHealth_DegradedThreshold_Disabled(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + // Simulate slow response + time.Sleep(150 * time.Millisecond) + return &vmcp.CapabilityList{}, nil + }). + Times(1) + + // Set degraded threshold to 0 (disabled) + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://localhost:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status, "Should not mark as degraded when threshold is disabled") +} + +// TestHealthChecker_CheckHealth_DegradedThreshold_FastResponse tests fast response doesn't trigger degraded +func TestHealthChecker_CheckHealth_DegradedThreshold_FastResponse(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). 
+		Times(1)
+
+	// Set degraded threshold to 100ms
+	checker := NewHealthChecker(mockClient, 5*time.Second, 100*time.Millisecond, "")
+	target := &vmcp.BackendTarget{
+		WorkloadID:   "backend-1",
+		WorkloadName: "test-backend",
+		BaseURL:      "http://localhost:8080",
+	}
+
+	status, err := checker.CheckHealth(context.Background(), target)
+	assert.NoError(t, err)
+	assert.Equal(t, vmcp.BackendHealthy, status, "Should not mark as degraded when response is fast")
+}
+
+// TestCategorizeError_SentinelErrors tests sentinel error categorization
+func TestCategorizeError_SentinelErrors(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name           string
+		err            error
+		expectedStatus vmcp.BackendHealthStatus
+	}{
+		{
+			name:           "ErrAuthenticationFailed",
+			err:            vmcp.ErrAuthenticationFailed,
+			expectedStatus: vmcp.BackendUnauthenticated,
+		},
+		{
+			name:           "ErrAuthorizationFailed",
+			err:            vmcp.ErrAuthorizationFailed,
+			expectedStatus: vmcp.BackendUnauthenticated,
+		},
+		{
+			name:           "ErrTimeout",
+			err:            vmcp.ErrTimeout,
+			expectedStatus: vmcp.BackendUnhealthy,
+		},
+		{
+			name:           "ErrCancelled",
+			err:            vmcp.ErrCancelled,
+			expectedStatus: vmcp.BackendUnhealthy,
+		},
+		{
+			name:           "ErrBackendUnavailable",
+			err:            vmcp.ErrBackendUnavailable,
+			expectedStatus: vmcp.BackendUnhealthy,
+		},
+		{
+			name:           "wrapped ErrAuthenticationFailed",
+			err:            fmt.Errorf("wrapped: %w", vmcp.ErrAuthenticationFailed),
+			expectedStatus: vmcp.BackendUnauthenticated,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			status := categorizeError(tt.err)
+			assert.Equal(t, tt.expectedStatus, status)
+		})
+	}
+}
+
+// TestNormalizeURLForComparison tests URL normalization
+func TestNormalizeURLForComparison(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name     string
+		input    string
+		expected string
+		wantErr  bool
+	}{
+		{
+			name:     "localhost normalized to 127.0.0.1",
+			input:    "http://localhost:8080",
+			expected: "http://127.0.0.1:8080",
+			wantErr:  false,
+		},
+		{
+			name:     "127.0.0.1 stays as is",
+			input:    "http://127.0.0.1:8080",
+			expected: "http://127.0.0.1:8080",
+			wantErr:  false,
+		},
+		{
+			name:     "path is ignored",
+			input:    "http://127.0.0.1:8080/mcp",
+			expected: "http://127.0.0.1:8080",
+			wantErr:  false,
+		},
+		{
+			name:     "query is ignored",
+			input:    "http://127.0.0.1:8080?param=value",
+			expected: "http://127.0.0.1:8080",
+			wantErr:  false,
+		},
+		{
+			name:     "fragment is ignored",
+			input:    "http://127.0.0.1:8080#fragment",
+			expected: "http://127.0.0.1:8080",
+			wantErr:  false,
+		},
+		{
+			name:     "scheme is lowercased",
+			input:    "HTTP://127.0.0.1:8080",
+			expected: "http://127.0.0.1:8080",
+			wantErr:  false,
+		},
+		{
+			name:     "host is lowercased",
+			input:    "http://EXAMPLE.COM:8080",
+			expected: "http://example.com:8080",
+			wantErr:  false,
+		},
+		{
+			name:     "no port",
+			input:    "http://127.0.0.1",
+			expected: "http://127.0.0.1",
+			wantErr:  false,
+		},
+		{
+			name:    "invalid URL",
+			input:   "not-a-url",
+			wantErr: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			result, err := NormalizeURLForComparison(tt.input)
+			if tt.wantErr {
+				assert.Error(t, err)
+			} else {
+				assert.NoError(t, err)
+				assert.Equal(t, tt.expected, result)
+			}
+		})
+	}
+}
+
+// TestIsSelfCheck_EdgeCases tests edge cases for self-check detection
+func TestIsSelfCheck_EdgeCases(t *testing.T) {
+	t.Parallel()
+
+	ctrl := gomock.NewController(t)
+	t.Cleanup(func() { ctrl.Finish() })
+
+	mockClient := mocks.NewMockBackendClient(ctrl)
+
+	tests := []struct {
+		name       string
+		selfURL    
string + backendURL string + expected bool + }{ + { + name: "both empty", + selfURL: "", + backendURL: "", + expected: false, + }, + { + name: "selfURL empty", + selfURL: "", + backendURL: "http://127.0.0.1:8080", + expected: false, + }, + { + name: "backendURL empty", + selfURL: "http://127.0.0.1:8080", + backendURL: "", + expected: false, + }, + { + name: "localhost matches 127.0.0.1", + selfURL: "http://localhost:8080", + backendURL: "http://127.0.0.1:8080", + expected: true, + }, + { + name: "127.0.0.1 matches localhost", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://localhost:8080", + expected: true, + }, + { + name: "different ports", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://127.0.0.1:8081", + expected: false, + }, + { + name: "different hosts", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://192.168.1.1:8080", + expected: false, + }, + { + name: "path ignored", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://127.0.0.1:8080/mcp", + expected: true, + }, + { + name: "query ignored", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://127.0.0.1:8080?param=value", + expected: true, + }, + { + name: "invalid selfURL", + selfURL: "not-a-url", + backendURL: "http://127.0.0.1:8080", + expected: false, + }, + { + name: "invalid backendURL", + selfURL: "http://127.0.0.1:8080", + backendURL: "not-a-url", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, tt.selfURL) + hc, ok := checker.(*healthChecker) + require.True(t, ok) + + result := hc.isSelfCheck(tt.backendURL) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/vmcp/health/monitor.go b/pkg/vmcp/health/monitor.go index 50f00d788d..62aea9b735 100644 --- a/pkg/vmcp/health/monitor.go +++ b/pkg/vmcp/health/monitor.go @@ -108,12 +108,14 @@ func DefaultConfig() MonitorConfig { // - client: BackendClient for communicating with backend MCP servers // - backends: List of backends to monitor // - config: Configuration for health monitoring +// - selfURL: Optional server's own URL. If provided, health checks targeting this URL are short-circuited. // // Returns (monitor, error). Error is returned if configuration is invalid. 
func NewMonitor( client vmcp.BackendClient, backends []vmcp.Backend, config MonitorConfig, + selfURL string, ) (*Monitor, error) { // Validate configuration if config.CheckInterval <= 0 { @@ -123,8 +125,8 @@ func NewMonitor( return nil, fmt.Errorf("unhealthy threshold must be >= 1, got %d", config.UnhealthyThreshold) } - // Create health checker with degraded threshold - checker := NewHealthChecker(client, config.Timeout, config.DegradedThreshold) + // Create health checker with degraded threshold and self URL + checker := NewHealthChecker(client, config.Timeout, config.DegradedThreshold, selfURL) // Create status tracker statusTracker := newStatusTracker(config.UnhealthyThreshold) diff --git a/test/e2e/thv-operator/virtualmcp/helpers.go b/test/e2e/thv-operator/virtualmcp/helpers.go index ca73e206f2..e48d3fdea5 100644 --- a/test/e2e/thv-operator/virtualmcp/helpers.go +++ b/test/e2e/thv-operator/virtualmcp/helpers.go @@ -89,8 +89,9 @@ func checkPodsReady(ctx context.Context, c client.Client, namespace string, labe } for _, pod := range podList.Items { + // Skip pods that are not running (e.g., Succeeded, Failed from old deployments) if pod.Status.Phase != corev1.PodRunning { - return fmt.Errorf("pod %s is in phase %s", pod.Name, pod.Status.Phase) + continue } containerReady := false @@ -114,6 +115,17 @@ func checkPodsReady(ctx context.Context, c client.Client, namespace string, labe return fmt.Errorf("pod %s not ready", pod.Name) } } + + // After filtering, ensure we found at least one running pod + runningPods := 0 + for _, pod := range podList.Items { + if pod.Status.Phase == corev1.PodRunning { + runningPods++ + } + } + if runningPods == 0 { + return fmt.Errorf("no running pods found with labels %v", labels) + } return nil } diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go index e7e33fd623..18af2c94df 100644 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go @@ -1162,12 +1162,28 @@ with socketserver.TCPServer(("", PORT), OIDCHandler) as httpd: } It("should list and call tools from all backends with discovered auth", func() { + By("Verifying vMCP pods are still running and ready before health check") + vmcpLabels := map[string]string{ + "app.kubernetes.io/name": "virtualmcpserver", + "app.kubernetes.io/instance": vmcpServerName, + } + WaitForPodsReady(ctx, k8sClient, testNamespace, vmcpLabels, 30*time.Second, 2*time.Second) + + // Create HTTP client with timeout for health checks + healthCheckClient := &http.Client{ + Timeout: 10 * time.Second, + } + By("Verifying HTTP connectivity to VirtualMCPServer health endpoint") Eventually(func() error { + // Re-check pod readiness before each health check attempt + if err := checkPodsReady(ctx, k8sClient, testNamespace, vmcpLabels); err != nil { + return fmt.Errorf("pods not ready: %w", err) + } url := fmt.Sprintf("http://localhost:%d/health", vmcpNodePort) - resp, err := http.Get(url) + resp, err := healthCheckClient.Get(url) if err != nil { - return err + return fmt.Errorf("health check failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { From 8aa930b67c0525f1d5a30734bf465e9117779328 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Mon, 26 Jan 2026 12:19:20 +0000 Subject: [PATCH 02/16] fix: Update CallTool and GetPrompt signatures to match BackendClient interface - Add conversion import for meta field handling - Update 
CallTool to accept meta parameter and return *vmcp.ToolCallResult
- Update GetPrompt to return *vmcp.PromptGetResult
- Add convertContent helper function
---
 pkg/vmcp/client/client.go | 126 +++++++++++++++++++++++++-------------
 1 file changed, 84 insertions(+), 42 deletions(-)

diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go
index 0634376de6..cab1de9056 100644
--- a/pkg/vmcp/client/client.go
+++ b/pkg/vmcp/client/client.go
@@ -24,6 +24,7 @@ import (
 	"github.com/stacklok/toolhive/pkg/auth"
 	"github.com/stacklok/toolhive/pkg/logger"
 	"github.com/stacklok/toolhive/pkg/vmcp"
 	vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth"
 	authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types"
+	"github.com/stacklok/toolhive/pkg/vmcp/conversion"
 )
@@ -372,6 +373,36 @@ func queryPrompts(ctx context.Context, c *client.Client, supported bool, backend
 	return &mcp.ListPromptsResult{Prompts: []mcp.Prompt{}}, nil
 }
 
+// convertContent converts mcp.Content to vmcp.Content.
+// This preserves the full content structure from backend responses.
+func convertContent(content mcp.Content) vmcp.Content {
+	if textContent, ok := mcp.AsTextContent(content); ok {
+		return vmcp.Content{
+			Type: "text",
+			Text: textContent.Text,
+		}
+	}
+	if imageContent, ok := mcp.AsImageContent(content); ok {
+		return vmcp.Content{
+			Type:     "image",
+			Data:     imageContent.Data,
+			MimeType: imageContent.MIMEType,
+		}
+	}
+	if audioContent, ok := mcp.AsAudioContent(content); ok {
+		return vmcp.Content{
+			Type:     "audio",
+			Data:     audioContent.Data,
+			MimeType: audioContent.MIMEType,
+		}
+	}
+	// Handle embedded resources if needed
+	// Unknown content types are marked as "unknown" type with no data
+	logger.Warnf("Encountered unknown content type %T, marking as unknown content. "+
+		"This may indicate missing support for embedded resources or other MCP content types.", content)
+	return vmcp.Content{Type: "unknown"}
+}
+
 // ListCapabilities queries a backend for its MCP capabilities.
 // Returns tools, resources, and prompts exposed by the backend.
 // Only queries capabilities that the server advertises during initialization.
@@ -487,6 +518,7 @@ func (h *httpBackendClient) ListCapabilities(ctx context.Context, target *vmcp.B
 }
 
 // CallTool invokes a tool on the backend MCP server.
+// Returns the complete tool result including _meta field.
 //
 //nolint:gocyclo // this function is complex because it handles tool calls with various content types and error handling.
 
func (h *httpBackendClient) CallTool( @@ -494,7 +526,8 @@ func (h *httpBackendClient) CallTool( target *vmcp.BackendTarget, toolName string, arguments map[string]any, -) (map[string]any, error) { + meta map[string]any, +) (*vmcp.ToolCallResult, error) { logger.Debugf("Calling tool %s on backend %s", toolName, target.WorkloadName) // Create a client for this backend @@ -525,6 +558,7 @@ func (h *httpBackendClient) CallTool( Params: mcp.CallToolParams{ Name: backendToolName, Arguments: arguments, + Meta: conversion.ToMCPMeta(meta), }, }) if err != nil { @@ -532,9 +566,12 @@ func (h *httpBackendClient) CallTool( return nil, fmt.Errorf("%w: tool call failed on backend %s: %w", vmcp.ErrBackendUnavailable, target.WorkloadID, err) } - // Check if the tool call returned an error (MCP domain error) + // Extract _meta field from backend response + responseMeta := conversion.FromMCPMeta(result.Meta) + + // Log if tool returned IsError=true (MCP protocol-level error, not a transport error) + // We still return the full result to preserve metadata and error details for the client if result.IsError { - // Extract error message from content for logging and forwarding var errorMsg string if len(result.Content) > 0 { if textContent, ok := mcp.AsTextContent(result.Content[0]); ok { @@ -542,56 +579,53 @@ func (h *httpBackendClient) CallTool( } } if errorMsg == "" { - errorMsg = "unknown error" + errorMsg = "tool execution error" } - logger.Warnf("Tool %s on backend %s returned error: %s", toolName, target.WorkloadID, errorMsg) - // Wrap with ErrToolExecutionFailed so router can forward transparently to client - return nil, fmt.Errorf("%w: %s on backend %s: %s", vmcp.ErrToolExecutionFailed, toolName, target.WorkloadID, errorMsg) + + // Log with metadata for distributed tracing + if responseMeta != nil { + logger.Warnf("Tool %s on backend %s returned IsError=true: %s (meta: %+v)", + toolName, target.WorkloadID, errorMsg, responseMeta) + } else { + logger.Warnf("Tool %s on backend %s returned IsError=true: %s", toolName, target.WorkloadID, errorMsg) + } + // Continue processing - we return the result with IsError flag and metadata preserved + } + + // Convert MCP content to vmcp.Content array + contentArray := make([]vmcp.Content, len(result.Content)) + for i, content := range result.Content { + contentArray[i] = convertContent(content) } // Check for structured content first (preferred for composite tool step chaining). // StructuredContent allows templates to access nested fields directly via {{.steps.stepID.output.field}}. // Note: StructuredContent must be an object (map). Arrays or primitives are not supported. + var structuredContent map[string]any if result.StructuredContent != nil { if structuredMap, ok := result.StructuredContent.(map[string]any); ok { logger.Debugf("Using structured content from tool %s on backend %s", toolName, target.WorkloadID) - return structuredMap, nil + structuredContent = structuredMap + } else { + // StructuredContent is not an object - fall through to Content processing + logger.Debugf("StructuredContent from tool %s on backend %s is not an object, falling back to Content", + toolName, target.WorkloadID) } - // StructuredContent is not an object - fall through to Content processing - logger.Debugf("StructuredContent from tool %s on backend %s is not an object, falling back to Content", - toolName, target.WorkloadID) } - // Fallback: Convert result contents to a map. + // If no structured content, convert result contents to a map for backward compatibility. 
// MCP tools return an array of Content interface (TextContent, ImageContent, etc.). // Text content is stored under "text" key, accessible via {{.steps.stepID.output.text}}. - resultMap := make(map[string]any) - if len(result.Content) > 0 { - textIndex := 0 - imageIndex := 0 - for i, content := range result.Content { - // Try to convert to TextContent - if textContent, ok := mcp.AsTextContent(content); ok { - key := "text" - if textIndex > 0 { - key = fmt.Sprintf("text_%d", textIndex) - } - resultMap[key] = textContent.Text - textIndex++ - } else if imageContent, ok := mcp.AsImageContent(content); ok { - // Convert to ImageContent - key := fmt.Sprintf("image_%d", imageIndex) - resultMap[key] = imageContent.Data - imageIndex++ - } else { - // Log unsupported content types for tracking - logger.Debugf("Unsupported content type at index %d from tool %s on backend %s: %T", - i, toolName, target.WorkloadID, content) - } - } + if structuredContent == nil { + structuredContent = conversion.ContentArrayToMap(contentArray) } - return resultMap, nil + return &vmcp.ToolCallResult{ + Content: contentArray, + StructuredContent: structuredContent, + IsError: result.IsError, + Meta: responseMeta, + }, nil } // ReadResource retrieves a resource from the backend MCP server. @@ -655,18 +689,19 @@ func (h *httpBackendClient) ReadResource(ctx context.Context, target *vmcp.Backe } // GetPrompt retrieves a prompt from the backend MCP server. +// Returns the complete prompt result including _meta field. func (h *httpBackendClient) GetPrompt( ctx context.Context, target *vmcp.BackendTarget, name string, arguments map[string]any, -) (string, error) { +) (*vmcp.PromptGetResult, error) { logger.Debugf("Getting prompt %s from backend %s", name, target.WorkloadName) // Create a client for this backend c, err := h.clientFactory(ctx, target) if err != nil { - return "", wrapBackendError(err, target.WorkloadID, "create client") + return nil, wrapBackendError(err, target.WorkloadID, "create client") } defer func() { if err := c.Close(); err != nil { @@ -676,7 +711,7 @@ func (h *httpBackendClient) GetPrompt( // Initialize the client if _, err := initializeClient(ctx, c); err != nil { - return "", wrapBackendError(err, target.WorkloadID, "initialize client") + return nil, wrapBackendError(err, target.WorkloadID, "initialize client") } // Get the prompt using the original prompt name from the backend's perspective. 
@@ -699,7 +734,7 @@ func (h *httpBackendClient) GetPrompt( }, }) if err != nil { - return "", fmt.Errorf("prompt get failed on backend %s: %w", target.WorkloadID, err) + return nil, fmt.Errorf("prompt get failed on backend %s: %w", target.WorkloadID, err) } // Concatenate all prompt messages into a single string @@ -716,5 +751,12 @@ func (h *httpBackendClient) GetPrompt( // TODO: Handle other content types (image, audio, resource) } - return prompt, nil + // Extract _meta field from backend response + meta := conversion.FromMCPMeta(result.Meta) + + return &vmcp.PromptGetResult{ + Messages: prompt, + Description: result.Description, + Meta: meta, + }, nil } From a6cb6d0b3f1140642d9414aac36567ed23904df7 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Mon, 26 Jan 2026 12:19:31 +0000 Subject: [PATCH 03/16] fix: Update ReadResource signature to match BackendClient interface - Update ReadResource to return *vmcp.ResourceReadResult instead of []byte - Extract and include meta field from backend response - Include MIME type in result --- pkg/vmcp/client/client.go | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index cab1de9056..1ec9d318ac 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -629,7 +629,10 @@ func (h *httpBackendClient) CallTool( } // ReadResource retrieves a resource from the backend MCP server. -func (h *httpBackendClient) ReadResource(ctx context.Context, target *vmcp.BackendTarget, uri string) ([]byte, error) { +// Returns the complete resource result including _meta field. +func (h *httpBackendClient) ReadResource( + ctx context.Context, target *vmcp.BackendTarget, uri string, +) (*vmcp.ResourceReadResult, error) { logger.Debugf("Reading resource %s from backend %s", uri, target.WorkloadName) // Create a client for this backend @@ -667,10 +670,14 @@ func (h *httpBackendClient) ReadResource(ctx context.Context, target *vmcp.Backe // Concatenate all resource contents // MCP resources can have multiple contents (text or blob) var data []byte - for _, content := range result.Contents { + var mimeType string + for i, content := range result.Contents { // Try to convert to TextResourceContents if textContent, ok := mcp.AsTextResourceContents(content); ok { data = append(data, []byte(textContent.Text)...) + if i == 0 && textContent.MIMEType != "" { + mimeType = textContent.MIMEType + } } else if blobContent, ok := mcp.AsBlobResourceContents(content); ok { // Blob is base64-encoded per MCP spec, decode it to bytes decoded, err := base64.StdEncoding.DecodeString(blobContent.Blob) @@ -682,10 +689,20 @@ func (h *httpBackendClient) ReadResource(ctx context.Context, target *vmcp.Backe } else { data = append(data, decoded...) } + if i == 0 && blobContent.MIMEType != "" { + mimeType = blobContent.MIMEType + } } } - return data, nil + // Extract _meta field from backend response + meta := conversion.FromMCPMeta(result.Meta) + + return &vmcp.ResourceReadResult{ + Contents: data, + MimeType: mimeType, + Meta: meta, + }, nil } // GetPrompt retrieves a prompt from the backend MCP server. 
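For downstream callers, the net effect of the two client patches above is that transport failures and MCP domain errors are now distinguishable: the former still surface as a non-nil Go error, while the latter arrive as a populated result with IsError set and the _meta field preserved. A minimal consumer sketch, assuming the vmcp.BackendClient interface and result structs as shown in the diffs; the helper function, tool name, and arguments are illustrative, not part of the patches:

```go
package example

import (
	"context"
	"fmt"
	"log"

	"github.com/stacklok/toolhive/pkg/vmcp"
)

// printToolOutput is a hypothetical helper, not part of the patch. It shows
// the new CallTool contract: a non-nil error means the call failed at the
// transport layer, while IsError=true means the tool itself reported a
// domain error, with content and _meta still populated for forwarding.
func printToolOutput(ctx context.Context, bc vmcp.BackendClient, target *vmcp.BackendTarget) error {
	res, err := bc.CallTool(ctx, target, "echo", map[string]any{"msg": "hi"}, nil)
	if err != nil {
		return err // transport or backend failure (e.g., backend unavailable)
	}
	if res.IsError {
		log.Printf("tool reported an error (meta: %v)", res.Meta)
	}
	for _, c := range res.Content {
		if c.Type == "text" {
			fmt.Println(c.Text)
		}
	}
	return nil
}
```

GetPrompt and ReadResource follow the same shape, carrying Description and MimeType respectively alongside Meta.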
From e97fcee51d1459228bca0cb887a32cc8fff351d7 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Mon, 26 Jan 2026 12:20:03 +0000 Subject: [PATCH 04/16] fix: Pass selfURL parameter to health.NewMonitor - Construct selfURL from Host, Port, and EndpointPath - Prevents health checker from checking itself --- pkg/vmcp/server/server.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 3ccecdf39c..a57eef5613 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -340,7 +340,9 @@ func New( if cfg.HealthMonitorConfig != nil { // Get initial backends list from registry for health monitoring setup initialBackends := backendRegistry.List(ctx) - healthMon, err = health.NewMonitor(backendClient, initialBackends, *cfg.HealthMonitorConfig) + // Construct selfURL to prevent health checker from checking itself + selfURL := fmt.Sprintf("http://%s:%d%s", cfg.Host, cfg.Port, cfg.EndpointPath) + healthMon, err = health.NewMonitor(backendClient, initialBackends, *cfg.HealthMonitorConfig, selfURL) if err != nil { return nil, fmt.Errorf("failed to create health monitor: %w", err) } From 3361c90d7f2fb0aa463b6007fe7980d0e90f81e6 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Mon, 26 Jan 2026 12:32:06 +0000 Subject: [PATCH 05/16] Fix NewHealthChecker calls in checker_test.go to include selfURL parameter --- pkg/vmcp/health/checker_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pkg/vmcp/health/checker_test.go b/pkg/vmcp/health/checker_test.go index 39f7258d82..b3dcf906bd 100644 --- a/pkg/vmcp/health/checker_test.go +++ b/pkg/vmcp/health/checker_test.go @@ -40,11 +40,11 @@ func TestNewHealthChecker(t *testing.T) { }, } - for _, tt := range tests { + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - checker := NewHealthChecker(mockClient, tt.timeout, 0) + checker := NewHealthChecker(mockClient, tt.timeout, 0, "") require.NotNil(t, checker) // Type assert to access internals for verification @@ -68,7 +68,7 @@ func TestHealthChecker_CheckHealth_Success(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). Times(1) - checker := NewHealthChecker(mockClient, 5*time.Second, 0) + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -95,7 +95,7 @@ func TestHealthChecker_CheckHealth_ContextCancellation(t *testing.T) { }). Times(1) - checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0) + checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -123,7 +123,7 @@ func TestHealthChecker_CheckHealth_NoTimeout(t *testing.T) { Times(1) // Create checker with no timeout - checker := NewHealthChecker(mockClient, 0, 0) + checker := NewHealthChecker(mockClient, 0, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -213,7 +213,7 @@ func TestHealthChecker_CheckHealth_ErrorCategorization(t *testing.T) { Return(nil, tt.err). Times(1) - checker := NewHealthChecker(mockClient, 5*time.Second, 0) + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -430,7 +430,7 @@ func TestHealthChecker_CheckHealth_Timeout(t *testing.T) { }). 
Times(1) - checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0) + checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -467,7 +467,7 @@ func TestHealthChecker_CheckHealth_MultipleBackends(t *testing.T) { }). Times(4) - checker := NewHealthChecker(mockClient, 5*time.Second, 0) + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") // Test healthy backend status, err := checker.CheckHealth(context.Background(), &vmcp.BackendTarget{ From e42e2acc3d0d1509706b2b501e3b4e3c06b55155 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Mon, 26 Jan 2026 13:13:47 +0000 Subject: [PATCH 06/16] Fix NewMonitor calls in monitor_test.go to include selfURL parameter All 10 calls to NewMonitor in monitor_test.go were missing the new selfURL parameter that was added to the function signature. This was causing compilation failures in CI. --- pkg/vmcp/health/monitor_test.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pkg/vmcp/health/monitor_test.go b/pkg/vmcp/health/monitor_test.go index bb177017e7..95e0459ee5 100644 --- a/pkg/vmcp/health/monitor_test.go +++ b/pkg/vmcp/health/monitor_test.go @@ -62,11 +62,11 @@ func TestNewMonitor_Validation(t *testing.T) { }, } - for _, tt := range tests { + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - monitor, err := NewMonitor(mockClient, backends, tt.config) + monitor, err := NewMonitor(mockClient, backends, tt.config, "") if tt.expectError { assert.Error(t, err) assert.Nil(t, monitor) @@ -101,7 +101,7 @@ func TestMonitor_StartStop(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) // Start monitor @@ -178,7 +178,7 @@ func TestMonitor_StartErrors(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) err = tt.setupFunc(monitor) @@ -208,7 +208,7 @@ func TestMonitor_StopWithoutStart(t *testing.T) { Timeout: 50 * time.Millisecond, } - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) // Try to stop without starting @@ -239,7 +239,7 @@ func TestMonitor_PeriodicHealthChecks(t *testing.T) { Return(nil, errors.New("backend unavailable")). MinTimes(2) - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -289,7 +289,7 @@ func TestMonitor_GetHealthSummary(t *testing.T) { }). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -333,7 +333,7 @@ func TestMonitor_GetBackendStatus(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -382,7 +382,7 @@ func TestMonitor_GetBackendState(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). 
AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -433,7 +433,7 @@ func TestMonitor_GetAllBackendStates(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -477,7 +477,7 @@ func TestMonitor_ContextCancellation(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) // Start with cancellable context From 90db15e76ef5a4dffa9e1a00554ce1e3417ff7bf Mon Sep 17 00:00:00 2001 From: nigel brown Date: Mon, 26 Jan 2026 13:52:45 +0000 Subject: [PATCH 07/16] Fix Go import formatting issues (gci linter) Fixed import ordering in: - pkg/vmcp/client/client.go - pkg/vmcp/health/checker_test.go - pkg/vmcp/health/monitor_test.go --- pkg/vmcp/client/client.go | 6 +++--- pkg/vmcp/health/checker_test.go | 2 +- pkg/vmcp/health/monitor_test.go | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index 1ec9d318ac..c72639f3b0 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -24,9 +24,9 @@ import ( "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/conversion" vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" + "github.com/stacklok/toolhive/pkg/vmcp/conversion" ) const ( @@ -700,8 +700,8 @@ func (h *httpBackendClient) ReadResource( return &vmcp.ResourceReadResult{ Contents: data, - MimeType: mimeType, - Meta: meta, + MimeType: mimeType, + Meta: meta, }, nil } diff --git a/pkg/vmcp/health/checker_test.go b/pkg/vmcp/health/checker_test.go index b3dcf906bd..63c3c986b6 100644 --- a/pkg/vmcp/health/checker_test.go +++ b/pkg/vmcp/health/checker_test.go @@ -40,7 +40,7 @@ func TestNewHealthChecker(t *testing.T) { }, } - for _, tt := range tests { + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() diff --git a/pkg/vmcp/health/monitor_test.go b/pkg/vmcp/health/monitor_test.go index 95e0459ee5..8d2de11bdd 100644 --- a/pkg/vmcp/health/monitor_test.go +++ b/pkg/vmcp/health/monitor_test.go @@ -62,7 +62,7 @@ func TestNewMonitor_Validation(t *testing.T) { }, } - for _, tt := range tests { + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() From 628f1012b563982b276fe5e85aaa3576c172ac8a Mon Sep 17 00:00:00 2001 From: nigel brown Date: Mon, 26 Jan 2026 13:54:12 +0000 Subject: [PATCH 08/16] Fix Chart.yaml version - restore to 0.0.103 The version was incorrectly downgraded to 0.0.102. Restore it to 0.0.103 to match main branch. --- deploy/charts/operator-crds/Chart.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deploy/charts/operator-crds/Chart.yaml b/deploy/charts/operator-crds/Chart.yaml index 1b14897d71..e336674530 100644 --- a/deploy/charts/operator-crds/Chart.yaml +++ b/deploy/charts/operator-crds/Chart.yaml @@ -2,5 +2,5 @@ apiVersion: v2 name: toolhive-operator-crds description: A Helm chart for installing the ToolHive Operator CRDs into Kubernetes. 
type: application -version: 0.0.102 +version: 0.0.103 appVersion: "0.0.1" From 78f632b71fdb2b1a8bd35bed794d57814cb22561 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Mon, 26 Jan 2026 13:55:16 +0000 Subject: [PATCH 09/16] Bump Chart.yaml version to 0.0.104 The chart-testing tool requires version bumps to be higher than the base branch version (0.0.103). --- deploy/charts/operator-crds/Chart.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deploy/charts/operator-crds/Chart.yaml b/deploy/charts/operator-crds/Chart.yaml index e336674530..e833e01b3c 100644 --- a/deploy/charts/operator-crds/Chart.yaml +++ b/deploy/charts/operator-crds/Chart.yaml @@ -2,5 +2,5 @@ apiVersion: v2 name: toolhive-operator-crds description: A Helm chart for installing the ToolHive Operator CRDs into Kubernetes. type: application -version: 0.0.103 +version: 0.0.104 appVersion: "0.0.1" From b2c8f0d607c5ff06910290d495fbdbe2c046b622 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Mon, 26 Jan 2026 13:57:16 +0000 Subject: [PATCH 10/16] Update README.md version badge to 0.0.104 Match the Chart.yaml version update to satisfy helm-docs pre-commit hook. --- deploy/charts/operator-crds/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deploy/charts/operator-crds/README.md b/deploy/charts/operator-crds/README.md index 93948d1568..9c63449337 100644 --- a/deploy/charts/operator-crds/README.md +++ b/deploy/charts/operator-crds/README.md @@ -1,6 +1,6 @@ # ToolHive Operator CRDs Helm Chart -![Version: 0.0.103](https://img.shields.io/badge/Version-0.0.103-informational?style=flat-square) +![Version: 0.0.104](https://img.shields.io/badge/Version-0.0.104-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) A Helm chart for installing the ToolHive Operator CRDs into Kubernetes. From 58e3d0b0aedf3d1ee0eccf8456f7eefb3fbb1436 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Tue, 27 Jan 2026 11:06:14 +0000 Subject: [PATCH 11/16] Refactor vMCP tracing and remove health checker self-check Move telemetry provider initialization earlier in vmcp serve command to enable distributed tracing in the aggregator. The aggregator now accepts an explicit tracer provider parameter instead of using the global otel tracer, following dependency injection best practices. Improve tracing error handling by using named return values and deferred error recording in aggregator methods, ensuring errors are properly captured in traces. Remove health checker self-check functionality that prevented the server from checking its own health endpoint. This simplifies the implementation and removes unnecessary URL normalization logic. 
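A note on the Go mechanics behind the tracing error-handling change: a deferred closure can observe a function's final return value only when that value is named, so the retErr named returns introduced below are what let a single deferred block record whichever error any return path produced. A standalone sketch of the pattern, with illustrative function and span names:

```go
package example

import (
	"context"
	"errors"

	"go.opentelemetry.io/otel/codes"
	"go.opentelemetry.io/otel/trace"
)

// doWork records an error on its span exactly once, regardless of which
// return path produced it: the deferred closure reads the named return
// value retErr after the function body has assigned it.
func doWork(ctx context.Context, tracer trace.Tracer) (retErr error) {
	_, span := tracer.Start(ctx, "example.doWork")
	defer func() {
		if retErr != nil {
			span.RecordError(retErr)
			span.SetStatus(codes.Error, retErr.Error())
		}
		span.End()
	}()

	// Every `return err` below this point is captured by the defer above.
	return errors.New("boom")
}
```

With an unnamed error return, the deferred closure would always see nil and failed operations would be reported as successful spans.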
Changes: - Add tracerProvider parameter to aggregator.NewDefaultAggregator - Use noop tracer when provider is nil - Improve span error handling with deferred recording - Remove selfURL parameter from health.NewHealthChecker - Delete pkg/vmcp/health/checker_selfcheck_test.go - Update all tests to match new function signatures - Add debug logging for auth strategy application in client --- cmd/vmcp/app/commands.go | 41 +- pkg/vmcp/aggregator/default_aggregator.go | 74 ++- .../aggregator/default_aggregator_test.go | 18 +- pkg/vmcp/client/client.go | 2 + pkg/vmcp/health/checker.go | 76 --- pkg/vmcp/health/checker_selfcheck_test.go | 504 ------------------ pkg/vmcp/health/checker_test.go | 14 +- pkg/vmcp/health/monitor.go | 6 +- pkg/vmcp/health/monitor_test.go | 20 +- pkg/vmcp/server/integration_test.go | 9 +- pkg/vmcp/server/server.go | 4 +- test/e2e/thv-operator/virtualmcp/helpers.go | 10 +- test/integration/vmcp/helpers/vmcp_server.go | 2 +- 13 files changed, 115 insertions(+), 665 deletions(-) delete mode 100644 pkg/vmcp/health/checker_selfcheck_test.go diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index f9c0aa8a70..aa55537e3a 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -12,6 +12,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" + "go.opentelemetry.io/otel/trace" "k8s.io/client-go/rest" "github.com/stacklok/toolhive/pkg/audit" @@ -310,8 +311,27 @@ func runServe(cmd *cobra.Command, _ []string) error { return fmt.Errorf("failed to create conflict resolver: %w", err) } - // Create aggregator - agg := aggregator.NewDefaultAggregator(backendClient, conflictResolver, cfg.Aggregation.Tools) + // If telemetry is configured, create the provider early so aggregator can use it + var telemetryProvider *telemetry.Provider + if cfg.Telemetry != nil { + telemetryProvider, err = telemetry.NewProvider(ctx, *cfg.Telemetry) + if err != nil { + return fmt.Errorf("failed to create telemetry provider: %w", err) + } + defer func() { + err := telemetryProvider.Shutdown(ctx) + if err != nil { + logger.Errorf("failed to shutdown telemetry provider: %v", err) + } + }() + } + + // Create aggregator with tracer provider (nil if telemetry not configured) + var tracerProvider trace.TracerProvider + if telemetryProvider != nil { + tracerProvider = telemetryProvider.TracerProvider() + } + agg := aggregator.NewDefaultAggregator(backendClient, conflictResolver, cfg.Aggregation.Tools, tracerProvider) // Use DynamicRegistry for version-based cache invalidation // Works in both standalone (CLI with YAML config) and Kubernetes (operator-deployed) modes @@ -381,21 +401,8 @@ func runServe(cmd *cobra.Command, _ []string) error { host, _ := cmd.Flags().GetString("host") port, _ := cmd.Flags().GetInt("port") - // If telemetry is configured, create the provider. 
- var telemetryProvider *telemetry.Provider - if cfg.Telemetry != nil { - var err error - telemetryProvider, err = telemetry.NewProvider(ctx, *cfg.Telemetry) - if err != nil { - return fmt.Errorf("failed to create telemetry provider: %w", err) - } - defer func() { - err := telemetryProvider.Shutdown(ctx) - if err != nil { - logger.Errorf("failed to shutdown telemetry provider: %v", err) - } - }() - } + // Note: telemetryProvider was already created earlier (before aggregator creation) + // to enable tracing in the aggregator // Configure health monitoring if enabled var healthMonitorConfig *health.MonitorConfig diff --git a/pkg/vmcp/aggregator/default_aggregator.go b/pkg/vmcp/aggregator/default_aggregator.go index 3cf2846fcc..20985f9a1f 100644 --- a/pkg/vmcp/aggregator/default_aggregator.go +++ b/pkg/vmcp/aggregator/default_aggregator.go @@ -8,10 +8,10 @@ import ( "fmt" "sync" - "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" "golang.org/x/sync/errgroup" "github.com/stacklok/toolhive/pkg/logger" @@ -31,10 +31,12 @@ type defaultAggregator struct { // NewDefaultAggregator creates a new default aggregator implementation. // conflictResolver handles tool name conflicts across backends. // workloadConfigs specifies per-backend tool filtering and overrides. +// tracerProvider is used to create a tracer for distributed tracing (pass nil for no tracing). func NewDefaultAggregator( backendClient vmcp.BackendClient, conflictResolver ConflictResolver, workloadConfigs []*config.WorkloadToolConfig, + tracerProvider trace.TracerProvider, ) Aggregator { // Build tool config map for quick lookup by backend ID toolConfigMap := make(map[string]*config.WorkloadToolConfig) @@ -44,23 +46,37 @@ func NewDefaultAggregator( } } + // Create tracer from provider (use noop tracer if provider is nil) + var tracer trace.Tracer + if tracerProvider != nil { + tracer = tracerProvider.Tracer("github.com/stacklok/toolhive/pkg/vmcp/aggregator") + } else { + tracer = noop.NewTracerProvider().Tracer("github.com/stacklok/toolhive/pkg/vmcp/aggregator") + } + return &defaultAggregator{ backendClient: backendClient, conflictResolver: conflictResolver, toolConfigMap: toolConfigMap, - tracer: otel.Tracer("github.com/stacklok/toolhive/pkg/vmcp/aggregator"), + tracer: tracer, } } // QueryCapabilities queries a single backend for its MCP capabilities. // Returns the raw capabilities (tools, resources, prompts) from the backend. -func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp.Backend) (*BackendCapabilities, error) { +func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp.Backend) (_ *BackendCapabilities, retErr error) { ctx, span := a.tracer.Start(ctx, "aggregator.QueryCapabilities", trace.WithAttributes( attribute.String("backend.id", backend.ID), ), ) - defer span.End() + defer func() { + if retErr != nil { + span.RecordError(retErr) + span.SetStatus(codes.Error, retErr.Error()) + } + span.End() + }() logger.Debugf("Querying capabilities from backend %s", backend.ID) @@ -71,8 +87,6 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp. 
// Query capabilities using the backend client capabilities, err := a.backendClient.ListCapabilities(ctx, target) if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("%w: %s: %w", ErrBackendQueryFailed, backend.ID, err) } @@ -106,13 +120,19 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp. func (a *defaultAggregator) QueryAllCapabilities( ctx context.Context, backends []vmcp.Backend, -) (map[string]*BackendCapabilities, error) { +) (_ map[string]*BackendCapabilities, retErr error) { ctx, span := a.tracer.Start(ctx, "aggregator.QueryAllCapabilities", trace.WithAttributes( attribute.Int("backends.count", len(backends)), ), ) - defer span.End() + defer func() { + if retErr != nil { + span.RecordError(retErr) + span.SetStatus(codes.Error, retErr.Error()) + } + span.End() + }() logger.Infof("Querying capabilities from %d backends", len(backends)) @@ -146,16 +166,11 @@ func (a *defaultAggregator) QueryAllCapabilities( // Wait for all queries to complete if err := g.Wait(); err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("capability queries failed: %w", err) } if len(capabilities) == 0 { - err := fmt.Errorf("no backends returned capabilities") - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) - return nil, err + return nil, fmt.Errorf("no backends returned capabilities") } span.SetAttributes( @@ -171,13 +186,19 @@ func (a *defaultAggregator) QueryAllCapabilities( func (a *defaultAggregator) ResolveConflicts( ctx context.Context, capabilities map[string]*BackendCapabilities, -) (*ResolvedCapabilities, error) { +) (_ *ResolvedCapabilities, retErr error) { ctx, span := a.tracer.Start(ctx, "aggregator.ResolveConflicts", trace.WithAttributes( attribute.Int("backends.count", len(capabilities)), ), ) - defer span.End() + defer func() { + if retErr != nil { + span.RecordError(retErr) + span.SetStatus(codes.Error, retErr.Error()) + } + span.End() + }() logger.Debugf("Resolving conflicts across %d backends", len(capabilities)) @@ -194,8 +215,6 @@ func (a *defaultAggregator) ResolveConflicts( if a.conflictResolver != nil { resolvedTools, err = a.conflictResolver.ResolveToolConflicts(ctx, toolsByBackend) if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("conflict resolution failed: %w", err) } } else { @@ -383,13 +402,22 @@ func (a *defaultAggregator) MergeCapabilities( // 2. Query all backends // 3. Resolve conflicts // 4. 
Merge into final view with full backend information -func (a *defaultAggregator) AggregateCapabilities(ctx context.Context, backends []vmcp.Backend) (*AggregatedCapabilities, error) { +func (a *defaultAggregator) AggregateCapabilities( + ctx context.Context, + backends []vmcp.Backend, +) (_ *AggregatedCapabilities, retErr error) { ctx, span := a.tracer.Start(ctx, "aggregator.AggregateCapabilities", trace.WithAttributes( attribute.Int("backends.count", len(backends)), ), ) - defer span.End() + defer func() { + if retErr != nil { + span.RecordError(retErr) + span.SetStatus(codes.Error, retErr.Error()) + } + span.End() + }() logger.Infof("Starting capability aggregation for %d backends", len(backends)) @@ -400,24 +428,18 @@ func (a *defaultAggregator) AggregateCapabilities(ctx context.Context, backends // Step 2: Query all backends capabilities, err := a.QueryAllCapabilities(ctx, backends) if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("failed to query backends: %w", err) } // Step 3: Resolve conflicts resolved, err := a.ResolveConflicts(ctx, capabilities) if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("failed to resolve conflicts: %w", err) } // Step 4: Merge into final view with full backend information aggregated, err := a.MergeCapabilities(ctx, resolved, registry) if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("failed to merge capabilities: %w", err) } diff --git a/pkg/vmcp/aggregator/default_aggregator_test.go b/pkg/vmcp/aggregator/default_aggregator_test.go index 3798f5324a..6d07c23af5 100644 --- a/pkg/vmcp/aggregator/default_aggregator_test.go +++ b/pkg/vmcp/aggregator/default_aggregator_test.go @@ -35,7 +35,7 @@ func TestDefaultAggregator_QueryCapabilities(t *testing.T) { mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(expectedCaps, nil) - agg := NewDefaultAggregator(mockClient, nil, nil) + agg := NewDefaultAggregator(mockClient, nil, nil, nil) result, err := agg.QueryCapabilities(context.Background(), backend) require.NoError(t, err) @@ -59,7 +59,7 @@ func TestDefaultAggregator_QueryCapabilities(t *testing.T) { mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()). Return(nil, errors.New("connection failed")) - agg := NewDefaultAggregator(mockClient, nil, nil) + agg := NewDefaultAggregator(mockClient, nil, nil, nil) result, err := agg.QueryCapabilities(context.Background(), backend) require.Error(t, err) @@ -90,7 +90,7 @@ func TestDefaultAggregator_QueryAllCapabilities(t *testing.T) { mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps1, nil) mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps2, nil) - agg := NewDefaultAggregator(mockClient, nil, nil) + agg := NewDefaultAggregator(mockClient, nil, nil, nil) result, err := agg.QueryAllCapabilities(context.Background(), backends) require.NoError(t, err) @@ -120,7 +120,7 @@ func TestDefaultAggregator_QueryAllCapabilities(t *testing.T) { return nil, errors.New("connection timeout") }).Times(2) - agg := NewDefaultAggregator(mockClient, nil, nil) + agg := NewDefaultAggregator(mockClient, nil, nil, nil) result, err := agg.QueryAllCapabilities(context.Background(), backends) require.NoError(t, err) @@ -140,7 +140,7 @@ func TestDefaultAggregator_QueryAllCapabilities(t *testing.T) { mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()). 
Return(nil, errors.New("connection failed")) - agg := NewDefaultAggregator(mockClient, nil, nil) + agg := NewDefaultAggregator(mockClient, nil, nil, nil) result, err := agg.QueryAllCapabilities(context.Background(), backends) require.Error(t, err) @@ -171,7 +171,7 @@ func TestDefaultAggregator_ResolveConflicts(t *testing.T) { }, } - agg := NewDefaultAggregator(nil, nil, nil) + agg := NewDefaultAggregator(nil, nil, nil, nil) resolved, err := agg.ResolveConflicts(context.Background(), capabilities) require.NoError(t, err) @@ -204,7 +204,7 @@ func TestDefaultAggregator_ResolveConflicts(t *testing.T) { }, } - agg := NewDefaultAggregator(nil, nil, nil) + agg := NewDefaultAggregator(nil, nil, nil, nil) resolved, err := agg.ResolveConflicts(context.Background(), capabilities) require.NoError(t, err) @@ -263,7 +263,7 @@ func TestDefaultAggregator_MergeCapabilities(t *testing.T) { } registry := vmcp.NewImmutableRegistry(backends) - agg := NewDefaultAggregator(nil, nil, nil) + agg := NewDefaultAggregator(nil, nil, nil, nil) aggregated, err := agg.MergeCapabilities(context.Background(), resolved, registry) require.NoError(t, err) @@ -331,7 +331,7 @@ func TestDefaultAggregator_AggregateCapabilities(t *testing.T) { mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps1, nil) mockClient.EXPECT().ListCapabilities(gomock.Any(), gomock.Any()).Return(caps2, nil) - agg := NewDefaultAggregator(mockClient, nil, nil) + agg := NewDefaultAggregator(mockClient, nil, nil, nil) result, err := agg.AggregateCapabilities(context.Background(), backends) require.NoError(t, err) diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index c72639f3b0..3993ca6caa 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -169,6 +169,8 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm target.WorkloadID, err) } + logger.Debugf("Applied authentication strategy %q to backend %s", authStrategy.Name(), target.WorkloadID) + // Add authentication layer with pre-resolved strategy baseTransport = &authRoundTripper{ base: baseTransport, diff --git a/pkg/vmcp/health/checker.go b/pkg/vmcp/health/checker.go index bf6f5c329c..ccc3a8effc 100644 --- a/pkg/vmcp/health/checker.go +++ b/pkg/vmcp/health/checker.go @@ -11,8 +11,6 @@ import ( "context" "errors" "fmt" - "net/url" - "strings" "time" "github.com/stacklok/toolhive/pkg/logger" @@ -31,10 +29,6 @@ type healthChecker struct { // If a health check succeeds but takes longer than this duration, the backend is marked degraded. // Zero means disabled (backends will never be marked degraded based on response time alone). degradedThreshold time.Duration - - // selfURL is the server's own URL. If a health check targets this URL, it's short-circuited. - // This prevents the server from trying to health check itself. - selfURL string } // NewHealthChecker creates a new health checker that uses BackendClient.ListCapabilities @@ -45,20 +39,17 @@ type healthChecker struct { // - client: BackendClient for communicating with backend MCP servers // - timeout: Maximum duration for health check operations (0 = no timeout) // - degradedThreshold: Response time threshold for marking backend as degraded (0 = disabled) -// - selfURL: Optional server's own URL. If provided, health checks targeting this URL are short-circuited. // // Returns a new HealthChecker implementation. 
func NewHealthChecker( client vmcp.BackendClient, timeout time.Duration, degradedThreshold time.Duration, - selfURL string, ) vmcp.HealthChecker { return &healthChecker{ client: client, timeout: timeout, degradedThreshold: degradedThreshold, - selfURL: selfURL, } } @@ -89,14 +80,6 @@ func (h *healthChecker) CheckHealth(ctx context.Context, target *vmcp.BackendTar logger.Debugf("Performing health check for backend %s (%s)", target.WorkloadName, target.BaseURL) - // Short-circuit health check if targeting ourselves - // This prevents the server from trying to health check itself, which would work - // but is wasteful and can cause connection issues during startup - if h.selfURL != "" && h.isSelfCheck(target.BaseURL) { - logger.Debugf("Skipping health check for backend %s - this is the server itself", target.WorkloadName) - return vmcp.BackendHealthy, nil - } - // Track response time for degraded detection startTime := time.Now() @@ -162,62 +145,3 @@ func categorizeError(err error) vmcp.BackendHealthStatus { // Default to unhealthy for unknown errors return vmcp.BackendUnhealthy } - -// isSelfCheck checks if a backend URL matches the server's own URL. -// URLs are normalized before comparison to handle variations like: -// - http://127.0.0.1:PORT vs http://localhost:PORT -// - http://HOST:PORT vs http://HOST:PORT/ -func (h *healthChecker) isSelfCheck(backendURL string) bool { - if h.selfURL == "" || backendURL == "" { - return false - } - - // Normalize both URLs for comparison - backendNormalized, err := NormalizeURLForComparison(backendURL) - if err != nil { - return false - } - - selfNormalized, err := NormalizeURLForComparison(h.selfURL) - if err != nil { - return false - } - - return backendNormalized == selfNormalized -} - -// NormalizeURLForComparison normalizes a URL for comparison by: -// - Parsing and reconstructing the URL -// - Converting localhost/127.0.0.1 to a canonical form -// - Comparing only scheme://host:port (ignoring path, query, fragment) -// - Lowercasing scheme and host -// Exported for testing purposes -func NormalizeURLForComparison(rawURL string) (string, error) { - u, err := url.Parse(rawURL) - if err != nil { - return "", err - } - // Validate that we have a scheme and host (basic URL validation) - if u.Scheme == "" || u.Host == "" { - return "", fmt.Errorf("invalid URL: missing scheme or host") - } - - // Normalize host: convert localhost to 127.0.0.1 for consistency - host := strings.ToLower(u.Hostname()) - if host == "localhost" { - host = "127.0.0.1" - } - - // Reconstruct URL with normalized components (scheme://host:port only) - // We ignore path, query, and fragment for comparison - normalized := &url.URL{ - Scheme: strings.ToLower(u.Scheme), - } - if u.Port() != "" { - normalized.Host = host + ":" + u.Port() - } else { - normalized.Host = host - } - - return normalized.String(), nil -} diff --git a/pkg/vmcp/health/checker_selfcheck_test.go b/pkg/vmcp/health/checker_selfcheck_test.go deleted file mode 100644 index ff963d8d35..0000000000 --- a/pkg/vmcp/health/checker_selfcheck_test.go +++ /dev/null @@ -1,504 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
-// SPDX-License-Identifier: Apache-2.0 - -package health - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - "github.com/stacklok/toolhive/pkg/vmcp" - "github.com/stacklok/toolhive/pkg/vmcp/mocks" -) - -// TestHealthChecker_CheckHealth_SelfCheck tests self-check detection -func TestHealthChecker_CheckHealth_SelfCheck(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := mocks.NewMockBackendClient(ctrl) - // Should not call ListCapabilities for self-check - mockClient.EXPECT(). - ListCapabilities(gomock.Any(), gomock.Any()). - Times(0) - - checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") - target := &vmcp.BackendTarget{ - WorkloadID: "backend-1", - WorkloadName: "test-backend", - BaseURL: "http://127.0.0.1:8080", // Same as selfURL - } - - status, err := checker.CheckHealth(context.Background(), target) - assert.NoError(t, err) - assert.Equal(t, vmcp.BackendHealthy, status) -} - -// TestHealthChecker_CheckHealth_SelfCheck_Localhost tests localhost normalization -func TestHealthChecker_CheckHealth_SelfCheck_Localhost(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := mocks.NewMockBackendClient(ctrl) - mockClient.EXPECT(). - ListCapabilities(gomock.Any(), gomock.Any()). - Times(0) - - checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://localhost:8080") - target := &vmcp.BackendTarget{ - WorkloadID: "backend-1", - WorkloadName: "test-backend", - BaseURL: "http://127.0.0.1:8080", // localhost should match 127.0.0.1 - } - - status, err := checker.CheckHealth(context.Background(), target) - assert.NoError(t, err) - assert.Equal(t, vmcp.BackendHealthy, status) -} - -// TestHealthChecker_CheckHealth_SelfCheck_Reverse tests reverse localhost normalization -func TestHealthChecker_CheckHealth_SelfCheck_Reverse(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := mocks.NewMockBackendClient(ctrl) - mockClient.EXPECT(). - ListCapabilities(gomock.Any(), gomock.Any()). - Times(0) - - checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") - target := &vmcp.BackendTarget{ - WorkloadID: "backend-1", - WorkloadName: "test-backend", - BaseURL: "http://localhost:8080", // 127.0.0.1 should match localhost - } - - status, err := checker.CheckHealth(context.Background(), target) - assert.NoError(t, err) - assert.Equal(t, vmcp.BackendHealthy, status) -} - -// TestHealthChecker_CheckHealth_SelfCheck_DifferentPort tests different ports don't match -func TestHealthChecker_CheckHealth_SelfCheck_DifferentPort(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := mocks.NewMockBackendClient(ctrl) - mockClient.EXPECT(). - ListCapabilities(gomock.Any(), gomock.Any()). - Return(&vmcp.CapabilityList{}, nil). 
- Times(1) - - checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") - target := &vmcp.BackendTarget{ - WorkloadID: "backend-1", - WorkloadName: "test-backend", - BaseURL: "http://127.0.0.1:8081", // Different port - } - - status, err := checker.CheckHealth(context.Background(), target) - assert.NoError(t, err) - assert.Equal(t, vmcp.BackendHealthy, status) -} - -// TestHealthChecker_CheckHealth_SelfCheck_EmptyURL tests empty URLs -func TestHealthChecker_CheckHealth_SelfCheck_EmptyURL(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := mocks.NewMockBackendClient(ctrl) - mockClient.EXPECT(). - ListCapabilities(gomock.Any(), gomock.Any()). - Return(&vmcp.CapabilityList{}, nil). - Times(1) - - checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") - target := &vmcp.BackendTarget{ - WorkloadID: "backend-1", - WorkloadName: "test-backend", - BaseURL: "http://127.0.0.1:8080", - } - - status, err := checker.CheckHealth(context.Background(), target) - assert.NoError(t, err) - assert.Equal(t, vmcp.BackendHealthy, status) -} - -// TestHealthChecker_CheckHealth_SelfCheck_InvalidURL tests invalid URLs -func TestHealthChecker_CheckHealth_SelfCheck_InvalidURL(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := mocks.NewMockBackendClient(ctrl) - mockClient.EXPECT(). - ListCapabilities(gomock.Any(), gomock.Any()). - Return(&vmcp.CapabilityList{}, nil). - Times(1) - - checker := NewHealthChecker(mockClient, 5*time.Second, 0, "not-a-valid-url") - target := &vmcp.BackendTarget{ - WorkloadID: "backend-1", - WorkloadName: "test-backend", - BaseURL: "http://127.0.0.1:8080", - } - - status, err := checker.CheckHealth(context.Background(), target) - assert.NoError(t, err) - assert.Equal(t, vmcp.BackendHealthy, status) -} - -// TestHealthChecker_CheckHealth_SelfCheck_WithPath tests URLs with paths are normalized -func TestHealthChecker_CheckHealth_SelfCheck_WithPath(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := mocks.NewMockBackendClient(ctrl) - mockClient.EXPECT(). - ListCapabilities(gomock.Any(), gomock.Any()). - Times(0) - - checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") - target := &vmcp.BackendTarget{ - WorkloadID: "backend-1", - WorkloadName: "test-backend", - BaseURL: "http://127.0.0.1:8080/mcp", // Path should be ignored - } - - status, err := checker.CheckHealth(context.Background(), target) - assert.NoError(t, err) - assert.Equal(t, vmcp.BackendHealthy, status) -} - -// TestHealthChecker_CheckHealth_DegradedThreshold tests degraded threshold detection -func TestHealthChecker_CheckHealth_DegradedThreshold(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := mocks.NewMockBackendClient(ctrl) - mockClient.EXPECT(). - ListCapabilities(gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { - // Simulate slow response - time.Sleep(150 * time.Millisecond) - return &vmcp.CapabilityList{}, nil - }). 
- Times(1) - - // Set degraded threshold to 100ms - checker := NewHealthChecker(mockClient, 5*time.Second, 100*time.Millisecond, "") - target := &vmcp.BackendTarget{ - WorkloadID: "backend-1", - WorkloadName: "test-backend", - BaseURL: "http://localhost:8080", - } - - status, err := checker.CheckHealth(context.Background(), target) - assert.NoError(t, err) - assert.Equal(t, vmcp.BackendDegraded, status, "Should mark as degraded when response time exceeds threshold") -} - -// TestHealthChecker_CheckHealth_DegradedThreshold_Disabled tests disabled degraded threshold -func TestHealthChecker_CheckHealth_DegradedThreshold_Disabled(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := mocks.NewMockBackendClient(ctrl) - mockClient.EXPECT(). - ListCapabilities(gomock.Any(), gomock.Any()). - DoAndReturn(func(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { - // Simulate slow response - time.Sleep(150 * time.Millisecond) - return &vmcp.CapabilityList{}, nil - }). - Times(1) - - // Set degraded threshold to 0 (disabled) - checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") - target := &vmcp.BackendTarget{ - WorkloadID: "backend-1", - WorkloadName: "test-backend", - BaseURL: "http://localhost:8080", - } - - status, err := checker.CheckHealth(context.Background(), target) - assert.NoError(t, err) - assert.Equal(t, vmcp.BackendHealthy, status, "Should not mark as degraded when threshold is disabled") -} - -// TestHealthChecker_CheckHealth_DegradedThreshold_FastResponse tests fast response doesn't trigger degraded -func TestHealthChecker_CheckHealth_DegradedThreshold_FastResponse(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockClient := mocks.NewMockBackendClient(ctrl) - mockClient.EXPECT(). - ListCapabilities(gomock.Any(), gomock.Any()). - Return(&vmcp.CapabilityList{}, nil). 
- Times(1) - - // Set degraded threshold to 100ms - checker := NewHealthChecker(mockClient, 5*time.Second, 100*time.Millisecond, "") - target := &vmcp.BackendTarget{ - WorkloadID: "backend-1", - WorkloadName: "test-backend", - BaseURL: "http://localhost:8080", - } - - status, err := checker.CheckHealth(context.Background(), target) - assert.NoError(t, err) - assert.Equal(t, vmcp.BackendHealthy, status, "Should not mark as degraded when response is fast") -} - -// TestCategorizeError_SentinelErrors tests sentinel error categorization -func TestCategorizeError_SentinelErrors(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - err error - expectedStatus vmcp.BackendHealthStatus - }{ - { - name: "ErrAuthenticationFailed", - err: vmcp.ErrAuthenticationFailed, - expectedStatus: vmcp.BackendUnauthenticated, - }, - { - name: "ErrAuthorizationFailed", - err: vmcp.ErrAuthorizationFailed, - expectedStatus: vmcp.BackendUnauthenticated, - }, - { - name: "ErrTimeout", - err: vmcp.ErrTimeout, - expectedStatus: vmcp.BackendUnhealthy, - }, - { - name: "ErrCancelled", - err: vmcp.ErrCancelled, - expectedStatus: vmcp.BackendUnhealthy, - }, - { - name: "ErrBackendUnavailable", - err: vmcp.ErrBackendUnavailable, - expectedStatus: vmcp.BackendUnhealthy, - }, - { - name: "wrapped ErrAuthenticationFailed", - err: errors.New("wrapped: " + vmcp.ErrAuthenticationFailed.Error()), - expectedStatus: vmcp.BackendUnauthenticated, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - status := categorizeError(tt.err) - assert.Equal(t, tt.expectedStatus, status) - }) - } -} - -// TestNormalizeURLForComparison tests URL normalization -func TestNormalizeURLForComparison(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input string - expected string - wantErr bool - }{ - { - name: "localhost normalized to 127.0.0.1", - input: "http://localhost:8080", - expected: "http://127.0.0.1:8080", - wantErr: false, - }, - { - name: "127.0.0.1 stays as is", - input: "http://127.0.0.1:8080", - expected: "http://127.0.0.1:8080", - wantErr: false, - }, - { - name: "path is ignored", - input: "http://127.0.0.1:8080/mcp", - expected: "http://127.0.0.1:8080", - wantErr: false, - }, - { - name: "query is ignored", - input: "http://127.0.0.1:8080?param=value", - expected: "http://127.0.0.1:8080", - wantErr: false, - }, - { - name: "fragment is ignored", - input: "http://127.0.0.1:8080#fragment", - expected: "http://127.0.0.1:8080", - wantErr: false, - }, - { - name: "scheme is lowercased", - input: "HTTP://127.0.0.1:8080", - expected: "http://127.0.0.1:8080", - wantErr: false, - }, - { - name: "host is lowercased", - input: "http://EXAMPLE.COM:8080", - expected: "http://example.com:8080", - wantErr: false, - }, - { - name: "no port", - input: "http://127.0.0.1", - expected: "http://127.0.0.1", - wantErr: false, - }, - { - name: "invalid URL", - input: "not-a-url", - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - result, err := NormalizeURLForComparison(tt.input) - if tt.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tt.expected, result) - } - }) - } -} - -// TestIsSelfCheck_EdgeCases tests edge cases for self-check detection -func TestIsSelfCheck_EdgeCases(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - t.Cleanup(func() { ctrl.Finish() }) - - mockClient := mocks.NewMockBackendClient(ctrl) - - tests := []struct { - name string - selfURL 
string - backendURL string - expected bool - }{ - { - name: "both empty", - selfURL: "", - backendURL: "", - expected: false, - }, - { - name: "selfURL empty", - selfURL: "", - backendURL: "http://127.0.0.1:8080", - expected: false, - }, - { - name: "backendURL empty", - selfURL: "http://127.0.0.1:8080", - backendURL: "", - expected: false, - }, - { - name: "localhost matches 127.0.0.1", - selfURL: "http://localhost:8080", - backendURL: "http://127.0.0.1:8080", - expected: true, - }, - { - name: "127.0.0.1 matches localhost", - selfURL: "http://127.0.0.1:8080", - backendURL: "http://localhost:8080", - expected: true, - }, - { - name: "different ports", - selfURL: "http://127.0.0.1:8080", - backendURL: "http://127.0.0.1:8081", - expected: false, - }, - { - name: "different hosts", - selfURL: "http://127.0.0.1:8080", - backendURL: "http://192.168.1.1:8080", - expected: false, - }, - { - name: "path ignored", - selfURL: "http://127.0.0.1:8080", - backendURL: "http://127.0.0.1:8080/mcp", - expected: true, - }, - { - name: "query ignored", - selfURL: "http://127.0.0.1:8080", - backendURL: "http://127.0.0.1:8080?param=value", - expected: true, - }, - { - name: "invalid selfURL", - selfURL: "not-a-url", - backendURL: "http://127.0.0.1:8080", - expected: false, - }, - { - name: "invalid backendURL", - selfURL: "http://127.0.0.1:8080", - backendURL: "not-a-url", - expected: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - checker := NewHealthChecker(mockClient, 5*time.Second, 0, tt.selfURL) - hc, ok := checker.(*healthChecker) - require.True(t, ok) - - result := hc.isSelfCheck(tt.backendURL) - assert.Equal(t, tt.expected, result) - }) - } -} diff --git a/pkg/vmcp/health/checker_test.go b/pkg/vmcp/health/checker_test.go index 63c3c986b6..39f7258d82 100644 --- a/pkg/vmcp/health/checker_test.go +++ b/pkg/vmcp/health/checker_test.go @@ -44,7 +44,7 @@ func TestNewHealthChecker(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - checker := NewHealthChecker(mockClient, tt.timeout, 0, "") + checker := NewHealthChecker(mockClient, tt.timeout, 0) require.NotNil(t, checker) // Type assert to access internals for verification @@ -68,7 +68,7 @@ func TestHealthChecker_CheckHealth_Success(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). Times(1) - checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") + checker := NewHealthChecker(mockClient, 5*time.Second, 0) target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -95,7 +95,7 @@ func TestHealthChecker_CheckHealth_ContextCancellation(t *testing.T) { }). Times(1) - checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0, "") + checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0) target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -123,7 +123,7 @@ func TestHealthChecker_CheckHealth_NoTimeout(t *testing.T) { Times(1) // Create checker with no timeout - checker := NewHealthChecker(mockClient, 0, 0, "") + checker := NewHealthChecker(mockClient, 0, 0) target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -213,7 +213,7 @@ func TestHealthChecker_CheckHealth_ErrorCategorization(t *testing.T) { Return(nil, tt.err). 
Times(1) - checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") + checker := NewHealthChecker(mockClient, 5*time.Second, 0) target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -430,7 +430,7 @@ func TestHealthChecker_CheckHealth_Timeout(t *testing.T) { }). Times(1) - checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0, "") + checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0) target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -467,7 +467,7 @@ func TestHealthChecker_CheckHealth_MultipleBackends(t *testing.T) { }). Times(4) - checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") + checker := NewHealthChecker(mockClient, 5*time.Second, 0) // Test healthy backend status, err := checker.CheckHealth(context.Background(), &vmcp.BackendTarget{ diff --git a/pkg/vmcp/health/monitor.go b/pkg/vmcp/health/monitor.go index 62aea9b735..50f00d788d 100644 --- a/pkg/vmcp/health/monitor.go +++ b/pkg/vmcp/health/monitor.go @@ -108,14 +108,12 @@ func DefaultConfig() MonitorConfig { // - client: BackendClient for communicating with backend MCP servers // - backends: List of backends to monitor // - config: Configuration for health monitoring -// - selfURL: Optional server's own URL. If provided, health checks targeting this URL are short-circuited. // // Returns (monitor, error). Error is returned if configuration is invalid. func NewMonitor( client vmcp.BackendClient, backends []vmcp.Backend, config MonitorConfig, - selfURL string, ) (*Monitor, error) { // Validate configuration if config.CheckInterval <= 0 { @@ -125,8 +123,8 @@ func NewMonitor( return nil, fmt.Errorf("unhealthy threshold must be >= 1, got %d", config.UnhealthyThreshold) } - // Create health checker with degraded threshold and self URL - checker := NewHealthChecker(client, config.Timeout, config.DegradedThreshold, selfURL) + // Create health checker with degraded threshold + checker := NewHealthChecker(client, config.Timeout, config.DegradedThreshold) // Create status tracker statusTracker := newStatusTracker(config.UnhealthyThreshold) diff --git a/pkg/vmcp/health/monitor_test.go b/pkg/vmcp/health/monitor_test.go index 8d2de11bdd..bb177017e7 100644 --- a/pkg/vmcp/health/monitor_test.go +++ b/pkg/vmcp/health/monitor_test.go @@ -66,7 +66,7 @@ func TestNewMonitor_Validation(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - monitor, err := NewMonitor(mockClient, backends, tt.config, "") + monitor, err := NewMonitor(mockClient, backends, tt.config) if tt.expectError { assert.Error(t, err) assert.Nil(t, monitor) @@ -101,7 +101,7 @@ func TestMonitor_StartStop(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). 
AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config, "") + monitor, err := NewMonitor(mockClient, backends, config) require.NoError(t, err) // Start monitor @@ -178,7 +178,7 @@ func TestMonitor_StartErrors(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - monitor, err := NewMonitor(mockClient, backends, config, "") + monitor, err := NewMonitor(mockClient, backends, config) require.NoError(t, err) err = tt.setupFunc(monitor) @@ -208,7 +208,7 @@ func TestMonitor_StopWithoutStart(t *testing.T) { Timeout: 50 * time.Millisecond, } - monitor, err := NewMonitor(mockClient, backends, config, "") + monitor, err := NewMonitor(mockClient, backends, config) require.NoError(t, err) // Try to stop without starting @@ -239,7 +239,7 @@ func TestMonitor_PeriodicHealthChecks(t *testing.T) { Return(nil, errors.New("backend unavailable")). MinTimes(2) - monitor, err := NewMonitor(mockClient, backends, config, "") + monitor, err := NewMonitor(mockClient, backends, config) require.NoError(t, err) ctx := context.Background() @@ -289,7 +289,7 @@ func TestMonitor_GetHealthSummary(t *testing.T) { }). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config, "") + monitor, err := NewMonitor(mockClient, backends, config) require.NoError(t, err) ctx := context.Background() @@ -333,7 +333,7 @@ func TestMonitor_GetBackendStatus(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config, "") + monitor, err := NewMonitor(mockClient, backends, config) require.NoError(t, err) ctx := context.Background() @@ -382,7 +382,7 @@ func TestMonitor_GetBackendState(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config, "") + monitor, err := NewMonitor(mockClient, backends, config) require.NoError(t, err) ctx := context.Background() @@ -433,7 +433,7 @@ func TestMonitor_GetAllBackendStates(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config, "") + monitor, err := NewMonitor(mockClient, backends, config) require.NoError(t, err) ctx := context.Background() @@ -477,7 +477,7 @@ func TestMonitor_ContextCancellation(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). 
AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config, "") + monitor, err := NewMonitor(mockClient, backends, config) require.NoError(t, err) // Start with cancellable context diff --git a/pkg/vmcp/server/integration_test.go b/pkg/vmcp/server/integration_test.go index 0086a8b4a9..8bb5e40231 100644 --- a/pkg/vmcp/server/integration_test.go +++ b/pkg/vmcp/server/integration_test.go @@ -119,6 +119,7 @@ func TestIntegration_AggregatorToRouterToServer(t *testing.T) { mockBackendClient, conflictResolver, nil, // no tool configs + nil, // no tracer provider in tests ) // Step 3: Run aggregation on mock backends @@ -311,7 +312,7 @@ func TestIntegration_HTTPRequestFlowWithRoutingTable(t *testing.T) { // Create discovery manager conflictResolver := aggregator.NewPrefixConflictResolver("{workload}_") - agg := aggregator.NewDefaultAggregator(mockBackendClient, conflictResolver, nil) + agg := aggregator.NewDefaultAggregator(mockBackendClient, conflictResolver, nil, nil) discoveryMgr, err := discovery.NewManager(agg) require.NoError(t, err) @@ -501,7 +502,7 @@ func TestIntegration_ConflictResolutionStrategies(t *testing.T) { Times(2) resolver := aggregator.NewPrefixConflictResolver("{workload}_") - agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil) + agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil, nil) result, err := agg.AggregateCapabilities(ctx, createBackendsWithConflicts()) require.NoError(t, err) @@ -539,7 +540,7 @@ func TestIntegration_ConflictResolutionStrategies(t *testing.T) { resolver, err := aggregator.NewPriorityConflictResolver([]string{"backend1", "backend2"}) require.NoError(t, err) - agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil) + agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil, nil) result, err := agg.AggregateCapabilities(ctx, createBackendsWithConflicts()) require.NoError(t, err) @@ -659,7 +660,7 @@ func TestIntegration_AuditLogging(t *testing.T) { Discover(gomock.Any(), gomock.Any()). DoAndReturn(func(_ context.Context, _ []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) { resolver := aggregator.NewPrefixConflictResolver("{workload}_") - agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil) + agg := aggregator.NewDefaultAggregator(mockBackendClient, resolver, nil, nil) return agg.AggregateCapabilities(ctx, backends) }). 
AnyTimes() diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index a57eef5613..3ccecdf39c 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -340,9 +340,7 @@ func New( if cfg.HealthMonitorConfig != nil { // Get initial backends list from registry for health monitoring setup initialBackends := backendRegistry.List(ctx) - // Construct selfURL to prevent health checker from checking itself - selfURL := fmt.Sprintf("http://%s:%d%s", cfg.Host, cfg.Port, cfg.EndpointPath) - healthMon, err = health.NewMonitor(backendClient, initialBackends, *cfg.HealthMonitorConfig, selfURL) + healthMon, err = health.NewMonitor(backendClient, initialBackends, *cfg.HealthMonitorConfig) if err != nil { return nil, fmt.Errorf("failed to create health monitor: %w", err) } diff --git a/test/e2e/thv-operator/virtualmcp/helpers.go b/test/e2e/thv-operator/virtualmcp/helpers.go index e48d3fdea5..fc20dd36a5 100644 --- a/test/e2e/thv-operator/virtualmcp/helpers.go +++ b/test/e2e/thv-operator/virtualmcp/helpers.go @@ -37,7 +37,7 @@ import ( ) // WaitForVirtualMCPServerReady waits for a VirtualMCPServer to reach Ready status -// and ensures the associated pods are actually running and ready +// and ensures at least one associated pod is actually running and ready func WaitForVirtualMCPServerReady( ctx context.Context, c client.Client, @@ -58,7 +58,7 @@ func WaitForVirtualMCPServerReady( for _, condition := range vmcpServer.Status.Conditions { if condition.Type == "Ready" { if condition.Status == "True" { - // Also check that the pods are actually running and ready + // Also check that at least one pod is actually running and ready labels := map[string]string{ "app.kubernetes.io/name": "virtualmcpserver", "app.kubernetes.io/instance": name, @@ -75,7 +75,9 @@ func WaitForVirtualMCPServerReady( }, timeout, pollingInterval).Should(gomega.Succeed()) } -// checkPodsReady checks if all pods matching the given labels are ready +// checkPodsReady checks if at least one pod matching the given labels is ready. +// This is typically used when checking for a single expected pod (e.g., one replica deployment). +// Pods not in Running phase are skipped (e.g., Succeeded, Failed from previous deployments). 
func checkPodsReady(ctx context.Context, c client.Client, namespace string, labels map[string]string) error { podList := &corev1.PodList{} if err := c.List(ctx, podList, @@ -246,7 +248,7 @@ func GetVirtualMCPServerPods(ctx context.Context, c client.Client, vmcpServerNam return podList, err } -// WaitForPodsReady waits for all pods matching labels to be ready +// WaitForPodsReady waits for at least one pod matching labels to be ready func WaitForPodsReady( ctx context.Context, c client.Client, diff --git a/test/integration/vmcp/helpers/vmcp_server.go b/test/integration/vmcp/helpers/vmcp_server.go index 95aa5f5142..99caa6c086 100644 --- a/test/integration/vmcp/helpers/vmcp_server.go +++ b/test/integration/vmcp/helpers/vmcp_server.go @@ -161,7 +161,7 @@ func NewVMCPServer( } // Create aggregator - agg := aggregator.NewDefaultAggregator(backendClient, conflictResolver, nil) + agg := aggregator.NewDefaultAggregator(backendClient, conflictResolver, nil, nil) // Create discovery manager discoveryMgr, err := discovery.NewManager(agg) From cd7b756e1f3c732ca5c6a493719348e55b9c12dc Mon Sep 17 00:00:00 2001 From: nigel brown Date: Tue, 27 Jan 2026 11:55:11 +0000 Subject: [PATCH 12/16] Add explanatory comment for MCP SDK Meta limitations Restores comment explaining why Meta field preservation is important for ReadResource, in anticipation of future SDK improvements. This addresses PR feedback to maintain context about the SDK's current limitations regarding Meta field handling. --- pkg/vmcp/client/client.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index 3993ca6caa..a30b717ce1 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -700,6 +700,8 @@ func (h *httpBackendClient) ReadResource( // Extract _meta field from backend response meta := conversion.FromMCPMeta(result.Meta) + // Note: Due to MCP SDK limitations, the SDK's ReadResourceResult may not include Meta. + // This preserves it for future SDK improvements. return &vmcp.ResourceReadResult{ Contents: data, MimeType: mimeType, From af70d9415565dc78924c5b7985c76d25111251a2 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Tue, 27 Jan 2026 11:58:26 +0000 Subject: [PATCH 13/16] Update test helper comments to clarify pod readiness contract - Clarify that checkPodsReady waits for at least one pod (not all pods) - Add context that helpers are used for single replica deployments - Update comments on WaitForPodsReady and WaitForVirtualMCPServerReady Addresses code review feedback from PR review. --- test/e2e/thv-operator/virtualmcp/helpers.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/e2e/thv-operator/virtualmcp/helpers.go b/test/e2e/thv-operator/virtualmcp/helpers.go index fc20dd36a5..56b9bbf52f 100644 --- a/test/e2e/thv-operator/virtualmcp/helpers.go +++ b/test/e2e/thv-operator/virtualmcp/helpers.go @@ -37,7 +37,8 @@ import ( ) // WaitForVirtualMCPServerReady waits for a VirtualMCPServer to reach Ready status -// and ensures at least one associated pod is actually running and ready +// and ensures at least one associated pod is actually running and ready. +// This is used when waiting for a single expected pod (e.g., one replica deployment). func WaitForVirtualMCPServerReady( ctx context.Context, c client.Client, @@ -75,8 +76,8 @@ func WaitForVirtualMCPServerReady( }, timeout, pollingInterval).Should(gomega.Succeed()) } -// checkPodsReady checks if at least one pod matching the given labels is ready. 
-// This is typically used when checking for a single expected pod (e.g., one replica deployment). +// checkPodsReady waits for at least one pod matching the given labels to be ready. +// This is used when checking for a single expected pod (e.g., one replica deployment). // Pods not in Running phase are skipped (e.g., Succeeded, Failed from previous deployments). func checkPodsReady(ctx context.Context, c client.Client, namespace string, labels map[string]string) error { podList := &corev1.PodList{} @@ -248,7 +249,8 @@ func GetVirtualMCPServerPods(ctx context.Context, c client.Client, vmcpServerNam return podList, err } -// WaitForPodsReady waits for at least one pod matching labels to be ready +// WaitForPodsReady waits for at least one pod matching labels to be ready. +// This is used when waiting for a single expected pod to be ready (e.g., one replica deployment). func WaitForPodsReady( ctx context.Context, c client.Client, From bc966369800833c25a2c35256a7493b9ff181482 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Tue, 27 Jan 2026 12:05:12 +0000 Subject: [PATCH 14/16] Complete error capture pattern in MergeCapabilities defer - Add named return value (retErr error) to MergeCapabilities - Add error capture in defer statement with span.RecordError and span.SetStatus - Ensures consistent error handling pattern across all aggregator methods This completes the implementation of the error capture pattern suggested in code review for all methods with tracing spans. --- pkg/vmcp/aggregator/default_aggregator.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pkg/vmcp/aggregator/default_aggregator.go b/pkg/vmcp/aggregator/default_aggregator.go index 20985f9a1f..ca51d207d8 100644 --- a/pkg/vmcp/aggregator/default_aggregator.go +++ b/pkg/vmcp/aggregator/default_aggregator.go @@ -274,7 +274,7 @@ func (a *defaultAggregator) MergeCapabilities( ctx context.Context, resolved *ResolvedCapabilities, registry vmcp.BackendRegistry, -) (*AggregatedCapabilities, error) { +) (_ *AggregatedCapabilities, retErr error) { ctx, span := a.tracer.Start(ctx, "aggregator.MergeCapabilities", trace.WithAttributes( attribute.Int("resolved.tools", len(resolved.Tools)), @@ -282,7 +282,13 @@ func (a *defaultAggregator) MergeCapabilities( attribute.Int("resolved.prompts", len(resolved.Prompts)), ), ) - defer span.End() + defer func() { + if retErr != nil { + span.RecordError(retErr) + span.SetStatus(codes.Error, retErr.Error()) + } + span.End() + }() logger.Debugf("Merging capabilities into final view") From c514590c0ad6247335a92496dc9d2b26c7fde95c Mon Sep 17 00:00:00 2001 From: nigel brown Date: Tue, 27 Jan 2026 16:57:21 +0000 Subject: [PATCH 15/16] Remove singleflight race condition fix Moving the singleflight deduplication logic to a separate PR as it addresses a different race condition from the one fixed in #3450. The fix prevents duplicate capability aggregation when multiple concurrent requests arrive simultaneously at startup. 
--- .gitignore | 6 ----- pkg/vmcp/discovery/manager.go | 44 +++++++---------------------------- 2 files changed, 8 insertions(+), 42 deletions(-) diff --git a/.gitignore b/.gitignore index 34dcc23d79..f0840c001e 100644 --- a/.gitignore +++ b/.gitignore @@ -44,9 +44,3 @@ coverage* crd-helm-wrapper cmd/vmcp/__debug_bin* - -# Demo files -examples/operator/virtual-mcps/vmcp_optimizer.yaml -scripts/k8s_vmcp_optimizer_demo.sh -examples/ingress/mcp-servers-ingress.yaml -/vmcp diff --git a/pkg/vmcp/discovery/manager.go b/pkg/vmcp/discovery/manager.go index 6dfa659512..9bdfdc1d39 100644 --- a/pkg/vmcp/discovery/manager.go +++ b/pkg/vmcp/discovery/manager.go @@ -1,6 +1,3 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - // Package discovery provides lazy per-user capability discovery for vMCP servers. // // This package implements per-request capability discovery with user-specific @@ -18,8 +15,6 @@ import ( "sync" "time" - "golang.org/x/sync/singleflight" - "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" @@ -70,9 +65,6 @@ type DefaultManager struct { stopCh chan struct{} stopOnce sync.Once wg sync.WaitGroup - // singleFlight ensures only one aggregation happens per cache key at a time - // This prevents concurrent requests from all triggering aggregation - singleFlight singleflight.Group } // NewManager creates a new discovery manager with the given aggregator. @@ -136,9 +128,6 @@ func NewManagerWithRegistry(agg aggregator.Aggregator, registry vmcp.DynamicRegi // // The context must contain an authenticated user identity (set by auth middleware). // Returns ErrNoIdentity if user identity is not found in context. -// -// This method uses singleflight to ensure that concurrent requests for the same -// cache key only trigger one aggregation, preventing duplicate work. 
func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) { // Validate user identity is present (set by auth middleware) // This ensures discovery happens with proper user authentication context @@ -150,7 +139,7 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) // Generate cache key from user identity and backend set cacheKey := m.generateCacheKey(identity.Subject, backends) - // Check cache first (with read lock) + // Check cache first if caps := m.getCachedCapabilities(cacheKey); caps != nil { logger.Debugf("Cache hit for user %s (key: %s)", identity.Subject, cacheKey) return caps, nil @@ -158,33 +147,16 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) logger.Debugf("Cache miss - performing capability discovery for user: %s", identity.Subject) - // Use singleflight to ensure only one aggregation happens per cache key - // Even if multiple requests come in concurrently, they'll all wait for the same result - result, err, _ := m.singleFlight.Do(cacheKey, func() (interface{}, error) { - // Double-check cache after acquiring singleflight lock - // Another goroutine might have populated it while we were waiting - if caps := m.getCachedCapabilities(cacheKey); caps != nil { - logger.Debugf("Cache populated while waiting - returning cached result for user %s", identity.Subject) - return caps, nil - } - - // Perform aggregation - caps, err := m.aggregator.AggregateCapabilities(ctx, backends) - if err != nil { - return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err) - } - - // Cache the result (skips caching if at capacity and key doesn't exist) - m.cacheCapabilities(cacheKey, caps) - - return caps, nil - }) - + // Cache miss - perform aggregation + caps, err := m.aggregator.AggregateCapabilities(ctx, backends) if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err) } - return result.(*aggregator.AggregatedCapabilities), nil + // Cache the result (skips caching if at capacity and key doesn't exist) + m.cacheCapabilities(cacheKey, caps) + + return caps, nil } // Stop gracefully stops the manager and cleans up resources. From ff267fe8e3bef641fed7d4969c667598999cff0a Mon Sep 17 00:00:00 2001 From: nigel brown Date: Tue, 27 Jan 2026 17:02:43 +0000 Subject: [PATCH 16/16] Add SPDX license headers to manager.go --- pkg/vmcp/discovery/manager.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/vmcp/discovery/manager.go b/pkg/vmcp/discovery/manager.go index 9bdfdc1d39..0845118ee1 100644 --- a/pkg/vmcp/discovery/manager.go +++ b/pkg/vmcp/discovery/manager.go @@ -1,3 +1,6 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + // Package discovery provides lazy per-user capability discovery for vMCP servers. // // This package implements per-request capability discovery with user-specific