1- import { describe , it , expect , beforeAll , afterAll , vi , afterEach } from "vitest" ;
1+ import { describe , it , expect } from "vitest" ;
22import { ToolBase , type ToolArgs } from "../../src/tools/index.js" ;
33import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js" ;
44import { z } from "zod" ;
5- import { Session } from "../../src/common/session.js" ;
6- import { Server } from "../../src/server.js" ;
75import type { TelemetryToolMetadata } from "../../src/telemetry/types.js" ;
8- import { CompositeLogger } from "../../src/common/logger.js" ;
9- import { ExportsManager } from "../../src/common/exportsManager.js" ;
10- import { Telemetry } from "../../src/telemetry/telemetry.js" ;
11- import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js" ;
12- import { Client } from "@modelcontextprotocol/sdk/client/index.js" ;
13- import { InMemoryTransport } from "./inMemoryTransport.js" ;
14- import { MCPConnectionManager } from "../../src/common/connectionManager.js" ;
15- import { DeviceId } from "../../src/helpers/deviceId.js" ;
16- import { connectionErrorHandler } from "../../src/common/connectionErrorHandler.js" ;
17- import { Keychain } from "../../src/common/keychain.js" ;
18- import { Elicitation } from "../../src/elicitation.js" ;
19- import { defaultTestConfig , driverOptions } from "./helpers.js" ;
20- import { VectorSearchEmbeddingsManager } from "../../src/common/search/vectorSearchEmbeddingsManager.js" ;
21- import { defaultCreateAtlasLocalClient } from "../../src/common/atlasLocal.js" ;
6+ import { defaultTestConfig , driverOptions , setupIntegrationTest } from "./helpers.js" ;
227
238describe ( "Custom Tools" , ( ) => {
24- let mcpClient : Client ;
25- let mcpServer : Server ;
26- let deviceId : DeviceId ;
27-
28- beforeAll ( async ( ) => {
29- const userConfig = { ...defaultTestConfig } ;
30-
31- const clientTransport = new InMemoryTransport ( ) ;
32- const serverTransport = new InMemoryTransport ( ) ;
33- const logger = new CompositeLogger ( ) ;
34-
35- await serverTransport . start ( ) ;
36- await clientTransport . start ( ) ;
37-
38- void clientTransport . output . pipeTo ( serverTransport . input ) ;
39- void serverTransport . output . pipeTo ( clientTransport . input ) ;
40-
41- mcpClient = new Client (
42- {
43- name : "test-client" ,
44- version : "1.2.3" ,
9+ const { mcpClient, mcpServer } = setupIntegrationTest (
10+ ( ) => ( { ...defaultTestConfig } ) ,
11+ ( ) => driverOptions ,
12+ {
13+ serverOptions : {
14+ tools : [ CustomGreetingTool , CustomCalculatorTool ] ,
4515 } ,
46- {
47- capabilities : { } ,
48- }
49- ) ;
50-
51- const exportsManager = ExportsManager . init ( userConfig , logger ) ;
52-
53- deviceId = DeviceId . create ( logger ) ;
54- const connectionManager = new MCPConnectionManager ( userConfig , driverOptions , logger , deviceId ) ;
55-
56- const session = new Session ( {
57- apiBaseUrl : userConfig . apiBaseUrl ,
58- apiClientId : userConfig . apiClientId ,
59- apiClientSecret : userConfig . apiClientSecret ,
60- logger,
61- exportsManager,
62- connectionManager,
63- keychain : new Keychain ( ) ,
64- vectorSearchEmbeddingsManager : new VectorSearchEmbeddingsManager ( userConfig , connectionManager ) ,
65- atlasLocalClient : await defaultCreateAtlasLocalClient ( ) ,
66- } ) ;
67-
68- // Mock hasValidAccessToken for tests
69- if ( ! userConfig . apiClientId && ! userConfig . apiClientSecret ) {
70- const mockFn = vi . fn ( ) . mockResolvedValue ( true ) ;
71- session . apiClient . validateAccessToken = mockFn ;
72- }
73-
74- userConfig . telemetry = "disabled" ;
75-
76- const telemetry = Telemetry . create ( session , userConfig , deviceId ) ;
77-
78- const mcpServerInstance = new McpServer ( {
79- name : "test-server" ,
80- version : "5.2.3" ,
81- } ) ;
82-
83- const elicitation = new Elicitation ( { server : mcpServerInstance . server } ) ;
84-
85- mcpServer = new Server ( {
86- session,
87- userConfig,
88- telemetry,
89- mcpServer : mcpServerInstance ,
90- elicitation,
91- connectionErrorHandler,
92- tools : [ CustomGreetingTool , CustomCalculatorTool ] ,
93- } ) ;
94-
95- await mcpServer . connect ( serverTransport ) ;
96- await mcpClient . connect ( clientTransport ) ;
97- } ) ;
98-
99- afterEach ( async ( ) => {
100- if ( mcpServer ) {
101- await mcpServer . session . disconnect ( ) ;
10216 }
103-
104- vi . clearAllMocks ( ) ;
105- } ) ;
106-
107- afterAll ( async ( ) => {
108- await mcpClient . close ( ) ;
109-
110- await mcpServer . close ( ) ;
111-
112- deviceId . close ( ) ;
113- } ) ;
17+ ) ;
11418
11519 it ( "should register custom tools instead of default tools" , async ( ) => {
11620 // Check that custom tools are registered
117- const tools = await mcpClient . listTools ( ) ;
21+ const tools = await mcpClient ( ) . listTools ( ) ;
11822 const customGreetingTool = tools . tools . find ( ( t ) => t . name === "custom_greeting" ) ;
11923 const customCalculatorTool = tools . tools . find ( ( t ) => t . name === "custom_calculator" ) ;
12024
@@ -127,7 +31,7 @@ describe("Custom Tools", () => {
12731 } ) ;
12832
12933 it ( "should execute custom tools" , async ( ) => {
130- const result = await mcpClient . callTool ( {
34+ const result = await mcpClient ( ) . callTool ( {
13135 name : "custom_greeting" ,
13236 arguments : { name : "World" } ,
13337 } ) ;
@@ -139,7 +43,7 @@ describe("Custom Tools", () => {
13943 } ,
14044 ] ) ;
14145
142- const result2 = await mcpClient . callTool ( {
46+ const result2 = await mcpClient ( ) . callTool ( {
14347 name : "custom_calculator" ,
14448 arguments : { a : 5 , b : 3 } ,
14549 } ) ;
@@ -151,7 +55,7 @@ describe("Custom Tools", () => {
15155 } ,
15256 ] ) ;
15357
154- const result3 = await mcpClient . callTool ( {
58+ const result3 = await mcpClient ( ) . callTool ( {
15559 name : "custom_calculator" ,
15660 arguments : { a : 4 , b : 7 } ,
15761 } ) ;
@@ -165,11 +69,11 @@ describe("Custom Tools", () => {
16569 } ) ;
16670
16771 it ( "should respect tool categories and operation types from custom tools" , ( ) => {
168- const customGreetingTool = mcpServer . tools . find ( ( t ) => t . name === "custom_greeting" ) ;
72+ const customGreetingTool = mcpServer ( ) . tools . find ( ( t ) => t . name === "custom_greeting" ) ;
16973 expect ( customGreetingTool ?. category ) . toBe ( "mongodb" ) ;
17074 expect ( customGreetingTool ?. operationType ) . toBe ( "read" ) ;
17175
172- const customCalculatorTool = mcpServer . tools . find ( ( t ) => t . name === "custom_calculator" ) ;
76+ const customCalculatorTool = mcpServer ( ) . tools . find ( ( t ) => t . name === "custom_calculator" ) ;
17377 expect ( customCalculatorTool ?. category ) . toBe ( "mongodb" ) ;
17478 expect ( customCalculatorTool ?. operationType ) . toBe ( "read" ) ;
17579 } ) ;
0 commit comments