Skip to content

Commit 929e9a4

Browse files
committed
tool loop & tool registry files
1 parent f86cf10 commit 929e9a4

File tree

2 files changed

+94
-10
lines changed

2 files changed

+94
-10
lines changed

src/examples/server/toolLoop.ts

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import type {
1717
ServerNotification,
1818
} from "../../types.js";
1919
import { ToolRegistry } from "./toolRegistry.js";
20+
import { ToolResultContent } from "../../../dist/esm/types.js";
2021

2122
/**
2223
* Interface for tracking aggregated token usage across API calls.
@@ -29,6 +30,13 @@ interface AggregatedUsage {
2930
api_calls: number;
3031
}
3132

33+
export class BreakToolLoopError extends Error {
34+
constructor(message: string) {
35+
super(message);
36+
}
37+
}
38+
39+
3240
/**
3341
* Runs a tool loop using sampling.
3442
* Continues until the LLM provides a final answer.
@@ -40,13 +48,14 @@ export async function runToolLoop(
4048
registry: ToolRegistry,
4149
maxIterations?: number,
4250
systemPrompt?: string,
51+
defaultToolChoice?: CreateMessageRequest["params"]['toolChoice']
4352
},
4453
extra: RequestHandlerExtra<ServerRequest, ServerNotification>
4554
): Promise<{ answer: string; transcript: SamplingMessage[]; usage: AggregatedUsage }> {
4655
const messages: SamplingMessage[] = [...options.initialMessages];
4756

4857
// Initialize usage tracking
49-
const aggregatedUsage: AggregatedUsage = {
58+
const usage: AggregatedUsage = {
5059
input_tokens: 0,
5160
output_tokens: 0,
5261
cache_creation_input_tokens: 0,
@@ -57,6 +66,8 @@ export async function runToolLoop(
5766
let iteration = 0;
5867
const maxIterations = options.maxIterations ?? Number.POSITIVE_INFINITY;
5968

69+
const defaultToolChoice = options.defaultToolChoice ?? {mode: "auto"};
70+
6071
let request: CreateMessageRequest["params"] | undefined
6172
let response: CreateMessageResult | undefined
6273
while (iteration < maxIterations) {
@@ -69,17 +80,17 @@ export async function runToolLoop(
6980
maxTokens: 4000,
7081
tools: iteration < maxIterations ? options.registry.tools : undefined,
7182
// Don't allow tool calls at the last iteration: finish with an answer no matter what!
72-
tool_choice: { mode: iteration < maxIterations ? "auto" : "none" },
83+
tool_choice: iteration < maxIterations ? defaultToolChoice : { mode: "none" },
7384
});
7485

7586
// Aggregate usage statistics from the response
7687
if (response._meta?.usage) {
77-
const usage = response._meta.usage as any;
78-
aggregatedUsage.input_tokens += usage.input_tokens || 0;
79-
aggregatedUsage.output_tokens += usage.output_tokens || 0;
80-
aggregatedUsage.cache_creation_input_tokens += usage.cache_creation_input_tokens || 0;
81-
aggregatedUsage.cache_read_input_tokens += usage.cache_read_input_tokens || 0;
82-
aggregatedUsage.api_calls += 1;
88+
const responseUsage = response._meta.usage as any;
89+
usage.input_tokens += responseUsage.input_tokens || 0;
90+
usage.output_tokens += responseUsage.output_tokens || 0;
91+
usage.cache_creation_input_tokens += responseUsage.cache_creation_input_tokens || 0;
92+
usage.cache_read_input_tokens += responseUsage.cache_read_input_tokens || 0;
93+
usage.api_calls += 1;
8394
}
8495

8596
// Add assistant's response to message history
@@ -100,7 +111,16 @@ export async function runToolLoop(
100111
data: `Loop iteration ${iteration}: ${toolCalls.length} tool invocation(s) requested`,
101112
});
102113

103-
const toolResults = await options.registry.callTools(toolCalls, extra);
114+
let toolResults: ToolResultContent[];
115+
try {
116+
toolResults = await options.registry.callTools(toolCalls, extra);
117+
} catch (error) {
118+
if (error instanceof BreakToolLoopError) {
119+
return {answer: `${error.message}`, transcript: messages, usage};
120+
}
121+
console.log(error);
122+
throw new Error(`Tool call failed: ${error}`)
123+
}
104124

105125
messages.push({
106126
role: "user",
@@ -127,7 +147,7 @@ export async function runToolLoop(
127147
return {
128148
answer: contentArray.map(block => block.text).join("\n\n"),
129149
transcript: messages,
130-
usage: aggregatedUsage
150+
usage
131151
};
132152
} else if (response?.stopReason === "maxTokens") {
133153
throw new Error("LLM response hit max tokens limit");
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import { McpServer, RegisteredTool } from "../../server/mcp.js";
2+
import { RequestHandlerExtra } from "../../shared/protocol.js";
3+
import { z } from "zod";
4+
import type {
5+
Tool,
6+
ToolCallContent,
7+
ToolResultContent,
8+
ServerRequest,
9+
ServerNotification,
10+
} from "../../types.js";
11+
import { zodToJsonSchema } from "zod-to-json-schema";
12+
import { BreakToolLoopError } from "./toolLoop.js"
13+
14+
15+
export class ToolRegistry {
16+
readonly tools: Tool[]
17+
18+
constructor(private toolDefinitions: {[name: string]: Pick<RegisteredTool, 'title' | 'description' | 'inputSchema' | 'outputSchema' | 'annotations' | '_meta' | 'callback'> }) {
19+
this.tools = Object.entries(this.toolDefinitions).map(([name, tool]) => (<Tool>{
20+
name,
21+
title: tool.title,
22+
description: tool.description,
23+
inputSchema: tool.inputSchema ? zodToJsonSchema(tool.inputSchema) : undefined,
24+
outputSchema: tool.outputSchema ? zodToJsonSchema(tool.outputSchema) : undefined,
25+
annotations: tool.annotations,
26+
_meta: tool._meta,
27+
}));
28+
}
29+
30+
register(server: McpServer) {
31+
for (const [name, tool] of Object.entries(this.toolDefinitions)) {
32+
server.registerTool(name, {
33+
title: tool.title,
34+
description: tool.description,
35+
inputSchema: tool.inputSchema?.shape,
36+
outputSchema: tool.outputSchema?.shape,
37+
annotations: tool.annotations,
38+
_meta: tool._meta,
39+
}, tool.callback);
40+
}
41+
}
42+
43+
async callTools(toolCalls: ToolCallContent[], extra: RequestHandlerExtra<ServerRequest, ServerNotification>): Promise<ToolResultContent[]> {
44+
return Promise.all(toolCalls.map(async ({ name, id, input }) => {
45+
const tool = this.toolDefinitions[name];
46+
if (!tool) {
47+
throw new Error(`Tool ${name} not found`);
48+
}
49+
try {
50+
return <ToolResultContent>{
51+
type: "tool_result",
52+
toolUseId: id,
53+
// Copies fields: content, structuredContent?, isError?
54+
...await tool.callback(input as any, extra),
55+
};
56+
} catch (error) {
57+
if (error instanceof BreakToolLoopError) {
58+
throw error;
59+
}
60+
throw new Error(`Tool ${name} failed: ${error instanceof Error ? `${error.message}\n${error.stack}` : error}`);
61+
}
62+
}));
63+
}
64+
}

0 commit comments

Comments
 (0)