11use std:: { collections:: HashMap , fmt:: Display , ops:: Deref , sync:: Arc } ;
22
33use anyhow:: { Context as _, Result } ;
4+ use async_openai:: types:: chat:: {
5+ ChatCompletionMessageToolCalls , ChatCompletionRequestToolMessageContent ,
6+ } ;
47use async_openai:: {
58 Client ,
69 config:: OpenAIConfig ,
@@ -380,27 +383,22 @@ impl From<ChatMessage> for ChatCompletionRequestMessage {
380383 ChatMessage :: ToolCall ( content) => ChatCompletionRequestMessage :: Assistant (
381384 async_openai:: types:: chat:: ChatCompletionRequestAssistantMessage {
382385 content : None ,
383- tool_calls : Some ( vec ! [
384- async_openai:: types:: chat:: ChatCompletionMessageToolCalls :: Function (
385- async_openai:: types:: chat:: ChatCompletionMessageToolCall {
386- id: content. id,
387- function: async_openai:: types:: chat:: FunctionCall {
388- name: content. name,
389- arguments: content. arguments,
390- } ,
386+ tool_calls : Some ( vec ! [ ChatCompletionMessageToolCalls :: Function (
387+ async_openai:: types:: chat:: ChatCompletionMessageToolCall {
388+ id: content. id,
389+ function: async_openai:: types:: chat:: FunctionCall {
390+ name: content. name,
391+ arguments: content. arguments,
391392 } ,
392- ) ,
393- ] ) ,
393+ } ,
394+ ) ] ) ,
394395 ..Default :: default ( )
395396 } ,
396397 ) ,
397398 ChatMessage :: ToolResponse ( content) => ChatCompletionRequestMessage :: Tool (
398399 async_openai:: types:: chat:: ChatCompletionRequestToolMessage {
399400 tool_call_id : content. id ,
400- content :
401- async_openai:: types:: chat:: ChatCompletionRequestToolMessageContent :: Text (
402- content. result ,
403- ) ,
401+ content : ChatCompletionRequestToolMessageContent :: Text ( content. result ) ,
404402 } ,
405403 ) ,
406404 }
@@ -420,16 +418,17 @@ fn from_openai_chat_messages(messages: Vec<ChatCompletionRequestMessage>) -> Vec
420418 ChatCompletionRequestMessage :: Assistant ( assistant_msg) => {
421419 if let Some ( tool_calls) = & assistant_msg. tool_calls {
422420 for tool_call in tool_calls {
423- // Extract function call from the enum
424- if let async_openai:: types:: chat:: ChatCompletionMessageToolCalls :: Function (
425- func_call,
426- ) = tool_call
427- {
421+ if let ChatCompletionMessageToolCalls :: Function ( func_call) = tool_call {
428422 chat_messages. push ( ChatMessage :: ToolCall ( ToolCallContent {
429423 id : func_call. id . clone ( ) ,
430424 name : func_call. function . name . clone ( ) ,
431425 arguments : func_call. function . arguments . clone ( ) ,
432426 } ) ) ;
427+ } else {
428+ tracing:: warn!(
429+ ?tool_call,
430+ "Encountered unexpected non-function tool call"
431+ ) ;
433432 }
434433 }
435434 }
@@ -441,10 +440,7 @@ fn from_openai_chat_messages(messages: Vec<ChatCompletionRequestMessage>) -> Vec
441440 }
442441 }
443442 ChatCompletionRequestMessage :: Tool ( tool_msg) => {
444- if let async_openai:: types:: chat:: ChatCompletionRequestToolMessageContent :: Text (
445- text,
446- ) = tool_msg. content
447- {
443+ if let ChatCompletionRequestToolMessageContent :: Text ( text) = tool_msg. content {
448444 chat_messages. push ( ChatMessage :: ToolResponse ( ToolResponseContent {
449445 id : tool_msg. tool_call_id . clone ( ) ,
450446 result : text,
@@ -513,7 +509,7 @@ pub fn tool_calling_loop(
513509 chat_messages : Vec < ChatMessage > ,
514510 tool_set : & mut impl Toolset ,
515511 model : Option < String > ,
516- ) -> anyhow :: Result < String > {
512+ ) -> Result < String > {
517513 let mut messages: Vec < ChatCompletionRequestMessage > =
518514 vec ! [ ChatCompletionRequestSystemMessage :: from( system_message) . into( ) ] ;
519515
@@ -530,7 +526,7 @@ pub fn tool_calling_loop(
530526 . map ( |t| t. deref ( ) . try_into ( ) )
531527 . collect :: < Result < Vec < async_openai:: types:: chat:: ChatCompletionTools > , _ > > ( ) ?;
532528
533- let mut response = crate :: openai :: tool_calling_blocking (
529+ let mut response = tool_calling_blocking (
534530 provider,
535531 messages. clone ( ) ,
536532 open_ai_tools. clone ( ) ,
@@ -559,62 +555,51 @@ pub fn tool_calling_loop(
559555 . first ( )
560556 . and_then ( |choice| choice. message . tool_calls . as_ref ( ) )
561557 {
562- let mut tool_calls_messages: Vec <
563- async_openai:: types:: chat:: ChatCompletionMessageToolCalls ,
564- > = vec ! [ ] ;
565- let mut tool_response_messages: Vec <
566- async_openai:: types:: chat:: ChatCompletionRequestMessage ,
567- > = vec ! [ ] ;
558+ let mut tool_calls_messages: Vec < ChatCompletionMessageToolCalls > = vec ! [ ] ;
559+ let mut tool_response_messages: Vec < ChatCompletionRequestMessage > = vec ! [ ] ;
568560
569561 for call in tool_calls {
570562 // Extract function call from the enum
571563 let ( id, function_name, function_args) = match call {
572- async_openai :: types :: chat :: ChatCompletionMessageToolCalls :: Function ( func_call) => (
564+ ChatCompletionMessageToolCalls :: Function ( func_call) => (
573565 func_call. id . clone ( ) ,
574566 func_call. function . name . clone ( ) ,
575567 func_call. function . arguments . clone ( ) ,
576568 ) ,
577- async_openai :: types :: chat :: ChatCompletionMessageToolCalls :: Custom ( _ ) => {
578- // Skip custom tool calls as we only handle function calls
569+ ChatCompletionMessageToolCalls :: Custom ( custom ) => {
570+ tracing :: warn! ( ?custom , "Encountered unexpected custom tool call" ) ;
579571 continue ;
580572 }
581573 } ;
582574
583575 let tool_response = tool_set. call_tool ( & function_name, & function_args) ;
584-
585576 let tool_response_str = serde_json:: to_string ( & tool_response)
586577 . context ( "Failed to serialize tool response" ) ?;
587578
588- tool_calls_messages. push (
589- async_openai:: types:: chat:: ChatCompletionMessageToolCalls :: Function (
590- async_openai:: types:: chat:: ChatCompletionMessageToolCall {
591- id : id. clone ( ) ,
592- function : async_openai:: types:: chat:: FunctionCall {
593- name : function_name,
594- arguments : function_args,
595- } ,
579+ tool_calls_messages. push ( ChatCompletionMessageToolCalls :: Function (
580+ async_openai:: types:: chat:: ChatCompletionMessageToolCall {
581+ id : id. clone ( ) ,
582+ function : async_openai:: types:: chat:: FunctionCall {
583+ name : function_name,
584+ arguments : function_args,
596585 } ,
597- ) ,
598- ) ;
586+ } ,
587+ ) ) ;
599588
600- tool_response_messages. push ( async_openai :: types :: chat :: ChatCompletionRequestMessage :: Tool (
589+ tool_response_messages. push ( ChatCompletionRequestMessage :: Tool (
601590 async_openai:: types:: chat:: ChatCompletionRequestToolMessage {
602591 tool_call_id : id. clone ( ) ,
603- content : async_openai:: types:: chat:: ChatCompletionRequestToolMessageContent :: Text (
604- tool_response_str,
605- ) ,
592+ content : ChatCompletionRequestToolMessageContent :: Text ( tool_response_str) ,
606593 } ,
607594 ) ) ;
608595 }
609596
610- messages. push (
611- async_openai:: types:: chat:: ChatCompletionRequestMessage :: Assistant (
612- async_openai:: types:: chat:: ChatCompletionRequestAssistantMessage {
613- tool_calls : Some ( tool_calls_messages) ,
614- ..Default :: default ( )
615- } ,
616- ) ,
617- ) ;
597+ messages. push ( ChatCompletionRequestMessage :: Assistant (
598+ async_openai:: types:: chat:: ChatCompletionRequestAssistantMessage {
599+ tool_calls : Some ( tool_calls_messages) ,
600+ ..Default :: default ( )
601+ } ,
602+ ) ) ;
618603
619604 messages. extend ( tool_response_messages) ;
620605
@@ -733,14 +718,14 @@ pub fn tool_calling_loop_stream(
733718 ) ,
734719 ) ;
735720
736- tool_response_messages. push ( async_openai :: types :: chat :: ChatCompletionRequestMessage :: Tool (
737- async_openai:: types:: chat:: ChatCompletionRequestToolMessage {
738- tool_call_id : id . clone ( ) ,
739- content : async_openai :: types :: chat :: ChatCompletionRequestToolMessageContent :: Text (
740- tool_response_str,
741- ) ,
742- } ,
743- ) ) ;
721+ tool_response_messages. push (
722+ async_openai:: types:: chat:: ChatCompletionRequestMessage :: Tool (
723+ async_openai :: types :: chat :: ChatCompletionRequestToolMessage {
724+ tool_call_id : id . clone ( ) ,
725+ content : ChatCompletionRequestToolMessageContent :: Text ( tool_response_str) ,
726+ } ,
727+ ) ,
728+ ) ;
744729 }
745730
746731 messages. push (
0 commit comments