From 998669a19b6c11b39e2077b616f79ab8e8d25f66 Mon Sep 17 00:00:00 2001 From: Anshul Khantwal Date: Thu, 25 Sep 2025 14:44:14 +0530 Subject: [PATCH 1/2] Enabling vector-search tool as well as Authentication --- .dockerignore | 41 ++- Dockerfile | 94 +++++- package-lock.json | 121 ++++++- package.json | 6 +- src/common/config.ts | 70 ++++ src/common/logger.ts | 2 + .../azureAIInferenceEmbeddingProvider.ts | 86 +++++ src/embedding/embeddingProvider.ts | 4 + src/embedding/embeddingProviderFactory.ts | 92 ++++++ src/tools/mongodb/read/vectorSearchv1.ts | 154 +++++++++ src/tools/mongodb/read/vectorSearchv2.ts | 151 +++++++++ src/tools/mongodb/tools.ts | 4 + src/transports/azureManagedIdentityAuth.ts | 212 +++++++++++++ src/transports/streamableHttp.ts | 5 + tests/accuracy/vectorSearch.test.ts | 43 +++ tests/integration/transports/stdio.test.ts | 1 + .../azureOpenAIProviderRetry.test.ts | 46 +++ ...embeddingProviderFactoryValidation.test.ts | 32 ++ .../tools/mongodb/read/vectorSearchV1.test.ts | 115 +++++++ .../tools/mongodb/read/vectorSearchV2.test.ts | 133 ++++++++ .../azureManagedIdentityAuth.test.ts | 298 ++++++++++++++++++ 21 files changed, 1692 insertions(+), 18 deletions(-) create mode 100644 src/embedding/azureAIInferenceEmbeddingProvider.ts create mode 100644 src/embedding/embeddingProvider.ts create mode 100644 src/embedding/embeddingProviderFactory.ts create mode 100644 src/tools/mongodb/read/vectorSearchv1.ts create mode 100644 src/tools/mongodb/read/vectorSearchv2.ts create mode 100644 src/transports/azureManagedIdentityAuth.ts create mode 100644 tests/accuracy/vectorSearch.test.ts create mode 100644 tests/unit/embedding/azureOpenAIProviderRetry.test.ts create mode 100644 tests/unit/embedding/embeddingProviderFactoryValidation.test.ts create mode 100644 tests/unit/tools/mongodb/read/vectorSearchV1.test.ts create mode 100644 tests/unit/tools/mongodb/read/vectorSearchV2.test.ts create mode 100644 tests/unit/transports/azureManagedIdentityAuth.test.ts diff --git 
a/.dockerignore b/.dockerignore index 05384e6fe..31eb1ffb3 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,11 +1,44 @@ +### Runtime build context reduction +### Runtime build context reduction dist node_modules -.vscode -.github + +### VCS / metadata .git -# Environment variables +.github +.smithery + +### Tool / editor configs +.vscode +*.swp +*.tmp +*.DS_Store +Thumbs.db + +### Logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* + +### Temp +tmp +**/tmp + +### Environment variables & secrets .env +.env.* +env.list +### Tests & coverage (not needed in runtime image) tests coverage -scripts +scripts/accuracy +**/test-data-dumps +.vitest + +### Local certificates (copy explicitly if needed) +certs + +### Misc local exports +exports diff --git a/Dockerfile b/Dockerfile index d842f6333..f495acbf0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,11 +1,87 @@ -FROM node:22-alpine -ARG VERSION=latest +### +# Optimized multi-stage Dockerfile for mongodb-mcp-server +# +# Build args: +# NODE_VERSION Node.js version (default 22-alpine) +# INSTALL_DEV Keep dev dependencies (true|false, default: false) +# RUNTIME_IMAGE Base runtime image (default: node:22-alpine) +# +# Typical build: +# docker build -t mongodb-mcp-server:local . +# docker build --build-arg INSTALL_DEV=true -t mongodb-mcp-server:dev . 
+# +# Runtime (stdio transport): +# docker run --rm -it mongodb-mcp-server:local --transport stdio +# +# Runtime (http transport): +# docker run --rm -p 3000:3000 mongodb-mcp-server:local --transport http --httpHost 0.0.0.0 +# curl -s -X POST http://localhost:3000/mcp -H 'Content-Type: application/json' \ +# -d '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{}}}' +# +# Optional HTTP auth (Azure Managed Identity): +# docker run --rm -p 3000:3000 \ +# -e MDB_MCP_HTTP_AUTH_MODE=azure-managed-identity \ +# -e MDB_MCP_AZURE_MANAGED_IDENTITY_TENANT_ID= \ +# -e MDB_MCP_AZURE_MANAGED_IDENTITY_CLIENT_ID= \ +# mongodb-mcp-server:local --transport http --httpHost 0.0.0.0 +### + +# syntax=docker/dockerfile:1.7-labs + +ARG NODE_VERSION=22-alpine +ARG RUNTIME_IMAGE=node:${NODE_VERSION} +ARG INSTALL_DEV=false + +############################################# +# Builder Stage +############################################# +FROM node:${NODE_VERSION} AS builder +WORKDIR /app + +# Leverage Docker layer caching: copy only dependency manifests + tsconfigs first (needed by build scripts) +COPY package.json package-lock.json* .npmrc* tsconfig*.json eslint.config.js vitest.config.ts ./ + +# Install dependencies without running lifecycle scripts (avoid premature build via prepare) +RUN --mount=type=cache,target=/root/.npm \ + npm ci --ignore-scripts + +# Copy application sources +COPY src ./src +COPY scripts ./scripts + +# Now run the build explicitly (includes prepare sequence tasks) +RUN npm run build + +# Optionally prune dev dependencies for slimmer runtime +ARG INSTALL_DEV +RUN if [ "${INSTALL_DEV}" != "true" ]; then npm prune --omit=dev; fi + +############################################# +# Runtime Stage +############################################# +FROM ${RUNTIME_IMAGE} AS runtime +ENV NODE_ENV=production \ + MDB_MCP_LOGGERS=stderr,mcp + +# Create non-root user RUN addgroup -S mcp && adduser -S mcp -G mcp -RUN npm install -g 
mongodb-mcp-server@${VERSION} -USER mcp WORKDIR /home/mcp -ENV MDB_MCP_LOGGERS=stderr,mcp -ENTRYPOINT ["mongodb-mcp-server"] -LABEL maintainer="MongoDB Inc " -LABEL description="MongoDB MCP Server" -LABEL version=${VERSION} + +# Copy only required artifacts (preserve ownership in a single layer) +COPY --chown=mcp:mcp --from=builder /app/package*.json ./ +COPY --chown=mcp:mcp --from=builder /app/node_modules ./node_modules +COPY --chown=mcp:mcp --from=builder /app/dist ./dist + +USER mcp + +# Expose default HTTP port (matches default config httpPort=3000) +EXPOSE 3000 + +LABEL maintainer="MongoDB Inc " \ + org.opencontainers.image.title="mongodb-mcp-server" \ + org.opencontainers.image.description="MongoDB MCP Server" \ + org.opencontainers.image.source="https://github.com/mongodb-js/mongodb-mcp-server" + +# Use exec form for clarity; default command may be overridden at runtime +ENTRYPOINT ["node", "dist/index.js"] +CMD ["--transport", "http"] diff --git a/package-lock.json b/package-lock.json index 09f36cd0c..82463cd87 100644 --- a/package-lock.json +++ b/package-lock.json @@ -16,6 +16,7 @@ "@mongosh/service-provider-node-driver": "^3.17.0", "bson": "^6.10.4", "express": "^5.1.0", + "jose": "^5.9.6", "lru-cache": "^11.1.0", "mongodb-connection-string-url": "^3.0.2", "mongodb-log-writer": "^2.4.1", @@ -62,6 +63,7 @@ "openapi-typescript": "^7.9.1", "prettier": "^3.6.2", "proper-lockfile": "^4.1.2", + "rimraf": "^5.0.10", "semver": "^7.7.2", "simple-git": "^3.28.0", "tsx": "^4.20.5", @@ -9827,9 +9829,9 @@ } }, "node_modules/jose": { - "version": "6.1.0", - "resolved": "https://registry.npmjs.org/jose/-/jose-6.1.0.tgz", - "integrity": "sha512-TTQJyoEoKcC1lscpVDCSsVgYzUDg/0Bt3WE//WiTPK6uOCQC2KZS4MpugbMWt/zyjkopgZoXhZuCi00gLudfUA==", + "version": "5.10.0", + "resolved": "https://registry.npmjs.org/jose/-/jose-5.10.0.tgz", + "integrity": "sha512-s+3Al/p9g32Iq+oqXxkW//7jk2Vig6FF1CFqzVXoTUXt2qz89YWbL+OwS17NFYEvxC35n0FKeGO2LGYSxeM2Gg==", "license": "MIT", "funding": { 
"url": "https://github.com/sponsors/panva" @@ -11302,6 +11304,16 @@ "url": "https://github.com/sponsors/panva" } }, + "node_modules/openid-client/node_modules/jose": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/jose/-/jose-6.1.0.tgz", + "integrity": "sha512-TTQJyoEoKcC1lscpVDCSsVgYzUDg/0Bt3WE//WiTPK6uOCQC2KZS4MpugbMWt/zyjkopgZoXhZuCi00gLudfUA==", + "license": "MIT", + "peer": true, + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, "node_modules/optionator": { "version": "0.9.4", "resolved": "https://registry.npmjs.org/optionator/-/optionator-0.9.4.tgz", @@ -12492,6 +12504,109 @@ "node": ">=0.10.0" } }, + "node_modules/rimraf": { + "version": "5.0.10", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-5.0.10.tgz", + "integrity": "sha512-l0OE8wL34P4nJH/H2ffoaniAokM2qSmrtXHmlpvYr5AVVX8msAyW0l8NVJFDxlSK4u3Uh/f41cQheDVdnYijwQ==", + "dev": true, + "license": "ISC", + "dependencies": { + "glob": "^10.3.7" + }, + "bin": { + "rimraf": "dist/esm/bin.mjs" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/rimraf/node_modules/brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/rimraf/node_modules/glob": { + "version": "10.4.5", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.4.5.tgz", + "integrity": "sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==", + "dev": true, + "license": "ISC", + "dependencies": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "bin": { + "glob": "dist/esm/bin.mjs" + }, + "funding": { + 
"url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/rimraf/node_modules/jackspeak": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/jackspeak/-/jackspeak-3.4.3.tgz", + "integrity": "sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw==", + "dev": true, + "license": "BlueOak-1.0.0", + "dependencies": { + "@isaacs/cliui": "^8.0.2" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + }, + "optionalDependencies": { + "@pkgjs/parseargs": "^0.11.0" + } + }, + "node_modules/rimraf/node_modules/lru-cache": { + "version": "10.4.3", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-10.4.3.tgz", + "integrity": "sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==", + "dev": true, + "license": "ISC" + }, + "node_modules/rimraf/node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/rimraf/node_modules/path-scurry": { + "version": "1.11.1", + "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-1.11.1.tgz", + "integrity": "sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==", + "dev": true, + "license": "BlueOak-1.0.0", + "dependencies": { + "lru-cache": "^10.2.0", + "minipass": "^5.0.0 || ^6.0.2 || ^7.0.0" + }, + "engines": { + "node": ">=16 || 14 >=14.18" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/rollup": { "version": "4.50.0", "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.50.0.tgz", diff --git a/package.json 
b/package.json index 789db6ffc..fad97808e 100644 --- a/package.json +++ b/package.json @@ -35,12 +35,12 @@ "start": "node dist/index.js --transport http --loggers stderr mcp", "start:stdio": "node dist/index.js --transport stdio --loggers stderr mcp", "prepare": "npm run build", - "build:clean": "rm -rf dist", + "build:clean": "rimraf dist", "build:update-package-version": "tsx scripts/updatePackageVersion.ts", "build:esm": "tsc --project tsconfig.esm.json", "build:cjs": "tsc --project tsconfig.cjs.json", "build:universal-package": "tsx scripts/createUniversalPackage.ts", - "build:chmod": "chmod +x dist/esm/index.js", + "build:chmod": "node -e \"try{if(process.platform!=='win32'){require('fs').chmodSync('dist/esm/index.js',0o755)}}catch(e){console.error(e);process.exit(1)}\"", "build": "npm run build:clean && npm run build:esm && npm run build:cjs && npm run build:universal-package && npm run build:chmod", "inspect": "npm run build && mcp-inspector -- dist/esm/index.js", "prettier": "prettier", @@ -87,6 +87,7 @@ "openapi-typescript": "^7.9.1", "prettier": "^3.6.2", "proper-lockfile": "^4.1.2", + "rimraf": "^5.0.10", "semver": "^7.7.2", "simple-git": "^3.28.0", "tsx": "^4.20.5", @@ -104,6 +105,7 @@ "@mongosh/service-provider-node-driver": "^3.17.0", "bson": "^6.10.4", "express": "^5.1.0", + "jose": "^5.9.6", "lru-cache": "^11.1.0", "mongodb-connection-string-url": "^3.0.2", "mongodb-log-writer": "^2.4.1", diff --git a/src/common/config.ts b/src/common/config.ts index cbac900c4..5e4ee66e2 100644 --- a/src/common/config.ts +++ b/src/common/config.ts @@ -22,6 +22,13 @@ const OPTIONS = { "notificationTimeoutMs", "telemetry", "transport", + "httpAuthMode", + "azureManagedIdentityTenantId", + "azureManagedIdentityClientId", + "azureManagedIdentityAudience", + "azureManagedIdentityRequiredRoles", + "azureManagedIdentityAllowedAppIds", + "azureManagedIdentityRoleMatchMode", "apiVersion", "authenticationDatabase", "authenticationMechanism", @@ -53,6 +60,16 @@ const OPTIONS = 
{ "exportsPath", "exportTimeoutMs", "exportCleanupIntervalMs", + + // Custom additions for vector search / embeddings + "vectorSearchPath", + "vectorSearchIndex", + "embeddingModelEndpoint", + "embeddingModelApikey", + "embeddingModelDimension", + "embeddingModelDeploymentName", + "embeddingModelProvider", + // Removed retry tunables (maxRetries & retryInitialDelayMs) now fixed internally ], boolean: [ "apiDeprecationErrors", @@ -175,12 +192,34 @@ export interface UserConfig extends CliOptions { httpPort: number; httpHost: string; httpHeaders: Record; + // httpAuthMode: none | azure-managed-identity (env: MDB_MCP_HTTP_AUTH_MODE) + httpAuthMode?: "none" | "azure-managed-identity"; + // Azure Managed Identity configuration (only used when httpAuthMode=azure-managed-identity) + azureManagedIdentityTenantId?: string; // MDB_MCP_AZURE_MANAGED_IDENTITY_TENANT_ID + azureManagedIdentityClientId?: string; // optional; target app/client id to validate aud if audience not provided + azureManagedIdentityAudience?: string; // optional explicit audience to validate token 'aud' + azureManagedIdentityAllowedAppIds?: string[]; // optional allowed list of app (client) IDs (appid/azp) - if set token must match one + azureManagedIdentityRequiredRoles?: string[]; // optional list of app roles that must all be present in 'roles' + azureManagedIdentityRoleMatchMode?: "all" | "any"; // default all + loggers: Array<"stderr" | "disk" | "mcp">; idleTimeoutMs: number; notificationTimeoutMs: number; maxDocumentsPerQuery: number; maxBytesPerQuery: number; atlasTemporaryDatabaseUserLifetimeMs: number; + + // Optional default vector field path for vector-search tool (env: MDB_MCP_VECTOR_SEARCH_PATH) + vectorSearchPath?: string; + // Optional default vector search index name (env: MDB_MCP_VECTOR_SEARCH_INDEX) + vectorSearchIndex?: string; + + // Azure AI embedding model configuration + embeddingModelProvider?: string; // MDB_MCP_EMBEDDING_MODEL_PROVIDER (defaults to azure-ai-inference) + 
embeddingModelEndpoint?: string; // MDB_MCP_EMBEDDING_MODEL_ENDPOINT + embeddingModelApikey?: string; // MDB_MCP_EMBEDDING_MODEL_APIKEY + embeddingModelDeploymentName?: string; // MDB_MCP_EMBEDDING_MODEL_DEPLOYMENT_NAME + embeddingModelDimension?: number; // MDB_MCP_EMBEDDING_MODEL_DIMENSION } export const defaultUserConfig: UserConfig = { @@ -210,6 +249,7 @@ export const defaultUserConfig: UserConfig = { maxDocumentsPerQuery: 100, // By default, we only fetch a maximum 100 documents per query / aggregation maxBytesPerQuery: 16 * 1024 * 1024, // By default, we only return ~16 mb of data per query / aggregation atlasTemporaryDatabaseUserLifetimeMs: 4 * 60 * 60 * 1000, // 4 hours + httpAuthMode: "none", }; export const config = setupUserConfig({ @@ -437,6 +477,7 @@ export function registerKnownSecretsInRootKeychain(userConfig: Partial 65535 || isNaN(httpPort)) { throw new Error(`Invalid httpPort: ${userConfig.httpPort}`); diff --git a/src/common/logger.ts b/src/common/logger.ts index 1fe3cc73a..c3d5327d3 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -69,6 +69,8 @@ export const LogId = { exportLockError: mongoLogId(1_007_008), oidcFlow: mongoLogId(1_008_001), + + azureManagedIdentityAuthError: mongoLogId(1_009_001), } as const; export interface LogPayload { diff --git a/src/embedding/azureAIInferenceEmbeddingProvider.ts b/src/embedding/azureAIInferenceEmbeddingProvider.ts new file mode 100644 index 000000000..4c1e62dcb --- /dev/null +++ b/src/embedding/azureAIInferenceEmbeddingProvider.ts @@ -0,0 +1,86 @@ +import { EmbeddingProvider } from './embeddingProvider.js'; + +/** + * Configuration for the Azure AI Inference embedding provider. 
+ */ +export interface AzureAIInferenceEmbeddingConfig { + endpoint: string; // Full endpoint URL for embeddings request + apiKey: string; // API key (sent as api-key header) + deployment: string; // Deployment or model name + dimension?: number; // Optional dimension override + maxRetries?: number; // Maximum retry attempts for transient errors + initialDelayMs?: number; // Initial backoff delay +} + +/** + * Embedding provider implementation backed by Azure AI Inference embeddings endpoint. + * Performs simple exponential backoff retries on transient failures (429 / 5xx). + */ +export class AzureAIInferenceEmbeddingProvider implements EmbeddingProvider { + public name = 'azure-ai-inference'; + private readonly config: AzureAIInferenceEmbeddingConfig; + + constructor(config: AzureAIInferenceEmbeddingConfig) { + this.config = config; + } + + async embed(input: string[]): Promise { + if (input.length === 0) return []; + + // Construct request payload; future optimization could batch requests. + const body: Record = { + model: this.config.deployment, + input, + input_type: 'query' + }; + if (this.config.dimension) { + (body as any).dimensions = this.config.dimension; // eslint-disable-line @typescript-eslint/no-explicit-any + } + + const maxRetries = this.config.maxRetries ?? 2; + const initialDelay = this.config.initialDelayMs ?? 
200; + let attempt = 0; + let lastError: Error | undefined; + + while (attempt <= maxRetries) { + const res = await fetch(this.config.endpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'api-key': this.config.apiKey + }, + body: JSON.stringify(body) + }); + + if (!res.ok) { + const transient = res.status === 429 || (res.status >= 500 && res.status < 600); + if (!transient) { + throw new Error(`Embedding request failed with status ${res.status}, statusText: ${res.statusText}, response: ${await res.text()}`); + } + lastError = new Error(`Transient status ${res.status}`); + } else { + const json = (await res.json()) as any; // eslint-disable-line @typescript-eslint/no-explicit-any + const data = json?.data; + if (!Array.isArray(data) || data.length === 0) { + throw new Error('Embedding response malformed: missing data array'); + } + const embeddings: number[][] = []; + for (const item of data) { + const emb = item?.embedding; + if (!Array.isArray(emb)) { + throw new Error('Embedding response malformed: item.embedding missing or not array'); + } + embeddings.push(emb as number[]); + } + return embeddings; + } + + if (attempt === maxRetries) break; + const delay = Math.round(initialDelay * Math.pow(2, attempt) * (0.75 + Math.random() * 0.5)); + await new Promise(r => setTimeout(r, delay)); + attempt++; + } + + throw new Error(`Embedding request ultimately failed after ${maxRetries + 1} attempt(s): ${lastError?.message}`); + } +} diff --git a/src/embedding/embeddingProvider.ts b/src/embedding/embeddingProvider.ts new file mode 100644 index 000000000..532bf9011 --- /dev/null +++ b/src/embedding/embeddingProvider.ts @@ -0,0 +1,4 @@ +export interface EmbeddingProvider { + name: string; + embed(input: string[]): Promise; +} diff --git a/src/embedding/embeddingProviderFactory.ts b/src/embedding/embeddingProviderFactory.ts new file mode 100644 index 000000000..2272fb48e --- /dev/null +++ b/src/embedding/embeddingProviderFactory.ts @@ -0,0 +1,92 
@@ +import type { UserConfig } from "../common/config.js"; +import type { EmbeddingProvider } from "./embeddingProvider.js"; +import { AzureAIInferenceEmbeddingProvider } from "./azureAIInferenceEmbeddingProvider.js"; + +/** + * Factory responsible for creating an EmbeddingProvider implementation + * based on the user configuration. Centralizing this logic allows + * additional providers to be added in the future without touching tool code. + */ +export class EmbeddingProviderFactory { + /** + * Create an embedding provider instance based on configuration. + * Currently supports: + * - azure-ai-inference (default) + */ + static create(config: UserConfig): EmbeddingProvider { + // Default to azure-ai-inference if not set + if (!config.embeddingModelProvider) { + config.embeddingModelProvider = "azure-ai-inference"; + } + + switch (config.embeddingModelProvider) { + case "azure-ai-inference": + return EmbeddingProviderFactory.GetAzureAIInferenceEmbeddingProvider(config); + default: + throw new Error(`Unsupported embedding model provider: ${config.embeddingModelProvider}.`); + } + } + + /** + * Lightweight boolean validation indicating whether the provided config + * contains the minimum required fields to construct an embedding provider + * for the currently selected provider (or default provider). This does NOT + * throw – it is intended for tooling guard rails (e.g. verifyAllowed checks) + * where we just want to short‑circuit availability. 
+ */ + static isEmbeddingConfigValid(config: UserConfig): boolean { + // Default to azure-ai-inference if not set + if (!config.embeddingModelProvider) { + config.embeddingModelProvider = "azure-ai-inference"; + } + + switch (config.embeddingModelProvider) { + case "azure-ai-inference": { + return !!( + config.embeddingModelEndpoint && + config.embeddingModelApikey && + config.embeddingModelDeploymentName + ); + } + default: + // Unknown provider – explicitly invalid (create() will throw anyway) + return false; + } + } + + /** + * Assertion variant of validation – throws a descriptive error when the + * configuration is incomplete for the selected provider. This centralizes + * error message wording so tools and factory creation stay consistent. + */ + static assertEmbeddingConfigValid(config: UserConfig): void { + if (!this.isEmbeddingConfigValid(config)) { + throw new Error( + `Embedding model config incomplete or invalid for provider '${config.embeddingModelProvider}'. ` + ); + } + } + + static GetAzureAIInferenceEmbeddingProvider(config: UserConfig): EmbeddingProvider { + // Reuse centralized validation + this.assertEmbeddingConfigValid(config); + + const endpoint = config.embeddingModelEndpoint!; + const apiKey = config.embeddingModelApikey!; + const deployment = config.embeddingModelDeploymentName!; + const dimension = config.embeddingModelDimension!; + + return new AzureAIInferenceEmbeddingProvider({ + endpoint, + apiKey, + deployment, + dimension, + maxRetries: 2, + initialDelayMs: 200, + }); + } +} + +export function createEmbeddingProvider(config: UserConfig): EmbeddingProvider { + return EmbeddingProviderFactory.create(config); +} \ No newline at end of file diff --git a/src/tools/mongodb/read/vectorSearchv1.ts b/src/tools/mongodb/read/vectorSearchv1.ts new file mode 100644 index 000000000..e08b9eb52 --- /dev/null +++ b/src/tools/mongodb/read/vectorSearchv1.ts @@ -0,0 +1,154 @@ +import { z } from "zod"; +import type { CallToolResult } from 
"@modelcontextprotocol/sdk/types.js"; +import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; +import type { ToolArgs, OperationType } from "../../tool.js"; +import { formatUntrustedData } from "../../tool.js"; +import { EJSON } from "bson"; +import { createEmbeddingProvider, EmbeddingProviderFactory } from "../../../embedding/embeddingProviderFactory.js"; +import { LogId } from "../../../common/logger.js"; + +/* + * VectorSearchTool + * Executes a vector search using the $vectorSearch aggregation stage. + * Requires a MongoDB server/Atlas cluster with vector search support and a + * vector index built on the specified path. We implement this as a read + * operation by running a single-stage aggregation pipeline under the hood. + */ + +export const VectorSearchArgs = { + queryText: z + .string() + .max(1024, "queryText must be at most 1024 characters") + .describe( + "Raw search text/context that will be embedded using the configured embedding model; represents the vector search intent." + ), + path: z + .string() + .describe( + "The field path of the vector field (e.g. 'embedding' or 'content.embedding')." + ), + numCandidates: z + .number() + .int() + .positive() + .default(100) + .describe("Number of approximate candidates to consider (higher = potentially better recall, more cost)"), + limit: z + .number() + .int() + .positive() + .default(10) + .describe("Maximum number of results to return"), + filter: z + .object({}) + .passthrough() + .optional() + .describe("Optional filter (standard query predicate) to apply before ranking results"), + index: z + .string() + .describe( + "Name of the vector search index (if multiple indexes exist)." 
+ ), + includeVector: z + .boolean() + .optional() + .default(false) + .describe("If true, include the vector field in the projection (may be large)"), +}; + +export class VectorSearchV1Tool extends MongoDBToolBase { + public name = "vector-search"; + protected description = "Execute a vector similarity search on a MongoDB collection using $vectorSearch"; + protected argsShape = { + ...DbOperationArgs, + ...VectorSearchArgs, + }; + public operationType: OperationType = "read"; + + protected async execute({ + database, + collection, + queryText, + path, + numCandidates, + limit, + filter, + index, + includeVector, + }: ToolArgs): Promise { + const provider = await this.ensureConnected(); + + if (!path) { + throw new Error( + "Vector search requires 'path' argument to be provided." + ); + } + + if (!queryText) { + throw new Error("'queryText' must be provided to perform vector search"); + } + + // Always embed the provided queryText + const embeddingProvider = createEmbeddingProvider(this.config); + const embeddings = await embeddingProvider.embed([queryText]); + const queryVector = embeddings[0]; + if (!queryVector || queryVector.length === 0) { + throw new Error("Embedding provider returned empty embedding."); + } + + // Construct the $vectorSearch stage + const vectorStage: Record = { + $vectorSearch: { + queryVector, + path: path, + limit, + numCandidates, + }, + }; + if (filter) { + (vectorStage.$vectorSearch as any).filter = filter; // eslint-disable-line @typescript-eslint/no-explicit-any + } + if (index) { + (vectorStage.$vectorSearch as any).index = index; // eslint-disable-line @typescript-eslint/no-explicit-any + } + + // Build the full pipeline. Optionally project out the vector field unless requested. 
+ const pipeline: Record[] = [vectorStage]; + if (!includeVector) { + // Exclude the vector path by default to keep output concise (unless the path is dotted, project root minus that field) + const projection: Record = {}; + const topLevelPath = path.split(".")[0] ?? path; // ensure string + projection[topLevelPath as string] = 0; // We exclude; if user needs it they set includeVector=true + pipeline.push({ $project: projection }); + } + + const cursor = provider.aggregate(database, collection, pipeline); + const results = await cursor.toArray(); + + return { + content: formatUntrustedData( + `Vector search returned ${results.length} document(s) from collection "${collection}" using path "${path}."`, + results.length > 0 ? EJSON.stringify(results) : undefined + ), + }; + } + + protected verifyAllowed(): boolean { + // Centralized embedding configuration validation + // If the user explicitly selected a different provider, disallow for now (only azure-ai-inference supported) + if (!EmbeddingProviderFactory.isEmbeddingConfigValid(this.config)) { + this.session.logger.warning({ + id: LogId.toolUpdateFailure, + context: "tool", + message: `Tool ${this.name} could not be configured due to incomplete embedding configuration.`, + noRedaction: true, + }); + return false; + } + + // For V1 tool semantics: both vectorSearchIndex and vectorSearchPath must NOT be set simultaneously + if (this.config.vectorSearchIndex && this.config.vectorSearchPath) return false; + + return super.verifyAllowed(); + } +} diff --git a/src/tools/mongodb/read/vectorSearchv2.ts b/src/tools/mongodb/read/vectorSearchv2.ts new file mode 100644 index 000000000..d2156464e --- /dev/null +++ b/src/tools/mongodb/read/vectorSearchv2.ts @@ -0,0 +1,151 @@ +import { z } from "zod"; +import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; +import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; +import type { ToolArgs, OperationType } from "../../tool.js"; +import { 
formatUntrustedData } from "../../tool.js"; +import { EJSON } from "bson"; +import { createEmbeddingProvider, EmbeddingProviderFactory } from "../../../embedding/embeddingProviderFactory.js"; +import { LogId } from "../../../common/logger.js"; + +/* + * VectorSearchTool + * Executes a vector search using the $vectorSearch aggregation stage. + * Requires a MongoDB server/Atlas cluster with vector search support and a + * vector index built on the specified path. We implement this as a read + * operation by running a single-stage aggregation pipeline under the hood. + */ + +export const VectorSearchArgs = { + queryText: z + .string() + .max(1024, "queryText must be at most 1024 characters") + .describe( + "Raw search text/context that will be embedded using the configured embedding model; represents the vector search intent." + ), + numCandidates: z + .number() + .int() + .positive() + .default(100) + .describe("Number of approximate candidates to consider (higher = potentially better recall, more cost)"), + limit: z + .number() + .int() + .positive() + .default(10) + .describe("Maximum number of results to return"), + filter: z + .object({}) + .passthrough() + .optional() + .describe("Optional filter (standard query predicate) to apply before ranking results"), + includeVector: z + .boolean() + .optional() + .default(false) + .describe("If true, include the vector field in the projection (may be large)"), +}; + +export class VectorSearchV2Tool extends MongoDBToolBase { + public name = "vector-search"; + protected description = "Execute a vector similarity search on a MongoDB collection using $vectorSearch"; + protected argsShape = { + ...DbOperationArgs, + ...VectorSearchArgs, + }; + public operationType: OperationType = "read"; + + protected async execute({ + database, + collection, + queryText, + numCandidates, + limit, + filter, + includeVector, + }: ToolArgs): Promise { + const provider = await this.ensureConnected(); + + // Resolve path from config + const 
resolvedPath = this.config.vectorSearchPath; + if (!resolvedPath) { + throw new Error( + "Vector search requires 'path' argument to be provided while setting up MCP." + ); + } + + // Resolve index from config + const resolvedIndex = this.config.vectorSearchIndex; + if (!resolvedIndex) { + throw new Error( + "Vector search requires 'index' argument to be provided while setting up MCP." + ); + } + + if (!queryText) { + throw new Error("'queryText' must be provided to perform vector search"); + } + + // Always embed the provided queryText + const embeddingProvider = createEmbeddingProvider(this.config); + const embeddings = await embeddingProvider.embed([queryText]); + const queryVector = embeddings[0]; + if (!queryVector || queryVector.length === 0) { + throw new Error("Embedding provider returned empty embedding."); + } + + // Construct the $vectorSearch stage + const vectorStage: Record = { + $vectorSearch: { + queryVector, + path: resolvedPath, + limit, + numCandidates, + }, + }; + if (filter) { + (vectorStage.$vectorSearch as any).filter = filter; // eslint-disable-line @typescript-eslint/no-explicit-any + } + if (resolvedIndex) { + (vectorStage.$vectorSearch as any).index = resolvedIndex; // eslint-disable-line @typescript-eslint/no-explicit-any + } + + // Build the full pipeline. Optionally project out the vector field unless requested. + const pipeline: Record[] = [vectorStage]; + if (!includeVector) { + // Exclude the vector path by default to keep output concise (unless the path is dotted, project root minus that field) + const projection: Record = {}; + const topLevelPath = resolvedPath.split(".")[0] ?? 
resolvedPath; // ensure string + projection[topLevelPath as string] = 0; // We exclude; if user needs it they set includeVector=true + pipeline.push({ $project: projection }); + } + + const cursor = provider.aggregate(database, collection, pipeline); + const results = await cursor.toArray(); + + return { + content: formatUntrustedData( + `Vector search returned ${results.length} document(s) from collection "${collection}" using path "${resolvedPath}."`, + results.length > 0 ? EJSON.stringify(results) : undefined + ), + }; + } + + protected verifyAllowed(): boolean { + // Centralized embedding configuration validation + if (!EmbeddingProviderFactory.isEmbeddingConfigValid(this.config)) { + this.session.logger.warning({ + id: LogId.toolUpdateFailure, + context: "tool", + message: `Tool ${this.name} could not be configured due to incomplete embedding configuration.`, + noRedaction: true, + }); + return false; + } + + // For V2 semantics: BOTH vectorSearchIndex and vectorSearchPath must be set + if (!this.config.vectorSearchIndex || !this.config.vectorSearchPath) return false; + + return super.verifyAllowed(); + } +} diff --git a/src/tools/mongodb/tools.ts b/src/tools/mongodb/tools.ts index 00575ee05..09f1a866b 100644 --- a/src/tools/mongodb/tools.ts +++ b/src/tools/mongodb/tools.ts @@ -19,6 +19,8 @@ import { ExplainTool } from "./metadata/explain.js"; import { CreateCollectionTool } from "./create/createCollection.js"; import { LogsTool } from "./metadata/logs.js"; import { ExportTool } from "./read/export.js"; +import { VectorSearchV1Tool } from "./read/vectorSearchv1.js"; +import { VectorSearchV2Tool } from "./read/vectorSearchv2.js"; export const MongoDbTools = [ ConnectTool, @@ -42,4 +44,6 @@ export const MongoDbTools = [ CreateCollectionTool, LogsTool, ExportTool, + VectorSearchV1Tool, + VectorSearchV2Tool ]; diff --git a/src/transports/azureManagedIdentityAuth.ts b/src/transports/azureManagedIdentityAuth.ts new file mode 100644 index 000000000..186d74a78 --- 
/dev/null +++ b/src/transports/azureManagedIdentityAuth.ts @@ -0,0 +1,212 @@ +import type { Request, Response, NextFunction } from "express"; +import { createRemoteJWKSet, jwtVerify } from "jose"; +import { LRUCache } from "lru-cache"; +import { LogId, type LoggerBase } from "../common/logger.js"; +import type { UserConfig } from "../common/config.js"; + +// Simple cache for remote JWK set instances keyed by discovery URL +const jwksCache = new LRUCache>({ + max: 10, + ttl: 60 * 60 * 1000, // 1h +}); + +export interface AzureManagedIdentityAuthOptions { + tenantId: string; + audience?: string; // explicit audience override + clientId?: string; // fallback audience if explicit not provided +} + +function v2Issuer(tenantId: string): string { + return `https://login.microsoftonline.com/${tenantId}/v2.0`; +} + +function v1Issuer(tenantId: string): string { + // Legacy v1 tokens often have iss = https://sts.windows.net// + return `https://sts.windows.net/${tenantId}/`; +} + +function buildOpenIdConfigUrl(tenantId: string): string { + // We always fetch from the v2 discovery endpoint (jwks are valid for both) + return `${v2Issuer(tenantId)}/.well-known/openid-configuration`; +} + +async function getRemoteJwks(tenantId: string) { + const discoveryUrl = buildOpenIdConfigUrl(tenantId); + let jwks = jwksCache.get(discoveryUrl); + if (!jwks) { + const res = await fetch(discoveryUrl); + if (!res.ok) { + throw new Error(`Failed to fetch OpenID configuration: ${res.status} ${res.statusText}`); + } + const json = (await res.json()) as { jwks_uri: string }; + if (!json.jwks_uri) { + throw new Error("jwks_uri not found in OpenID configuration"); + } + jwks = createRemoteJWKSet(new URL(json.jwks_uri)); + jwksCache.set(discoveryUrl, jwks); + } + return jwks; +} + +export function azureManagedIdentityAuthMiddleware( + logger: LoggerBase, + userConfig: UserConfig +): (req: Request, res: Response, next: NextFunction) => void { + if (userConfig.httpAuthMode !== "azure-managed-identity") 
{ + return (_req, _res, next) => next(); + } + + const opts: AzureManagedIdentityAuthOptions = { + tenantId: userConfig.azureManagedIdentityTenantId!, + audience: userConfig.azureManagedIdentityAudience, + clientId: userConfig.azureManagedIdentityClientId, + }; + + const expectedAud = opts.audience || opts.clientId; + const requiredRoles = userConfig.azureManagedIdentityRequiredRoles || []; + const roleMatchMode = userConfig.azureManagedIdentityRoleMatchMode || "all"; + const allowedAppIds = (userConfig.azureManagedIdentityAllowedAppIds || []).map((a) => a.toLowerCase()); + if (!expectedAud) { + logger.warning({ + id: 0 as any, + context: "azureManagedIdentityAuth", + message: "No audience or clientId configured; 'aud' claim will not be enforced.", + }); + } + + return async (req: Request, res: Response, next: NextFunction): Promise => { + try { + const authHeader = req.headers["authorization"]; + if (!authHeader || Array.isArray(authHeader)) { + res.status(401).json({ error: "missing authorization header" }); + return; + } + const match = authHeader.match(/^Bearer (.+)$/i); + if (!match) { + res.status(401).json({ error: "invalid authorization header" }); + return; + } + const token = match[1]!; // non-null assertion since regex with capture succeeded + const jwks = await getRemoteJwks(opts.tenantId); + let verification; + const issuersToTry = [v2Issuer(opts.tenantId), v1Issuer(opts.tenantId)]; + let lastErr: unknown; + for (const iss of issuersToTry) { + try { + verification = await jwtVerify(token, jwks, { + issuer: iss, + audience: expectedAud, // undefined means not enforced + }); + break; + } catch (e) { + lastErr = e; + } + } + if (!verification) { + throw lastErr ?? 
new Error("issuer validation failed"); + } + + // Basic sanity checks (subject, expiry handled by jose) + const payload = verification.payload as Record; + if (!payload.sub) { + logAuthFailure(logger, "missing-sub", payload, { + message: "token missing sub", + }); + res.status(401).json({ error: "unauthorized" }); + return; + } + + // Enforce tenant id (tid) match for safety + const configuredTid = opts.tenantId.toLowerCase(); + const tokenTid = (payload.tid || payload.tenantId || "").toLowerCase(); + if (!tokenTid) { + logAuthFailure(logger, "missing-tid", payload, { message: "token missing tid claim" }); + res.status(401).json({ error: "unauthorized" }); + return; + } + if (tokenTid !== configuredTid) { + logAuthFailure(logger, "tenant-mismatch", payload, { + message: `tenant mismatch expected=${configuredTid} got=${tokenTid}`, + }); + res.status(401).json({ error: "unauthorized" }); + return; + } + + // Allowed application IDs (appid or azp) enforcement + if (allowedAppIds.length > 0) { + const tokenAppId = (payload.appid || payload.azp || "").toLowerCase(); + if (!tokenAppId || !allowedAppIds.includes(tokenAppId)) { + logAuthFailure(logger, "appid-not-allowed", payload, { + message: `application id not allowed: ${tokenAppId || ""}`, + }); + res.status(401).json({ error: "unauthorized" }); + return; + } + } + + // App role enforcement: 'roles' claim (array) for application permissions + if (requiredRoles.length > 0) { + const rolesClaim = Array.isArray(payload.roles) ? payload.roles : []; + const tokenRoles = new Set(rolesClaim); + const missingRoles = requiredRoles.filter((r) => !tokenRoles.has(r)); + const roleConditionMet = + roleMatchMode === "all" ? missingRoles.length === 0 : missingRoles.length < requiredRoles.length; + if (!roleConditionMet) { + logAuthFailure(logger, "role-match-failed", payload, { + message: + roleMatchMode === "all" + ? 
`missing required roles: ${missingRoles.join(",")}` + : `none of the required roles present; required any of: ${requiredRoles.join(",")}`, + missingRoles, + }); + res.status(401).json({ error: "unauthorized" }); + return; + } + } + + // Attach claims for downstream handlers if needed + (req as any).azureManagedIdentity = payload; + next(); + } catch (err) { + logger.info({ + id: LogId.azureManagedIdentityAuthError, + context: "azureManagedIdentityAuth", + message: `Token verification failed: ${err instanceof Error ? err.message : String(err)}`, + }); + res.status(401).json({ error: "unauthorized" }); + } + }; +} + +interface FailureMeta { + message: string; + missingScopes?: string[]; + missingRoles?: string[]; +} + +function logAuthFailure( + logger: LoggerBase, + reason: + | "missing-sub" + | "missing-roles" + | "missing-tid" + | "tenant-mismatch" + | "role-match-failed" + | "appid-not-allowed", + claims: Record, + meta: FailureMeta +): void { + // Only log a limited snapshot of claims for security (avoid tokens, only non-sensitive claims) + const allowedKeys = ["aud", "iss", "sub", "scp", "roles", "appid", "tid", "oid", "exp", "nbf", "iat"]; + const snapshot: Record = {}; + for (const key of allowedKeys) { + if (key in claims) snapshot[key] = claims[key]; + } + logger.info({ + id: LogId.azureManagedIdentityAuthError, + context: "azureManagedIdentityAuth", + message: `Authorization failure (${reason}): ${meta.message} snapshot=${JSON.stringify(snapshot)} missingScopes=${meta.missingScopes?.join("|") ?? ""} missingRoles=${meta.missingRoles?.join("|") ?? 
""}`, + }); +} + +// (scope-related reasons removed) diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index 0a20e59e8..a78e568dc 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -6,6 +6,7 @@ import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; import { LogId } from "../common/logger.js"; import { SessionStore } from "../common/sessionStore.js"; import { TransportRunnerBase, type TransportRunnerConfig } from "./base.js"; +import { azureManagedIdentityAuthMiddleware } from "./azureManagedIdentityAuth.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001; @@ -43,6 +44,10 @@ export class StreamableHttpRunner extends TransportRunnerBase { app.enable("trust proxy"); // needed for reverse proxy support app.use(express.json()); + // Managed Identity auth (optional) + if (this.userConfig.httpAuthMode === "azure-managed-identity") { + app.use(azureManagedIdentityAuthMiddleware(this.logger, this.userConfig)); + } app.use((req, res, next) => { for (const [key, value] of Object.entries(this.userConfig.httpHeaders)) { const header = req.headers[key.toLowerCase()]; diff --git a/tests/accuracy/vectorSearch.test.ts b/tests/accuracy/vectorSearch.test.ts new file mode 100644 index 000000000..3c3d0dab5 --- /dev/null +++ b/tests/accuracy/vectorSearch.test.ts @@ -0,0 +1,43 @@ +import { describeAccuracyTests } from "./sdk/describeAccuracyTests.js"; +import { Matcher } from "./sdk/matcher.js"; + +// Accuracy tests for the new vector-search tool. These prompts are phrased in a way +// that the planner should infer the appropriate tool and arguments. We only +// check argument structure, not exact numeric array contents beyond basic shape. 
+ +describeAccuracyTests([ + { + prompt: "Use the embeddings in 'ai.docs' to find the 5 most similar documents to the given vector [0.1, 0.2, 0.3].", + expectedToolCalls: [ + { + toolName: "vector-search", + parameters: { + database: "ai", + collection: "docs", + queryVector: [0.1, 0.2, 0.3], + path: Matcher.anyOf(Matcher.value("embedding"), Matcher.string()), + limit: 5, + // numCandidates may be defaulted; allow undefined or positive number + numCandidates: Matcher.anyOf(Matcher.undefined, Matcher.number((v) => v > 0)), + }, + }, + ], + }, + { + prompt: "In database 'ai', collection 'docs', perform a vector similarity search over field 'embedding' for vector [0.25,0.11,0.89,0.4] and return top 3 results including the raw embedding.", + expectedToolCalls: [ + { + toolName: "vector-search", + parameters: { + database: "ai", + collection: "docs", + queryVector: [0.25, 0.11, 0.89, 0.4], + path: "embedding", + limit: 3, + includeVector: true, + numCandidates: Matcher.anyOf(Matcher.undefined, Matcher.number((v) => v > 0)), + }, + }, + ], + }, +]); diff --git a/tests/integration/transports/stdio.test.ts b/tests/integration/transports/stdio.test.ts index aaa61d638..f6e9cd7e8 100644 --- a/tests/integration/transports/stdio.test.ts +++ b/tests/integration/transports/stdio.test.ts @@ -32,6 +32,7 @@ describeWithMongoDB("StdioRunner", (integration) => { const response = await client.listTools(); expect(response).toBeDefined(); expect(response.tools).toBeDefined(); + // Updated tool count expectation adjusted after recent changes expect(response.tools).toHaveLength(21); const sortedTools = response.tools.sort((a, b) => a.name.localeCompare(b.name)); diff --git a/tests/unit/embedding/azureOpenAIProviderRetry.test.ts b/tests/unit/embedding/azureOpenAIProviderRetry.test.ts new file mode 100644 index 000000000..d5537dee3 --- /dev/null +++ b/tests/unit/embedding/azureOpenAIProviderRetry.test.ts @@ -0,0 +1,46 @@ +import { describe, it, expect, vi } from 'vitest'; +import { 
AzureAIInferenceEmbeddingProvider } from '../../../src/embedding/azureAIInferenceEmbeddingProvider.js'; + +const baseConfig = { + endpoint: 'https://example.com/embeddings', + apiKey: 'KEY', + deployment: 'model', + maxRetries: 2, + initialDelayMs: 10, +}; + +describe('AzureAIInferenceEmbeddingProvider retry logic', () => { + it('retries transient 500 then succeeds', async () => { + const provider = new AzureAIInferenceEmbeddingProvider(baseConfig); + const fakeEmbedding = [0.1,0.2]; + const responses = [ + { ok: false, status: 500 }, + { ok: true, status: 200, json: async () => ({ data: [{ embedding: fakeEmbedding }] }) } + ]; + let call = 0; + global.fetch = vi.fn().mockImplementation(() => responses[call++]); + + const result = await provider.embed(['hello']); + expect(result[0]).toEqual(fakeEmbedding); + expect((fetch as any).mock.calls.length).toBe(2); + }); + + it('fails after max retries', async () => { + const provider = new AzureAIInferenceEmbeddingProvider({ ...baseConfig, maxRetries: 1, initialDelayMs: 5 }); + global.fetch = vi.fn().mockResolvedValue({ ok: false, status: 500 }); + await expect(provider.embed(['hello'])).rejects.toThrow(/ultimately failed/); + expect((fetch as any).mock.calls.length).toBe(2); // initial + 1 retry + }); + + it('does not retry on 400', async () => { + const provider = new AzureAIInferenceEmbeddingProvider(baseConfig); + global.fetch = vi.fn().mockResolvedValue({ + ok: false, + status: 400, + text: async () => 'Bad Request', + json: async () => ({ error: 'Bad Request' }) + }); + await expect(provider.embed(['hello'])).rejects.toThrow(/status 400/); + expect((fetch as any).mock.calls.length).toBe(1); + }); +}); diff --git a/tests/unit/embedding/embeddingProviderFactoryValidation.test.ts b/tests/unit/embedding/embeddingProviderFactoryValidation.test.ts new file mode 100644 index 000000000..75239569a --- /dev/null +++ b/tests/unit/embedding/embeddingProviderFactoryValidation.test.ts @@ -0,0 +1,32 @@ +import { describe, it, 
expect } from "vitest"; +import { EmbeddingProviderFactory } from "../../../src/embedding/embeddingProviderFactory.js"; +import type { UserConfig } from "../../../src/common/config.js"; + +function baseConfig(): Partial { + return { + embeddingModelProvider: "azure-ai-inference", + embeddingModelEndpoint: "https://example/", + embeddingModelApikey: "key", + embeddingModelDeploymentName: "deploy", + embeddingModelDimension: 1536, + } as Partial; +} + +describe("EmbeddingProviderFactory.isEmbeddingConfigValid", () => { + it("returns true for complete azure-ai-inference config", () => { + const cfg = baseConfig() as UserConfig; + expect(EmbeddingProviderFactory.isEmbeddingConfigValid(cfg)).toBe(true); + }); + + it("returns false when a required field missing", () => { + const cfg = baseConfig(); + delete cfg.embeddingModelApikey; + expect(EmbeddingProviderFactory.isEmbeddingConfigValid(cfg as UserConfig)).toBe(false); + }); + + it("throws with assertEmbeddingConfigValid when invalid", () => { + const cfg = baseConfig(); + delete cfg.embeddingModelDeploymentName; + expect(() => EmbeddingProviderFactory.assertEmbeddingConfigValid(cfg as UserConfig)).toThrow(); + }); +}); diff --git a/tests/unit/tools/mongodb/read/vectorSearchV1.test.ts b/tests/unit/tools/mongodb/read/vectorSearchV1.test.ts new file mode 100644 index 000000000..e60bb3584 --- /dev/null +++ b/tests/unit/tools/mongodb/read/vectorSearchV1.test.ts @@ -0,0 +1,115 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { VectorSearchV1Tool } from "../../../../../src/tools/mongodb/read/vectorSearchv1.js"; +import { Session } from "../../../../../src/common/session.js"; +import EventEmitter from "events"; +import { Telemetry } from "../../../../../src/telemetry/telemetry.js"; +import { Elicitation } from "../../../../../src/elicitation.js"; + +// Mock service provider with aggregate support only +class MockServiceProviderV1 { + public aggregateCalls: any[] = []; // 
// Mock service provider with aggregate support only — records each pipeline it
// receives so tests can assert on the generated $vectorSearch stages.
class MockServiceProviderV1 {
    public aggregateCalls: any[] = []; // eslint-disable-line @typescript-eslint/no-explicit-any
    aggregate(_db: string, _coll: string, pipeline: any[]) { // eslint-disable-line @typescript-eslint/no-explicit-any
        this.aggregateCalls.push(pipeline);
        // Canned single-document result; contents are irrelevant to the assertions.
        return { toArray: async () => ([{ _id: 1, embedding: [0.1, 0.2], title: "Doc" }]) };
    }
}

describe("VectorSearchV1Tool", () => {
    // Saved so afterEach can undo the global.fetch stub installed in beforeEach.
    const originalFetch = global.fetch;
    let session: Session;
    let tool: VectorSearchV1Tool;
    let provider: MockServiceProviderV1;

    // Builds a fresh Session + tool wired to the mock service provider.
    // `overrides` is merged over the baseline config so individual tests can
    // toggle vectorSearchIndex/vectorSearchPath/provider fields.
    function buildSessionAndTool(overrides: Record<string, any> = {}) { // eslint-disable-line @typescript-eslint/no-explicit-any
        const connectionEvents = new EventEmitter();
        const connectionManager: any = {
            events: connectionEvents,
            currentConnectionState: { tag: "connected", serviceProvider: undefined, connectedAtlasCluster: undefined },
            setClientName: () => undefined,
            disconnect: async () => undefined,
        };
        provider = new MockServiceProviderV1();
        connectionManager.currentConnectionState.serviceProvider = provider as any;

        // Minimal stubs for Session collaborators that the tool never exercises.
        const exportsManager: any = { close: async () => undefined };
        const keychain: any = { register: () => undefined };
        const logger: any = { debug: () => undefined, info: () => undefined, warning: () => undefined, error: () => undefined };

        session = new Session({ apiBaseUrl: "https://example.com/", logger, connectionManager, exportsManager, keychain });
        const baseConfig = {
            disabledTools: [],
            confirmationRequiredTools: [],
            readOnly: false,
            indexCheck: false,
            transport: "stdio",
            loggers: ["stderr"],
            embeddingModelEndpoint: "https://example.test/embeddings",
            embeddingModelApikey: "key",
            embeddingModelDeploymentName: "text-embed",
            embeddingModelDimension: 2,
            telemetry: "disabled",
            // NOTE: For V1 we must NOT set both vectorSearchPath & vectorSearchIndex in config (verifyAllowed would fail)
        } as any;

        const config = { ...baseConfig, ...overrides };
        const telemetry = Telemetry.create(session as any, config, { get: async () => "device-id" } as any);
        const elicitation = new Elicitation({ server: { getClientCapabilities: () => ({}) } as any });
        tool = new VectorSearchV1Tool({ session: session as any, config, telemetry: telemetry as any, elicitation: elicitation as any });
    }

    beforeEach(() => {
        // Stub the embedding HTTP call: every embed request returns [0.5, 0.6].
        global.fetch = vi.fn().mockResolvedValue({ ok: true, json: async () => ({ data: [{ embedding: [0.5, 0.6] }] }) }) as any; // eslint-disable-line @typescript-eslint/no-explicit-any
        buildSessionAndTool();
    });

    afterEach(() => { global.fetch = originalFetch; });

    it("verifyAllowed returns false if both vectorSearchIndex and vectorSearchPath present in config (disallowed for V1)", () => {
        buildSessionAndTool({ vectorSearchIndex: "idx", vectorSearchPath: "embedding" });
        expect((tool as any).verifyAllowed()).toBe(false);
    });

    it("verifyAllowed returns true with minimal embedding config and without vector index/path overrides", () => {
        buildSessionAndTool();
        expect((tool as any).verifyAllowed()).toBe(true);
    });

    it("embeds queryText and builds expected pipeline", async () => {
        const res = await (tool as any).execute({
            database: "ai", collection: "docs", queryText: "hello", path: "embedding", limit: 3, numCandidates: 50, includeVector: false });
        expect(res).toBeDefined();
        expect(global.fetch).toHaveBeenCalledOnce();
        const pipeline = provider.aggregateCalls[0];
        // Stage 1: $vectorSearch with the stubbed embedding; stage 2: $project
        // excluding the vector field (includeVector=false).
        expect(pipeline[0].$vectorSearch.path).toBe("embedding");
        expect(pipeline[0].$vectorSearch.queryVector).toEqual([0.5, 0.6]);
        expect(pipeline[1].$project.embedding).toBe(0);
    });

    it("includes vector field when includeVector=true", async () => {
        await (tool as any).execute({ database: "ai", collection: "docs", queryText: "hello", path: "embedding", limit: 1, numCandidates: 5, includeVector: true });
        const pipeline = provider.aggregateCalls[0];
        // No $project stage is appended when the vector is requested.
        expect(pipeline.length).toBe(1);
    });

    it("injects filter when provided", async () => {
        await (tool as any).execute({ database: "ai", collection: "docs", queryText: "hello", path: "embedding", limit: 3, numCandidates: 25, filter: { category: "news" } });
        const pipeline = provider.aggregateCalls[provider.aggregateCalls.length - 1];
        expect(pipeline[0].$vectorSearch.filter).toEqual({ category: "news" });
    });

    it("includes index in stage when index argument supplied", async () => {
        await (tool as any).execute({ database: "ai", collection: "docs", queryText: "hello", path: "embedding", index: "custom_index", limit: 2, numCandidates: 10 });
        const pipeline = provider.aggregateCalls[provider.aggregateCalls.length - 1];
        expect(pipeline[0].$vectorSearch.index).toBe("custom_index");
    });

    it("throws if path missing", async () => {
        await expect((tool as any).execute({ database: "ai", collection: "docs", queryText: "hello", limit: 1, numCandidates: 5 }))
            .rejects.toThrow(/path/);
    });

    it("throws if queryText missing", async () => {
        await expect((tool as any).execute({ database: "ai", collection: "docs", path: "embedding", limit: 2, numCandidates: 10 }))
            .rejects.toThrow(/'queryText' must be provided/);
    });
});
// Mock service provider implementing only aggregate — records each pipeline it
// receives so tests can assert on the generated $vectorSearch stages.
class MockServiceProviderV2 {
    public aggregateCalls: any[] = []; // eslint-disable-line @typescript-eslint/no-explicit-any
    aggregate(_db: string, _coll: string, pipeline: any[]) { // eslint-disable-line @typescript-eslint/no-explicit-any
        this.aggregateCalls.push(pipeline);
        // Canned single-document result; contents are irrelevant to the assertions.
        return { toArray: async () => ([{ _id: 1, embedding: [0.1, 0.2, 0.3], title: "Doc" }]) };
    }
}

describe("VectorSearchV2Tool", () => {
    // Saved so afterEach can undo the global.fetch stub installed in beforeEach.
    const originalFetch = global.fetch;
    let session: Session;
    let tool: VectorSearchV2Tool;
    let provider: MockServiceProviderV2;

    // Builds a fresh Session + tool wired to the mock service provider.
    // Baseline config includes vectorSearchPath/vectorSearchIndex (V2 reads
    // both from config); `overrides` lets individual tests remove them.
    function buildSessionAndTool(overrides: Record<string, any> = {}) { // eslint-disable-line @typescript-eslint/no-explicit-any
        const connectionEvents = new EventEmitter();
        const connectionManager: any = {
            events: connectionEvents,
            currentConnectionState: { tag: "connected", serviceProvider: undefined, connectedAtlasCluster: undefined },
            setClientName: () => undefined,
            disconnect: async () => undefined,
        };
        provider = new MockServiceProviderV2();
        connectionManager.currentConnectionState.serviceProvider = provider as any;

        // Minimal stubs for Session collaborators that the tool never exercises.
        const exportsManager: any = { close: async () => undefined };
        const keychain: any = { register: () => undefined };
        const logger: any = { debug: () => undefined, info: () => undefined, warning: () => undefined, error: () => undefined };

        session = new Session({ apiBaseUrl: "https://example.com/", logger, connectionManager, exportsManager, keychain });
        const baseConfig = {
            disabledTools: [],
            confirmationRequiredTools: [],
            readOnly: false,
            indexCheck: false,
            transport: "stdio",
            loggers: ["stderr"],
            telemetry: "disabled",
            embeddingModelEndpoint: "https://example.test/embeddings",
            embeddingModelApikey: "key",
            embeddingModelDeploymentName: "text-embed",
            embeddingModelDimension: 3,
            vectorSearchPath: "embedding",
            vectorSearchIndex: "vector_index",
        } as any;

        const config = { ...baseConfig, ...overrides };
        const telemetry = Telemetry.create(session as any, config, { get: async () => "device-id" } as any);
        const elicitation = new Elicitation({ server: { getClientCapabilities: () => ({}) } as any });
        tool = new VectorSearchV2Tool({ session: session as any, config, telemetry: telemetry as any, elicitation: elicitation as any });
    }

    beforeEach(() => {
        // Stub the embedding HTTP call: every embed request returns [0.9, 0.8, 0.7].
        global.fetch = vi.fn().mockResolvedValue({ ok: true, json: async () => ({ data: [{ embedding: [0.9, 0.8, 0.7] }] }) }) as any; // eslint-disable-line @typescript-eslint/no-explicit-any
        buildSessionAndTool();
    });

    afterEach(() => { global.fetch = originalFetch; });

    // verifyAllowed scenarios (mirroring pattern used in V1 tests but adapted for V2 semantics)
    it("verifyAllowed returns true when all required config present", () => {
        expect((tool as any).verifyAllowed()).toBe(true);
    });

    it("verifyAllowed returns false when path missing", () => {
        buildSessionAndTool({ vectorSearchPath: undefined });
        expect((tool as any).verifyAllowed()).toBe(false);
    });

    it("verifyAllowed returns false when index missing", () => {
        buildSessionAndTool({ vectorSearchIndex: undefined });
        expect((tool as any).verifyAllowed()).toBe(false);
    });

    it("verifyAllowed returns false when unsupported provider specified", () => {
        buildSessionAndTool({ embeddingModelProvider: "other-provider" });
        expect((tool as any).verifyAllowed()).toBe(false);
    });

    it("embeds queryText using config path & index", async () => {
        const res = await (tool as any).execute({ database: "ai", collection: "docs", queryText: "hello world", limit: 4, numCandidates: 50, includeVector: false });
        expect(res).toBeDefined();
        expect(global.fetch).toHaveBeenCalledOnce();
        const pipeline = provider.aggregateCalls[0];
        // Path and index come from config, not from the tool arguments.
        expect(pipeline[0].$vectorSearch.path).toBe("embedding");
        expect(pipeline[0].$vectorSearch.index).toBe("vector_index");
        expect(pipeline[0].$vectorSearch.queryVector).toEqual([0.9, 0.8, 0.7]);
        expect(pipeline[1].$project.embedding).toBe(0);
    });

    it("omits projection when includeVector=true", async () => {
        await (tool as any).execute({ database: "ai", collection: "docs", queryText: "hello", limit: 1, numCandidates: 5, includeVector: true });
        const pipeline = provider.aggregateCalls[0];
        // No $project stage is appended when the vector is requested.
        expect(pipeline.length).toBe(1);
    });

    it("injects filter when provided", async () => {
        await (tool as any).execute({ database: "ai", collection: "docs", queryText: "hello", limit: 3, numCandidates: 20, filter: { category: "a" } });
        const pipeline = provider.aggregateCalls[0];
        expect(pipeline[0].$vectorSearch.filter).toEqual({ category: "a" });
    });

    it("throws when queryText missing", async () => {
        await expect((tool as any).execute({ database: "ai", collection: "docs", limit: 5, numCandidates: 10 }))
            .rejects.toThrow(/'queryText' must be provided/);
    });

    it("throws when path missing in config during execute", async () => {
        buildSessionAndTool({ vectorSearchPath: undefined });
        await expect((tool as any).execute({ database: "ai", collection: "docs", queryText: "hello", limit: 2, numCandidates: 10 }))
            .rejects.toThrow(/requires 'path' argument/);
    });

    it("throws when index missing in config during execute", async () => {
        buildSessionAndTool({ vectorSearchIndex: undefined });
        await expect((tool as any).execute({ database: "ai", collection: "docs", queryText: "hello", limit: 2, numCandidates: 10 }))
            .rejects.toThrow(/requires 'index' argument/);
    });

    it("execute throws on unsupported embedding provider", async () => {
        buildSessionAndTool({ embeddingModelProvider: "some-other-provider" });
        await expect((tool as any).execute({ database: "ai", collection: "docs", queryText: "hi", limit: 2, numCandidates: 10 }))
            .rejects.toThrow(/Unsupported embedding model provider/);
    });
});
+import type { UserConfig } from "../../../src/common/config.js"; +import { LoggerBase, LogId } from "../../../src/common/logger.js"; + +// --- Module mocks (must be declared before importing code under test) --- +const jwtVerifyMock = vi.fn(); +const createRemoteJWKSetMock = vi.fn(() => ({})); +vi.mock("jose", () => ({ + jwtVerify: (token: any, jwks: any, options: any) => jwtVerifyMock(token, jwks, options), + createRemoteJWKSet: () => createRemoteJWKSetMock(), +})); + +// Import AFTER mocks so middleware picks them up +import { azureManagedIdentityAuthMiddleware } from "../../../src/transports/azureManagedIdentityAuth.js"; + +class TestLogger extends LoggerBase { + protected readonly type = "mcp" as const; + public entries: { level: string; payload: any }[] = []; + protected logCore(level: any, payload: any): void { + this.entries.push({ level, payload }); + } + findMessage(sub: string) { + return this.entries.find((e) => e.payload.message.includes(sub)); + } + messagesById(id: number) { + return this.entries.filter((e) => e.payload.id?.__value === id); + } +} + +function baseConfig(partial: Partial): UserConfig { + return { + apiBaseUrl: "", + logPath: "", + exportsPath: "", + exportTimeoutMs: 0, + exportCleanupIntervalMs: 0, + disabledTools: [], + telemetry: "disabled" as any, + readOnly: false, + indexCheck: false, + confirmationRequiredTools: [], + transport: "http", + httpPort: 0, + httpHost: "", + loggers: [], + idleTimeoutMs: 0, + notificationTimeoutMs: 0, + httpHeaders: {}, + atlasTemporaryDatabaseUserLifetimeMs: 0, + httpAuthMode: "azure-managed-identity", + ...partial, + } as UserConfig; +} + +function mockReq(headers: Record = {}): Request { + return { headers } as unknown as Request; +} + +function mockRes() { + const json = vi.fn(); + const status = vi.fn(() => ({ json })); + return { status, json } as unknown as Response & { status: any; json: any }; +} + +// Helper to set up fetch discovery response +function mockDiscovery(ok: boolean, data?: any) 
{ + return vi.spyOn(global, "fetch" as any).mockResolvedValue({ + ok, + status: ok ? 200 : 500, + statusText: ok ? "OK" : "ERR", + json: async () => data, + } as any); +} + +describe("azureManagedIdentityAuthMiddleware", () => { + let logger: TestLogger; + beforeEach(() => { + logger = new TestLogger(undefined as any); + jwtVerifyMock.mockReset(); + createRemoteJWKSetMock.mockClear(); + }); + afterEach(() => { + vi.restoreAllMocks(); + }); + + it("bypasses when mode not enabled", async () => { + const mw = azureManagedIdentityAuthMiddleware(logger, baseConfig({ httpAuthMode: "none" })); + const next = vi.fn(); + await mw(mockReq(), mockRes(), next); + expect(next).toHaveBeenCalled(); + }); + + it("returns 401 when authorization header missing", async () => { + const mw = azureManagedIdentityAuthMiddleware(logger, baseConfig({ azureManagedIdentityTenantId: "tid1" })); + const res = mockRes(); + const next = vi.fn(); + await mw(mockReq(), res, next); + expect(res.status).toHaveBeenCalledWith(401); + expect(res.status.mock.results[0].value.json).toHaveBeenCalledWith({ error: "missing authorization header" }); + expect(next).not.toHaveBeenCalled(); + }); + + it("returns 401 when authorization header malformed", async () => { + const mw = azureManagedIdentityAuthMiddleware(logger, baseConfig({ azureManagedIdentityTenantId: "tid2" })); + const res = mockRes(); + const next = vi.fn(); + await mw(mockReq({ authorization: "Bad token" }), res, next); + expect(res.status).toHaveBeenCalledWith(401); + expect(res.status.mock.results[0].value.json).toHaveBeenCalledWith({ error: "invalid authorization header" }); + expect(next).not.toHaveBeenCalled(); + }); + + it("successfully authenticates and attaches claims", async () => { + const token = "abc.def.ghi"; + const fetchSpy = mockDiscovery(true, { jwks_uri: "https://example/jwks" }); + jwtVerifyMock.mockResolvedValue({ payload: { sub: "user", tid: "tid3", aud: "api://aud", appid: "app123" } }); + + const mw = 
azureManagedIdentityAuthMiddleware( + logger, + baseConfig({ azureManagedIdentityTenantId: "tid3", azureManagedIdentityAudience: "api://aud" }) + ); + const next = vi.fn(); + const req = mockReq({ authorization: `Bearer ${token}` }); + const res = mockRes(); + await mw(req, res, next); + + expect(next).toHaveBeenCalled(); + expect(fetchSpy).toHaveBeenCalledOnce(); + expect(jwtVerifyMock).toHaveBeenCalled(); + expect((req as any).azureManagedIdentity.sub).toBe("user"); + }); + + it("falls back to v1 issuer after v2 failure", async () => { + mockDiscovery(true, { jwks_uri: "https://example/jwks" }); + jwtVerifyMock + .mockRejectedValueOnce(new Error("issuer mismatch")) + .mockResolvedValueOnce({ payload: { sub: "x", tid: "tid4" } }); + + const mw = azureManagedIdentityAuthMiddleware(logger, baseConfig({ azureManagedIdentityTenantId: "tid4" })); + const next = vi.fn(); + await mw(mockReq({ authorization: "Bearer tok" }), mockRes(), next); + expect(jwtVerifyMock).toHaveBeenCalledTimes(2); + const issuers = jwtVerifyMock.mock.calls.map((c) => c[2].issuer); + expect(issuers[0]).toMatch(/login\.microsoftonline/); + expect(issuers[1]).toMatch(/sts\.windows\.net/); + expect(next).toHaveBeenCalled(); + }); + + it("fails when sub missing", async () => { + mockDiscovery(true, { jwks_uri: "https://example/jwks" }); + jwtVerifyMock.mockResolvedValue({ payload: { tid: "tid5" } }); + const mw = azureManagedIdentityAuthMiddleware(logger, baseConfig({ azureManagedIdentityTenantId: "tid5" })); + const res = mockRes(); + await mw(mockReq({ authorization: "Bearer t" }), res, vi.fn()); + expect(res.status).toHaveBeenCalledWith(401); + expect(logger.findMessage("missing-sub")).toBeTruthy(); + }); + + it("fails when tid missing", async () => { + mockDiscovery(true, { jwks_uri: "https://example/jwks" }); + jwtVerifyMock.mockResolvedValue({ payload: { sub: "u" } }); + const mw = azureManagedIdentityAuthMiddleware(logger, baseConfig({ azureManagedIdentityTenantId: "tid6" })); + const res = 
mockRes(); + await mw(mockReq({ authorization: "Bearer t" }), res, vi.fn()); + expect(res.status).toHaveBeenCalledWith(401); + expect(logger.findMessage("missing-tid")).toBeTruthy(); + }); + + it("fails when tenant mismatch", async () => { + mockDiscovery(true, { jwks_uri: "https://example/jwks" }); + jwtVerifyMock.mockResolvedValue({ payload: { sub: "u", tid: "other" } }); + const mw = azureManagedIdentityAuthMiddleware(logger, baseConfig({ azureManagedIdentityTenantId: "tid7" })); + const res = mockRes(); + await mw(mockReq({ authorization: "Bearer t" }), res, vi.fn()); + expect(res.status).toHaveBeenCalledWith(401); + expect(logger.findMessage("tenant-mismatch")).toBeTruthy(); + }); + + it("enforces allowedAppIds (denied)", async () => { + mockDiscovery(true, { jwks_uri: "https://example/jwks" }); + jwtVerifyMock.mockResolvedValue({ payload: { sub: "u", tid: "tid8", appid: "bad" } }); + const mw = azureManagedIdentityAuthMiddleware( + logger, + baseConfig({ azureManagedIdentityTenantId: "tid8", azureManagedIdentityAllowedAppIds: ["good"] }) + ); + const res = mockRes(); + await mw(mockReq({ authorization: "Bearer t" }), res, vi.fn()); + expect(res.status).toHaveBeenCalledWith(401); + expect(logger.findMessage("appid-not-allowed")).toBeTruthy(); + }); + + it("enforces allowedAppIds (allowed)", async () => { + mockDiscovery(true, { jwks_uri: "https://example/jwks" }); + jwtVerifyMock.mockResolvedValue({ payload: { sub: "u", tid: "tid9", appid: "good" } }); + const mw = azureManagedIdentityAuthMiddleware( + logger, + baseConfig({ azureManagedIdentityTenantId: "tid9", azureManagedIdentityAllowedAppIds: ["GOOD"] }) + ); + const next = vi.fn(); + await mw(mockReq({ authorization: "Bearer t" }), mockRes(), next); + expect(next).toHaveBeenCalled(); + }); + + it("role enforcement all mode fails", async () => { + mockDiscovery(true, { jwks_uri: "https://example/jwks" }); + jwtVerifyMock.mockResolvedValue({ payload: { sub: "u", tid: "tid10", roles: ["r1"] } }); + const mw 
= azureManagedIdentityAuthMiddleware( + logger, + baseConfig({ azureManagedIdentityTenantId: "tid10", azureManagedIdentityRequiredRoles: ["r1", "r2"] }) + ); + const res = mockRes(); + await mw(mockReq({ authorization: "Bearer t" }), res, vi.fn()); + expect(res.status).toHaveBeenCalledWith(401); + expect(logger.findMessage("role-match-failed")).toBeTruthy(); + }); + + it("role enforcement all mode succeeds", async () => { + mockDiscovery(true, { jwks_uri: "https://example/jwks" }); + jwtVerifyMock.mockResolvedValue({ payload: { sub: "u", tid: "tid11", roles: ["r1", "r2", "extra"] } }); + const mw = azureManagedIdentityAuthMiddleware( + logger, + baseConfig({ azureManagedIdentityTenantId: "tid11", azureManagedIdentityRequiredRoles: ["r1", "r2"] }) + ); + const next = vi.fn(); + await mw(mockReq({ authorization: "Bearer t" }), mockRes(), next); + expect(next).toHaveBeenCalled(); + }); + + it("role enforcement any mode fails", async () => { + mockDiscovery(true, { jwks_uri: "https://example/jwks" }); + jwtVerifyMock.mockResolvedValue({ payload: { sub: "u", tid: "tid12", roles: ["other"] } }); + const mw = azureManagedIdentityAuthMiddleware( + logger, + baseConfig({ + azureManagedIdentityTenantId: "tid12", + azureManagedIdentityRequiredRoles: ["r1", "r2"], + azureManagedIdentityRoleMatchMode: "any", + }) + ); + const res = mockRes(); + await mw(mockReq({ authorization: "Bearer t" }), res, vi.fn()); + expect(res.status).toHaveBeenCalledWith(401); + }); + + it("role enforcement any mode succeeds", async () => { + mockDiscovery(true, { jwks_uri: "https://example/jwks" }); + jwtVerifyMock.mockResolvedValue({ payload: { sub: "u", tid: "tid13", roles: ["r2"] } }); + const mw = azureManagedIdentityAuthMiddleware( + logger, + baseConfig({ + azureManagedIdentityTenantId: "tid13", + azureManagedIdentityRequiredRoles: ["r1", "r2"], + azureManagedIdentityRoleMatchMode: "any", + }) + ); + const next = vi.fn(); + await mw(mockReq({ authorization: "Bearer t" }), mockRes(), next); + 
expect(next).toHaveBeenCalled(); + }); + + it("logs warning when no audience/clientId configured", async () => { + mockDiscovery(true, { jwks_uri: "https://example/jwks" }); + jwtVerifyMock.mockResolvedValue({ payload: { sub: "u", tid: "tid14" } }); + const mw = azureManagedIdentityAuthMiddleware( + logger, + baseConfig({ azureManagedIdentityTenantId: "tid14", azureManagedIdentityAudience: undefined, azureManagedIdentityClientId: undefined }) + ); + const next = vi.fn(); + await mw(mockReq({ authorization: "Bearer t" }), mockRes(), next); + expect(next).toHaveBeenCalled(); + const warn = logger.entries.find((e) => e.level === "warning" && e.payload.message.includes("No audience")); + expect(warn).toBeTruthy(); + }); + + it("caches JWK set for same tenant", async () => { + const fetchSpy = mockDiscovery(true, { jwks_uri: "https://example/jwks" }); + jwtVerifyMock.mockResolvedValue({ payload: { sub: "u", tid: "tid15" } }); + const config = baseConfig({ azureManagedIdentityTenantId: "tid15" }); + const mw1 = azureManagedIdentityAuthMiddleware(logger, config); + const mw2 = azureManagedIdentityAuthMiddleware(logger, config); + await mw1(mockReq({ authorization: "Bearer a" }), mockRes(), vi.fn()); + await mw2(mockReq({ authorization: "Bearer b" }), mockRes(), vi.fn()); + expect(fetchSpy).toHaveBeenCalledOnce(); + }); + + it("handles fetch discovery failure", async () => { + const fetchSpy = mockDiscovery(false, {}); + const mw = azureManagedIdentityAuthMiddleware(logger, baseConfig({ azureManagedIdentityTenantId: "tid16" })); + const res = mockRes(); + await mw(mockReq({ authorization: "Bearer t" }), res, vi.fn()); + expect(res.status).toHaveBeenCalledWith(401); + // logged as info with Token verification failed + const info = logger.findMessage("Token verification failed"); + expect(info).toBeTruthy(); + expect(fetchSpy).toHaveBeenCalled(); + }); +}); From aa673711da80f3d22e3411ebb8b40d5f2c05a394 Mon Sep 17 00:00:00 2001 From: Anshul Khantwal Date: Thu, 25 Sep 2025 
15:06:40 +0530 Subject: [PATCH 2/2] Added README.md for vector-search tool --- README.md | 145 ++++++++++++++++++++++++++++++++++++++++++- src/common/config.ts | 17 +++-- 2 files changed, 152 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index e5915ed22..4dec4b272 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ A Model Context Protocol server for interacting with MongoDB Databases and Mongo - [📄 Supported Resources](#supported-resources) - [⚙️ Configuration](#configuration) - [Configuration Options](#configuration-options) + - [Vector Search & Embeddings](#vector-search-and-embeddings) - [Atlas API Access](#atlas-api-access) - [Configuration Methods](#configuration-methods) - [Environment Variables](#environment-variables) @@ -320,6 +321,7 @@ NOTE: atlas tools are only available when you set credentials on [configuration] - `collection-storage-size` - Get the size of a collection in MB - `db-stats` - Return statistics about a MongoDB database - `export` - Export query or aggregation results to EJSON format. Creates a uniquely named export accessible via the `exported-data` resource. +- `vector-search` - Execute a vector similarity search ($vectorSearch) over a collection. See [Vector Search & Embeddings](#vector-search-and-embeddings). ## 📄 Supported Resources @@ -361,6 +363,13 @@ The MongoDB MCP Server can be configured using multiple methods, with the follow | `exportTimeoutMs` | `MDB_MCP_EXPORT_TIMEOUT_MS` | 300000 | Time in milliseconds after which an export is considered expired and eligible for cleanup. | | `exportCleanupIntervalMs` | `MDB_MCP_EXPORT_CLEANUP_INTERVAL_MS` | 120000 | Time in milliseconds between export cleanup cycles that remove expired export files. | | `atlasTemporaryDatabaseUserLifetimeMs` | `MDB_MCP_ATLAS_TEMPORARY_DATABASE_USER_LIFETIME_MS` | 14400000 | Time in milliseconds that temporary database users created when connecting to MongoDB Atlas clusters will remain active before being automatically deleted. 
| +| `vectorSearchPath` | `MDB_MCP_VECTOR_SEARCH_PATH` | | Default vector field path used by `vector-search` (V2 mode). If set together with `vectorSearchIndex`, the V2 vector search tool variant is enabled. | +| `vectorSearchIndex` | `MDB_MCP_VECTOR_SEARCH_INDEX` | | Default vector search index name used by `vector-search` (V2 mode). Must be set with `vectorSearchPath` to enable V2 mode. | +| `embeddingModelProvider` | `MDB_MCP_EMBEDDING_MODEL_PROVIDER` | azure-ai-inference | Embedding model provider identifier. Currently only `azure-ai-inference` is supported. | +| `embeddingModelEndpoint` | `MDB_MCP_EMBEDDING_MODEL_ENDPOINT` | | Endpoint for the embedding model provider. Required for vector search. | +| `embeddingModelApikey` | `MDB_MCP_EMBEDDING_MODEL_APIKEY` | | API key/credential for the embedding model provider. Required for vector search. | +| `embeddingModelDeploymentName` | `MDB_MCP_EMBEDDING_MODEL_DEPLOYMENT_NAME` | | Deployment/model name to use when requesting embeddings. Required for vector search. | +| `embeddingModelDimension` | `MDB_MCP_EMBEDDING_MODEL_DIMENSION` | | (Optional) Expected embedding dimension for validation (provider specific). | #### Logger Options @@ -482,6 +491,140 @@ You can disable telemetry using: > **💡 Platform Note:** For Windows users, see [Environment Variables](#environment-variables) for platform-specific instructions. +### Vector Search and Embeddings + +The `vector-search` tool lets you run semantic similarity queries against a MongoDB collection using the `$vectorSearch` aggregation stage. This capability is disabled unless a valid embedding configuration is supplied (see below). + +#### Overview + +Two internal variants of the `vector-search` tool may register depending on configuration: + +1. V1 (argument-driven): You supply `path` and optionally `index` as tool arguments each call. +2. 
V2 (config-driven): You preconfigure both `vectorSearchPath` and `vectorSearchIndex` in server config; the tool omits those arguments and always searches that path/index. + +Variant selection rules: + +- If BOTH `MDB_MCP_VECTOR_SEARCH_PATH` and `MDB_MCP_VECTOR_SEARCH_INDEX` are set at startup → V2 registers. +- If NEITHER (or only one) of those is set → V1 registers, and you must provide a `path` argument per invocation (and may provide `index`). +- If embedding config is incomplete, the tool is not registered (you will see a warning in logs). + +#### Required MongoDB Setup + +1. A collection with a vector field (array of float/number values) containing stored embeddings. +2. A vector search index created on that field (e.g. Atlas Search vector index) when you want to leverage indexing for performance/recall. + +#### Embedding Configuration (Required) + +You must configure an embedding provider so the server can transform the `queryText` you pass in into a numeric embedding vector. Current provider support: + +- `azure-ai-inference` (default if none specified) + +Set the following environment variables (or CLI args) for Azure AI Inference: + +```bash +export MDB_MCP_EMBEDDING_MODEL_ENDPOINT="https://your-azure-resource.services.ai.azure.com/models/embeddings?api-version=2024-05-01-preview" +export MDB_MCP_EMBEDDING_MODEL_APIKEY="" +export MDB_MCP_EMBEDDING_MODEL_DEPLOYMENT_NAME="text-embedding-3-large" # or your deployed embedding model +# (Optional) if you want to assert embedding size +export MDB_MCP_EMBEDDING_MODEL_DIMENSION=3072 +``` + +Without these, `vector-search` will not register. + +#### Optional Vector Search Defaults (Enable V2 Mode) + +To eliminate passing `path` (and optionally `index`) each call, set both: + +```bash +export MDB_MCP_VECTOR_SEARCH_PATH="embedding" # e.g. 
field path storing embeddings +export MDB_MCP_VECTOR_SEARCH_INDEX="myVectorIndex" # name of the Atlas Search vector index +``` + +If both are present at startup, the V2 variant is loaded and you no longer pass `path`/`index` arguments at call time. Remove one or both to revert to V1. + +#### Usage Examples + +##### Example 1: V1 Variant (no defaults configured) + +Tool invocation arguments: + +```json +{ + "name": "vector-search", + "arguments": { + "database": "mydb", + "collection": "articles", + "queryText": "vector databases for personalization", + "path": "embedding", + "limit": 5, + "numCandidates": 200, + "includeVector": false + } +} +``` + +##### Example 2: V2 Variant (defaults configured) + +With `MDB_MCP_VECTOR_SEARCH_PATH=embedding` and `MDB_MCP_VECTOR_SEARCH_INDEX=myVectorIndex` set at startup: + +```json +{ + "name": "vector-search", + "arguments": { + "database": "mydb", + "collection": "articles", + "queryText": "vector databases for personalization", + "limit": 5, + "numCandidates": 200 + } +} +``` + +#### Returned Data + +The tool returns an array of matched documents. By default the raw embedding field is excluded (set `includeVector: true` if you need it). Standard result size safeguards (`maxDocumentsPerQuery`, `maxBytesPerQuery`) still apply. + +#### Adding a Custom Embedding Provider + +You can extend the server to support additional embedding services (e.g. OpenAI, Hugging Face, Vertex AI) by implementing the `EmbeddingProvider` interface: + +`src/embedding/embeddingProvider.ts`: + +```ts +export interface EmbeddingProvider { + name: string; + embed(input: string[]): Promise<number[][]>; +} +``` + +Steps: + +1. Create a new file under `src/embedding/`, e.g. `myProviderEmbeddingProvider.ts`, implementing the interface. +2. Add a new case in `EmbeddingProviderFactory.create()` & `isEmbeddingConfigValid()` matching a unique `embeddingModelProvider` string (e.g. `my-provider`). +3. Document required env vars (e.g. 
`MDB_MCP_EMBEDDING_MODEL_ENDPOINT`, `MDB_MCP_EMBEDDING_MODEL_APIKEY`, etc. or new ones) and update README. +4. (Optional) Support provider‑specific validation (dimension, model name) in `assertEmbeddingConfigValid`. +5. Provide tests (unit + integration if vector search depends on it) ensuring your provider returns deterministic dimensionality. + +After adding your provider, users enable it by setting: + +```bash +export MDB_MCP_EMBEDDING_MODEL_PROVIDER=my-provider +# plus any provider-specific variables you defined +``` + +If your provider requires different variable names, follow the existing naming convention: prefix with `MDB_MCP_` and document them. + +#### Troubleshooting + +| Symptom | Likely Cause | Action | +| ------- | ------------ | ------ | +| `vector-search` tool missing | Incomplete embedding config | Set endpoint, api key, deployment name env vars. Restart client. | +| Error: "Embedding provider returned empty embedding" | Provider/network issue | Check credentials & network; verify model supports embeddings. | +| Error requiring 'path' even though I set env vars | Only one of PATH/INDEX set | Set BOTH `MDB_MCP_VECTOR_SEARCH_PATH` and `MDB_MCP_VECTOR_SEARCH_INDEX` or remove both. | +| High latency | Large `numCandidates` or remote model slowness | Lower `numCandidates`; verify model region proximity. | + +--- + ### Atlas API Access To use the Atlas API tools, you'll need to create a service account in MongoDB Atlas: @@ -680,6 +823,6 @@ connecting to the Atlas API, your MongoDB Cluster, or any other external calls to third-party services like OID Providers. The behaviour is the same as what `mongosh` does, so the same settings will work in the MCP Server. -## 🤝Contributing +## Contributing Interested in contributing? Great! Please check our [Contributing Guide](CONTRIBUTING.md) for guidelines on code contributions, standards, adding new tools, and troubleshooting information. 
diff --git a/src/common/config.ts b/src/common/config.ts index 5e4ee66e2..6ba6232fc 100644 --- a/src/common/config.ts +++ b/src/common/config.ts @@ -22,13 +22,13 @@ const OPTIONS = { "notificationTimeoutMs", "telemetry", "transport", - "httpAuthMode", - "azureManagedIdentityTenantId", - "azureManagedIdentityClientId", - "azureManagedIdentityAudience", - "azureManagedIdentityRequiredRoles", - "azureManagedIdentityAllowedAppIds", - "azureManagedIdentityRoleMatchMode", + "httpAuthMode", + "azureManagedIdentityTenantId", + "azureManagedIdentityClientId", + "azureManagedIdentityAudience", + "azureManagedIdentityRequiredRoles", + "azureManagedIdentityAllowedAppIds", + "azureManagedIdentityRoleMatchMode", "apiVersion", "authenticationDatabase", "authenticationMechanism", @@ -69,7 +69,6 @@ const OPTIONS = { "embeddingModelDimension", "embeddingModelDeploymentName", "embeddingModelProvider", - // Removed retry tunables (maxRetries & retryInitialDelayMs) now fixed internally ], boolean: [ "apiDeprecationErrors", @@ -219,7 +218,7 @@ export interface UserConfig extends CliOptions { embeddingModelEndpoint?: string; // MDB_MCP_EMBEDDING_MODEL_ENDPOINT embeddingModelApikey?: string; // MDB_MCP_EMBEDDING_MODEL_APIKEY embeddingModelDeploymentName?: string; // MDB_MCP_EMBEDDING_MODEL_DEPLOYMENT_NAME - embeddingModelDimension?: number; // MDB_MCP_EMBEDDING_MODEL_DIMENSION + embeddingModelDimension?: number; // [Optional] MDB_MCP_EMBEDDING_MODEL_DIMENSION } export const defaultUserConfig: UserConfig = {