diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index 3d506bdb..bec693ae 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -22,11 +22,14 @@ from agentlab.agents.agent_args import AgentArgs from agentlab.llm.llm_utils import image_to_png_base64_url from agentlab.llm.response_api import ( + APIPayload, ClaudeResponseModelArgs, LLMOutput, MessageBuilder, OpenAIChatModelArgs, OpenAIResponseModelArgs, + OpenRouterModelArgs, + ToolCalls, ) from agentlab.llm.tracking import cost_tracker_decorator @@ -98,7 +101,8 @@ def flatten(self) -> list[MessageBuilder]: messages.extend(group.messages) # Mark all summarized messages for caching if i == len(self.groups) - keep_last_n_obs: - messages[i].mark_all_previous_msg_for_caching() + if not isinstance(messages[i], ToolCalls): + messages[i].mark_all_previous_msg_for_caching() return messages def set_last_summary(self, summary: MessageBuilder): @@ -163,18 +167,15 @@ class Obs(Block): use_dom: bool = False use_som: bool = False use_tabs: bool = False - add_mouse_pointer: bool = False + # add_mouse_pointer: bool = False use_zoomed_webpage: bool = False def apply( self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput ) -> dict: - if last_llm_output.tool_calls is None: - obs_msg = llm.msg.user() # type: MessageBuilder - else: - obs_msg = llm.msg.tool(last_llm_output.raw_response) # type: MessageBuilder - + obs_msg = llm.msg.user() + tool_calls = last_llm_output.tool_calls if self.use_last_error: if obs["last_action_error"] != "": obs_msg.add_text(f"Last action error:\n{obs['last_action_error']}") @@ -186,13 +187,12 @@ def apply( else: screenshot = obs["screenshot"] - if self.add_mouse_pointer: - # TODO this mouse pointer should be added at the browsergym level - screenshot = np.array( - agent_utils.add_mouse_pointer_from_action( - Image.fromarray(obs["screenshot"]), obs["last_action"] - ) - ) + # if self.add_mouse_pointer: + # screenshot = np.array( + # agent_utils.add_mouse_pointer_from_action( + # Image.fromarray(obs["screenshot"]), obs["last_action"] + # ) + # ) obs_msg.add_image(image_to_png_base64_url(screenshot)) if self.use_axtree: @@ -203,6 +203,13 @@ def apply( obs_msg.add_text(_format_tabs(obs)) discussion.append(obs_msg) + + if tool_calls: + for call in tool_calls: + call.response_text("See Observation") + tool_response = llm.msg.add_responded_tool_calls(tool_calls) + discussion.append(tool_response) + return obs_msg @@ -254,8 +261,8 @@ def apply(self, llm, discussion: StructuredDiscussion) -> dict: msg = llm.msg.user().add_text("""Summarize\n""") discussion.append(msg) - # TODO need to make sure we don't force tool use here - summary_response = llm(messages=discussion.flatten(), tool_choice="none") + + summary_response = llm(APIPayload(messages=discussion.flatten())) summary_msg = llm.msg.assistant().add_text(summary_response.think) discussion.append(summary_msg) @@ -320,25 +327,6 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict: discussion.append(msg) -class ToolCall(Block): - - def __init__(self, tool_server): - self.tool_server = tool_server - - def apply(self, llm, messages: list[MessageBuilder], obs: dict) -> dict: - # build the message by adding components to obs - response: LLMOutput = llm(messages=self.messages) - - messages.append(response.assistant_message) # this is tool call - - tool_answer = self.tool_server.call_tool(response) 
- tool_msg = llm.msg.tool() # type: MessageBuilder - tool_msg.add_tool_id(response.last_computer_call_id) - tool_msg.update_last_raw_response(response) - tool_msg.add_text(str(tool_answer)) - messages.append(tool_msg) - - @dataclass class PromptConfig: tag_screenshot: bool = True # Whether to tag the screenshot with the last action. @@ -394,7 +382,7 @@ def __init__( self.call_ids = [] - self.llm = model_args.make_model(extra_kwargs={"tools": self.tools}) + self.llm = model_args.make_model() self.msg_builder = model_args.get_message_builder() self.llm.msg = self.msg_builder @@ -462,13 +450,15 @@ def get_action(self, obs: Any) -> float: messages = self.discussion.flatten() response: LLMOutput = self.llm( - messages=messages, - tool_choice="any", - cache_tool_definition=True, - cache_complete_prompt=False, - use_cache_breakpoints=True, + APIPayload( + messages=messages, + tools=self.tools, # You can update tools available tools now. + tool_choice="any", + cache_tool_definition=True, + cache_complete_prompt=False, + use_cache_breakpoints=True, + ) ) - action = response.action think = response.think last_summary = self.discussion.get_last_summary() @@ -476,7 +466,7 @@ def get_action(self, obs: Any) -> float: think = last_summary.content[0]["text"] + "\n" + think self.discussion.new_group() - self.discussion.append(response.tool_calls) + # self.discussion.append(response.tool_calls) # No need to append tool calls anymore. self.last_response = response self._responses.append(response) # may be useful for debugging @@ -486,8 +476,11 @@ def get_action(self, obs: Any) -> float: tools_msg = MessageBuilder("tool_description").add_text(tools_str) # Adding these extra messages to visualize in gradio - messages.insert(0, tools_msg) # insert at the beginning of the messages - messages.append(response.tool_calls) + messages.insert(0, tools_msg) # insert at the beginning of the message + # This avoids the assertion error with self.llm.user().add_responded_tool_calls(tool_calls) + msg = self.llm.msg("tool") + msg.responded_tool_calls = response.tool_calls + messages.append(msg) agent_info = bgym.AgentInfo( think=think, @@ -533,6 +526,31 @@ def get_action(self, obs: Any) -> float: vision_support=True, ) +O3_RESPONSE_MODEL = OpenAIResponseModelArgs( + model_name="o3-2025-04-16", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=None, # O3 does not support temperature + vision_support=True, +) +O3_CHATAPI_MODEL = OpenAIChatModelArgs( + model_name="o3-2025-04-16", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=None, + vision_support=True, +) + +GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs( + model_name="openai/gpt-4.1", + max_total_tokens=200_000, + max_input_tokens=200_000, + max_new_tokens=2_000, + temperature=None, # O3 does not support temperature + vision_support=True, +) DEFAULT_PROMPT_CONFIG = PromptConfig( tag_screenshot=True, @@ -548,8 +566,8 @@ def get_action(self, obs: Any) -> float: summarizer=Summarizer(do_summary=True), general_hints=GeneralHints(use_hints=False), task_hint=TaskHint(use_task_hint=True), - keep_last_n_obs=None, # keep only the last observation in the discussion - multiaction=False, # whether to use multi-action or not + keep_last_n_obs=None, + multiaction=True, # whether to use multi-action or not # action_subsets=("bid",), action_subsets=("coord"), # action_subsets=("coord", "bid"), @@ -559,3 +577,18 @@ def get_action(self, obs: Any) -> float: model_args=CLAUDE_MODEL_CONFIG, 
config=DEFAULT_PROMPT_CONFIG, ) + +OAI_AGENT = ToolUseAgentArgs( + model_args=GPT_4_1, + config=DEFAULT_PROMPT_CONFIG, +) + +OAI_CHATAPI_AGENT = ToolUseAgentArgs( + model_args=O3_CHATAPI_MODEL, + config=DEFAULT_PROMPT_CONFIG, +) + +OAI_OPENROUTER_AGENT = ToolUseAgentArgs( + model_args=GPT4_1_OPENROUTER_MODEL, + config=DEFAULT_PROMPT_CONFIG, +) diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index bd1c6ad4..8b846f5f 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -26,6 +26,7 @@ from agentlab.llm.llm_utils import BaseMessage as AgentLabBaseMessage from agentlab.llm.llm_utils import Discussion from agentlab.llm.response_api import MessageBuilder +from agentlab.llm.response_api import ToolCalls select_dir_instructions = "Select Experiment Directory" AGENT_NAME_KEY = "agent.agent_name" @@ -673,6 +674,9 @@ def dict_to_markdown(d: dict): str: A markdown-formatted string representation of the dictionary. """ if not isinstance(d, dict): + if isinstance(d, ToolCalls): + # ToolCalls rendered by to_markdown method. + return "" warning(f"Expected dict, got {type(d)}") return repr(d) if not d: diff --git a/src/agentlab/llm/response_api.py b/src/agentlab/llm/response_api.py index 8e6f3695..e8c74849 100644 --- a/src/agentlab/llm/response_api.py +++ b/src/agentlab/llm/response_api.py @@ -3,10 +3,12 @@ import os from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Literal, Optional, Union import openai from anthropic import Anthropic +from anthropic.types import Completion +from anthropic.types import Message as AnthrophicMessage from openai import OpenAI from agentlab.llm.llm_utils import image_to_png_base64_url @@ -30,23 +32,90 @@ Message = Dict[str, Union[str, List[ContentItem]]] +@dataclass +class ToolCall: + """Represents a tool call made by the LLM. + Attributes: + name: Name of the tool called. + arguments: Arguments passed to the tool. + raw_call: The raw call object from the LLM API. + tool_response: Output of the tool call goes here. It can be only one content item. + """ + + name: str = field(default=None) + arguments: Dict[str, Any] = field(default_factory=dict) + raw_call: Any = field(default=None) + tool_response: ContentItem = None + + @property + def is_response_set(self) -> bool: + """Check if the tool response is set.""" + return self.tool_response is not None + + def response_text(self, text: str) -> "MessageBuilder": + self.tool_response = {"text": text} + return self + + def response_image(self, image: str) -> "MessageBuilder": + self.tool_response = {"image": image} + return self + + def __repr__(self): + return f"ToolCall(name={self.name}, arguments={self.arguments})" + + +@dataclass +class ToolCalls: + """A collection of tool calls made by the LLM. + + Attributes: + tool_calls: List of ToolCall objects. + raw_calls: Represents raw tool calls object returned by a LLM API, may contain one or more tool calls. 
+ """ + + tool_calls: List[ToolCall] = field(default_factory=list) + raw_calls: List[Any] = field(default_factory=list) + + def add_tool_call(self, tool_call: ToolCall) -> "ToolCalls": + self.tool_calls.append(tool_call) + return self + + @property + def all_responses_set(self) -> bool: + """Check if all tool calls have responses set.""" + return all(call.is_response_set for call in self.tool_calls) + + def __len__(self) -> int: + """Return the number of tool calls.""" + return len(self.tool_calls) + + def __iter__(self): + """Make ToolCalls iterable.""" + return iter(self.tool_calls) + + def __bool__(self): + """Check if there are any tool calls.""" + return len(self.tool_calls) > 0 + + @dataclass class LLMOutput: """Serializable object for the output of a response LLM.""" - raw_response: Any = field(default_factory=dict) + raw_response: Any = field(default=None) think: str = field(default="") - action: str = field(default=None) # Default action if no tool call is made - tool_calls: Any = field(default=None) # This will hold the tool call response if any + action: str | None = field(default=None) # Default action if no tool call is made + tool_calls: ToolCalls | None = field( + default=None + ) # This will hold the tool call response if any class MessageBuilder: def __init__(self, role: str): self.role = role - self.last_raw_response: LLMOutput = None self.content: List[ContentItem] = [] - self.tool_call_id: Optional[str] = None + self.responded_tool_calls: ToolCalls = None @classmethod def system(cls) -> "MessageBuilder": @@ -60,19 +129,11 @@ def user(cls) -> "MessageBuilder": def assistant(cls) -> "MessageBuilder": return cls("assistant") - @classmethod - def tool(cls, last_raw_response) -> "MessageBuilder": - return cls("tool").update_last_raw_response(last_raw_response) - @abstractmethod def prepare_message(self) -> List[Message]: """Prepare the message for the API call.""" raise NotImplementedError("Subclasses must implement this method.") - def update_last_raw_response(self, last_raw_response: Any) -> "MessageBuilder": - self.last_raw_response = last_raw_response - return self - def add_text(self, text: str) -> "MessageBuilder": self.content.append({"text": text}) return self @@ -89,6 +150,22 @@ def to_markdown(self) -> str: elif "image" in item: parts.append(f"![Image]({item['image']})") + # Tool call markdown repr + if self.responded_tool_calls is not None: + for i, tool_call in enumerate(self.responded_tool_calls.tool_calls, 1): + parts.append( + f"\n**Tool Call {i}**: {tool_call_to_python_code(tool_call.name, tool_call.arguments)}" + ) + response = tool_call.tool_response + if response is not None: + parts.append(f"\n**Tool Response {i}:**") + content = ( + f"```\n{response['text']}\n```" + if "text" in response + else f"![Tool Response Image]({response['image']})" + ) + parts.append(content) + markdown = f"### {self.role.capitalize()}\n" markdown += "\n".join(parts) @@ -104,8 +181,13 @@ def mark_all_previous_msg_for_caching(self): # This is a placeholder for future implementation. raise NotImplementedError - -# TODO: Support parallel tool calls. + @classmethod + def add_responded_tool_calls(cls, responded_tool_calls: ToolCalls) -> "MessageBuilder": + """Add tool calls to the message content.""" + assert responded_tool_calls.all_responses_set, "All tool calls must have a response." 
+ msg = cls("tool") + msg.responded_tool_calls = responded_tool_calls + return msg class OpenAIResponseAPIMessageBuilder(MessageBuilder): @@ -117,44 +199,62 @@ def system(cls) -> "OpenAIResponseAPIMessageBuilder": def prepare_message(self) -> List[Message]: content = [] for item in self.content: - if "text" in item: - content_type = "input_text" if self.role != "assistant" else "output_text" - content.append({"type": content_type, "text": item["text"]}) + content.append(self.convert_content_to_expected_format(item)) + output = [{"role": self.role, "content": content}] - elif "image" in item: - content.append({"type": "input_image", "image_url": item["image"]}) + return output if self.role != "tool" else self.handle_tool_call() - output = [{"role": self.role, "content": content}] - if self.role != "tool": - return output + def convert_content_to_expected_format(self, content: ContentItem) -> ContentItem: + """Convert the content item to the expected format for OpenAI Responses.""" + if "text" in content: + content_type = "input_text" if self.role != "assistant" else "output_text" + return {"type": content_type, "text": content["text"]} + elif "image" in content: + return {"type": "input_image", "image_url": content["image"]} else: - tool_call_response = self.handle_tool_call(content) - return tool_call_response + raise ValueError(f"Unsupported content type: {content}") - def handle_tool_call(self, content): + def handle_tool_call(self) -> List[Message]: """Handle the tool call response from the last raw response.""" + if self.responded_tool_calls is None: + raise ValueError("No tool calls found in responded_tool_calls") + output = [] - head_content, *tail_content = content - api_response = self.last_raw_response - fn_calls = [content for content in api_response.output if content.type == "function_call"] - assert len(fn_calls) > 0, "No function calls found in the last response" - if len(fn_calls) > 1: - logging.warning("Using only the first tool call from many.") - - first_fn_call_id = fn_calls[0].call_id - fn_output = head_content.get("text", "Function call answer in next message") - fn_call_response = { - "type": "function_call_output", - "call_id": first_fn_call_id, - "output": fn_output, - } - output.append(fn_call_response) - if tail_content: - # if there are more content items, add them as a new user message - output.append({"role": "user", "content": tail_content}) + output.extend(self.responded_tool_calls.raw_calls.output) # this contains response + for fn_call in self.responded_tool_calls: + call_type = fn_call.raw_call.type + call_id = fn_call.raw_call.call_id + call_response = fn_call.tool_response + + match call_type: + case "function_call": + # image output is not supported in function calls response. + assert ( + "image" not in call_response + ), "Image output is not supported in function calls response." + fn_call_response = { + "type": "function_call_output", + "call_id": call_id, + "output": self.convert_content_to_expected_format(call_response)["text"], + } + output.append(fn_call_response) + + case "computer_call": + # For computer calls, use only images are expected. + assert ( + "text" not in call_response + ), "Text output is not supported in computer calls response." 
+ computer_call_output = { + "type": "computer_call_output", + "call_id": call_id, + "output": self.convert_content_to_expected_format(call_response), + } + output.append(computer_call_output) # this needs to be a screenshot + return output - def mark_all_previous_msg_for_caching(self) -> List[Message]: + def mark_all_previous_msg_for_caching(self): + """Nothing special to do here for openAI. They do not have a notion of cache breakpoints.""" pass @@ -164,29 +264,9 @@ def prepare_message(self) -> List[Message]: content = [self.transform_content(item) for item in self.content] output = {"role": self.role, "content": content} - if self.role == "system": - logging.info( - "Treating system message as 'user'. In the Anthropic API, system messages should be passed as a direct input to the client." - ) - output["role"] = "user" - if self.role == "tool": + return self.handle_tool_call() - api_response = self.last_raw_response - fn_calls = [content for content in api_response.content if content.type == "tool_use"] - assert len(fn_calls) > 0, "No tool calls found in the last response" - if len(fn_calls) > 1: - logging.warning("Using only the first tool call from many.") - tool_call_id = fn_calls[0].id # Using the first tool call ID - - output["role"] = "user" - output["content"] = [ - { - "type": "tool_result", - "tool_use_id": tool_call_id, - "content": output["content"], - } - ] if self.role == "assistant": # Strip whitespace from assistant text responses. See anthropic error code 400. for c in output["content"]: @@ -194,6 +274,32 @@ def prepare_message(self) -> List[Message]: c["text"] = c["text"].strip() return [output] + def handle_tool_call(self) -> List[Message]: + """Handle the tool call response from the last raw response.""" + if self.responded_tool_calls is None: + raise ValueError("No tool calls found in responded_tool_calls") + + llm_tool_call = { + "role": "assistant", + "content": self.responded_tool_calls.raw_calls.content, + } # Add the toolcall block + tool_response = {"role": "user", "content": []} # Anthropic expects a list of messages + for call in self.responded_tool_calls: + assert ( + "image" not in call.tool_response + ), "Image output is not supported in tool calls response." 
+ tool_response["content"].append( + { + "type": "tool_result", + "tool_use_id": call.raw_call.id, + "content": self.transform_content(call.tool_response)[ + "text" + ], # needs to be str + } + ) + + return [llm_tool_call, tool_response] + def transform_content(self, content: ContentItem) -> ContentItem: """Transform content item to the format expected by Anthropic API.""" if "text" in content: @@ -224,13 +330,13 @@ class OpenAIChatCompletionAPIMessageBuilder(MessageBuilder): def prepare_message(self) -> List[Message]: """Prepare the message for the OpenAI API.""" - content = [self.transform_content(item) for item in self.content] - if self.role == "tool": - return self.handle_tool_call(content) - else: - return [{"role": self.role, "content": content}] + content = [] + for item in self.content: + content.append(self.convert_content_to_expected_format(item)) + output = [{"role": self.role, "content": content}] + return output if self.role != "tool" else self.handle_tool_call() - def transform_content(self, content: ContentItem) -> ContentItem: + def convert_content_to_expected_format(self, content: ContentItem) -> ContentItem: """Transform content item to the format expected by OpenAI ChatCompletion.""" if "text" in content: return {"type": "text", "text": content["text"]} @@ -239,30 +345,57 @@ def transform_content(self, content: ContentItem) -> ContentItem: else: raise ValueError(f"Unsupported content type: {content}") - def handle_tool_call(self, content) -> List[Message]: + def handle_tool_call(self) -> List[Message]: """Handle the tool call response from the last raw response.""" + if self.responded_tool_calls is None: + raise ValueError("No tool calls found in responded_tool_calls") output = [] - content_head, *content_tail = content - api_response = self.last_raw_response.choices[0].message - fn_calls = getattr(api_response, "tool_calls", None) - assert fn_calls is not None, "Tool calls not found in the last response" - if len(fn_calls) > 1: - logging.warning("Using only the first tool call from many.") - - # a function_call_output dict has keys "role", "tool_call_id" and "content" - tool_call_reponse = { - "role": "tool", - "tool_call_id": fn_calls[0].id, # using the first tool call ID - "content": content_head.get("text", "Tool call answer in next message"), - "name": fn_calls[0].function.name, # required with OpenRouter - } + output.append( + self.responded_tool_calls.raw_calls.choices[0].message + ) # add raw calls to output + for fn_call in self.responded_tool_calls: + raw_call = fn_call.raw_call + assert ( + "image" not in fn_call.tool_response + ), "Image output is not supported in function calls response." + # a function_call_output dict has keys "role", "tool_call_id" and "content" + tool_call_reponse = { + "name": raw_call["function"]["name"], # required with OpenRouter + "role": "tool", + "tool_call_id": raw_call["id"], + "content": self.convert_content_to_expected_format(fn_call.tool_response)["text"], + } + output.append(tool_call_reponse) - output.append(tool_call_reponse) - if content_tail: - # if there are more content items, add them as a new user message - output.append({"role": "user", "content": content_tail}) return output + def mark_all_previous_msg_for_caching(self): + """Nothing special to do here for openAI. 
They do not have a notion of cache breakpoints.""" + pass + + +@dataclass +class APIPayload: + messages: List[MessageBuilder] | None = None + tools: List[Dict[str, Any]] | None = None + tool_choice: Literal["none", "auto", "any", "required"] | None = None + force_call_tool: str | None = ( + None # Name of the tool to call # If set, will force the LLM to call this tool. + ) + use_cache_breakpoints: bool = ( + False # If True, will apply cache breakpoints to the messages. # applicable for Anthropic + ) + cache_tool_definition: bool = ( + False # If True, will cache the tool definition in the last message. + ) + cache_complete_prompt: bool = ( + False # If True, will cache the complete prompt in the last message. + ) + + def __post_init__(self): + if self.tool_choice and self.force_call_tool: + raise ValueError("tool_choice and force_call_tool are mutually exclusive") + # # Base class for all API Endpoints class BaseResponseModel(ABC): @@ -270,25 +403,22 @@ def __init__( self, model_name: str, api_key: Optional[str] = None, - temperature: float = 0.5, - max_tokens: int = 100, - extra_kwargs: Optional[Dict[str, Any]] = None, + temperature: float | None = None, + max_tokens: int | None = None, ): self.model_name = model_name self.api_key = api_key self.temperature = temperature self.max_tokens = max_tokens - self.extra_kwargs = extra_kwargs or {} - super().__init__() - def __call__(self, messages: list[dict | MessageBuilder], **kwargs) -> dict: + def __call__(self, payload: APIPayload) -> LLMOutput: """Make a call to the model and return the parsed response.""" - response = self._call_api(messages, **kwargs) + response = self._call_api(payload) return self._parse_response(response) @abstractmethod - def _call_api(self, messages: list[dict | MessageBuilder], **kwargs) -> Any: + def _call_api(self, payload: APIPayload) -> Any: """Make a call to the model API and return the raw response.""" pass @@ -298,6 +428,38 @@ def _parse_response(self, response: Any) -> LLMOutput: pass +class AgentlabAction: + """ + Collection of utility function to convert tool calls to Agentlab action format. + """ + + @staticmethod + def convert_toolcall_to_agentlab_action_format(toolcall: ToolCall) -> str: + """Convert a tool call to an Agentlab environment action string. + + Args: + toolcall: ToolCall object containing the name and arguments of the tool call. + + Returns: + A string representing the action in Agentlab format i.e. python function call string. + """ + + tool_name, tool_args = toolcall.name, toolcall.arguments + return tool_call_to_python_code(tool_name, tool_args) + + @staticmethod + def convert_multiactions_to_agentlab_action_format(actions: list[str]) -> str | None: + """Convert multiple actions list to a format that env supports. + + Args: + actions: List of action strings to be joined. + + Returns: + Joined actions separated by newlines, or None if empty. 
+ """ + return "\n".join(actions) if actions else None + + class BaseModelWithPricing(TrackAPIPricingMixin, BaseResponseModel): pass @@ -306,46 +468,47 @@ class OpenAIResponseModel(BaseModelWithPricing): def __init__( self, model_name: str, + base_url: Optional[str] = None, api_key: Optional[str] = None, - temperature: float = 0.5, - max_tokens: int = 100, - extra_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, + temperature: float | None = None, + max_tokens: int | None = 100, ): - self.tools = kwargs.pop("tools", None) - super().__init__( - model_name=model_name, - api_key=api_key, - temperature=temperature, - max_tokens=max_tokens, - extra_kwargs=extra_kwargs, - **kwargs, + self.action_space_as_tools = True # this should be a config + super().__init__( # This is passed to BaseModel + model_name=model_name, api_key=api_key, temperature=temperature, max_tokens=max_tokens ) - self.client = OpenAI(api_key=api_key) + client_args = {} + if base_url is not None: + client_args["base_url"] = base_url + if api_key is not None: + client_args["api_key"] = api_key + self.client = OpenAI(**client_args) + # Init pricing tracker after super() so that all attributes have been set. + self.init_pricing_tracker(pricing_api="openai") # Use the PricingMixin - def _call_api( - self, messages: list[Any | MessageBuilder], tool_choice: str = "auto", **kwargs - ) -> dict: - input = [] - for msg in messages: - input.extend(msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg]) + def _call_api(self, payload: APIPayload) -> "OpenAIResponseObject": + input = [] + for msg in payload.messages: + input.extend(msg.prepare_message()) api_params: Dict[str, Any] = { "model": self.model_name, "input": input, - "temperature": self.temperature, - "max_output_tokens": self.max_tokens, - **self.extra_kwargs, } + # Not all Open AI models support these parameters (example: o3), so we check if they are set. 
+ if self.temperature is not None: + api_params["temperature"] = self.temperature + if self.max_tokens is not None: + api_params["max_output_tokens"] = self.max_tokens + if payload.tools is not None: + api_params["tools"] = payload.tools + if payload.tool_choice is not None and payload.force_call_tool is None: + api_params["tool_choice"] = ( + "required" if payload.tool_choice in ("required", "any") else payload.tool_choice + ) + if payload.force_call_tool is not None: + api_params["tool_choice"] = {"type": "function", "name": payload.force_call_tool} - if self.tools is not None: - api_params["tools"] = self.tools - if tool_choice in ("any", "required"): - tool_choice = "required" - - api_params["tool_choice"] = tool_choice - - # api_params |= kwargs # Merge any additional parameters passed response = call_openai_api_with_retries( self.client.responses.create, api_params, @@ -353,110 +516,221 @@ def _call_api( return response - def _parse_response(self, response: dict) -> dict: - result = LLMOutput( + def _parse_response(self, response: "OpenAIResponseObject") -> LLMOutput: + """Parse the raw response from the OpenAI Responses API.""" + + think_output = self._extract_thinking_content_from_response(response) + toolcalls = self._extract_tool_calls_from_response(response) + + if self.action_space_as_tools: + env_action = self._extract_env_actions_from_toolcalls(toolcalls) + else: + env_action = self._extract_env_actions_from_text_response(response) + + return LLMOutput( raw_response=response, - think="", - action=None, - tool_calls=None, + think=think_output, + action=env_action if env_action is not None else None, + tool_calls=toolcalls if toolcalls is not None else None, ) - interesting_keys = ["output_text"] + + def _extract_tool_calls_from_response(self, response: "OpenAIResponseObject") -> ToolCalls: + """Extracts tool calls from the response.""" + tool_calls = [] for output in response.output: if output.type == "function_call": - result.action = tool_call_to_python_code(output.name, json.loads(output.arguments)) - result.tool_calls = output - break - elif output.type == "reasoning": - if len(output.summary) > 0: - result.think += output.summary[0].text + "\n" + tool_name = output.name + tool_args = json.loads(output.arguments) + elif output.type == "computer_call": + tool_name, tool_args = self.cua_action_to_env_tool_name_and_args(output.action) + else: + continue + tool_calls.append(ToolCall(name=tool_name, arguments=tool_args, raw_call=output)) + + return ToolCalls(tool_calls=tool_calls, raw_calls=response) + + def _extract_env_actions_from_toolcalls(self, toolcalls: ToolCalls) -> Any | None: + """Extracts actions from the response.""" + if not toolcalls: + return None + + actions = [ + AgentlabAction.convert_toolcall_to_agentlab_action_format(call) for call in toolcalls + ] + actions = ( + AgentlabAction.convert_multiactions_to_agentlab_action_format(actions) + if len(actions) > 1 + else actions[0] + ) + return actions + def _extract_thinking_content_from_response(self, response: "OpenAIResponseObject") -> str: + """Extracts the thinking content from the response.""" + thinking_content = "" + for output in response.output: + if output.type == "reasoning": + if len(output.summary) > 0: + thinking_content += output.summary[0].text + "\n" elif output.type == "message" and output.content: - result.think += output.content[0].text + "\n" - for key in interesting_keys: - if key_content := getattr(output, "output_text", None) is not None: - result.think += f"<{key}>{key_content}" - return 
result + thinking_content += output.content[0].text + "\n" + elif hasattr(output, "output_text") and output.output_text: + thinking_content += f"{output.output_text}\n" + return thinking_content + + def cua_action_to_env_tool_name_and_args(self, action: str) -> tuple[str, Dict[str, Any]]: + """ "Overwrite this method to convert a computer action to agentlab action string""" + raise NotImplementedError( + "This method should be implemented in the subclass to convert a computer action to agentlab action string." + ) + + def _extract_env_actions_from_text_response( + self, response: "OpenAIResponseObject" + ) -> str | None: + """Extracts environment actions from the text response.""" + # Use when action space is not given as tools. + pass class OpenAIChatCompletionModel(BaseModelWithPricing): def __init__( self, model_name: str, - client_args: Optional[Dict[str, Any]] = {}, - temperature: float = 0.5, - max_tokens: int = 100, - extra_kwargs: Optional[Dict[str, Any]] = None, - *args, - **kwargs, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + temperature: float | None = None, + max_tokens: int | None = 100, ): - - self.tools = self.format_tools_for_chat_completion(kwargs.pop("tools", None)) - super().__init__( model_name=model_name, temperature=temperature, max_tokens=max_tokens, - extra_kwargs=extra_kwargs, - *args, - **kwargs, ) - - self.client = OpenAI( - **client_args - ) # Ensures client_args is a dict or defaults to an empty dict - - def _call_api( - self, messages: list[dict | MessageBuilder], tool_choice: str = "auto" - ) -> openai.types.chat.ChatCompletion: + self.action_space_as_tools = True # this should be a config + client_args = {} + if base_url is not None: + client_args["base_url"] = base_url + if api_key is not None: + client_args["api_key"] = api_key + self.client = OpenAI(**client_args) + self.init_pricing_tracker(pricing_api="openai") # Use the PricingMixin + + def _call_api(self, payload: APIPayload) -> "openai.types.chat.ChatCompletion": input = [] - for msg in messages: - input.extend(msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg]) + for msg in payload.messages: + input.extend(msg.prepare_message()) api_params: Dict[str, Any] = { "model": self.model_name, "messages": input, - "temperature": self.temperature, - "max_tokens": self.max_tokens, - **self.extra_kwargs, # Pass tools, tool_choice, etc. here } - if self.tools is not None: - api_params["tools"] = self.tools + if self.temperature is not None: + api_params["temperature"] = self.temperature + + if self.max_tokens is not None: + api_params["max_completion_tokens"] = self.max_tokens + + if payload.tools is not None: + # tools format is OpenAI Response API format. 
+ api_params["tools"] = self.format_tools_for_chat_completion(payload.tools) - if tool_choice in ("any", "required"): - tool_choice = "required" - api_params["tool_choice"] = tool_choice + if payload.tool_choice is not None and payload.force_call_tool is None: + api_params["tool_choice"] = ( + "required" if payload.tool_choice in ("required", "any") else payload.tool_choice + ) + + if payload.force_call_tool is not None: + api_params["tool_choice"] = { + "type": "function", + "function": {"name": payload.force_call_tool}, + } response = call_openai_api_with_retries(self.client.chat.completions.create, api_params) return response - def _parse_response(self, response: openai.types.chat.ChatCompletion) -> LLMOutput: + def _parse_response(self, response: "openai.types.chat.ChatCompletion") -> LLMOutput: + think_output = self._extract_thinking_content_from_response(response) + tool_calls = self._extract_tool_calls_from_response(response) - output = LLMOutput( + if self.action_space_as_tools: + env_action = self._extract_env_actions_from_toolcalls(tool_calls) + else: + env_action = self._extract_env_actions_from_text_response(response) + return LLMOutput( raw_response=response, - think="", - action=None, # Default if no tool call - tool_calls=None, + think=think_output, + action=env_action if env_action is not None else None, + tool_calls=tool_calls if tool_calls is not None else None, ) + + def _extract_thinking_content_from_response( + self, response: openai.types.chat.ChatCompletion, wrap_tag="think" + ): + """Extracts the content from the message, including reasoning if available. + It wraps the reasoning around ... for easy identification of reasoning content, + When LLM produces 'text' and 'reasoning' in the same message. + Note: The wrapping of 'thinking' content may not be nedeed and may be reconsidered. + + Args: + response: The message object or dict containing content and reasoning. + wrap_tag: The tag name to wrap reasoning content (default: "think"). + + Returns: + str: The extracted content with reasoning wrapped in specified tags. 
+ """ + message = response.choices[0].message + if not isinstance(message, dict): + message = message.to_dict() + + reasoning_content = message.get("reasoning", None) + msg_content = message.get("text", "") # works for Open-router + if reasoning_content: + # Wrap reasoning in tags with newlines for clarity + reasoning_content = f"<{wrap_tag}>{reasoning_content}\n" + logging.debug("Extracting content from response.choices[i].message.reasoning") + else: + reasoning_content = "" + return f"{reasoning_content}{msg_content}{message.get('content', '')}" + + def _extract_tool_calls_from_response( + self, response: openai.types.chat.ChatCompletion + ) -> ToolCalls | None: + """Extracts tool calls from the response.""" message = response.choices[0].message.to_dict() - output.think = self.extract_content_with_reasoning(message) - - if tool_calls := message.get("tool_calls", None): - for tool_call in tool_calls: - function = tool_call["function"] - arguments = json.loads(function["arguments"]) - func_args_str = ", ".join( - [ - f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}" - for k, v in arguments.items() - ] + tool_calls = message.get("tool_calls", None) + if tool_calls is None: + return None + tool_call_list = [] + for tc in tool_calls: + tool_call_list.append( + ToolCall( + name=tc["function"]["name"], + arguments=json.loads(tc["function"]["arguments"]), + raw_call=tc, ) - output.action = f"{function['name']}({func_args_str})" - output.tool_calls = { - "role": "assistant", - "tool_calls": [message["tool_calls"][0]], # Use only the first tool call - } - break - return output + ) + return ToolCalls(tool_calls=tool_call_list, raw_calls=response) + + def _extract_env_actions_from_toolcalls(self, toolcalls: ToolCalls) -> Any | None: + """Extracts actions from the response.""" + if not toolcalls: + return None + + actions = [ + AgentlabAction.convert_toolcall_to_agentlab_action_format(call) for call in toolcalls + ] + actions = ( + AgentlabAction.convert_multiactions_to_agentlab_action_format(actions) + if len(actions) > 1 + else actions[0] + ) + return actions + + def _extract_env_actions_from_text_response( + self, response: "openai.types.chat.ChatCompletion" + ) -> str | None: + """Extracts environment actions from the text response.""" + # Use when action space is not given as tools. + pass @staticmethod def format_tools_for_chat_completion(tools): @@ -483,93 +757,69 @@ def format_tools_for_chat_completion(tools): ] return formatted_tools - @staticmethod - def extract_content_with_reasoning(message, wrap_tag="think"): - """Extracts the content from the message, including reasoning if available. - It wraps the reasoning around ... for easy identification of reasoning content, - When LLM produces 'text' and 'reasoning' in the same message. - Note: The wrapping of 'thinking' content may not be nedeed and may be reconsidered. - - Args: - message: The message object or dict containing content and reasoning. - wrap_tag: The tag name to wrap reasoning content (default: "think"). - - Returns: - str: The extracted content with reasoning wrapped in specified tags. 
- """ - if not isinstance(message, dict): - message = message.to_dict() - - reasoning_content = message.get("reasoning", None) - msg_content = message.get("text", "") # works for OR - - if reasoning_content: - # Wrap reasoning in tags with newlines for clarity - reasoning_content = f"<{wrap_tag}>{reasoning_content}\n" - logging.debug("Extracting content from response.choices[i].message.reasoning") - else: - reasoning_content = "" - return f"{reasoning_content}{msg_content}{message.get('content', '')}" - class ClaudeResponseModel(BaseModelWithPricing): def __init__( self, model_name: str, + base_url: Optional[str] = None, api_key: Optional[str] = None, - temperature: float = 0.5, - max_tokens: int = 100, - extra_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, + temperature: float | None = None, + max_tokens: int | None = 100, ): - self.tools = kwargs.pop("tools", None) + self.action_space_as_tools = True # this should be a config super().__init__( model_name=model_name, api_key=api_key, temperature=temperature, max_tokens=max_tokens, - extra_kwargs=extra_kwargs, - **kwargs, ) - - self.client = Anthropic(api_key=api_key) - - def _call_api( - self, messages: list[dict | MessageBuilder], tool_choice="auto", **kwargs - ) -> dict: - input = [] - - sys_msg, other_msgs = self.filter_system_messages(messages) + client_args = {} + if base_url is not None: + client_args["base_url"] = base_url + if api_key is not None: + client_args["api_key"] = api_key + self.client = Anthropic(**client_args) + self.init_pricing_tracker(pricing_api="anthropic") # Use the PricingMixin + + def _call_api(self, payload: APIPayload) -> Completion: + sys_msg, other_msgs = self.filter_system_messages(payload.messages) sys_msg_text = "\n".join(c["text"] for m in sys_msg for c in m.content) + input = [] for msg in other_msgs: - temp = msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg] - if kwargs.pop("use_cache_breakpoints", False): + temp = msg.prepare_message() + if payload.use_cache_breakpoints: temp = self.apply_cache_breakpoints(msg, temp) input.extend(temp) - if tool_choice in ("any", "required"): - tool_choice = "any" # Claude API expects "any" and gpt expects "required" - api_params: Dict[str, Any] = { "model": self.model_name, "messages": input, - "temperature": self.temperature, - "max_tokens": self.max_tokens, - "system": sys_msg_text, # Anthropic API expects system message as a string - "tool_choice": {"type": tool_choice}, # Tool choice for Claude API - **self.extra_kwargs, # Pass tools, tool_choice, etc. here - } - if self.tools is not None: - api_params["tools"] = self.tools - if kwargs.pop("cache_tool_definition", False): - # Indicating cache control for the last tool enables caching of all previous tool definitions. + "system": sys_msg_text, + } # Anthropic API expects system message as a string + + if self.temperature is not None: + api_params["temperature"] = self.temperature + if self.max_tokens is not None: + api_params["max_tokens"] = self.max_tokens + + if payload.tools is not None: + api_params["tools"] = payload.tools + if payload.tool_choice is not None and payload.force_call_tool is None: + api_params["tool_choice"] = ( + {"type": "any"} + if payload.tool_choice in ("required", "any") + else {"type": payload.tool_choice} + ) + if payload.force_call_tool is not None: + api_params["tool_choice"] = {"type": "tool", "name": payload.force_call_tool} + if payload.cache_tool_definition: + # Indicating cache control for the last message enables caching of the last message. 
api_params["tools"][-1]["cache_control"] = {"type": "ephemeral"} - if kwargs.pop("cache_complete_prompt", False): + if payload.cache_complete_prompt: # Indicating cache control for the last message enables caching of the complete prompt. api_params["messages"][-1]["content"][-1]["cache_control"] = {"type": "ephemeral"} - if self.extra_kwargs.get("reasoning", None) is not None: - api_params["reasoning"] = self.extra_kwargs["reasoning"] response = call_anthropic_api_with_retries(self.client.messages.create, api_params) @@ -592,26 +842,58 @@ def filter_system_messages(messages: list[dict | MessageBuilder]) -> tuple[Messa other_msgs.append(msg) return sys_msgs, other_msgs - def _parse_response(self, response: dict) -> dict: - result = LLMOutput( + def _parse_response(self, response: "AnthrophicMessage") -> LLMOutput: + + toolcalls = self._extract_tool_calls_from_response(response) + think_output = self._extract_thinking_content_from_response(response) + if self.action_space_as_tools: + env_action = self._extract_env_actions_from_toolcalls(toolcalls) + else: + env_action = self._extract_env_actions_from_text_response(response) + return LLMOutput( raw_response=response, - think="", - action=None, - tool_calls={ - "role": "assistant", - "content": response.content, - }, + think=think_output, + action=env_action if env_action is not None else None, + tool_calls=toolcalls if toolcalls is not None else None, ) + + def _extract_tool_calls_from_response(self, response: "AnthrophicMessage") -> ToolCalls: + """Extracts tool calls from the response.""" + tool_calls = [] for output in response.content: if output.type == "tool_use": - result.action = tool_call_to_python_code(output.name, output.input) - elif output.type == "text": - result.think += output.text - return result + tool_calls.append( + ToolCall( + name=output.name, + arguments=output.input, + raw_call=output, + ) + ) + return ToolCalls(tool_calls=tool_calls, raw_calls=response) + + def _extract_thinking_content_from_response(self, response: "AnthrophicMessage"): + """Extracts the thinking content from the response.""" + return "".join(output.text for output in response.content if output.type == "text") + + def _extract_env_actions_from_toolcalls(self, toolcalls: ToolCalls) -> Any | None: + """Extracts actions from the response.""" + if not toolcalls: + return None + + actions = [ + AgentlabAction.convert_toolcall_to_agentlab_action_format(call) for call in toolcalls + ] + actions = ( + AgentlabAction.convert_multiactions_to_agentlab_action_format(actions) + if len(actions) > 1 + else actions[0] + ) + return actions - # def ensure_cache_conditions(self, msgs: List[Message]) -> bool: - # """Ensure API specific cache conditions are met.""" - # assert sum(getattr(msg, "_cache_breakpoint", 0) for msg in msgs) <= 4, "Too many cache breakpoints in the message." + def _extract_env_actions_from_text_response(self, response: "AnthrophicMessage") -> str | None: + """Extracts environment actions from the text response.""" + # Use when action space is not given as tools. + pass def apply_cache_breakpoints(self, msg: Message, prepared_msg: dict) -> List[Message]: """Apply cache breakpoints to the messages.""" @@ -621,6 +903,8 @@ def apply_cache_breakpoints(self, msg: Message, prepared_msg: dict) -> List[Mess # Factory classes to create the appropriate model based on the API endpoint. 
+ + @dataclass class OpenAIResponseModelArgs(BaseModelArgs): """Serializable object for instantiating a generic chat model with an OpenAI @@ -628,14 +912,11 @@ class OpenAIResponseModelArgs(BaseModelArgs): api = "openai" - def make_model(self, extra_kwargs=None, **kwargs): + def make_model(self): return OpenAIResponseModel( model_name=self.model_name, temperature=self.temperature, max_tokens=self.max_new_tokens, - extra_kwargs=extra_kwargs, - pricing_api="openai", - **kwargs, ) def get_message_builder(self) -> MessageBuilder: @@ -649,14 +930,11 @@ class ClaudeResponseModelArgs(BaseModelArgs): api = "anthropic" - def make_model(self, extra_kwargs=None, **kwargs): + def make_model(self): return ClaudeResponseModel( model_name=self.model_name, temperature=self.temperature, max_tokens=self.max_new_tokens, - extra_kwargs=extra_kwargs, - pricing_api="anthropic", - **kwargs, ) def get_message_builder(self) -> MessageBuilder: @@ -670,14 +948,11 @@ class OpenAIChatModelArgs(BaseModelArgs): api = "openai" - def make_model(self, extra_kwargs=None, **kwargs): + def make_model(self): return OpenAIChatCompletionModel( model_name=self.model_name, temperature=self.temperature, max_tokens=self.max_new_tokens, - extra_kwargs=extra_kwargs, - pricing_api="openai", - **kwargs, ) def get_message_builder(self) -> MessageBuilder: @@ -691,42 +966,13 @@ class OpenRouterModelArgs(BaseModelArgs): api: str = "openai" # tool description format used by actionset.to_tool_description() in bgym - def make_model(self, extra_kwargs=None, **kwargs): + def make_model(self): return OpenAIChatCompletionModel( - client_args={ - "base_url": "https://openrouter.ai/api/v1", - "api_key": os.getenv("OPENROUTER_API_KEY"), - }, + base_url="https://openrouter.ai/api/v1", + api_key=os.getenv("OPENROUTER_API_KEY"), model_name=self.model_name, temperature=self.temperature, max_tokens=self.max_new_tokens, - extra_kwargs=extra_kwargs, - pricing_api="openrouter", - **kwargs, - ) - - def get_message_builder(self) -> MessageBuilder: - return OpenAIChatCompletionAPIMessageBuilder - - -class VLLMModelArgs(BaseModelArgs): - """Serializable object for instantiating a generic chat model with a VLLM - model.""" - - api = "openai" # tool description format used by actionset.to_tool_description() in bgym - - def make_model(self, extra_kwargs=None, **kwargs): - return OpenAIChatCompletionModel( - client_args={ - "base_url": "http://localhost:8000/v1", - "api_key": os.getenv("VLLM_API_KEY", "EMPTY"), - }, - model_name=self.model_name, # this needs to be set - temperature=self.temperature, - max_tokens=self.max_new_tokens, - extra_kwargs=extra_kwargs, - pricing_api="vllm", - **kwargs, ) def get_message_builder(self) -> MessageBuilder: @@ -743,3 +989,29 @@ def tool_call_to_python_code(func_name, kwargs): args_str = ", ".join(f"{key}={repr(value)}" for key, value in kwargs.items()) return f"{func_name}({args_str})" + + +# ___Not__Tested__# + +# class VLLMModelArgs(BaseModelArgs): +# """Serializable object for instantiating a generic chat model with a VLLM +# model.""" + +# api = "openai" # tool description format used by actionset.to_tool_description() in bgym + +# def make_model(self, extra_kwargs=None, **kwargs): +# return OpenAIChatCompletionModel( +# client_args={ +# "base_url": "http://localhost:8000/v1", +# "api_key": os.getenv("VLLM_API_KEY", "EMPTY"), +# }, +# model_name=self.model_name, # this needs to be set +# temperature=self.temperature, +# max_tokens=self.max_new_tokens, +# extra_kwargs=extra_kwargs, +# pricing_api="vllm", +# **kwargs, +# ) + 
+# def get_message_builder(self) -> MessageBuilder: +# return OpenAIChatCompletionAPIMessageBuilder diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py index 23e48930..b8bcce7c 100644 --- a/src/agentlab/llm/tracking.py +++ b/src/agentlab/llm/tracking.py @@ -13,7 +13,7 @@ TRACKER = threading.local() -ANTHROPHIC_CACHE_PRICING_FACTOR = { +ANTHROPIC_CACHE_PRICING_FACTOR = { "cache_read_tokens": 0.1, # Cost for 5 min ephemeral cache. See Pricing Here: https://docs.anthropic.com/en/docs/about-claude/pricing#model-pricing "cache_write_tokens": 1.25, } @@ -151,20 +151,20 @@ class TrackAPIPricingMixin: def reset_stats(self): self.stats = Stats() - def __init__(self, *args, **kwargs): - pricing_api = kwargs.pop("pricing_api", None) + def init_pricing_tracker( + self, pricing_api=None + ): # TODO: Use this function in the base class init instead of having a init in the Mixin class. self._pricing_api = pricing_api - super().__init__(*args, **kwargs) self.set_pricing_attributes() self.reset_stats() def __call__(self, *args, **kwargs): """Call the API and update the pricing tracker.""" + # 'self' here calls ._call_api() method of the subclass response = self._call_api(*args, **kwargs) - usage = dict(getattr(response, "usage", {})) if "prompt_tokens_details" in usage: - usage["cached_tokens"] = usage["prompt_token_details"].cached_tokens + usage["cached_tokens"] = usage["prompt_tokens_details"].cached_tokens if "input_tokens_details" in usage: usage["cached_tokens"] = usage["input_tokens_details"].cached_tokens usage = {f"usage_{k}": v for k, v in usage.items() if isinstance(v, (int, float))} @@ -274,8 +274,8 @@ def get_effective_cost_from_antrophic_api(self, response) -> float: cache_read_tokens = getattr(usage, "cache_input_tokens", 0) cache_write_tokens = getattr(usage, "cache_creation_input_tokens", 0) - cache_read_cost = self.input_cost * ANTHROPHIC_CACHE_PRICING_FACTOR["cache_read_tokens"] - cache_write_cost = self.input_cost * ANTHROPHIC_CACHE_PRICING_FACTOR["cache_write_tokens"] + cache_read_cost = self.input_cost * ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"] + cache_write_cost = self.input_cost * ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"] # Calculate the effective cost effective_cost = ( @@ -284,6 +284,10 @@ def get_effective_cost_from_antrophic_api(self, response) -> float: + cache_read_tokens * cache_read_cost + cache_write_tokens * cache_write_cost ) + if effective_cost < 0: + logging.warning( + "Anthropic: Negative effective cost detected.(Impossible! 
Likely a bug)" + ) return effective_cost def get_effective_cost_from_openai_api(self, response) -> float: @@ -308,25 +312,29 @@ def get_effective_cost_from_openai_api(self, response) -> float: return 0.0 api_type = "chatcompletion" if hasattr(usage, "prompt_tokens_details") else "response" if api_type == "chatcompletion": - total_input_tokens = usage.prompt_tokens + total_input_tokens = usage.prompt_tokens # (cache read tokens + new input tokens) output_tokens = usage.completion_tokens cached_input_tokens = usage.prompt_tokens_details.cached_tokens - non_cached_input_tokens = total_input_tokens - cached_input_tokens + new_input_tokens = total_input_tokens - cached_input_tokens elif api_type == "response": - total_input_tokens = usage.input_tokens + total_input_tokens = usage.input_tokens # (cache read tokens + new input tokens) output_tokens = usage.output_tokens cached_input_tokens = usage.input_tokens_details.cached_tokens - non_cached_input_tokens = total_input_tokens - cached_input_tokens + new_input_tokens = total_input_tokens - cached_input_tokens else: logging.warning(f"Unsupported API type: {api_type}. Defaulting cost to 0.0.") return 0.0 - cache_read_cost = self.input_cost * OPENAI_CACHE_PRICING_FACTOR["cache_read_tokens"] effective_cost = ( - self.input_cost * non_cached_input_tokens + self.input_cost * new_input_tokens + cached_input_tokens * cache_read_cost + self.output_cost * output_tokens ) + if effective_cost < 0: + logging.warning( + f"OpenAI: Negative effective cost detected.(Impossible! Likely a bug). " + f"New input tokens: {total_input_tokens}" + ) return effective_cost diff --git a/tests/agents/test_gaia_agent.py b/tests/agents/test_gaia_agent.py index 0d39f9ef..604ac00c 100644 --- a/tests/agents/test_gaia_agent.py +++ b/tests/agents/test_gaia_agent.py @@ -2,10 +2,15 @@ import uuid from pathlib import Path -from tapeagents.steps import ImageObservation +try: + from tapeagents.steps import ImageObservation -from agentlab.agents.tapeagent.agent import TapeAgent, TapeAgentArgs, load_config -from agentlab.benchmarks.gaia import GaiaBenchmark, GaiaQuestion + from agentlab.agents.tapeagent.agent import TapeAgent, TapeAgentArgs, load_config + from agentlab.benchmarks.gaia import GaiaBenchmark, GaiaQuestion +except ModuleNotFoundError: + import pytest + + pytest.skip("Skipping test due to missing dependencies", allow_module_level=True) def mock_dataset() -> dict: diff --git a/tests/llm/test_response_api.py b/tests/llm/test_response_api.py index 567b49da..6bb639f6 100644 --- a/tests/llm/test_response_api.py +++ b/tests/llm/test_response_api.py @@ -9,6 +9,7 @@ from agentlab.llm import tracking from agentlab.llm.response_api import ( AnthropicAPIMessageBuilder, + APIPayload, ClaudeResponseModelArgs, LLMOutput, OpenAIChatCompletionAPIMessageBuilder, @@ -55,6 +56,9 @@ def create_mock_openai_chat_completion( # or if get_tokens_counts_from_response had different fallback logic. 
completion.usage.prompt_tokens = prompt_tokens completion.usage.completion_tokens = completion_tokens + prompt_tokens_details_mock = MagicMock() + prompt_tokens_details_mock.cached_tokens = 0 + completion.usage.prompt_tokens_details = prompt_tokens_details_mock completion.model_dump.return_value = { "id": "chatcmpl-xxxx", @@ -68,6 +72,7 @@ def create_mock_openai_chat_completion( "output_tokens": completion_tokens, # Generic name "prompt_tokens": prompt_tokens, # OpenAI specific "completion_tokens": completion_tokens, # OpenAI specific + "prompt_tokens_details": {"cached_tokens": 0}, }, } message.to_dict.return_value = { @@ -78,6 +83,69 @@ def create_mock_openai_chat_completion( return completion +responses_api_tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get the current weather in a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather for.", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature.", + }, + }, + "required": ["location"], + }, + } +] + +chat_api_tools = [ + { + "type": "function", + "name": "get_weather", + "description": "Get the current weather in a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather for.", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit of temperature.", + }, + }, + "required": ["location"], + }, + } +] +anthropic_tools = [ + { + "name": "get_weather", + "description": "Get the current weather in a given location.", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The location to get the weather for.", + }, + }, + "required": ["location"], + }, + } +] + + # Helper to create a mock Anthropic response def create_mock_anthropic_response( text_content=None, tool_use=None, input_tokens=15, output_tokens=25 @@ -102,6 +170,8 @@ def create_mock_anthropic_response( response.usage = MagicMock() response.usage.input_tokens = input_tokens response.usage.output_tokens = output_tokens + response.usage.cache_input_tokens = 0 + response.usage.cache_creation_input_tokens = 0 return response @@ -114,7 +184,7 @@ def create_mock_openai_responses_api_response( Compatible with OpenAIResponseModel and TrackAPIPricingMixin. 
""" - response_mock = MagicMock(openai.types.responses.response) + response_mock = MagicMock(spec=openai.types.responses.response.Response) response_mock.type = "response" response_mock.output = [] @@ -138,11 +208,14 @@ def create_mock_openai_responses_api_response( response_mock.output.append(output_item_mock) # Token usage for pricing tracking - response_mock.usage = MagicMock() + response_mock.usage = MagicMock(spec=openai.types.responses.response.ResponseUsage) response_mock.usage.input_tokens = input_tokens response_mock.usage.output_tokens = output_tokens response_mock.usage.prompt_tokens = input_tokens response_mock.usage.completion_tokens = output_tokens + input_tokens_details_mock = MagicMock() + input_tokens_details_mock.cached_tokens = 0 + response_mock.usage.input_tokens_details = input_tokens_details_mock return response_mock @@ -196,13 +269,6 @@ def test_anthropic_api_message_builder_image(): def test_openai_chat_completion_api_message_builder_text(): builder = OpenAIChatCompletionAPIMessageBuilder.user() builder.add_text("Hello, ChatCompletion!") - # Mock last_response as it's used by tool role - builder.last_raw_response = MagicMock(spec=LLMOutput) - builder.last_raw_response.raw_response = MagicMock() - builder.last_raw_response.raw_response.choices = [MagicMock()] - builder.last_raw_response.raw_response.choices[0].message.to_dict.return_value = { - "tool_calls": [{"function": {"name": "some_function"}}] - } messages = builder.prepare_message() assert len(messages) == 1 @@ -213,13 +279,6 @@ def test_openai_chat_completion_api_message_builder_text(): def test_openai_chat_completion_api_message_builder_image(): builder = OpenAIChatCompletionAPIMessageBuilder.user() builder.add_image("data:image/jpeg;base64,CHATCOMPLETIONBASE64") - # Mock last_response - builder.last_raw_response = MagicMock(spec=LLMOutput) - builder.last_raw_response.raw_response = MagicMock() - builder.last_raw_response.raw_response.choices = [MagicMock()] - builder.last_raw_response.raw_response.choices[0].message.to_dict.return_value = { - "tool_calls": [{"function": {"name": "some_function"}}] - } messages = builder.prepare_message() assert len(messages) == 1 @@ -230,14 +289,12 @@ def test_openai_chat_completion_api_message_builder_image(): def test_openai_chat_completion_model_parse_and_cost(): - args = OpenAIChatModelArgs(model_name="gpt-3.5-turbo") # A cheap model for testing - # Mock the OpenAI client to avoid needing OPENAI_API_KEY + args = OpenAIChatModelArgs(model_name="gpt-3.5-turbo") with patch("agentlab.llm.response_api.OpenAI") as mock_openai_class: mock_client = MagicMock() mock_openai_class.return_value = mock_client model = args.make_model() - # Mock the API call mock_response = create_mock_openai_chat_completion( content="This is a test thought.", tool_calls=[ @@ -254,17 +311,18 @@ def test_openai_chat_completion_model_parse_and_cost(): with patch.object( model.client.chat.completions, "create", return_value=mock_response ) as mock_create: - with tracking.set_tracker() as global_tracker: # Use your global tracker + with tracking.set_tracker() as global_tracker: messages = [ - OpenAIChatCompletionAPIMessageBuilder.user() - .add_text("What's the weather in Paris?") - .prepare_message()[0] + OpenAIChatCompletionAPIMessageBuilder.user().add_text( + "What's the weather in Paris?" 
@@ -230,14 +289,12 @@ def test_openai_chat_completion_model_parse_and_cost():
-    args = OpenAIChatModelArgs(model_name="gpt-3.5-turbo")  # A cheap model for testing
-    # Mock the OpenAI client to avoid needing OPENAI_API_KEY
+    args = OpenAIChatModelArgs(model_name="gpt-3.5-turbo")
     with patch("agentlab.llm.response_api.OpenAI") as mock_openai_class:
         mock_client = MagicMock()
         mock_openai_class.return_value = mock_client
         model = args.make_model()
 
-        # Mock the API call
         mock_response = create_mock_openai_chat_completion(
             content="This is a test thought.",
             tool_calls=[
@@ -254,17 +311,18 @@
         with patch.object(
             model.client.chat.completions, "create", return_value=mock_response
         ) as mock_create:
-            with tracking.set_tracker() as global_tracker:  # Use your global tracker
+            with tracking.set_tracker() as global_tracker:
                 messages = [
-                    OpenAIChatCompletionAPIMessageBuilder.user()
-                    .add_text("What's the weather in Paris?")
-                    .prepare_message()[0]
+                    OpenAIChatCompletionAPIMessageBuilder.user().add_text(
+                        "What's the weather in Paris?"
+                    )
                 ]
-                parsed_output = model(messages)
+                payload = APIPayload(messages=messages)
+                parsed_output = model(payload)
                 mock_create.assert_called_once()
 
         assert parsed_output.raw_response.choices[0].message.content == "This is a test thought."
-        assert parsed_output.action == 'get_weather(location="Paris")'
+        assert parsed_output.action == """get_weather(location='Paris')"""
         assert parsed_output.raw_response.choices[0].message.tool_calls[0].id == "call_123"
         # Check cost tracking (token counts)
         assert global_tracker.stats["input_tokens"] == 50
@@ -273,7 +331,7 @@ def test_openai_chat_completion_model_parse_and_cost():
 
 
 def test_claude_response_model_parse_and_cost():
-    args = ClaudeResponseModelArgs(model_name="claude-3-haiku-20240307")  # A cheap model
+    args = ClaudeResponseModelArgs(model_name="claude-3-haiku-20240307")
     model = args.make_model()
 
     mock_anthropic_api_response = create_mock_anthropic_response(
@@ -287,33 +345,23 @@ def test_claude_response_model_parse_and_cost():
         model.client.messages, "create", return_value=mock_anthropic_api_response
     ) as mock_create:
         with tracking.set_tracker() as global_tracker:
-            messages = [
-                AnthropicAPIMessageBuilder.user()
-                .add_text("Search for latest news")
-                .prepare_message()[0]
-            ]
-            parsed_output = model(messages)
+            messages = [AnthropicAPIMessageBuilder.user().add_text("Search for latest news")]
+            payload = APIPayload(messages=messages)
+            parsed_output = model(payload)
             mock_create.assert_called_once()
 
-    fn_calls = [
-        content for content in parsed_output.raw_response.content if content.type == "tool_use"
-    ]
+    fn_call = next(iter(parsed_output.tool_calls))
+
     assert "Thinking about the request." in parsed_output.think
-    assert parsed_output.action == "search_web(query='latest news')"
-    assert fn_calls[0].id == "tool_abc"
+    assert parsed_output.action == """search_web(query='latest news')"""
+    assert fn_call.name == "search_web"
     assert global_tracker.stats["input_tokens"] == 40
     assert global_tracker.stats["output_tokens"] == 20
-    # assert global_tracker.stats["cost"] > 0  # Verify cost is calculated
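# --- Aside, not part of the patch above: the expected action strings in these tests changed
# from 'get_weather(location="Paris")' to "get_weather(location='Paris')". A plausible reason
# is that the parser now renders keyword arguments with repr(), which single-quotes strings.
# Hypothetical sketch of such a formatter (not the library's actual code):
def format_action(name: str, arguments: dict) -> str:
    kwargs = ", ".join(f"{key}={value!r}" for key, value in arguments.items())
    return f"{name}({kwargs})"

assert format_action("get_weather", {"location": "Paris"}) == "get_weather(location='Paris')"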
- """ args = OpenAIResponseModelArgs(model_name="gpt-4.1") - # Mock outputs mock_function_call_output = { "type": "function_call", "name": "get_current_weather", @@ -327,7 +375,6 @@ def test_openai_response_model_parse_and_cost(): output_tokens=40, ) - # Mock the OpenAI client to avoid needing OPENAI_API_KEY with patch("agentlab.llm.response_api.OpenAI") as mock_openai_class: mock_client = MagicMock() mock_openai_class.return_value = mock_client @@ -338,15 +385,16 @@ def test_openai_response_model_parse_and_cost(): ) as mock_create_method: with tracking.set_tracker() as global_tracker: messages = [ - OpenAIResponseAPIMessageBuilder.user() - .add_text("What's the weather in Boston?") - .prepare_message()[0] + OpenAIResponseAPIMessageBuilder.user().add_text("What's the weather in Boston?") ] - parsed_output = model(messages) + payload = APIPayload(messages=messages) + parsed_output = model(payload) mock_create_method.assert_called_once() fn_calls = [ - content for content in parsed_output.raw_response.output if content.type == "function_call" + content + for content in parsed_output.tool_calls.raw_calls.output + if content.type == "function_call" ] assert parsed_output.action == "get_current_weather(location='Boston, MA', unit='celsius')" assert fn_calls[0].call_id == "call_abc123" @@ -368,43 +416,20 @@ def test_openai_chat_completion_model_pricy_call(): max_new_tokens=100, ) - tools = [ - { - "type": "function", - "name": "get_weather", - "description": "Get the current weather in a given location.", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The location to get the weather for.", - }, - "unit": { - "type": "string", - "enum": ["celsius", "fahrenheit"], - "description": "The unit of temperature.", - }, - }, - "required": ["location"], - }, - } - ] - - model = args.make_model(tools=tools, tool_choice="required") + tools = chat_api_tools + model = args.make_model() with tracking.set_tracker() as global_tracker: messages = [ - OpenAIChatCompletionAPIMessageBuilder.user() - .add_text("What is the weather in Paris?") - .prepare_message()[0] + OpenAIChatCompletionAPIMessageBuilder.user().add_text("What is the weather in Paris?") ] - parsed_output = model(messages) + payload = APIPayload(messages=messages, tools=tools, tool_choice="required") + parsed_output = model(payload) assert parsed_output.raw_response is not None assert ( - parsed_output.action == 'get_weather(location="Paris")' - ), f""" Expected get_weather(location="Paris") but got {parsed_output.action}""" + parsed_output.action == "get_weather(location='Paris')" + ), f""" Expected get_weather(location='Paris') but got {parsed_output.action}""" assert global_tracker.stats["input_tokens"] > 0 assert global_tracker.stats["output_tokens"] > 0 assert global_tracker.stats["cost"] > 0 @@ -420,36 +445,18 @@ def test_claude_response_model_pricy_call(): temperature=1e-5, max_new_tokens=100, ) - tools = [ - { - "name": "get_weather", - "description": "Get the current weather in a given location.", - "input_schema": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The location to get the weather for.", - }, - }, - "required": ["location"], - }, - } - ] - model = args.make_model(tools=tools) + tools = anthropic_tools + model = args.make_model() with tracking.set_tracker() as global_tracker: - messages = [ - AnthropicAPIMessageBuilder.user() - .add_text("What is the weather in Paris?") - .prepare_message()[0] - ] - parsed_output = 
@@ -420,36 +445,18 @@ def test_claude_response_model_pricy_call():
         temperature=1e-5,
         max_new_tokens=100,
     )
-    tools = [
-        {
-            "name": "get_weather",
-            "description": "Get the current weather in a given location.",
-            "input_schema": {
-                "type": "object",
-                "properties": {
-                    "location": {
-                        "type": "string",
-                        "description": "The location to get the weather for.",
-                    },
-                },
-                "required": ["location"],
-            },
-        }
-    ]
-    model = args.make_model(tools=tools)
+    tools = anthropic_tools
+    model = args.make_model()
 
     with tracking.set_tracker() as global_tracker:
-        messages = [
-            AnthropicAPIMessageBuilder.user()
-            .add_text("What is the weather in Paris?")
-            .prepare_message()[0]
-        ]
-        parsed_output = model(messages)
+        messages = [AnthropicAPIMessageBuilder.user().add_text("What is the weather in Paris?")]
+        payload = APIPayload(messages=messages, tools=tools)
+        parsed_output = model(payload)
 
     assert parsed_output.raw_response is not None
     assert (
-        parsed_output.action == 'get_weather(location="Paris")'
-    ), f'Expected get_weather("Paris") but got {parsed_output.action}'
+        parsed_output.action == "get_weather(location='Paris')"
+    ), f"""Expected get_weather('Paris') but got {parsed_output.action}"""
     assert global_tracker.stats["input_tokens"] > 0
     assert global_tracker.stats["output_tokens"] > 0
     assert global_tracker.stats["cost"] > 0
@@ -464,42 +471,20 @@ def test_openai_response_model_pricy_call():
     """
     args = OpenAIResponseModelArgs(model_name="gpt-4.1", temperature=1e-5, max_new_tokens=100)
 
-    tools = [
-        {
-            "type": "function",
-            "name": "get_weather",
-            "description": "Get the current weather in a given location.",
-            "parameters": {
-                "type": "object",
-                "properties": {
-                    "location": {
-                        "type": "string",
-                        "description": "The location to get the weather for.",
-                    },
-                    "unit": {
-                        "type": "string",
-                        "enum": ["celsius", "fahrenheit"],
-                        "description": "The unit of temperature.",
-                    },
-                },
-                "required": ["location"],
-            },
-        }
-    ]
-    model = args.make_model(tools=tools)
+    tools = responses_api_tools
+    model = args.make_model()
 
     with tracking.set_tracker() as global_tracker:
         messages = [
-            OpenAIResponseAPIMessageBuilder.user()
-            .add_text("What is the weather in Paris?")
-            .prepare_message()[0]
+            OpenAIResponseAPIMessageBuilder.user().add_text("What is the weather in Paris?")
         ]
-        parsed_output = model(messages)
+        payload = APIPayload(messages=messages, tools=tools)
+        parsed_output = model(payload)
 
     assert parsed_output.raw_response is not None
     assert (
-        parsed_output.action == """get_weather(location="Paris")"""
-    ), f""" Expected get_weather(location="Paris") but got {parsed_output.action}"""
+        parsed_output.action == """get_weather(location='Paris')"""
+    ), f""" Expected get_weather(location='Paris') but got {parsed_output.action}"""
     assert global_tracker.stats["input_tokens"] > 0
     assert global_tracker.stats["output_tokens"] > 0
     assert global_tracker.stats["cost"] > 0
tool_call.response_text("Its sunny! 25°C") # Simulate tool execution and user follow-up messages += [ - parsed.tool_calls, # Add tool call from the model - builder.tool(parsed.raw_response).add_text("Its sunny! 25°C"), + builder.add_responded_tool_calls(parsed.tool_calls), builder.user().add_text("What is the weather in Delhi?"), ] - parsed = model(messages) + payload = APIPayload(messages=messages, tools=tools, tool_choice="required") + parsed = model(payload) - # Token and cost deltas delta_input = tracker.stats["input_tokens"] - prev_input delta_output = tracker.stats["output_tokens"] - prev_output delta_cost = tracker.stats["cost"] - prev_cost - # Assertions assert prev_input > 0 assert prev_output > 0 assert prev_cost > 0 assert parsed.raw_response is not None - assert parsed.action == 'get_weather(location="Delhi")', f"Unexpected action: {parsed.action}" + assert ( + parsed.action == """get_weather(location='Delhi')""" + ), f"Unexpected action: {parsed.action}" assert delta_input > 0 assert delta_output > 0 assert delta_cost > 0 @@ -609,38 +576,41 @@ def test_openai_chat_completion_model_with_multiple_messages_and_cost_tracking() } ] - model = args.make_model(tools=tools, tool_choice="required") + model = args.make_model() builder = args.get_message_builder() messages = [builder.user().add_text("What is the weather in Paris?")] with tracking.set_tracker() as tracker: - # First turn: get initial tool call - parsed = model(messages) + payload = APIPayload(messages=messages, tools=tools, tool_choice="required") + parsed = model(payload) prev_input = tracker.stats["input_tokens"] prev_output = tracker.stats["output_tokens"] prev_cost = tracker.stats["cost"] + for tool_call in parsed.tool_calls: + tool_call.response_text("Its sunny! 25°C") # Simulate tool execution and user follow-up messages += [ - parsed.tool_calls, # Add tool call from the model - builder.tool(parsed.raw_response).add_text("Its sunny! 25°C"), + builder.add_responded_tool_calls(parsed.tool_calls), builder.user().add_text("What is the weather in Delhi?"), ] + # Set tool responses - parsed = model(messages) + payload = APIPayload(messages=messages, tools=tools, tool_choice="required") + parsed = model(payload) - # Token and cost deltas delta_input = tracker.stats["input_tokens"] - prev_input delta_output = tracker.stats["output_tokens"] - prev_output delta_cost = tracker.stats["cost"] - prev_cost - # Assertions assert prev_input > 0 assert prev_output > 0 assert prev_cost > 0 assert parsed.raw_response is not None - assert parsed.action == 'get_weather(location="Delhi")', f"Unexpected action: {parsed.action}" + assert ( + parsed.action == """get_weather(location='Delhi')""" + ), f"Unexpected action: {parsed.action}" assert delta_input > 0 assert delta_output > 0 assert delta_cost > 0 @@ -676,35 +646,38 @@ def test_claude_model_with_multiple_messages_pricy_call(): }, } ] - model = model_factory.make_model(tools=tools) + model = model_factory.make_model() msg_builder = model_factory.get_message_builder() messages = [] messages.append(msg_builder.user().add_text("What is the weather in Paris?")) with tracking.set_tracker() as global_tracker: - llm_output1 = model(messages) + payload = APIPayload(messages=messages, tools=tools) + llm_output1 = model(payload) prev_input = global_tracker.stats["input_tokens"] prev_output = global_tracker.stats["output_tokens"] prev_cost = global_tracker.stats["cost"] - messages.append(llm_output1.tool_calls) - messages.append(msg_builder.tool(llm_output1.raw_response).add_text("Its sunny! 
25°C")) - messages.append(msg_builder.user().add_text("What is the weather in Delhi?")) - llm_output2 = model(messages) - # Token and cost deltas + for tool_call in llm_output1.tool_calls: + tool_call.response_text("It's sunny! 25°C") + messages += [ + msg_builder.add_responded_tool_calls(llm_output1.tool_calls), + msg_builder.user().add_text("What is the weather in Delhi?"), + ] + payload = APIPayload(messages=messages, tools=tools) + llm_output2 = model(payload) delta_input = global_tracker.stats["input_tokens"] - prev_input delta_output = global_tracker.stats["output_tokens"] - prev_output delta_cost = global_tracker.stats["cost"] - prev_cost - # Assertions assert prev_input > 0, "Expected previous input tokens to be greater than 0" assert prev_output > 0, "Expected previous output tokens to be greater than 0" assert prev_cost > 0, "Expected previous cost value to be greater than 0" assert llm_output2.raw_response is not None assert ( - llm_output2.action == 'get_weather(location="Delhi", unit="celsius")' - ), f'Expected get_weather("Delhi") but got {llm_output2.action}' + llm_output2.action == """get_weather(location='Delhi', unit='celsius')""" + ), f"""Expected get_weather('Delhi') but got {llm_output2.action}""" assert delta_input > 0, "Expected new input tokens to be greater than 0" assert delta_output > 0, "Expected new output tokens to be greater than 0" assert delta_cost > 0, "Expected new cost value to be greater than 0" @@ -713,9 +686,71 @@ def test_claude_model_with_multiple_messages_pricy_call(): assert global_tracker.stats["cost"] == pytest.approx(prev_cost + delta_cost) -# TODO: Add tests for image token costing (this is complex and model-specific) -# - For OpenAI, you'd need to know how they bill for images (e.g., fixed cost per image + tokens for text parts) -# - You'd likely need to mock the response from client.chat.completions.create to include specific usage for images. +## Test multiaction +@pytest.mark.pricy +def test_multi_action_tool_calls(): + """ + Test that the model can produce multiple tool calls in parallel. + Uncomment commented lines to see the full behaviour of models and tool choices. 
+ """ + # test_config (setting name, BaseModelArgs, model_name, tools) + tool_test_configs = [ + ( + "gpt-4.1-responses API", + OpenAIResponseModelArgs, + "gpt-4.1-2025-04-14", + responses_api_tools, + ), + ("gpt-4.1-chat Completions API", OpenAIChatModelArgs, "gpt-4.1-2025-04-14", chat_api_tools), + # ("claude-3", ClaudeResponseModelArgs, "claude-3-haiku-20240307", anthropic_tools), # fails + # ("claude-3.7", ClaudeResponseModelArgs, "claude-3-7-sonnet-20250219", anthropic_tools), # fails + ("claude-4-sonnet", ClaudeResponseModelArgs, "claude-sonnet-4-20250514", anthropic_tools), + # add more models as needed + ] + + def add_user_messages(msg_builder): + return [ + msg_builder.user().add_text("What is the weather in Paris and Delhi?"), + msg_builder.user().add_text("You must call multiple tools to achieve the task."), + ] + + res_df = [] + + for tool_choice in [ + # 'none', + # 'required', # fails for Responses API + # 'any', # fails for Responses API + "auto", + # 'get_weather' + ]: + for name, llm_class, checkpoint_name, tools in tool_test_configs: + print(name, "tool choice:", tool_choice, "\n", "**" * 10) + model_args = llm_class(model_name=checkpoint_name, max_new_tokens=200, temperature=None) + llm, msg_builder = model_args.make_model(), model_args.get_message_builder() + messages = add_user_messages(msg_builder) + if tool_choice == "get_weather": # force a specific tool call + response: LLMOutput = llm( + APIPayload(messages=messages, tools=tools, force_call_tool=tool_choice) + ) + else: + response: LLMOutput = llm( + APIPayload(messages=messages, tools=tools, tool_choice=tool_choice) + ) + num_tool_calls = len(response.tool_calls) if response.tool_calls else 0 + res_df.append( + { + "model": name, + "checkpoint": checkpoint_name, + "tool_choice": tool_choice, + "num_tool_calls": num_tool_calls, + "action": response.action, + } + ) + assert ( + num_tool_calls == 2 + ), f"Expected 2 tool calls, but got {num_tool_calls} for {name} with tool choice {tool_choice}" + # import pandas as pd + # print(pd.DataFrame(res_df)) EDGE_CASES = [