From ca1bdd610c1d32ef4bc779a2a5aca4e86def87cc Mon Sep 17 00:00:00 2001 From: Himanshu Singh Date: Thu, 11 Dec 2025 11:33:38 +0100 Subject: [PATCH] fix: fixes vector stage input parsing Users were unable to use VectorSearch stage because Zod was using our catch all stage (AnyAggregateStage) before VectorSearchStage to parse and validate even the vector seach stage input and it passed through because that schema is a catch all schema. Because of that, in the input that we received, outputDimension was not transformed ever. This commit changes the order of application of schema so that VectorSeachStage schema is validated first and then the catch all stage schema. --- src/tools/mongodb/read/aggregate.ts | 2 +- .../tools/mongodb/read/aggregate.test.ts | 66 ++++++++++++++++++- 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/src/tools/mongodb/read/aggregate.ts b/src/tools/mongodb/read/aggregate.ts index b98ed68c3..8c87dbc7d 100644 --- a/src/tools/mongodb/read/aggregate.ts +++ b/src/tools/mongodb/read/aggregate.ts @@ -43,7 +43,7 @@ const genericPipelineDescription = "An array of aggregation stages to execute."; export const getAggregateArgs = (vectorSearchEnabled: boolean) => ({ pipeline: z - .array(vectorSearchEnabled ? z.union([AnyAggregateStage, VectorSearchStage]) : AnyAggregateStage) + .array(vectorSearchEnabled ? z.union([VectorSearchStage, AnyAggregateStage]) : AnyAggregateStage) .describe(vectorSearchEnabled ? pipelineDescriptionWithVectorSearch : genericPipelineDescription), responseBytesLimit: z.number().optional().default(ONE_MB).describe(`\ The maximum number of bytes to return in the response. This value is capped by the server's configured maxBytesPerQuery and cannot be exceeded. \ diff --git a/tests/integration/tools/mongodb/read/aggregate.test.ts b/tests/integration/tools/mongodb/read/aggregate.test.ts index f35ada510..8a0c0adda 100644 --- a/tests/integration/tools/mongodb/read/aggregate.test.ts +++ b/tests/integration/tools/mongodb/read/aggregate.test.ts @@ -16,6 +16,7 @@ import { import * as constants from "../../../../../src/helpers/constants.js"; import { freshInsertDocuments } from "./find.test.js"; import { BSON } from "bson"; +import { DOCUMENT_EMBEDDINGS } from "./vyai/embeddings.js"; describeWithMongoDB("aggregate tool", (integration) => { afterEach(() => { @@ -384,8 +385,6 @@ describeWithMongoDB( } ); -import { DOCUMENT_EMBEDDINGS } from "./vyai/embeddings.js"; - describeWithMongoDB( "aggregate tool with atlas search enabled", (integration) => { @@ -921,6 +920,69 @@ If the user requests additional filtering, include filters in \`$vectorSearch.fi ); }); }); + + describe("outputDimension transformation", () => { + it.each([ + { numDimensions: 2048, outputDimension: "2048" }, + { numDimensions: 4096, outputDimension: "4096" }, + ])( + "should successfully transform outputDimension string '$outputDimension' to number", + async ({ numDimensions, outputDimension }) => { + await waitUntilSearchIsReady(integration.mongoClient()); + + const collection = integration.mongoClient().db(integration.randomDbName()).collection("databases"); + await collection.insertOne({ name: "mongodb", description_embedding: DOCUMENT_EMBEDDINGS.float }); + + await createVectorSearchIndexAndWait( + integration.mongoClient(), + integration.randomDbName(), + "databases", + [ + { + type: "vector", + path: "description_embedding", + numDimensions, + similarity: "cosine", + quantization: "none", + }, + ] + ); + + await integration.connectMcpClient(); + const response = await integration.mcpClient().callTool({ + name: "aggregate", + arguments: { + database: integration.randomDbName(), + collection: "databases", + pipeline: [ + { + $vectorSearch: { + index: "default", + path: "description_embedding", + queryVector: "example query", + numCandidates: 10, + limit: 10, + embeddingParameters: { + model: "voyage-3-large", + outputDimension, // Pass as string literal + }, + }, + }, + { + $project: { + description_embedding: 0, + }, + }, + ], + }, + }); + + const responseContent = getResponseContent(response); + // String should succeed and be transformed to number internally + expect(responseContent).toContain("The aggregation resulted in"); + } + ); + }); }, { getUserConfig: () => ({