diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py
index 3d506bdb..bec693ae 100644
--- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py
+++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py
@@ -22,11 +22,14 @@
from agentlab.agents.agent_args import AgentArgs
from agentlab.llm.llm_utils import image_to_png_base64_url
from agentlab.llm.response_api import (
+ APIPayload,
ClaudeResponseModelArgs,
LLMOutput,
MessageBuilder,
OpenAIChatModelArgs,
OpenAIResponseModelArgs,
+ OpenRouterModelArgs,
+ ToolCalls,
)
from agentlab.llm.tracking import cost_tracker_decorator
@@ -98,7 +101,8 @@ def flatten(self) -> list[MessageBuilder]:
messages.extend(group.messages)
# Mark all summarized messages for caching
if i == len(self.groups) - keep_last_n_obs:
- messages[i].mark_all_previous_msg_for_caching()
+ if not isinstance(messages[i], ToolCalls):
+ messages[i].mark_all_previous_msg_for_caching()
return messages
def set_last_summary(self, summary: MessageBuilder):
@@ -163,18 +167,15 @@ class Obs(Block):
use_dom: bool = False
use_som: bool = False
use_tabs: bool = False
- add_mouse_pointer: bool = False
+ # add_mouse_pointer: bool = False
use_zoomed_webpage: bool = False
def apply(
self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput
) -> dict:
- if last_llm_output.tool_calls is None:
- obs_msg = llm.msg.user() # type: MessageBuilder
- else:
- obs_msg = llm.msg.tool(last_llm_output.raw_response) # type: MessageBuilder
-
+ obs_msg = llm.msg.user()
+ tool_calls = last_llm_output.tool_calls
if self.use_last_error:
if obs["last_action_error"] != "":
obs_msg.add_text(f"Last action error:\n{obs['last_action_error']}")
@@ -186,13 +187,12 @@ def apply(
else:
screenshot = obs["screenshot"]
- if self.add_mouse_pointer:
- # TODO this mouse pointer should be added at the browsergym level
- screenshot = np.array(
- agent_utils.add_mouse_pointer_from_action(
- Image.fromarray(obs["screenshot"]), obs["last_action"]
- )
- )
+ # if self.add_mouse_pointer:
+ # screenshot = np.array(
+ # agent_utils.add_mouse_pointer_from_action(
+ # Image.fromarray(obs["screenshot"]), obs["last_action"]
+ # )
+ # )
obs_msg.add_image(image_to_png_base64_url(screenshot))
if self.use_axtree:
@@ -203,6 +203,13 @@ def apply(
obs_msg.add_text(_format_tabs(obs))
discussion.append(obs_msg)
+
+ if tool_calls:
+ for call in tool_calls:
+ call.response_text("See Observation")
+ tool_response = llm.msg.add_responded_tool_calls(tool_calls)
+ discussion.append(tool_response)
+
return obs_msg
@@ -254,8 +261,8 @@ def apply(self, llm, discussion: StructuredDiscussion) -> dict:
msg = llm.msg.user().add_text("""Summarize\n""")
discussion.append(msg)
- # TODO need to make sure we don't force tool use here
- summary_response = llm(messages=discussion.flatten(), tool_choice="none")
+
+ summary_response = llm(APIPayload(messages=discussion.flatten()))
summary_msg = llm.msg.assistant().add_text(summary_response.think)
discussion.append(summary_msg)
@@ -320,25 +327,6 @@ def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
discussion.append(msg)
-class ToolCall(Block):
-
- def __init__(self, tool_server):
- self.tool_server = tool_server
-
- def apply(self, llm, messages: list[MessageBuilder], obs: dict) -> dict:
- # build the message by adding components to obs
- response: LLMOutput = llm(messages=self.messages)
-
- messages.append(response.assistant_message) # this is tool call
-
- tool_answer = self.tool_server.call_tool(response)
- tool_msg = llm.msg.tool() # type: MessageBuilder
- tool_msg.add_tool_id(response.last_computer_call_id)
- tool_msg.update_last_raw_response(response)
- tool_msg.add_text(str(tool_answer))
- messages.append(tool_msg)
-
-
@dataclass
class PromptConfig:
tag_screenshot: bool = True # Whether to tag the screenshot with the last action.
@@ -394,7 +382,7 @@ def __init__(
self.call_ids = []
- self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})
+ self.llm = model_args.make_model()
self.msg_builder = model_args.get_message_builder()
self.llm.msg = self.msg_builder
@@ -462,13 +450,15 @@ def get_action(self, obs: Any) -> float:
messages = self.discussion.flatten()
response: LLMOutput = self.llm(
- messages=messages,
- tool_choice="any",
- cache_tool_definition=True,
- cache_complete_prompt=False,
- use_cache_breakpoints=True,
+ APIPayload(
+ messages=messages,
+                tools=self.tools,  # The available tools can now be updated on each call.
+ tool_choice="any",
+ cache_tool_definition=True,
+ cache_complete_prompt=False,
+ use_cache_breakpoints=True,
+ )
)
-
action = response.action
think = response.think
last_summary = self.discussion.get_last_summary()
@@ -476,7 +466,7 @@ def get_action(self, obs: Any) -> float:
think = last_summary.content[0]["text"] + "\n" + think
self.discussion.new_group()
- self.discussion.append(response.tool_calls)
+ # self.discussion.append(response.tool_calls) # No need to append tool calls anymore.
self.last_response = response
self._responses.append(response) # may be useful for debugging
@@ -486,8 +476,11 @@ def get_action(self, obs: Any) -> float:
tools_msg = MessageBuilder("tool_description").add_text(tools_str)
# Adding these extra messages to visualize in gradio
- messages.insert(0, tools_msg) # insert at the beginning of the messages
- messages.append(response.tool_calls)
+        messages.insert(0, tools_msg)  # insert at the beginning of the messages
+        # Building the tool message directly avoids the assertion in add_responded_tool_calls (responses are not set here).
+ msg = self.llm.msg("tool")
+ msg.responded_tool_calls = response.tool_calls
+ messages.append(msg)
agent_info = bgym.AgentInfo(
think=think,
@@ -533,6 +526,31 @@ def get_action(self, obs: Any) -> float:
vision_support=True,
)
+O3_RESPONSE_MODEL = OpenAIResponseModelArgs(
+ model_name="o3-2025-04-16",
+ max_total_tokens=200_000,
+ max_input_tokens=200_000,
+ max_new_tokens=2_000,
+ temperature=None, # O3 does not support temperature
+ vision_support=True,
+)
+O3_CHATAPI_MODEL = OpenAIChatModelArgs(
+ model_name="o3-2025-04-16",
+ max_total_tokens=200_000,
+ max_input_tokens=200_000,
+ max_new_tokens=2_000,
+ temperature=None,
+ vision_support=True,
+)
+
+GPT4_1_OPENROUTER_MODEL = OpenRouterModelArgs(
+ model_name="openai/gpt-4.1",
+ max_total_tokens=200_000,
+ max_input_tokens=200_000,
+ max_new_tokens=2_000,
+    temperature=None,
+ vision_support=True,
+)
DEFAULT_PROMPT_CONFIG = PromptConfig(
tag_screenshot=True,
@@ -548,8 +566,8 @@ def get_action(self, obs: Any) -> float:
summarizer=Summarizer(do_summary=True),
general_hints=GeneralHints(use_hints=False),
task_hint=TaskHint(use_task_hint=True),
- keep_last_n_obs=None, # keep only the last observation in the discussion
- multiaction=False, # whether to use multi-action or not
+ keep_last_n_obs=None,
+ multiaction=True, # whether to use multi-action or not
# action_subsets=("bid",),
action_subsets=("coord"),
# action_subsets=("coord", "bid"),
@@ -559,3 +577,18 @@ def get_action(self, obs: Any) -> float:
model_args=CLAUDE_MODEL_CONFIG,
config=DEFAULT_PROMPT_CONFIG,
)
+
+OAI_AGENT = ToolUseAgentArgs(
+ model_args=GPT_4_1,
+ config=DEFAULT_PROMPT_CONFIG,
+)
+
+OAI_CHATAPI_AGENT = ToolUseAgentArgs(
+ model_args=O3_CHATAPI_MODEL,
+ config=DEFAULT_PROMPT_CONFIG,
+)
+
+OAI_OPENROUTER_AGENT = ToolUseAgentArgs(
+ model_args=GPT4_1_OPENROUTER_MODEL,
+ config=DEFAULT_PROMPT_CONFIG,
+)
diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py
index bd1c6ad4..8b846f5f 100644
--- a/src/agentlab/analyze/agent_xray.py
+++ b/src/agentlab/analyze/agent_xray.py
@@ -26,6 +26,7 @@
from agentlab.llm.llm_utils import BaseMessage as AgentLabBaseMessage
from agentlab.llm.llm_utils import Discussion
from agentlab.llm.response_api import MessageBuilder
+from agentlab.llm.response_api import ToolCalls
select_dir_instructions = "Select Experiment Directory"
AGENT_NAME_KEY = "agent.agent_name"
@@ -673,6 +674,9 @@ def dict_to_markdown(d: dict):
str: A markdown-formatted string representation of the dictionary.
"""
if not isinstance(d, dict):
+ if isinstance(d, ToolCalls):
+            # ToolCalls are rendered via MessageBuilder.to_markdown instead.
+ return ""
warning(f"Expected dict, got {type(d)}")
return repr(d)
if not d:
diff --git a/src/agentlab/llm/response_api.py b/src/agentlab/llm/response_api.py
index 8e6f3695..e8c74849 100644
--- a/src/agentlab/llm/response_api.py
+++ b/src/agentlab/llm/response_api.py
@@ -3,10 +3,12 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Union
import openai
from anthropic import Anthropic
+from anthropic.types import Completion
+from anthropic.types import Message as AnthropicMessage
from openai import OpenAI
from agentlab.llm.llm_utils import image_to_png_base64_url
@@ -30,23 +32,90 @@
Message = Dict[str, Union[str, List[ContentItem]]]
+@dataclass
+class ToolCall:
+ """Represents a tool call made by the LLM.
+ Attributes:
+ name: Name of the tool called.
+ arguments: Arguments passed to the tool.
+ raw_call: The raw call object from the LLM API.
+        tool_response: Output of the tool call. Only a single content item is supported.
+ """
+
+ name: str = field(default=None)
+ arguments: Dict[str, Any] = field(default_factory=dict)
+ raw_call: Any = field(default=None)
+ tool_response: ContentItem = None
+
+ @property
+ def is_response_set(self) -> bool:
+ """Check if the tool response is set."""
+ return self.tool_response is not None
+
+    def response_text(self, text: str) -> "ToolCall":
+        self.tool_response = {"text": text}
+        return self
+
+    def response_image(self, image: str) -> "ToolCall":
+        self.tool_response = {"image": image}
+        return self
+
+ def __repr__(self):
+ return f"ToolCall(name={self.name}, arguments={self.arguments})"
+
+
+@dataclass
+class ToolCalls:
+ """A collection of tool calls made by the LLM.
+
+ Attributes:
+ tool_calls: List of ToolCall objects.
+        raw_calls: The raw response object returned by the LLM API; it may contain one or more tool calls.
+ """
+
+ tool_calls: List[ToolCall] = field(default_factory=list)
+    raw_calls: Any = field(default=None)
+
+ def add_tool_call(self, tool_call: ToolCall) -> "ToolCalls":
+ self.tool_calls.append(tool_call)
+ return self
+
+ @property
+ def all_responses_set(self) -> bool:
+ """Check if all tool calls have responses set."""
+ return all(call.is_response_set for call in self.tool_calls)
+
+ def __len__(self) -> int:
+ """Return the number of tool calls."""
+ return len(self.tool_calls)
+
+ def __iter__(self):
+ """Make ToolCalls iterable."""
+ return iter(self.tool_calls)
+
+ def __bool__(self):
+ """Check if there are any tool calls."""
+ return len(self.tool_calls) > 0
+
+
@dataclass
class LLMOutput:
"""Serializable object for the output of a response LLM."""
- raw_response: Any = field(default_factory=dict)
+ raw_response: Any = field(default=None)
think: str = field(default="")
- action: str = field(default=None) # Default action if no tool call is made
- tool_calls: Any = field(default=None) # This will hold the tool call response if any
+ action: str | None = field(default=None) # Default action if no tool call is made
+ tool_calls: ToolCalls | None = field(
+ default=None
+ ) # This will hold the tool call response if any
class MessageBuilder:
def __init__(self, role: str):
self.role = role
- self.last_raw_response: LLMOutput = None
self.content: List[ContentItem] = []
- self.tool_call_id: Optional[str] = None
+ self.responded_tool_calls: ToolCalls = None
@classmethod
def system(cls) -> "MessageBuilder":
@@ -60,19 +129,11 @@ def user(cls) -> "MessageBuilder":
def assistant(cls) -> "MessageBuilder":
return cls("assistant")
- @classmethod
- def tool(cls, last_raw_response) -> "MessageBuilder":
- return cls("tool").update_last_raw_response(last_raw_response)
-
@abstractmethod
def prepare_message(self) -> List[Message]:
"""Prepare the message for the API call."""
raise NotImplementedError("Subclasses must implement this method.")
- def update_last_raw_response(self, last_raw_response: Any) -> "MessageBuilder":
- self.last_raw_response = last_raw_response
- return self
-
def add_text(self, text: str) -> "MessageBuilder":
self.content.append({"text": text})
return self
@@ -89,6 +150,22 @@ def to_markdown(self) -> str:
elif "image" in item:
                parts.append(f"![image]({item['image']})")
+ # Tool call markdown repr
+ if self.responded_tool_calls is not None:
+ for i, tool_call in enumerate(self.responded_tool_calls.tool_calls, 1):
+ parts.append(
+ f"\n**Tool Call {i}**: {tool_call_to_python_code(tool_call.name, tool_call.arguments)}"
+ )
+ response = tool_call.tool_response
+ if response is not None:
+ parts.append(f"\n**Tool Response {i}:**")
+ content = (
+ f"```\n{response['text']}\n```"
+ if "text" in response
+                        else f"![Tool Response {i}]({response['image']})"
+ )
+ parts.append(content)
+
markdown = f"### {self.role.capitalize()}\n"
markdown += "\n".join(parts)
@@ -104,8 +181,13 @@ def mark_all_previous_msg_for_caching(self):
# This is a placeholder for future implementation.
raise NotImplementedError
-
-# TODO: Support parallel tool calls.
+ @classmethod
+ def add_responded_tool_calls(cls, responded_tool_calls: ToolCalls) -> "MessageBuilder":
+ """Add tool calls to the message content."""
+ assert responded_tool_calls.all_responses_set, "All tool calls must have a response."
+ msg = cls("tool")
+ msg.responded_tool_calls = responded_tool_calls
+ return msg
class OpenAIResponseAPIMessageBuilder(MessageBuilder):
@@ -117,44 +199,62 @@ def system(cls) -> "OpenAIResponseAPIMessageBuilder":
def prepare_message(self) -> List[Message]:
content = []
for item in self.content:
- if "text" in item:
- content_type = "input_text" if self.role != "assistant" else "output_text"
- content.append({"type": content_type, "text": item["text"]})
+ content.append(self.convert_content_to_expected_format(item))
+ output = [{"role": self.role, "content": content}]
- elif "image" in item:
- content.append({"type": "input_image", "image_url": item["image"]})
+ return output if self.role != "tool" else self.handle_tool_call()
- output = [{"role": self.role, "content": content}]
- if self.role != "tool":
- return output
+ def convert_content_to_expected_format(self, content: ContentItem) -> ContentItem:
+ """Convert the content item to the expected format for OpenAI Responses."""
+ if "text" in content:
+ content_type = "input_text" if self.role != "assistant" else "output_text"
+ return {"type": content_type, "text": content["text"]}
+ elif "image" in content:
+ return {"type": "input_image", "image_url": content["image"]}
else:
- tool_call_response = self.handle_tool_call(content)
- return tool_call_response
+ raise ValueError(f"Unsupported content type: {content}")
- def handle_tool_call(self, content):
+ def handle_tool_call(self) -> List[Message]:
"""Handle the tool call response from the last raw response."""
+ if self.responded_tool_calls is None:
+ raise ValueError("No tool calls found in responded_tool_calls")
+
output = []
- head_content, *tail_content = content
- api_response = self.last_raw_response
- fn_calls = [content for content in api_response.output if content.type == "function_call"]
- assert len(fn_calls) > 0, "No function calls found in the last response"
- if len(fn_calls) > 1:
- logging.warning("Using only the first tool call from many.")
-
- first_fn_call_id = fn_calls[0].call_id
- fn_output = head_content.get("text", "Function call answer in next message")
- fn_call_response = {
- "type": "function_call_output",
- "call_id": first_fn_call_id,
- "output": fn_output,
- }
- output.append(fn_call_response)
- if tail_content:
- # if there are more content items, add them as a new user message
- output.append({"role": "user", "content": tail_content})
+        output.extend(self.responded_tool_calls.raw_calls.output)  # echo the raw output items, including the tool calls
+ for fn_call in self.responded_tool_calls:
+ call_type = fn_call.raw_call.type
+ call_id = fn_call.raw_call.call_id
+ call_response = fn_call.tool_response
+
+ match call_type:
+ case "function_call":
+                    # Image output is not supported in function call responses.
+ assert (
+ "image" not in call_response
+ ), "Image output is not supported in function calls response."
+ fn_call_response = {
+ "type": "function_call_output",
+ "call_id": call_id,
+ "output": self.convert_content_to_expected_format(call_response)["text"],
+ }
+ output.append(fn_call_response)
+
+ case "computer_call":
+                    # For computer calls, only image outputs are expected.
+                    assert (
+                        "text" not in call_response
+                    ), "Text output is not supported in computer call responses."
+ computer_call_output = {
+ "type": "computer_call_output",
+ "call_id": call_id,
+ "output": self.convert_content_to_expected_format(call_response),
+ }
+ output.append(computer_call_output) # this needs to be a screenshot
+
return output
- def mark_all_previous_msg_for_caching(self) -> List[Message]:
+ def mark_all_previous_msg_for_caching(self):
+        """Nothing to do for OpenAI; it has no notion of cache breakpoints."""
pass
@@ -164,29 +264,9 @@ def prepare_message(self) -> List[Message]:
content = [self.transform_content(item) for item in self.content]
output = {"role": self.role, "content": content}
- if self.role == "system":
- logging.info(
- "Treating system message as 'user'. In the Anthropic API, system messages should be passed as a direct input to the client."
- )
- output["role"] = "user"
-
if self.role == "tool":
+ return self.handle_tool_call()
- api_response = self.last_raw_response
- fn_calls = [content for content in api_response.content if content.type == "tool_use"]
- assert len(fn_calls) > 0, "No tool calls found in the last response"
- if len(fn_calls) > 1:
- logging.warning("Using only the first tool call from many.")
- tool_call_id = fn_calls[0].id # Using the first tool call ID
-
- output["role"] = "user"
- output["content"] = [
- {
- "type": "tool_result",
- "tool_use_id": tool_call_id,
- "content": output["content"],
- }
- ]
if self.role == "assistant":
# Strip whitespace from assistant text responses. See anthropic error code 400.
for c in output["content"]:
@@ -194,6 +274,32 @@ def prepare_message(self) -> List[Message]:
c["text"] = c["text"].strip()
return [output]
+ def handle_tool_call(self) -> List[Message]:
+ """Handle the tool call response from the last raw response."""
+ if self.responded_tool_calls is None:
+ raise ValueError("No tool calls found in responded_tool_calls")
+
+ llm_tool_call = {
+ "role": "assistant",
+ "content": self.responded_tool_calls.raw_calls.content,
+ } # Add the toolcall block
+        tool_response = {"role": "user", "content": []}  # Anthropic expects tool results as content blocks in a user message
+ for call in self.responded_tool_calls:
+ assert (
+ "image" not in call.tool_response
+ ), "Image output is not supported in tool calls response."
+ tool_response["content"].append(
+ {
+ "type": "tool_result",
+ "tool_use_id": call.raw_call.id,
+ "content": self.transform_content(call.tool_response)[
+ "text"
+ ], # needs to be str
+ }
+ )
+
+ return [llm_tool_call, tool_response]
+
def transform_content(self, content: ContentItem) -> ContentItem:
"""Transform content item to the format expected by Anthropic API."""
if "text" in content:
@@ -224,13 +330,13 @@ class OpenAIChatCompletionAPIMessageBuilder(MessageBuilder):
def prepare_message(self) -> List[Message]:
"""Prepare the message for the OpenAI API."""
- content = [self.transform_content(item) for item in self.content]
- if self.role == "tool":
- return self.handle_tool_call(content)
- else:
- return [{"role": self.role, "content": content}]
+ content = []
+ for item in self.content:
+ content.append(self.convert_content_to_expected_format(item))
+ output = [{"role": self.role, "content": content}]
+ return output if self.role != "tool" else self.handle_tool_call()
- def transform_content(self, content: ContentItem) -> ContentItem:
+ def convert_content_to_expected_format(self, content: ContentItem) -> ContentItem:
"""Transform content item to the format expected by OpenAI ChatCompletion."""
if "text" in content:
return {"type": "text", "text": content["text"]}
@@ -239,30 +345,57 @@ def transform_content(self, content: ContentItem) -> ContentItem:
else:
raise ValueError(f"Unsupported content type: {content}")
- def handle_tool_call(self, content) -> List[Message]:
+ def handle_tool_call(self) -> List[Message]:
"""Handle the tool call response from the last raw response."""
+ if self.responded_tool_calls is None:
+ raise ValueError("No tool calls found in responded_tool_calls")
output = []
- content_head, *content_tail = content
- api_response = self.last_raw_response.choices[0].message
- fn_calls = getattr(api_response, "tool_calls", None)
- assert fn_calls is not None, "Tool calls not found in the last response"
- if len(fn_calls) > 1:
- logging.warning("Using only the first tool call from many.")
-
- # a function_call_output dict has keys "role", "tool_call_id" and "content"
- tool_call_reponse = {
- "role": "tool",
- "tool_call_id": fn_calls[0].id, # using the first tool call ID
- "content": content_head.get("text", "Tool call answer in next message"),
- "name": fn_calls[0].function.name, # required with OpenRouter
- }
+ output.append(
+ self.responded_tool_calls.raw_calls.choices[0].message
+ ) # add raw calls to output
+ for fn_call in self.responded_tool_calls:
+ raw_call = fn_call.raw_call
+ assert (
+ "image" not in fn_call.tool_response
+ ), "Image output is not supported in function calls response."
+ # a function_call_output dict has keys "role", "tool_call_id" and "content"
+            tool_call_response = {
+ "name": raw_call["function"]["name"], # required with OpenRouter
+ "role": "tool",
+ "tool_call_id": raw_call["id"],
+ "content": self.convert_content_to_expected_format(fn_call.tool_response)["text"],
+ }
+            output.append(tool_call_response)
- output.append(tool_call_reponse)
- if content_tail:
- # if there are more content items, add them as a new user message
- output.append({"role": "user", "content": content_tail})
return output
+ def mark_all_previous_msg_for_caching(self):
+        """Nothing to do for OpenAI; it has no notion of cache breakpoints."""
+ pass
+
+
+@dataclass
+class APIPayload:
+ messages: List[MessageBuilder] | None = None
+ tools: List[Dict[str, Any]] | None = None
+ tool_choice: Literal["none", "auto", "any", "required"] | None = None
+    force_call_tool: str | None = None  # Name of a tool; if set, the LLM is forced to call it.
+    use_cache_breakpoints: bool = False  # If True, apply cache breakpoints to messages (Anthropic only).
+    cache_tool_definition: bool = False  # If True, mark the tool definitions for caching.
+    cache_complete_prompt: bool = False  # If True, mark the complete prompt for caching.
+
+ def __post_init__(self):
+ if self.tool_choice and self.force_call_tool:
+ raise ValueError("tool_choice and force_call_tool are mutually exclusive")
+
# # Base class for all API Endpoints
class BaseResponseModel(ABC):
@@ -270,25 +403,22 @@ def __init__(
self,
model_name: str,
api_key: Optional[str] = None,
- temperature: float = 0.5,
- max_tokens: int = 100,
- extra_kwargs: Optional[Dict[str, Any]] = None,
+ temperature: float | None = None,
+ max_tokens: int | None = None,
):
self.model_name = model_name
self.api_key = api_key
self.temperature = temperature
self.max_tokens = max_tokens
- self.extra_kwargs = extra_kwargs or {}
-
super().__init__()
- def __call__(self, messages: list[dict | MessageBuilder], **kwargs) -> dict:
+ def __call__(self, payload: APIPayload) -> LLMOutput:
"""Make a call to the model and return the parsed response."""
- response = self._call_api(messages, **kwargs)
+ response = self._call_api(payload)
return self._parse_response(response)
@abstractmethod
- def _call_api(self, messages: list[dict | MessageBuilder], **kwargs) -> Any:
+ def _call_api(self, payload: APIPayload) -> Any:
"""Make a call to the model API and return the raw response."""
pass
@@ -298,6 +428,38 @@ def _parse_response(self, response: Any) -> LLMOutput:
pass
+class AgentlabAction:
+ """
+    Collection of utility functions to convert tool calls to the AgentLab action format.
+ """
+
+ @staticmethod
+ def convert_toolcall_to_agentlab_action_format(toolcall: ToolCall) -> str:
+ """Convert a tool call to an Agentlab environment action string.
+
+ Args:
+ toolcall: ToolCall object containing the name and arguments of the tool call.
+
+ Returns:
+            A string representing the action in AgentLab format, i.e. a Python function call string.
+ """
+
+ tool_name, tool_args = toolcall.name, toolcall.arguments
+ return tool_call_to_python_code(tool_name, tool_args)
+
+ @staticmethod
+ def convert_multiactions_to_agentlab_action_format(actions: list[str]) -> str | None:
+        """Convert a list of actions to the format the environment supports.
+
+ Args:
+ actions: List of action strings to be joined.
+
+ Returns:
+ Joined actions separated by newlines, or None if empty.
+ """
+ return "\n".join(actions) if actions else None
+
+
class BaseModelWithPricing(TrackAPIPricingMixin, BaseResponseModel):
pass
@@ -306,46 +468,47 @@ class OpenAIResponseModel(BaseModelWithPricing):
def __init__(
self,
model_name: str,
+ base_url: Optional[str] = None,
api_key: Optional[str] = None,
- temperature: float = 0.5,
- max_tokens: int = 100,
- extra_kwargs: Optional[Dict[str, Any]] = None,
- **kwargs,
+ temperature: float | None = None,
+ max_tokens: int | None = 100,
):
- self.tools = kwargs.pop("tools", None)
- super().__init__(
- model_name=model_name,
- api_key=api_key,
- temperature=temperature,
- max_tokens=max_tokens,
- extra_kwargs=extra_kwargs,
- **kwargs,
+ self.action_space_as_tools = True # this should be a config
+ super().__init__( # This is passed to BaseModel
+ model_name=model_name, api_key=api_key, temperature=temperature, max_tokens=max_tokens
)
- self.client = OpenAI(api_key=api_key)
+ client_args = {}
+ if base_url is not None:
+ client_args["base_url"] = base_url
+ if api_key is not None:
+ client_args["api_key"] = api_key
+ self.client = OpenAI(**client_args)
+ # Init pricing tracker after super() so that all attributes have been set.
+ self.init_pricing_tracker(pricing_api="openai") # Use the PricingMixin
- def _call_api(
- self, messages: list[Any | MessageBuilder], tool_choice: str = "auto", **kwargs
- ) -> dict:
- input = []
- for msg in messages:
- input.extend(msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg])
+ def _call_api(self, payload: APIPayload) -> "OpenAIResponseObject":
+ input = []
+ for msg in payload.messages:
+ input.extend(msg.prepare_message())
api_params: Dict[str, Any] = {
"model": self.model_name,
"input": input,
- "temperature": self.temperature,
- "max_output_tokens": self.max_tokens,
- **self.extra_kwargs,
}
+        # Not all OpenAI models support these parameters (e.g., o3), so only include them when set.
+ if self.temperature is not None:
+ api_params["temperature"] = self.temperature
+ if self.max_tokens is not None:
+ api_params["max_output_tokens"] = self.max_tokens
+ if payload.tools is not None:
+ api_params["tools"] = payload.tools
+ if payload.tool_choice is not None and payload.force_call_tool is None:
+ api_params["tool_choice"] = (
+ "required" if payload.tool_choice in ("required", "any") else payload.tool_choice
+ )
+ if payload.force_call_tool is not None:
+ api_params["tool_choice"] = {"type": "function", "name": payload.force_call_tool}
- if self.tools is not None:
- api_params["tools"] = self.tools
- if tool_choice in ("any", "required"):
- tool_choice = "required"
-
- api_params["tool_choice"] = tool_choice
-
- # api_params |= kwargs # Merge any additional parameters passed
response = call_openai_api_with_retries(
self.client.responses.create,
api_params,
@@ -353,110 +516,221 @@ def _call_api(
return response
- def _parse_response(self, response: dict) -> dict:
- result = LLMOutput(
+ def _parse_response(self, response: "OpenAIResponseObject") -> LLMOutput:
+ """Parse the raw response from the OpenAI Responses API."""
+
+ think_output = self._extract_thinking_content_from_response(response)
+ toolcalls = self._extract_tool_calls_from_response(response)
+
+ if self.action_space_as_tools:
+ env_action = self._extract_env_actions_from_toolcalls(toolcalls)
+ else:
+ env_action = self._extract_env_actions_from_text_response(response)
+
+ return LLMOutput(
raw_response=response,
- think="",
- action=None,
- tool_calls=None,
+ think=think_output,
+ action=env_action if env_action is not None else None,
+ tool_calls=toolcalls if toolcalls is not None else None,
)
- interesting_keys = ["output_text"]
+
+ def _extract_tool_calls_from_response(self, response: "OpenAIResponseObject") -> ToolCalls:
+ """Extracts tool calls from the response."""
+ tool_calls = []
for output in response.output:
if output.type == "function_call":
- result.action = tool_call_to_python_code(output.name, json.loads(output.arguments))
- result.tool_calls = output
- break
- elif output.type == "reasoning":
- if len(output.summary) > 0:
- result.think += output.summary[0].text + "\n"
+ tool_name = output.name
+ tool_args = json.loads(output.arguments)
+ elif output.type == "computer_call":
+ tool_name, tool_args = self.cua_action_to_env_tool_name_and_args(output.action)
+ else:
+ continue
+ tool_calls.append(ToolCall(name=tool_name, arguments=tool_args, raw_call=output))
+
+ return ToolCalls(tool_calls=tool_calls, raw_calls=response)
+
+ def _extract_env_actions_from_toolcalls(self, toolcalls: ToolCalls) -> Any | None:
+ """Extracts actions from the response."""
+ if not toolcalls:
+ return None
+
+ actions = [
+ AgentlabAction.convert_toolcall_to_agentlab_action_format(call) for call in toolcalls
+ ]
+ actions = (
+ AgentlabAction.convert_multiactions_to_agentlab_action_format(actions)
+ if len(actions) > 1
+ else actions[0]
+ )
+ return actions
+ def _extract_thinking_content_from_response(self, response: "OpenAIResponseObject") -> str:
+ """Extracts the thinking content from the response."""
+ thinking_content = ""
+ for output in response.output:
+ if output.type == "reasoning":
+ if len(output.summary) > 0:
+ thinking_content += output.summary[0].text + "\n"
elif output.type == "message" and output.content:
- result.think += output.content[0].text + "\n"
- for key in interesting_keys:
- if key_content := getattr(output, "output_text", None) is not None:
-                    result.think += f"<{key}>{key_content}</{key}>"
- return result
+ thinking_content += output.content[0].text + "\n"
+ elif hasattr(output, "output_text") and output.output_text:
+ thinking_content += f"{output.output_text}\n"
+ return thinking_content
+
+ def cua_action_to_env_tool_name_and_args(self, action: str) -> tuple[str, Dict[str, Any]]:
+        """Override this method to convert a computer action to an AgentLab action string."""
+ raise NotImplementedError(
+ "This method should be implemented in the subclass to convert a computer action to agentlab action string."
+ )
+
+ def _extract_env_actions_from_text_response(
+ self, response: "OpenAIResponseObject"
+ ) -> str | None:
+ """Extracts environment actions from the text response."""
+ # Use when action space is not given as tools.
+ pass
class OpenAIChatCompletionModel(BaseModelWithPricing):
def __init__(
self,
model_name: str,
- client_args: Optional[Dict[str, Any]] = {},
- temperature: float = 0.5,
- max_tokens: int = 100,
- extra_kwargs: Optional[Dict[str, Any]] = None,
- *args,
- **kwargs,
+ base_url: Optional[str] = None,
+ api_key: Optional[str] = None,
+ temperature: float | None = None,
+ max_tokens: int | None = 100,
):
-
- self.tools = self.format_tools_for_chat_completion(kwargs.pop("tools", None))
-
super().__init__(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
- extra_kwargs=extra_kwargs,
- *args,
- **kwargs,
)
-
- self.client = OpenAI(
- **client_args
- ) # Ensures client_args is a dict or defaults to an empty dict
-
- def _call_api(
- self, messages: list[dict | MessageBuilder], tool_choice: str = "auto"
- ) -> openai.types.chat.ChatCompletion:
+ self.action_space_as_tools = True # this should be a config
+ client_args = {}
+ if base_url is not None:
+ client_args["base_url"] = base_url
+ if api_key is not None:
+ client_args["api_key"] = api_key
+ self.client = OpenAI(**client_args)
+ self.init_pricing_tracker(pricing_api="openai") # Use the PricingMixin
+
+ def _call_api(self, payload: APIPayload) -> "openai.types.chat.ChatCompletion":
input = []
- for msg in messages:
- input.extend(msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg])
+ for msg in payload.messages:
+ input.extend(msg.prepare_message())
api_params: Dict[str, Any] = {
"model": self.model_name,
"messages": input,
- "temperature": self.temperature,
- "max_tokens": self.max_tokens,
- **self.extra_kwargs, # Pass tools, tool_choice, etc. here
}
- if self.tools is not None:
- api_params["tools"] = self.tools
+ if self.temperature is not None:
+ api_params["temperature"] = self.temperature
+
+ if self.max_tokens is not None:
+ api_params["max_completion_tokens"] = self.max_tokens
+
+ if payload.tools is not None:
+            # Tools arrive in the OpenAI Responses API format; convert them for Chat Completions.
+ api_params["tools"] = self.format_tools_for_chat_completion(payload.tools)
- if tool_choice in ("any", "required"):
- tool_choice = "required"
- api_params["tool_choice"] = tool_choice
+ if payload.tool_choice is not None and payload.force_call_tool is None:
+ api_params["tool_choice"] = (
+ "required" if payload.tool_choice in ("required", "any") else payload.tool_choice
+ )
+
+ if payload.force_call_tool is not None:
+ api_params["tool_choice"] = {
+ "type": "function",
+ "function": {"name": payload.force_call_tool},
+ }
response = call_openai_api_with_retries(self.client.chat.completions.create, api_params)
return response
- def _parse_response(self, response: openai.types.chat.ChatCompletion) -> LLMOutput:
+ def _parse_response(self, response: "openai.types.chat.ChatCompletion") -> LLMOutput:
+ think_output = self._extract_thinking_content_from_response(response)
+ tool_calls = self._extract_tool_calls_from_response(response)
- output = LLMOutput(
+ if self.action_space_as_tools:
+ env_action = self._extract_env_actions_from_toolcalls(tool_calls)
+ else:
+ env_action = self._extract_env_actions_from_text_response(response)
+ return LLMOutput(
raw_response=response,
- think="",
- action=None, # Default if no tool call
- tool_calls=None,
+ think=think_output,
+ action=env_action if env_action is not None else None,
+ tool_calls=tool_calls if tool_calls is not None else None,
)
+
+ def _extract_thinking_content_from_response(
+ self, response: openai.types.chat.ChatCompletion, wrap_tag="think"
+ ):
+ """Extracts the content from the message, including reasoning if available.
+        It wraps the reasoning in <think>...</think> tags for easy identification of reasoning content
+        when the LLM produces 'text' and 'reasoning' in the same message.
+        Note: The wrapping of 'thinking' content may not be needed and may be reconsidered.
+
+ Args:
+ response: The message object or dict containing content and reasoning.
+ wrap_tag: The tag name to wrap reasoning content (default: "think").
+
+ Returns:
+ str: The extracted content with reasoning wrapped in specified tags.
+ """
+ message = response.choices[0].message
+ if not isinstance(message, dict):
+ message = message.to_dict()
+
+ reasoning_content = message.get("reasoning", None)
+        msg_content = message.get("text", "")  # works for OpenRouter
+ if reasoning_content:
+ # Wrap reasoning in tags with newlines for clarity
+            reasoning_content = f"<{wrap_tag}>{reasoning_content}</{wrap_tag}>\n"
+ logging.debug("Extracting content from response.choices[i].message.reasoning")
+ else:
+ reasoning_content = ""
+ return f"{reasoning_content}{msg_content}{message.get('content', '')}"
+
+ def _extract_tool_calls_from_response(
+ self, response: openai.types.chat.ChatCompletion
+ ) -> ToolCalls | None:
+ """Extracts tool calls from the response."""
message = response.choices[0].message.to_dict()
- output.think = self.extract_content_with_reasoning(message)
-
- if tool_calls := message.get("tool_calls", None):
- for tool_call in tool_calls:
- function = tool_call["function"]
- arguments = json.loads(function["arguments"])
- func_args_str = ", ".join(
- [
- f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
- for k, v in arguments.items()
- ]
+ tool_calls = message.get("tool_calls", None)
+ if tool_calls is None:
+ return None
+ tool_call_list = []
+ for tc in tool_calls:
+ tool_call_list.append(
+ ToolCall(
+ name=tc["function"]["name"],
+ arguments=json.loads(tc["function"]["arguments"]),
+ raw_call=tc,
)
- output.action = f"{function['name']}({func_args_str})"
- output.tool_calls = {
- "role": "assistant",
- "tool_calls": [message["tool_calls"][0]], # Use only the first tool call
- }
- break
- return output
+ )
+ return ToolCalls(tool_calls=tool_call_list, raw_calls=response)
+
+ def _extract_env_actions_from_toolcalls(self, toolcalls: ToolCalls) -> Any | None:
+ """Extracts actions from the response."""
+ if not toolcalls:
+ return None
+
+ actions = [
+ AgentlabAction.convert_toolcall_to_agentlab_action_format(call) for call in toolcalls
+ ]
+ actions = (
+ AgentlabAction.convert_multiactions_to_agentlab_action_format(actions)
+ if len(actions) > 1
+ else actions[0]
+ )
+ return actions
+
+ def _extract_env_actions_from_text_response(
+ self, response: "openai.types.chat.ChatCompletion"
+ ) -> str | None:
+ """Extracts environment actions from the text response."""
+ # Use when action space is not given as tools.
+ pass
@staticmethod
def format_tools_for_chat_completion(tools):
@@ -483,93 +757,69 @@ def format_tools_for_chat_completion(tools):
]
return formatted_tools
- @staticmethod
- def extract_content_with_reasoning(message, wrap_tag="think"):
- """Extracts the content from the message, including reasoning if available.
-        It wraps the reasoning around <think>...</think> for easy identification of reasoning content,
- When LLM produces 'text' and 'reasoning' in the same message.
- Note: The wrapping of 'thinking' content may not be nedeed and may be reconsidered.
-
- Args:
- message: The message object or dict containing content and reasoning.
- wrap_tag: The tag name to wrap reasoning content (default: "think").
-
- Returns:
- str: The extracted content with reasoning wrapped in specified tags.
- """
- if not isinstance(message, dict):
- message = message.to_dict()
-
- reasoning_content = message.get("reasoning", None)
- msg_content = message.get("text", "") # works for OR
-
- if reasoning_content:
- # Wrap reasoning in tags with newlines for clarity
-            reasoning_content = f"<{wrap_tag}>{reasoning_content}</{wrap_tag}>\n"
- logging.debug("Extracting content from response.choices[i].message.reasoning")
- else:
- reasoning_content = ""
- return f"{reasoning_content}{msg_content}{message.get('content', '')}"
-
class ClaudeResponseModel(BaseModelWithPricing):
def __init__(
self,
model_name: str,
+ base_url: Optional[str] = None,
api_key: Optional[str] = None,
- temperature: float = 0.5,
- max_tokens: int = 100,
- extra_kwargs: Optional[Dict[str, Any]] = None,
- **kwargs,
+ temperature: float | None = None,
+ max_tokens: int | None = 100,
):
- self.tools = kwargs.pop("tools", None)
+ self.action_space_as_tools = True # this should be a config
super().__init__(
model_name=model_name,
api_key=api_key,
temperature=temperature,
max_tokens=max_tokens,
- extra_kwargs=extra_kwargs,
- **kwargs,
)
-
- self.client = Anthropic(api_key=api_key)
-
- def _call_api(
- self, messages: list[dict | MessageBuilder], tool_choice="auto", **kwargs
- ) -> dict:
- input = []
-
- sys_msg, other_msgs = self.filter_system_messages(messages)
+ client_args = {}
+ if base_url is not None:
+ client_args["base_url"] = base_url
+ if api_key is not None:
+ client_args["api_key"] = api_key
+ self.client = Anthropic(**client_args)
+ self.init_pricing_tracker(pricing_api="anthropic") # Use the PricingMixin
+
+ def _call_api(self, payload: APIPayload) -> Completion:
+ sys_msg, other_msgs = self.filter_system_messages(payload.messages)
sys_msg_text = "\n".join(c["text"] for m in sys_msg for c in m.content)
+ input = []
for msg in other_msgs:
- temp = msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg]
- if kwargs.pop("use_cache_breakpoints", False):
+ temp = msg.prepare_message()
+ if payload.use_cache_breakpoints:
temp = self.apply_cache_breakpoints(msg, temp)
input.extend(temp)
- if tool_choice in ("any", "required"):
- tool_choice = "any" # Claude API expects "any" and gpt expects "required"
-
api_params: Dict[str, Any] = {
"model": self.model_name,
"messages": input,
- "temperature": self.temperature,
- "max_tokens": self.max_tokens,
- "system": sys_msg_text, # Anthropic API expects system message as a string
- "tool_choice": {"type": tool_choice}, # Tool choice for Claude API
- **self.extra_kwargs, # Pass tools, tool_choice, etc. here
- }
- if self.tools is not None:
- api_params["tools"] = self.tools
- if kwargs.pop("cache_tool_definition", False):
- # Indicating cache control for the last tool enables caching of all previous tool definitions.
+ "system": sys_msg_text,
+ } # Anthropic API expects system message as a string
+
+ if self.temperature is not None:
+ api_params["temperature"] = self.temperature
+ if self.max_tokens is not None:
+ api_params["max_tokens"] = self.max_tokens
+
+ if payload.tools is not None:
+ api_params["tools"] = payload.tools
+ if payload.tool_choice is not None and payload.force_call_tool is None:
+ api_params["tool_choice"] = (
+ {"type": "any"}
+ if payload.tool_choice in ("required", "any")
+ else {"type": payload.tool_choice}
+ )
+ if payload.force_call_tool is not None:
+ api_params["tool_choice"] = {"type": "tool", "name": payload.force_call_tool}
+ if payload.cache_tool_definition:
+            # Indicating cache control for the last tool enables caching of all previous tool definitions.
api_params["tools"][-1]["cache_control"] = {"type": "ephemeral"}
- if kwargs.pop("cache_complete_prompt", False):
+ if payload.cache_complete_prompt:
# Indicating cache control for the last message enables caching of the complete prompt.
api_params["messages"][-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}
- if self.extra_kwargs.get("reasoning", None) is not None:
- api_params["reasoning"] = self.extra_kwargs["reasoning"]
response = call_anthropic_api_with_retries(self.client.messages.create, api_params)
@@ -592,26 +842,58 @@ def filter_system_messages(messages: list[dict | MessageBuilder]) -> tuple[Messa
other_msgs.append(msg)
return sys_msgs, other_msgs
- def _parse_response(self, response: dict) -> dict:
- result = LLMOutput(
+    def _parse_response(self, response: "AnthropicMessage") -> LLMOutput:
+
+ toolcalls = self._extract_tool_calls_from_response(response)
+ think_output = self._extract_thinking_content_from_response(response)
+ if self.action_space_as_tools:
+ env_action = self._extract_env_actions_from_toolcalls(toolcalls)
+ else:
+ env_action = self._extract_env_actions_from_text_response(response)
+ return LLMOutput(
raw_response=response,
- think="",
- action=None,
- tool_calls={
- "role": "assistant",
- "content": response.content,
- },
+ think=think_output,
+ action=env_action if env_action is not None else None,
+ tool_calls=toolcalls if toolcalls is not None else None,
)
+
+    def _extract_tool_calls_from_response(self, response: "AnthropicMessage") -> ToolCalls:
+ """Extracts tool calls from the response."""
+ tool_calls = []
for output in response.content:
if output.type == "tool_use":
- result.action = tool_call_to_python_code(output.name, output.input)
- elif output.type == "text":
- result.think += output.text
- return result
+ tool_calls.append(
+ ToolCall(
+ name=output.name,
+ arguments=output.input,
+ raw_call=output,
+ )
+ )
+ return ToolCalls(tool_calls=tool_calls, raw_calls=response)
+
+    def _extract_thinking_content_from_response(self, response: "AnthropicMessage"):
+ """Extracts the thinking content from the response."""
+ return "".join(output.text for output in response.content if output.type == "text")
+
+ def _extract_env_actions_from_toolcalls(self, toolcalls: ToolCalls) -> Any | None:
+ """Extracts actions from the response."""
+ if not toolcalls:
+ return None
+
+ actions = [
+ AgentlabAction.convert_toolcall_to_agentlab_action_format(call) for call in toolcalls
+ ]
+ actions = (
+ AgentlabAction.convert_multiactions_to_agentlab_action_format(actions)
+ if len(actions) > 1
+ else actions[0]
+ )
+ return actions
- # def ensure_cache_conditions(self, msgs: List[Message]) -> bool:
- # """Ensure API specific cache conditions are met."""
- # assert sum(getattr(msg, "_cache_breakpoint", 0) for msg in msgs) <= 4, "Too many cache breakpoints in the message."
+    def _extract_env_actions_from_text_response(self, response: "AnthropicMessage") -> str | None:
+ """Extracts environment actions from the text response."""
+ # Use when action space is not given as tools.
+ pass
def apply_cache_breakpoints(self, msg: Message, prepared_msg: dict) -> List[Message]:
"""Apply cache breakpoints to the messages."""
@@ -621,6 +903,8 @@ def apply_cache_breakpoints(self, msg: Message, prepared_msg: dict) -> List[Mess
# Factory classes to create the appropriate model based on the API endpoint.
+
+
@dataclass
class OpenAIResponseModelArgs(BaseModelArgs):
"""Serializable object for instantiating a generic chat model with an OpenAI
@@ -628,14 +912,11 @@ class OpenAIResponseModelArgs(BaseModelArgs):
api = "openai"
- def make_model(self, extra_kwargs=None, **kwargs):
+ def make_model(self):
return OpenAIResponseModel(
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_new_tokens,
- extra_kwargs=extra_kwargs,
- pricing_api="openai",
- **kwargs,
)
def get_message_builder(self) -> MessageBuilder:
@@ -649,14 +930,11 @@ class ClaudeResponseModelArgs(BaseModelArgs):
api = "anthropic"
- def make_model(self, extra_kwargs=None, **kwargs):
+ def make_model(self):
return ClaudeResponseModel(
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_new_tokens,
- extra_kwargs=extra_kwargs,
- pricing_api="anthropic",
- **kwargs,
)
def get_message_builder(self) -> MessageBuilder:
@@ -670,14 +948,11 @@ class OpenAIChatModelArgs(BaseModelArgs):
api = "openai"
- def make_model(self, extra_kwargs=None, **kwargs):
+ def make_model(self):
return OpenAIChatCompletionModel(
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_new_tokens,
- extra_kwargs=extra_kwargs,
- pricing_api="openai",
- **kwargs,
)
def get_message_builder(self) -> MessageBuilder:
@@ -691,42 +966,13 @@ class OpenRouterModelArgs(BaseModelArgs):
api: str = "openai" # tool description format used by actionset.to_tool_description() in bgym
- def make_model(self, extra_kwargs=None, **kwargs):
+ def make_model(self):
return OpenAIChatCompletionModel(
- client_args={
- "base_url": "https://openrouter.ai/api/v1",
- "api_key": os.getenv("OPENROUTER_API_KEY"),
- },
+ base_url="https://openrouter.ai/api/v1",
+ api_key=os.getenv("OPENROUTER_API_KEY"),
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_new_tokens,
- extra_kwargs=extra_kwargs,
- pricing_api="openrouter",
- **kwargs,
- )
-
- def get_message_builder(self) -> MessageBuilder:
- return OpenAIChatCompletionAPIMessageBuilder
-
-
-class VLLMModelArgs(BaseModelArgs):
- """Serializable object for instantiating a generic chat model with a VLLM
- model."""
-
- api = "openai" # tool description format used by actionset.to_tool_description() in bgym
-
- def make_model(self, extra_kwargs=None, **kwargs):
- return OpenAIChatCompletionModel(
- client_args={
- "base_url": "http://localhost:8000/v1",
- "api_key": os.getenv("VLLM_API_KEY", "EMPTY"),
- },
- model_name=self.model_name, # this needs to be set
- temperature=self.temperature,
- max_tokens=self.max_new_tokens,
- extra_kwargs=extra_kwargs,
- pricing_api="vllm",
- **kwargs,
)
def get_message_builder(self) -> MessageBuilder:
@@ -743,3 +989,29 @@ def tool_call_to_python_code(func_name, kwargs):
args_str = ", ".join(f"{key}={repr(value)}" for key, value in kwargs.items())
return f"{func_name}({args_str})"
+
+
+# --- Not tested ---
+
+# class VLLMModelArgs(BaseModelArgs):
+# """Serializable object for instantiating a generic chat model with a VLLM
+# model."""
+
+# api = "openai" # tool description format used by actionset.to_tool_description() in bgym
+
+# def make_model(self, extra_kwargs=None, **kwargs):
+# return OpenAIChatCompletionModel(
+# client_args={
+# "base_url": "http://localhost:8000/v1",
+# "api_key": os.getenv("VLLM_API_KEY", "EMPTY"),
+# },
+# model_name=self.model_name, # this needs to be set
+# temperature=self.temperature,
+# max_tokens=self.max_new_tokens,
+# extra_kwargs=extra_kwargs,
+# pricing_api="vllm",
+# **kwargs,
+# )
+
+# def get_message_builder(self) -> MessageBuilder:
+# return OpenAIChatCompletionAPIMessageBuilder
diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py
index 23e48930..b8bcce7c 100644
--- a/src/agentlab/llm/tracking.py
+++ b/src/agentlab/llm/tracking.py
@@ -13,7 +13,7 @@
TRACKER = threading.local()
-ANTHROPHIC_CACHE_PRICING_FACTOR = {
+ANTHROPIC_CACHE_PRICING_FACTOR = {
"cache_read_tokens": 0.1, # Cost for 5 min ephemeral cache. See Pricing Here: https://docs.anthropic.com/en/docs/about-claude/pricing#model-pricing
"cache_write_tokens": 1.25,
}
@@ -151,20 +151,20 @@ class TrackAPIPricingMixin:
def reset_stats(self):
self.stats = Stats()
- def __init__(self, *args, **kwargs):
- pricing_api = kwargs.pop("pricing_api", None)
+ def init_pricing_tracker(
+ self, pricing_api=None
+    ):  # TODO: Call this from the base class __init__ instead of keeping an init in the mixin.
self._pricing_api = pricing_api
- super().__init__(*args, **kwargs)
self.set_pricing_attributes()
self.reset_stats()
def __call__(self, *args, **kwargs):
"""Call the API and update the pricing tracker."""
+        # 'self' here dispatches to the subclass's ._call_api() method.
response = self._call_api(*args, **kwargs)
-
usage = dict(getattr(response, "usage", {}))
if "prompt_tokens_details" in usage:
- usage["cached_tokens"] = usage["prompt_token_details"].cached_tokens
+ usage["cached_tokens"] = usage["prompt_tokens_details"].cached_tokens
if "input_tokens_details" in usage:
usage["cached_tokens"] = usage["input_tokens_details"].cached_tokens
usage = {f"usage_{k}": v for k, v in usage.items() if isinstance(v, (int, float))}
@@ -274,8 +274,8 @@ def get_effective_cost_from_antrophic_api(self, response) -> float:
cache_read_tokens = getattr(usage, "cache_input_tokens", 0)
cache_write_tokens = getattr(usage, "cache_creation_input_tokens", 0)
- cache_read_cost = self.input_cost * ANTHROPHIC_CACHE_PRICING_FACTOR["cache_read_tokens"]
- cache_write_cost = self.input_cost * ANTHROPHIC_CACHE_PRICING_FACTOR["cache_write_tokens"]
+ cache_read_cost = self.input_cost * ANTHROPIC_CACHE_PRICING_FACTOR["cache_read_tokens"]
+ cache_write_cost = self.input_cost * ANTHROPIC_CACHE_PRICING_FACTOR["cache_write_tokens"]
# Calculate the effective cost
effective_cost = (
@@ -284,6 +284,10 @@ def get_effective_cost_from_antrophic_api(self, response) -> float:
+ cache_read_tokens * cache_read_cost
+ cache_write_tokens * cache_write_cost
)
+ if effective_cost < 0:
+ logging.warning(
+                "Anthropic: Negative effective cost detected (impossible; likely a bug)."
+ )
return effective_cost
def get_effective_cost_from_openai_api(self, response) -> float:
@@ -308,25 +312,29 @@ def get_effective_cost_from_openai_api(self, response) -> float:
return 0.0
api_type = "chatcompletion" if hasattr(usage, "prompt_tokens_details") else "response"
if api_type == "chatcompletion":
- total_input_tokens = usage.prompt_tokens
+ total_input_tokens = usage.prompt_tokens # (cache read tokens + new input tokens)
output_tokens = usage.completion_tokens
cached_input_tokens = usage.prompt_tokens_details.cached_tokens
- non_cached_input_tokens = total_input_tokens - cached_input_tokens
+ new_input_tokens = total_input_tokens - cached_input_tokens
elif api_type == "response":
- total_input_tokens = usage.input_tokens
+ total_input_tokens = usage.input_tokens # (cache read tokens + new input tokens)
output_tokens = usage.output_tokens
cached_input_tokens = usage.input_tokens_details.cached_tokens
- non_cached_input_tokens = total_input_tokens - cached_input_tokens
+ new_input_tokens = total_input_tokens - cached_input_tokens
else:
logging.warning(f"Unsupported API type: {api_type}. Defaulting cost to 0.0.")
return 0.0
-
cache_read_cost = self.input_cost * OPENAI_CACHE_PRICING_FACTOR["cache_read_tokens"]
effective_cost = (
- self.input_cost * non_cached_input_tokens
+ self.input_cost * new_input_tokens
+ cached_input_tokens * cache_read_cost
+ self.output_cost * output_tokens
)
+ if effective_cost < 0:
+ logging.warning(
+                f"OpenAI: Negative effective cost detected (impossible; likely a bug). "
+                f"New input tokens: {new_input_tokens}"
+ )
return effective_cost
diff --git a/tests/agents/test_gaia_agent.py b/tests/agents/test_gaia_agent.py
index 0d39f9ef..604ac00c 100644
--- a/tests/agents/test_gaia_agent.py
+++ b/tests/agents/test_gaia_agent.py
@@ -2,10 +2,15 @@
import uuid
from pathlib import Path
-from tapeagents.steps import ImageObservation
+try:
+ from tapeagents.steps import ImageObservation
-from agentlab.agents.tapeagent.agent import TapeAgent, TapeAgentArgs, load_config
-from agentlab.benchmarks.gaia import GaiaBenchmark, GaiaQuestion
+ from agentlab.agents.tapeagent.agent import TapeAgent, TapeAgentArgs, load_config
+ from agentlab.benchmarks.gaia import GaiaBenchmark, GaiaQuestion
+except ModuleNotFoundError:
+ import pytest
+
+ pytest.skip("Skipping test due to missing dependencies", allow_module_level=True)
def mock_dataset() -> dict:
diff --git a/tests/llm/test_response_api.py b/tests/llm/test_response_api.py
index 567b49da..6bb639f6 100644
--- a/tests/llm/test_response_api.py
+++ b/tests/llm/test_response_api.py
@@ -9,6 +9,7 @@
from agentlab.llm import tracking
from agentlab.llm.response_api import (
AnthropicAPIMessageBuilder,
+ APIPayload,
ClaudeResponseModelArgs,
LLMOutput,
OpenAIChatCompletionAPIMessageBuilder,
@@ -55,6 +56,9 @@ def create_mock_openai_chat_completion(
# or if get_tokens_counts_from_response had different fallback logic.
completion.usage.prompt_tokens = prompt_tokens
completion.usage.completion_tokens = completion_tokens
+ prompt_tokens_details_mock = MagicMock()
+ prompt_tokens_details_mock.cached_tokens = 0
+ completion.usage.prompt_tokens_details = prompt_tokens_details_mock
completion.model_dump.return_value = {
"id": "chatcmpl-xxxx",
@@ -68,6 +72,7 @@ def create_mock_openai_chat_completion(
"output_tokens": completion_tokens, # Generic name
"prompt_tokens": prompt_tokens, # OpenAI specific
"completion_tokens": completion_tokens, # OpenAI specific
+ "prompt_tokens_details": {"cached_tokens": 0},
},
}
message.to_dict.return_value = {
@@ -78,6 +83,69 @@ def create_mock_openai_chat_completion(
return completion
+responses_api_tools = [
+ {
+ "type": "function",
+ "name": "get_weather",
+ "description": "Get the current weather in a given location.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The location to get the weather for.",
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description": "The unit of temperature.",
+ },
+ },
+ "required": ["location"],
+ },
+ }
+]
+
+chat_api_tools = [
+ {
+ "type": "function",
+ "name": "get_weather",
+ "description": "Get the current weather in a given location.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The location to get the weather for.",
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description": "The unit of temperature.",
+ },
+ },
+ "required": ["location"],
+ },
+ }
+]
+anthropic_tools = [
+ {
+ "name": "get_weather",
+ "description": "Get the current weather in a given location.",
+ "input_schema": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The location to get the weather for.",
+ },
+ },
+ "required": ["location"],
+ },
+ }
+]
+
+
# Helper to create a mock Anthropic response
def create_mock_anthropic_response(
text_content=None, tool_use=None, input_tokens=15, output_tokens=25
@@ -102,6 +170,8 @@ def create_mock_anthropic_response(
response.usage = MagicMock()
response.usage.input_tokens = input_tokens
response.usage.output_tokens = output_tokens
+ response.usage.cache_input_tokens = 0
+ response.usage.cache_creation_input_tokens = 0
return response
@@ -114,7 +184,7 @@ def create_mock_openai_responses_api_response(
Compatible with OpenAIResponseModel and TrackAPIPricingMixin.
"""
- response_mock = MagicMock(openai.types.responses.response)
+ response_mock = MagicMock(spec=openai.types.responses.response.Response)
response_mock.type = "response"
response_mock.output = []
@@ -138,11 +208,14 @@ def create_mock_openai_responses_api_response(
response_mock.output.append(output_item_mock)
# Token usage for pricing tracking
- response_mock.usage = MagicMock()
+ response_mock.usage = MagicMock(spec=openai.types.responses.response.ResponseUsage)
response_mock.usage.input_tokens = input_tokens
response_mock.usage.output_tokens = output_tokens
response_mock.usage.prompt_tokens = input_tokens
response_mock.usage.completion_tokens = output_tokens
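+    # The Responses API reports cached tokens under usage.input_tokens_details.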
+ input_tokens_details_mock = MagicMock()
+ input_tokens_details_mock.cached_tokens = 0
+ response_mock.usage.input_tokens_details = input_tokens_details_mock
return response_mock
@@ -196,13 +269,6 @@ def test_anthropic_api_message_builder_image():
def test_openai_chat_completion_api_message_builder_text():
builder = OpenAIChatCompletionAPIMessageBuilder.user()
builder.add_text("Hello, ChatCompletion!")
- # Mock last_response as it's used by tool role
- builder.last_raw_response = MagicMock(spec=LLMOutput)
- builder.last_raw_response.raw_response = MagicMock()
- builder.last_raw_response.raw_response.choices = [MagicMock()]
- builder.last_raw_response.raw_response.choices[0].message.to_dict.return_value = {
- "tool_calls": [{"function": {"name": "some_function"}}]
- }
messages = builder.prepare_message()
assert len(messages) == 1
@@ -213,13 +279,6 @@ def test_openai_chat_completion_api_message_builder_text():
def test_openai_chat_completion_api_message_builder_image():
builder = OpenAIChatCompletionAPIMessageBuilder.user()
builder.add_image("data:image/jpeg;base64,CHATCOMPLETIONBASE64")
- # Mock last_response
- builder.last_raw_response = MagicMock(spec=LLMOutput)
- builder.last_raw_response.raw_response = MagicMock()
- builder.last_raw_response.raw_response.choices = [MagicMock()]
- builder.last_raw_response.raw_response.choices[0].message.to_dict.return_value = {
- "tool_calls": [{"function": {"name": "some_function"}}]
- }
messages = builder.prepare_message()
assert len(messages) == 1
@@ -230,14 +289,12 @@ def test_openai_chat_completion_api_message_builder_image():
def test_openai_chat_completion_model_parse_and_cost():
- args = OpenAIChatModelArgs(model_name="gpt-3.5-turbo") # A cheap model for testing
- # Mock the OpenAI client to avoid needing OPENAI_API_KEY
+ args = OpenAIChatModelArgs(model_name="gpt-3.5-turbo")
with patch("agentlab.llm.response_api.OpenAI") as mock_openai_class:
mock_client = MagicMock()
mock_openai_class.return_value = mock_client
model = args.make_model()
- # Mock the API call
mock_response = create_mock_openai_chat_completion(
content="This is a test thought.",
tool_calls=[
@@ -254,17 +311,18 @@ def test_openai_chat_completion_model_parse_and_cost():
with patch.object(
model.client.chat.completions, "create", return_value=mock_response
) as mock_create:
- with tracking.set_tracker() as global_tracker: # Use your global tracker
+ with tracking.set_tracker() as global_tracker:
messages = [
- OpenAIChatCompletionAPIMessageBuilder.user()
- .add_text("What's the weather in Paris?")
- .prepare_message()[0]
+ OpenAIChatCompletionAPIMessageBuilder.user().add_text(
+ "What's the weather in Paris?"
+ )
]
- parsed_output = model(messages)
+ payload = APIPayload(messages=messages)
+ parsed_output = model(payload)
mock_create.assert_called_once()
assert parsed_output.raw_response.choices[0].message.content == "This is a test thought."
- assert parsed_output.action == 'get_weather(location="Paris")'
+ assert parsed_output.action == """get_weather(location='Paris')"""
assert parsed_output.raw_response.choices[0].message.tool_calls[0].id == "call_123"
# Check cost tracking (token counts)
assert global_tracker.stats["input_tokens"] == 50
@@ -273,7 +331,7 @@ def test_openai_chat_completion_model_parse_and_cost():
def test_claude_response_model_parse_and_cost():
- args = ClaudeResponseModelArgs(model_name="claude-3-haiku-20240307") # A cheap model
+ args = ClaudeResponseModelArgs(model_name="claude-3-haiku-20240307")
model = args.make_model()
mock_anthropic_api_response = create_mock_anthropic_response(
@@ -287,33 +345,23 @@ def test_claude_response_model_parse_and_cost():
model.client.messages, "create", return_value=mock_anthropic_api_response
) as mock_create:
with tracking.set_tracker() as global_tracker:
- messages = [
- AnthropicAPIMessageBuilder.user()
- .add_text("Search for latest news")
- .prepare_message()[0]
- ]
- parsed_output = model(messages)
+ messages = [AnthropicAPIMessageBuilder.user().add_text("Search for latest news")]
+ payload = APIPayload(messages=messages)
+ parsed_output = model(payload)
mock_create.assert_called_once()
- fn_calls = [
- content for content in parsed_output.raw_response.content if content.type == "tool_use"
- ]
+ fn_call = next(iter(parsed_output.tool_calls))
+
assert "Thinking about the request." in parsed_output.think
- assert parsed_output.action == "search_web(query='latest news')"
- assert fn_calls[0].id == "tool_abc"
+ assert parsed_output.action == """search_web(query='latest news')"""
+ assert fn_call.name == "search_web"
assert global_tracker.stats["input_tokens"] == 40
assert global_tracker.stats["output_tokens"] == 20
- # assert global_tracker.stats["cost"] > 0 # Verify cost is calculated
def test_openai_response_model_parse_and_cost():
- """
- Tests OpenAIResponseModel output parsing and cost tracking with both
- function_call and reasoning outputs.
- """
args = OpenAIResponseModelArgs(model_name="gpt-4.1")
- # Mock outputs
mock_function_call_output = {
"type": "function_call",
"name": "get_current_weather",
@@ -327,7 +375,6 @@ def test_openai_response_model_parse_and_cost():
output_tokens=40,
)
- # Mock the OpenAI client to avoid needing OPENAI_API_KEY
with patch("agentlab.llm.response_api.OpenAI") as mock_openai_class:
mock_client = MagicMock()
mock_openai_class.return_value = mock_client
@@ -338,15 +385,16 @@ def test_openai_response_model_parse_and_cost():
) as mock_create_method:
with tracking.set_tracker() as global_tracker:
messages = [
- OpenAIResponseAPIMessageBuilder.user()
- .add_text("What's the weather in Boston?")
- .prepare_message()[0]
+ OpenAIResponseAPIMessageBuilder.user().add_text("What's the weather in Boston?")
]
- parsed_output = model(messages)
+ payload = APIPayload(messages=messages)
+ parsed_output = model(payload)
mock_create_method.assert_called_once()
fn_calls = [
- content for content in parsed_output.raw_response.output if content.type == "function_call"
+ content
+ for content in parsed_output.tool_calls.raw_calls.output
+ if content.type == "function_call"
]
assert parsed_output.action == "get_current_weather(location='Boston, MA', unit='celsius')"
assert fn_calls[0].call_id == "call_abc123"
@@ -368,43 +416,20 @@ def test_openai_chat_completion_model_pricy_call():
max_new_tokens=100,
)
- tools = [
- {
- "type": "function",
- "name": "get_weather",
- "description": "Get the current weather in a given location.",
- "parameters": {
- "type": "object",
- "properties": {
- "location": {
- "type": "string",
- "description": "The location to get the weather for.",
- },
- "unit": {
- "type": "string",
- "enum": ["celsius", "fahrenheit"],
- "description": "The unit of temperature.",
- },
- },
- "required": ["location"],
- },
- }
- ]
-
- model = args.make_model(tools=tools, tool_choice="required")
+ tools = chat_api_tools
+ model = args.make_model()
with tracking.set_tracker() as global_tracker:
messages = [
- OpenAIChatCompletionAPIMessageBuilder.user()
- .add_text("What is the weather in Paris?")
- .prepare_message()[0]
+ OpenAIChatCompletionAPIMessageBuilder.user().add_text("What is the weather in Paris?")
]
- parsed_output = model(messages)
+ payload = APIPayload(messages=messages, tools=tools, tool_choice="required")
+ parsed_output = model(payload)
assert parsed_output.raw_response is not None
assert (
- parsed_output.action == 'get_weather(location="Paris")'
- ), f""" Expected get_weather(location="Paris") but got {parsed_output.action}"""
+ parsed_output.action == "get_weather(location='Paris')"
+ ), f""" Expected get_weather(location='Paris') but got {parsed_output.action}"""
assert global_tracker.stats["input_tokens"] > 0
assert global_tracker.stats["output_tokens"] > 0
assert global_tracker.stats["cost"] > 0
@@ -420,36 +445,18 @@ def test_claude_response_model_pricy_call():
temperature=1e-5,
max_new_tokens=100,
)
- tools = [
- {
- "name": "get_weather",
- "description": "Get the current weather in a given location.",
- "input_schema": {
- "type": "object",
- "properties": {
- "location": {
- "type": "string",
- "description": "The location to get the weather for.",
- },
- },
- "required": ["location"],
- },
- }
- ]
- model = args.make_model(tools=tools)
+ tools = anthropic_tools
+ model = args.make_model()
with tracking.set_tracker() as global_tracker:
- messages = [
- AnthropicAPIMessageBuilder.user()
- .add_text("What is the weather in Paris?")
- .prepare_message()[0]
- ]
- parsed_output = model(messages)
+ messages = [AnthropicAPIMessageBuilder.user().add_text("What is the weather in Paris?")]
+ payload = APIPayload(messages=messages, tools=tools)
+ parsed_output = model(payload)
assert parsed_output.raw_response is not None
assert (
- parsed_output.action == 'get_weather(location="Paris")'
- ), f'Expected get_weather("Paris") but got {parsed_output.action}'
+ parsed_output.action == "get_weather(location='Paris')"
+ ), f"""Expected get_weather('Paris') but got {parsed_output.action}"""
assert global_tracker.stats["input_tokens"] > 0
assert global_tracker.stats["output_tokens"] > 0
assert global_tracker.stats["cost"] > 0
@@ -464,42 +471,20 @@ def test_openai_response_model_pricy_call():
"""
args = OpenAIResponseModelArgs(model_name="gpt-4.1", temperature=1e-5, max_new_tokens=100)
- tools = [
- {
- "type": "function",
- "name": "get_weather",
- "description": "Get the current weather in a given location.",
- "parameters": {
- "type": "object",
- "properties": {
- "location": {
- "type": "string",
- "description": "The location to get the weather for.",
- },
- "unit": {
- "type": "string",
- "enum": ["celsius", "fahrenheit"],
- "description": "The unit of temperature.",
- },
- },
- "required": ["location"],
- },
- }
- ]
- model = args.make_model(tools=tools)
+ tools = responses_api_tools
+ model = args.make_model()
with tracking.set_tracker() as global_tracker:
messages = [
- OpenAIResponseAPIMessageBuilder.user()
- .add_text("What is the weather in Paris?")
- .prepare_message()[0]
+ OpenAIResponseAPIMessageBuilder.user().add_text("What is the weather in Paris?")
]
- parsed_output = model(messages)
+ payload = APIPayload(messages=messages, tools=tools)
+ parsed_output = model(payload)
assert parsed_output.raw_response is not None
assert (
- parsed_output.action == """get_weather(location="Paris")"""
- ), f""" Expected get_weather(location="Paris") but got {parsed_output.action}"""
+ parsed_output.action == """get_weather(location='Paris')"""
+ ), f""" Expected get_weather(location='Paris') but got {parsed_output.action}"""
assert global_tracker.stats["input_tokens"] > 0
assert global_tracker.stats["output_tokens"] > 0
assert global_tracker.stats["cost"] > 0
@@ -514,61 +499,43 @@ def test_openai_response_model_with_multiple_messages_and_cost_tracking():
"""
args = OpenAIResponseModelArgs(model_name="gpt-4.1", temperature=1e-5, max_new_tokens=100)
- tools = [
- {
- "type": "function",
- "name": "get_weather",
- "description": "Get the current weather in a given location.",
- "parameters": {
- "type": "object",
- "properties": {
- "location": {
- "type": "string",
- "description": "The location to get the weather for.",
- },
- "unit": {
- "type": "string",
- "enum": ["celsius", "fahrenheit"],
- "description": "The unit of temperature.",
- },
- },
- "required": ["location"],
- },
- }
- ]
-
- model = args.make_model(tools=tools, tool_choice="required")
+ tools = responses_api_tools
+ model = args.make_model()
builder = args.get_message_builder()
messages = [builder.user().add_text("What is the weather in Paris?")]
with tracking.set_tracker() as tracker:
- # First turn: get initial tool call
- parsed = model(messages)
+ payload = APIPayload(messages=messages, tools=tools, tool_choice="required")
+ parsed = model(payload)
prev_input = tracker.stats["input_tokens"]
prev_output = tracker.stats["output_tokens"]
prev_cost = tracker.stats["cost"]
+ assert parsed.tool_calls, "Expected tool calls in the response"
+ # Set tool responses
+ for tool_call in parsed.tool_calls:
+                tool_call.response_text("It's sunny! 25°C")
# Simulate tool execution and user follow-up
messages += [
- parsed.tool_calls, # Add tool call from the model
- builder.tool(parsed.raw_response).add_text("Its sunny! 25°C"),
+ builder.add_responded_tool_calls(parsed.tool_calls),
builder.user().add_text("What is the weather in Delhi?"),
]
- parsed = model(messages)
+ payload = APIPayload(messages=messages, tools=tools, tool_choice="required")
+ parsed = model(payload)
- # Token and cost deltas
delta_input = tracker.stats["input_tokens"] - prev_input
delta_output = tracker.stats["output_tokens"] - prev_output
delta_cost = tracker.stats["cost"] - prev_cost
- # Assertions
assert prev_input > 0
assert prev_output > 0
assert prev_cost > 0
assert parsed.raw_response is not None
- assert parsed.action == 'get_weather(location="Delhi")', f"Unexpected action: {parsed.action}"
+ assert (
+ parsed.action == """get_weather(location='Delhi')"""
+ ), f"Unexpected action: {parsed.action}"
assert delta_input > 0
assert delta_output > 0
assert delta_cost > 0
@@ -609,38 +576,41 @@ def test_openai_chat_completion_model_with_multiple_messages_and_cost_tracking()
}
]
- model = args.make_model(tools=tools, tool_choice="required")
+ model = args.make_model()
builder = args.get_message_builder()
messages = [builder.user().add_text("What is the weather in Paris?")]
with tracking.set_tracker() as tracker:
- # First turn: get initial tool call
- parsed = model(messages)
+ payload = APIPayload(messages=messages, tools=tools, tool_choice="required")
+ parsed = model(payload)
prev_input = tracker.stats["input_tokens"]
prev_output = tracker.stats["output_tokens"]
prev_cost = tracker.stats["cost"]
+            # Set tool responses
+            for tool_call in parsed.tool_calls:
+                tool_call.response_text("It's sunny! 25°C")
# Simulate tool execution and user follow-up
messages += [
- parsed.tool_calls, # Add tool call from the model
- builder.tool(parsed.raw_response).add_text("Its sunny! 25°C"),
+ builder.add_responded_tool_calls(parsed.tool_calls),
builder.user().add_text("What is the weather in Delhi?"),
]
- parsed = model(messages)
+ payload = APIPayload(messages=messages, tools=tools, tool_choice="required")
+ parsed = model(payload)
- # Token and cost deltas
delta_input = tracker.stats["input_tokens"] - prev_input
delta_output = tracker.stats["output_tokens"] - prev_output
delta_cost = tracker.stats["cost"] - prev_cost
- # Assertions
assert prev_input > 0
assert prev_output > 0
assert prev_cost > 0
assert parsed.raw_response is not None
- assert parsed.action == 'get_weather(location="Delhi")', f"Unexpected action: {parsed.action}"
+ assert (
+ parsed.action == """get_weather(location='Delhi')"""
+ ), f"Unexpected action: {parsed.action}"
assert delta_input > 0
assert delta_output > 0
assert delta_cost > 0
@@ -676,35 +646,38 @@ def test_claude_model_with_multiple_messages_pricy_call():
},
}
]
- model = model_factory.make_model(tools=tools)
+ model = model_factory.make_model()
msg_builder = model_factory.get_message_builder()
messages = []
messages.append(msg_builder.user().add_text("What is the weather in Paris?"))
with tracking.set_tracker() as global_tracker:
- llm_output1 = model(messages)
+ payload = APIPayload(messages=messages, tools=tools)
+ llm_output1 = model(payload)
prev_input = global_tracker.stats["input_tokens"]
prev_output = global_tracker.stats["output_tokens"]
prev_cost = global_tracker.stats["cost"]
- messages.append(llm_output1.tool_calls)
- messages.append(msg_builder.tool(llm_output1.raw_response).add_text("Its sunny! 25°C"))
- messages.append(msg_builder.user().add_text("What is the weather in Delhi?"))
- llm_output2 = model(messages)
- # Token and cost deltas
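+        # Record the tool results before adding the follow-up user turn.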
+ for tool_call in llm_output1.tool_calls:
+ tool_call.response_text("It's sunny! 25°C")
+ messages += [
+ msg_builder.add_responded_tool_calls(llm_output1.tool_calls),
+ msg_builder.user().add_text("What is the weather in Delhi?"),
+ ]
+ payload = APIPayload(messages=messages, tools=tools)
+ llm_output2 = model(payload)
delta_input = global_tracker.stats["input_tokens"] - prev_input
delta_output = global_tracker.stats["output_tokens"] - prev_output
delta_cost = global_tracker.stats["cost"] - prev_cost
- # Assertions
assert prev_input > 0, "Expected previous input tokens to be greater than 0"
assert prev_output > 0, "Expected previous output tokens to be greater than 0"
assert prev_cost > 0, "Expected previous cost value to be greater than 0"
assert llm_output2.raw_response is not None
assert (
- llm_output2.action == 'get_weather(location="Delhi", unit="celsius")'
- ), f'Expected get_weather("Delhi") but got {llm_output2.action}'
+ llm_output2.action == """get_weather(location='Delhi', unit='celsius')"""
+ ), f"""Expected get_weather('Delhi') but got {llm_output2.action}"""
assert delta_input > 0, "Expected new input tokens to be greater than 0"
assert delta_output > 0, "Expected new output tokens to be greater than 0"
assert delta_cost > 0, "Expected new cost value to be greater than 0"
@@ -713,9 +686,71 @@ def test_claude_model_with_multiple_messages_pricy_call():
assert global_tracker.stats["cost"] == pytest.approx(prev_cost + delta_cost)
-# TODO: Add tests for image token costing (this is complex and model-specific)
-# - For OpenAI, you'd need to know how they bill for images (e.g., fixed cost per image + tokens for text parts)
-# - You'd likely need to mock the response from client.chat.completions.create to include specific usage for images.
+## Test multi-action tool calls
+@pytest.mark.pricy
+def test_multi_action_tool_calls():
+ """
+ Test that the model can produce multiple tool calls in parallel.
+    Uncomment the commented-out configurations and tool choices to exercise the full matrix.
+ """
+    # Each config tuple: (setting name, BaseModelArgs subclass, checkpoint name, tools)
+ tool_test_configs = [
+ (
+            "gpt-4.1 Responses API",
+ OpenAIResponseModelArgs,
+ "gpt-4.1-2025-04-14",
+ responses_api_tools,
+ ),
+        ("gpt-4.1 Chat Completions API", OpenAIChatModelArgs, "gpt-4.1-2025-04-14", chat_api_tools),
+ # ("claude-3", ClaudeResponseModelArgs, "claude-3-haiku-20240307", anthropic_tools), # fails
+ # ("claude-3.7", ClaudeResponseModelArgs, "claude-3-7-sonnet-20250219", anthropic_tools), # fails
+ ("claude-4-sonnet", ClaudeResponseModelArgs, "claude-sonnet-4-20250514", anthropic_tools),
+ # add more models as needed
+ ]
+
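+    # Two user messages nudge the model toward issuing parallel get_weather calls.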
+ def add_user_messages(msg_builder):
+ return [
+ msg_builder.user().add_text("What is the weather in Paris and Delhi?"),
+ msg_builder.user().add_text("You must call multiple tools to achieve the task."),
+ ]
+
+ res_df = []
+
+ for tool_choice in [
+ # 'none',
+ # 'required', # fails for Responses API
+ # 'any', # fails for Responses API
+ "auto",
+ # 'get_weather'
+ ]:
+ for name, llm_class, checkpoint_name, tools in tool_test_configs:
+ print(name, "tool choice:", tool_choice, "\n", "**" * 10)
+ model_args = llm_class(model_name=checkpoint_name, max_new_tokens=200, temperature=None)
+ llm, msg_builder = model_args.make_model(), model_args.get_message_builder()
+ messages = add_user_messages(msg_builder)
+ if tool_choice == "get_weather": # force a specific tool call
+ response: LLMOutput = llm(
+ APIPayload(messages=messages, tools=tools, force_call_tool=tool_choice)
+ )
+ else:
+ response: LLMOutput = llm(
+ APIPayload(messages=messages, tools=tools, tool_choice=tool_choice)
+ )
+ num_tool_calls = len(response.tool_calls) if response.tool_calls else 0
+ res_df.append(
+ {
+ "model": name,
+ "checkpoint": checkpoint_name,
+ "tool_choice": tool_choice,
+ "num_tool_calls": num_tool_calls,
+ "action": response.action,
+ }
+ )
+ assert (
+ num_tool_calls == 2
+ ), f"Expected 2 tool calls, but got {num_tool_calls} for {name} with tool choice {tool_choice}"
+ # import pandas as pd
+ # print(pd.DataFrame(res_df))
EDGE_CASES = [