Skip to content

Commit f5774ed

Browse files
committed
refactor
- more logging in case something is unexpected. - directly import some types to remove clutter.
1 parent 423cca1 commit f5774ed

File tree

1 file changed

+50
-65
lines changed

1 file changed

+50
-65
lines changed

crates/but-action/src/openai.rs

Lines changed: 50 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use std::{collections::HashMap, fmt::Display, ops::Deref, sync::Arc};
22

33
use anyhow::{Context as _, Result};
4+
use async_openai::types::chat::{
5+
ChatCompletionMessageToolCalls, ChatCompletionRequestToolMessageContent,
6+
};
47
use async_openai::{
58
Client,
69
config::OpenAIConfig,
@@ -380,27 +383,22 @@ impl From<ChatMessage> for ChatCompletionRequestMessage {
380383
ChatMessage::ToolCall(content) => ChatCompletionRequestMessage::Assistant(
381384
async_openai::types::chat::ChatCompletionRequestAssistantMessage {
382385
content: None,
383-
tool_calls: Some(vec![
384-
async_openai::types::chat::ChatCompletionMessageToolCalls::Function(
385-
async_openai::types::chat::ChatCompletionMessageToolCall {
386-
id: content.id,
387-
function: async_openai::types::chat::FunctionCall {
388-
name: content.name,
389-
arguments: content.arguments,
390-
},
386+
tool_calls: Some(vec![ChatCompletionMessageToolCalls::Function(
387+
async_openai::types::chat::ChatCompletionMessageToolCall {
388+
id: content.id,
389+
function: async_openai::types::chat::FunctionCall {
390+
name: content.name,
391+
arguments: content.arguments,
391392
},
392-
),
393-
]),
393+
},
394+
)]),
394395
..Default::default()
395396
},
396397
),
397398
ChatMessage::ToolResponse(content) => ChatCompletionRequestMessage::Tool(
398399
async_openai::types::chat::ChatCompletionRequestToolMessage {
399400
tool_call_id: content.id,
400-
content:
401-
async_openai::types::chat::ChatCompletionRequestToolMessageContent::Text(
402-
content.result,
403-
),
401+
content: ChatCompletionRequestToolMessageContent::Text(content.result),
404402
},
405403
),
406404
}
@@ -420,16 +418,17 @@ fn from_openai_chat_messages(messages: Vec<ChatCompletionRequestMessage>) -> Vec
420418
ChatCompletionRequestMessage::Assistant(assistant_msg) => {
421419
if let Some(tool_calls) = &assistant_msg.tool_calls {
422420
for tool_call in tool_calls {
423-
// Extract function call from the enum
424-
if let async_openai::types::chat::ChatCompletionMessageToolCalls::Function(
425-
func_call,
426-
) = tool_call
427-
{
421+
if let ChatCompletionMessageToolCalls::Function(func_call) = tool_call {
428422
chat_messages.push(ChatMessage::ToolCall(ToolCallContent {
429423
id: func_call.id.clone(),
430424
name: func_call.function.name.clone(),
431425
arguments: func_call.function.arguments.clone(),
432426
}));
427+
} else {
428+
tracing::warn!(
429+
?tool_call,
430+
"Encountered unexpected non-function tool call"
431+
);
433432
}
434433
}
435434
}
@@ -441,10 +440,7 @@ fn from_openai_chat_messages(messages: Vec<ChatCompletionRequestMessage>) -> Vec
441440
}
442441
}
443442
ChatCompletionRequestMessage::Tool(tool_msg) => {
444-
if let async_openai::types::chat::ChatCompletionRequestToolMessageContent::Text(
445-
text,
446-
) = tool_msg.content
447-
{
443+
if let ChatCompletionRequestToolMessageContent::Text(text) = tool_msg.content {
448444
chat_messages.push(ChatMessage::ToolResponse(ToolResponseContent {
449445
id: tool_msg.tool_call_id.clone(),
450446
result: text,
@@ -513,7 +509,7 @@ pub fn tool_calling_loop(
513509
chat_messages: Vec<ChatMessage>,
514510
tool_set: &mut impl Toolset,
515511
model: Option<String>,
516-
) -> anyhow::Result<String> {
512+
) -> Result<String> {
517513
let mut messages: Vec<ChatCompletionRequestMessage> =
518514
vec![ChatCompletionRequestSystemMessage::from(system_message).into()];
519515

@@ -530,7 +526,7 @@ pub fn tool_calling_loop(
530526
.map(|t| t.deref().try_into())
531527
.collect::<Result<Vec<async_openai::types::chat::ChatCompletionTools>, _>>()?;
532528

533-
let mut response = crate::openai::tool_calling_blocking(
529+
let mut response = tool_calling_blocking(
534530
provider,
535531
messages.clone(),
536532
open_ai_tools.clone(),
@@ -559,62 +555,51 @@ pub fn tool_calling_loop(
559555
.first()
560556
.and_then(|choice| choice.message.tool_calls.as_ref())
561557
{
562-
let mut tool_calls_messages: Vec<
563-
async_openai::types::chat::ChatCompletionMessageToolCalls,
564-
> = vec![];
565-
let mut tool_response_messages: Vec<
566-
async_openai::types::chat::ChatCompletionRequestMessage,
567-
> = vec![];
558+
let mut tool_calls_messages: Vec<ChatCompletionMessageToolCalls> = vec![];
559+
let mut tool_response_messages: Vec<ChatCompletionRequestMessage> = vec![];
568560

569561
for call in tool_calls {
570562
// Extract function call from the enum
571563
let (id, function_name, function_args) = match call {
572-
async_openai::types::chat::ChatCompletionMessageToolCalls::Function(func_call) => (
564+
ChatCompletionMessageToolCalls::Function(func_call) => (
573565
func_call.id.clone(),
574566
func_call.function.name.clone(),
575567
func_call.function.arguments.clone(),
576568
),
577-
async_openai::types::chat::ChatCompletionMessageToolCalls::Custom(_) => {
578-
// Skip custom tool calls as we only handle function calls
569+
ChatCompletionMessageToolCalls::Custom(custom) => {
570+
tracing::warn!(?custom, "Encountered unexpected custom tool call");
579571
continue;
580572
}
581573
};
582574

583575
let tool_response = tool_set.call_tool(&function_name, &function_args);
584-
585576
let tool_response_str = serde_json::to_string(&tool_response)
586577
.context("Failed to serialize tool response")?;
587578

588-
tool_calls_messages.push(
589-
async_openai::types::chat::ChatCompletionMessageToolCalls::Function(
590-
async_openai::types::chat::ChatCompletionMessageToolCall {
591-
id: id.clone(),
592-
function: async_openai::types::chat::FunctionCall {
593-
name: function_name,
594-
arguments: function_args,
595-
},
579+
tool_calls_messages.push(ChatCompletionMessageToolCalls::Function(
580+
async_openai::types::chat::ChatCompletionMessageToolCall {
581+
id: id.clone(),
582+
function: async_openai::types::chat::FunctionCall {
583+
name: function_name,
584+
arguments: function_args,
596585
},
597-
),
598-
);
586+
},
587+
));
599588

600-
tool_response_messages.push(async_openai::types::chat::ChatCompletionRequestMessage::Tool(
589+
tool_response_messages.push(ChatCompletionRequestMessage::Tool(
601590
async_openai::types::chat::ChatCompletionRequestToolMessage {
602591
tool_call_id: id.clone(),
603-
content: async_openai::types::chat::ChatCompletionRequestToolMessageContent::Text(
604-
tool_response_str,
605-
),
592+
content: ChatCompletionRequestToolMessageContent::Text(tool_response_str),
606593
},
607594
));
608595
}
609596

610-
messages.push(
611-
async_openai::types::chat::ChatCompletionRequestMessage::Assistant(
612-
async_openai::types::chat::ChatCompletionRequestAssistantMessage {
613-
tool_calls: Some(tool_calls_messages),
614-
..Default::default()
615-
},
616-
),
617-
);
597+
messages.push(ChatCompletionRequestMessage::Assistant(
598+
async_openai::types::chat::ChatCompletionRequestAssistantMessage {
599+
tool_calls: Some(tool_calls_messages),
600+
..Default::default()
601+
},
602+
));
618603

619604
messages.extend(tool_response_messages);
620605

@@ -733,14 +718,14 @@ pub fn tool_calling_loop_stream(
733718
),
734719
);
735720

736-
tool_response_messages.push(async_openai::types::chat::ChatCompletionRequestMessage::Tool(
737-
async_openai::types::chat::ChatCompletionRequestToolMessage {
738-
tool_call_id: id.clone(),
739-
content: async_openai::types::chat::ChatCompletionRequestToolMessageContent::Text(
740-
tool_response_str,
741-
),
742-
},
743-
));
721+
tool_response_messages.push(
722+
async_openai::types::chat::ChatCompletionRequestMessage::Tool(
723+
async_openai::types::chat::ChatCompletionRequestToolMessage {
724+
tool_call_id: id.clone(),
725+
content: ChatCompletionRequestToolMessageContent::Text(tool_response_str),
726+
},
727+
),
728+
);
744729
}
745730

746731
messages.push(

0 commit comments

Comments
 (0)