@@ -17,6 +17,7 @@ import type {
   ServerNotification,
 } from "../../types.js";
 import { ToolRegistry } from "./toolRegistry.js";
+import { ToolResultContent } from "../../../dist/esm/types.js";
 
 /**
  * Interface for tracking aggregated token usage across API calls.
@@ -29,6 +30,13 @@ interface AggregatedUsage {
   api_calls: number;
 }
 
+export class BreakToolLoopError extends Error {
+  constructor(message: string) {
+    super(message);
+  }
+}
+
+
 /**
  * Runs a tool loop using sampling.
  * Continues until the LLM provides a final answer.
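
The new BreakToolLoopError gives tool implementations an escape hatch: anything thrown from callTools that is an instance of this class is caught later in the loop and returned as the final answer rather than rethrown. A minimal sketch of a tool using it, assuming a registerTool() registration call and handler signature (ToolRegistry's API is not part of this diff):

// Hypothetical handler; registerTool() and its signature are assumptions.
registry.registerTool("finish_early", async (args: { reason: string }) => {
  // Throwing here propagates out of callTools(); runToolLoop() catches
  // BreakToolLoopError and returns its message as the loop's answer.
  throw new BreakToolLoopError(`Stopped early: ${args.reason}`);
});
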
@@ -40,13 +48,14 @@ export async function runToolLoop(
     registry: ToolRegistry,
     maxIterations?: number,
     systemPrompt?: string,
+    defaultToolChoice?: CreateMessageRequest["params"]["toolChoice"],
   },
   extra: RequestHandlerExtra<ServerRequest, ServerNotification>
 ): Promise<{ answer: string; transcript: SamplingMessage[]; usage: AggregatedUsage }> {
   const messages: SamplingMessage[] = [...options.initialMessages];
 
   // Initialize usage tracking
-  const aggregatedUsage: AggregatedUsage = {
+  const usage: AggregatedUsage = {
     input_tokens: 0,
     output_tokens: 0,
     cache_creation_input_tokens: 0,
@@ -57,6 +66,8 @@ export async function runToolLoop(
   let iteration = 0;
   const maxIterations = options.maxIterations ?? Number.POSITIVE_INFINITY;
 
+  const defaultToolChoice = options.defaultToolChoice ?? { mode: "auto" };
+
   let request: CreateMessageRequest["params"] | undefined;
   let response: CreateMessageResult | undefined;
   while (iteration < maxIterations) {
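
With the fallback above, callers can steer tool selection on every iteration except the last, where the loop still forces { mode: "none" }. A hypothetical call site, assuming registry and extra are already in scope; only the "auto" and "none" modes appear in this diff, but any value matching CreateMessageRequest["params"]["toolChoice"] could be passed:

// Hypothetical call site; the initial message follows the SamplingMessage shape.
const { answer, usage } = await runToolLoop(
  {
    initialMessages: [
      { role: "user", content: { type: "text", text: "Summarize the repository." } },
    ],
    registry,
    maxIterations: 8,
    defaultToolChoice: { mode: "auto" }, // replaced by { mode: "none" } on the final iteration
  },
  extra
);
console.log(`${usage.api_calls} API call(s), ${usage.input_tokens} input tokens`);
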
@@ -69,17 +80,17 @@ export async function runToolLoop(
       maxTokens: 4000,
       tools: iteration < maxIterations ? options.registry.tools : undefined,
       // Don't allow tool calls at the last iteration: finish with an answer no matter what!
-      tool_choice: { mode: iteration < maxIterations ? "auto" : "none" },
+      tool_choice: iteration < maxIterations ? defaultToolChoice : { mode: "none" },
     });
 
     // Aggregate usage statistics from the response
     if (response._meta?.usage) {
-      const usage = response._meta.usage as any;
-      aggregatedUsage.input_tokens += usage.input_tokens || 0;
-      aggregatedUsage.output_tokens += usage.output_tokens || 0;
-      aggregatedUsage.cache_creation_input_tokens += usage.cache_creation_input_tokens || 0;
-      aggregatedUsage.cache_read_input_tokens += usage.cache_read_input_tokens || 0;
-      aggregatedUsage.api_calls += 1;
+      const responseUsage = response._meta.usage as any;
+      usage.input_tokens += responseUsage.input_tokens || 0;
+      usage.output_tokens += responseUsage.output_tokens || 0;
+      usage.cache_creation_input_tokens += responseUsage.cache_creation_input_tokens || 0;
+      usage.cache_read_input_tokens += responseUsage.cache_read_input_tokens || 0;
+      usage.api_calls += 1;
     }
 
     // Add assistant's response to message history
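
The renamed counters assume the client forwards provider usage through _meta.usage with Anthropic-style snake_case fields; a response without _meta.usage is simply not counted. For illustration, a hypothetical payload and its effect on the aggregate (the field values are assumptions, since _meta is passthrough metadata):

// Hypothetical _meta.usage payload on one CreateMessageResult:
const responseUsage = { input_tokens: 1200, output_tokens: 340, cache_read_input_tokens: 900 };
// After this iteration the aggregate gains 1200 input, 340 output, and 900
// cache-read tokens, api_calls increments by 1, and the absent
// cache_creation_input_tokens contributes 0 via the `|| 0` fallback.
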
@@ -100,7 +111,16 @@ export async function runToolLoop(
       data: `Loop iteration ${iteration}: ${toolCalls.length} tool invocation(s) requested`,
     });
 
-    const toolResults = await options.registry.callTools(toolCalls, extra);
+    let toolResults: ToolResultContent[];
+    try {
+      toolResults = await options.registry.callTools(toolCalls, extra);
+    } catch (error) {
+      if (error instanceof BreakToolLoopError) {
+        return { answer: error.message, transcript: messages, usage };
+      }
+      console.log(error);
+      throw new Error(`Tool call failed: ${error}`);
+    }
 
     messages.push({
       role: "user",
@@ -127,7 +147,7 @@ export async function runToolLoop(
       return {
         answer: contentArray.map(block => block.text).join("\n\n"),
         transcript: messages,
-        usage: aggregatedUsage
+        usage
       };
     } else if (response?.stopReason === "maxTokens") {
       throw new Error("LLM response hit max tokens limit");