Skip to content

Commit ca1bdd6

Browse files
fix: fixes vector stage input parsing
Users were unable to use VectorSearch stage because Zod was using our catch all stage (AnyAggregateStage) before VectorSearchStage to parse and validate even the vector seach stage input and it passed through because that schema is a catch all schema. Because of that, in the input that we received, outputDimension was not transformed ever. This commit changes the order of application of schema so that VectorSeachStage schema is validated first and then the catch all stage schema.
1 parent eab28b3 commit ca1bdd6

File tree

2 files changed

+65
-3
lines changed

2 files changed

+65
-3
lines changed

src/tools/mongodb/read/aggregate.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ const genericPipelineDescription = "An array of aggregation stages to execute.";
4343
export const getAggregateArgs = (vectorSearchEnabled: boolean) =>
4444
({
4545
pipeline: z
46-
.array(vectorSearchEnabled ? z.union([AnyAggregateStage, VectorSearchStage]) : AnyAggregateStage)
46+
.array(vectorSearchEnabled ? z.union([VectorSearchStage, AnyAggregateStage]) : AnyAggregateStage)
4747
.describe(vectorSearchEnabled ? pipelineDescriptionWithVectorSearch : genericPipelineDescription),
4848
responseBytesLimit: z.number().optional().default(ONE_MB).describe(`\
4949
The maximum number of bytes to return in the response. This value is capped by the server's configured maxBytesPerQuery and cannot be exceeded. \

tests/integration/tools/mongodb/read/aggregate.test.ts

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import {
1616
import * as constants from "../../../../../src/helpers/constants.js";
1717
import { freshInsertDocuments } from "./find.test.js";
1818
import { BSON } from "bson";
19+
import { DOCUMENT_EMBEDDINGS } from "./vyai/embeddings.js";
1920

2021
describeWithMongoDB("aggregate tool", (integration) => {
2122
afterEach(() => {
@@ -384,8 +385,6 @@ describeWithMongoDB(
384385
}
385386
);
386387

387-
import { DOCUMENT_EMBEDDINGS } from "./vyai/embeddings.js";
388-
389388
describeWithMongoDB(
390389
"aggregate tool with atlas search enabled",
391390
(integration) => {
@@ -921,6 +920,69 @@ If the user requests additional filtering, include filters in \`$vectorSearch.fi
921920
);
922921
});
923922
});
923+
924+
describe("outputDimension transformation", () => {
925+
it.each([
926+
{ numDimensions: 2048, outputDimension: "2048" },
927+
{ numDimensions: 4096, outputDimension: "4096" },
928+
])(
929+
"should successfully transform outputDimension string '$outputDimension' to number",
930+
async ({ numDimensions, outputDimension }) => {
931+
await waitUntilSearchIsReady(integration.mongoClient());
932+
933+
const collection = integration.mongoClient().db(integration.randomDbName()).collection("databases");
934+
await collection.insertOne({ name: "mongodb", description_embedding: DOCUMENT_EMBEDDINGS.float });
935+
936+
await createVectorSearchIndexAndWait(
937+
integration.mongoClient(),
938+
integration.randomDbName(),
939+
"databases",
940+
[
941+
{
942+
type: "vector",
943+
path: "description_embedding",
944+
numDimensions,
945+
similarity: "cosine",
946+
quantization: "none",
947+
},
948+
]
949+
);
950+
951+
await integration.connectMcpClient();
952+
const response = await integration.mcpClient().callTool({
953+
name: "aggregate",
954+
arguments: {
955+
database: integration.randomDbName(),
956+
collection: "databases",
957+
pipeline: [
958+
{
959+
$vectorSearch: {
960+
index: "default",
961+
path: "description_embedding",
962+
queryVector: "example query",
963+
numCandidates: 10,
964+
limit: 10,
965+
embeddingParameters: {
966+
model: "voyage-3-large",
967+
outputDimension, // Pass as string literal
968+
},
969+
},
970+
},
971+
{
972+
$project: {
973+
description_embedding: 0,
974+
},
975+
},
976+
],
977+
},
978+
});
979+
980+
const responseContent = getResponseContent(response);
981+
// String should succeed and be transformed to number internally
982+
expect(responseContent).toContain("The aggregation resulted in");
983+
}
984+
);
985+
});
924986
},
925987
{
926988
getUserConfig: () => ({

0 commit comments

Comments
 (0)