diff --git a/src/agentlab/agents/tool_use_agent/agent.py b/src/agentlab/agents/tool_use_agent/agent.py
index 7c115362..26949462 100644
--- a/src/agentlab/agents/tool_use_agent/agent.py
+++ b/src/agentlab/agents/tool_use_agent/agent.py
@@ -5,60 +5,28 @@
 import bgym
 import numpy as np
-from browsergym.core.observation import extract_screenshot
 from PIL import Image, ImageDraw
 
+from agentlab.agents import agent_utils
 from agentlab.agents.agent_args import AgentArgs
 from agentlab.llm.llm_utils import image_to_png_base64_url
 from agentlab.llm.response_api import (
+    BaseModelArgs,
     ClaudeResponseModelArgs,
     MessageBuilder,
+    OpenAIChatModelArgs,
     OpenAIResponseModelArgs,
+    OpenRouterModelArgs,
     ResponseLLMOutput,
+    VLLMModelArgs,
 )
 from agentlab.llm.tracking import cost_tracker_decorator
+from browsergym.core.observation import extract_screenshot
 
 if TYPE_CHECKING:
     from openai.types.responses import Response
 
 
-def tag_screenshot_with_action(screenshot: Image, action: str) -> Image:
-    """
-    If action is a coordinate action, try to render it on the screenshot.
-
-    e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot
-
-    Args:
-        screenshot: The screenshot to tag.
-        action: The action to tag the screenshot with.
-
-    Returns:
-        The tagged screenshot.
-
-    Raises:
-        ValueError: If the action parsing fails.
-    """
-    if action.startswith("mouse_click"):
-        try:
-            coords = action[action.index("(") + 1 : action.index(")")].split(",")
-            coords = [c.strip() for c in coords]
-            if len(coords) != 2:
-                raise ValueError(f"Invalid coordinate format: {coords}")
-            if coords[0].startswith("x="):
-                coords[0] = coords[0][2:]
-            if coords[1].startswith("y="):
-                coords[1] = coords[1][2:]
-            x, y = float(coords[0].strip()), float(coords[1].strip())
-            draw = ImageDraw.Draw(screenshot)
-            radius = 5
-            draw.ellipse(
-                (x - radius, y - radius, x + radius, y + radius), fill="red", outline="red"
-            )
-        except (ValueError, IndexError) as e:
-            logging.warning(f"Failed to parse action '{action}': {e}")
-    return screenshot
-
-
 @dataclass
 class ToolUseAgentArgs(AgentArgs):
     model_args: OpenAIResponseModelArgs = None
@@ -97,19 +65,9 @@ def __init__(
         self.model_args = model_args
         self.use_first_obs = use_first_obs
         self.tag_screenshot = tag_screenshot
         self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)
         self.tools = self.action_set.to_tool_description(api=model_args.api)
-        # count tools tokens
-        from agentlab.llm.llm_utils import count_tokens
-
-        tool_str = json.dumps(self.tools, indent=2)
-        print(f"Tool description: {tool_str}")
-        tool_tokens = count_tokens(tool_str, model_args.model_name)
-        print(f"Tool tokens: {tool_tokens}")
-
         self.call_ids = []
 
         # self.tools.append(
@@ -131,7 +89,7 @@ def __init__(
         # )
 
         self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})
-
+        self.msg_builder = model_args.get_message_builder()
         self.messages: list[MessageBuilder] = []
 
     def obs_preprocessor(self, obs):
@@ -140,7 +98,7 @@ def obs_preprocessor(self, obs):
             obs["screenshot"] = extract_screenshot(page)
             if self.tag_screenshot:
                 screenshot = Image.fromarray(obs["screenshot"])
-                screenshot = tag_screenshot_with_action(screenshot, obs["last_action"])
+                screenshot = agent_utils.tag_screenshot_with_action(screenshot, obs["last_action"])
                 obs["screenshot_tag"] = np.array(screenshot)
         else:
             raise ValueError("No page found in the observation.")
@@ -150,56 +108,31 @@ def obs_preprocessor(self, obs):
     @cost_tracker_decorator
     def get_action(self, obs: Any) -> float:
         if len(self.messages) == 0:
-            system_message = MessageBuilder.system().add_text(
-                "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
-            )
-            self.messages.append(system_message)
-
-            goal_message = MessageBuilder.user()
-            for content in obs["goal_object"]:
-                if content["type"] == "text":
-                    goal_message.add_text(content["text"])
-                elif content["type"] == "image_url":
-                    goal_message.add_image(content["image_url"])
-            self.messages.append(goal_message)
-
-            extra_info = []
-
-            extra_info.append(
-                """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n"""
-            )
-
-            self.messages.append(MessageBuilder.user().add_text("\n".join(extra_info)))
-
-            if self.use_first_obs:
-                msg = "Here is the first observation."
-                screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
-                if self.tag_screenshot:
-                    msg += " A red dot on screenshots indicate the previous click action."
-                message = MessageBuilder.user().add_text(msg)
-                message.add_image(image_to_png_base64_url(obs[screenshot_key]))
-                self.messages.append(message)
+            self.initialize_messages(obs)
         else:
-            if obs["last_action_error"] == "":
+            if obs["last_action_error"] == "":  # the last action produced no error
                 screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
-                tool_message = MessageBuilder.tool().add_image(
+                tool_message = self.msg_builder.tool().add_image(
                     image_to_png_base64_url(obs[screenshot_key])
                 )
+                tool_message.update_last_raw_response(self.last_response)
                 tool_message.add_tool_id(self.previous_call_id)
                 self.messages.append(tool_message)
             else:
-                tool_message = MessageBuilder.tool().add_text(
+                tool_message = self.msg_builder.tool().add_text(
                     f"Function call failed: {obs['last_action_error']}"
                 )
+                tool_message.update_last_raw_response(self.last_response)
                 tool_message.add_tool_id(self.previous_call_id)
                 self.messages.append(tool_message)
 
         response: ResponseLLMOutput = self.llm(messages=self.messages)
         action = response.action
         think = response.think
+        self.last_response = response
         self.previous_call_id = response.last_computer_call_id
-        self.messages.append(response.assistant_message)
+        self.messages.append(response.assistant_message)  # the assistant's tool-call message
 
         return (
             action,
@@ -210,6 +143,37 @@ def get_action(self, obs: Any) -> float:
             ),
         )
 
+    def initialize_messages(self, obs: Any) -> None:
+        system_message = self.msg_builder.system().add_text(
+            "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
+        )
+        self.messages.append(system_message)
+
+        goal_message = self.msg_builder.user()
+        for content in obs["goal_object"]:
+            if content["type"] == "text":
+                goal_message.add_text(content["text"])
+            elif content["type"] == "image_url":
+                goal_message.add_image(content["image_url"])
+        self.messages.append(goal_message)
+
+        extra_info = []
+
+        extra_info.append(
+            """Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for multiple selection in lists.\n"""
+        )
+
+        self.messages.append(self.msg_builder.user().add_text("\n".join(extra_info)))
+
+        if self.use_first_obs:
+            msg = "Here is the first observation."
+            screenshot_key = "screenshot_tag" if self.tag_screenshot else "screenshot"
+            if self.tag_screenshot:
+                msg += " A red dot on screenshots indicates the previous click action."
+            message = self.msg_builder.user().add_text(msg)
+            message.add_image(image_to_png_base64_url(obs[screenshot_key]))
+            self.messages.append(message)
+
 
 OPENAI_MODEL_CONFIG = OpenAIResponseModelArgs(
     model_name="gpt-4.1",
@@ -220,6 +184,14 @@ def get_action(self, obs: Any) -> float:
     vision_support=True,
 )
 
+OPENAI_CHATAPI_MODEL_CONFIG = OpenAIChatModelArgs(
+    model_name="gpt-4o-2024-08-06",
+    max_total_tokens=200_000,
+    max_input_tokens=200_000,
+    max_new_tokens=2_000,
+    temperature=0.1,
+    vision_support=True,
+)
 
 CLAUDE_MODEL_CONFIG = ClaudeResponseModelArgs(
     model_name="claude-3-7-sonnet-20250219",
@@ -231,6 +203,103 @@ def get_action(self, obs: Any) -> float:
 )
 
 
+
+# def get_openrouter_model(model_name: str, **open_router_args) -> OpenRouterModelArgs:
+#     default_model_args = {
+#         "max_total_tokens": 200_000,
+#         "max_input_tokens": 180_000,
+#         "max_new_tokens": 2_000,
+#         "temperature": 0.1,
+#         "vision_support": True,
+#     }
+#     merged_args = {**default_model_args, **open_router_args}
+
+#     return OpenRouterModelArgs(model_name=model_name, **merged_args)
+
+
+# def get_openrouter_tool_use_agent(
+#     model_name: str,
+#     model_args: dict = {},
+#     use_first_obs=True,
+#     tag_screenshot=True,
+#     use_raw_page_output=True,
+# ) -> ToolUseAgentArgs:
+#     # TODO: check whether the OpenRouter endpoint-specific args work
+#     if not supports_tool_calling(model_name):
+#         raise ValueError(f"Model {model_name} does not support tool calling.")
+
+#     model_args = get_openrouter_model(model_name, **model_args)
+
+#     return ToolUseAgentArgs(
+#         model_args=model_args,
+#         use_first_obs=use_first_obs,
+#         tag_screenshot=tag_screenshot,
+#         use_raw_page_output=use_raw_page_output,
+#     )
+
+
+# OPENROUTER_MODEL = get_openrouter_tool_use_agent("google/gemini-2.5-pro-preview")
+
 
 AGENT_CONFIG = ToolUseAgentArgs(
     model_args=CLAUDE_MODEL_CONFIG,
 )
+
+# MT_TOOL_USE_AGENT = ToolUseAgentArgs(
+#     model_args=OPENROUTER_MODEL,
+# )
+CHATAPI_AGENT_CONFIG = ToolUseAgentArgs(
+    model_args=OpenAIChatModelArgs(
+        model_name="gpt-4o-2024-11-20",
+        max_total_tokens=200_000,
+        max_input_tokens=200_000,
+        max_new_tokens=2_000,
+        temperature=0.7,
+        vision_support=True,
+    ),
+)
+
+
+OAI_CHAT_TOOL_AGENT = ToolUseAgentArgs(
+    model_args=OpenAIChatModelArgs(model_name="gpt-4o-2024-08-06")
+)
+
+
+PROVIDER_FACTORY_MAP = {
+    "openai": {"chatcompletion": OpenAIChatModelArgs, "response": OpenAIResponseModelArgs},
+    "openrouter": OpenRouterModelArgs,
+    "vllm": VLLMModelArgs,
+    "anthropic": ClaudeResponseModelArgs,
+}
+
+
+def get_tool_use_agent(
+    api_provider: str,
+    model_args: dict,
+    tool_use_agent_args: dict = None,
+    api_provider_spec=None,
+) -> ToolUseAgentArgs:
+
+    if api_provider == "openai":
+        assert (
+            api_provider_spec is not None
+        ), "An endpoint specification is required for the OpenAI provider. Choose between 'chatcompletion' and 'response'."
+
+    model_args_factory = (
+        PROVIDER_FACTORY_MAP[api_provider]
+        if api_provider_spec is None
+        else PROVIDER_FACTORY_MAP[api_provider][api_provider_spec]
+    )
+
+    # Create the agent with model arguments built by the factory
+    agent = ToolUseAgentArgs(
+        model_args=model_args_factory(**model_args), **(tool_use_agent_args or {})
+    )
+    return agent
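+
+
+# A usage sketch (hypothetical values; the model_args dict is unpacked into
+# the provider's BaseModelArgs factory, and OPENAI_API_KEY must be set):
+#
+#     agent_args = get_tool_use_agent(
+#         api_provider="openai",
+#         model_args={"model_name": "gpt-4o-2024-08-06", "max_new_tokens": 2_000},
+#         api_provider_spec="chatcompletion",
+#     )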
+
+
+# We want to support three providers:
+#   - Anthropic
+#   - OpenAI
+#   - vLLM (serves the OpenAI-compatible API)
diff --git a/src/agentlab/llm/llm_utils.py b/src/agentlab/llm/llm_utils.py
index 3dceda83..9f70ca42 100644
--- a/src/agentlab/llm/llm_utils.py
+++ b/src/agentlab/llm/llm_utils.py
@@ -12,6 +12,7 @@
 from warnings import warn
 
 import numpy as np
+import openai
 import tiktoken
 import yaml
 from langchain.schema import BaseMessage
@@ -90,6 +91,102 @@ def retry(
     raise ParseError(f"Could not parse a valid value after {n_retry} retries.")
 
 
+def call_with_retries(client_function, api_params, max_retries=5):
+    """
+    Makes an API call with retries for transient failures, rate limiting,
+    and invalid or error-containing responses.
+
+    Args:
+        client_function (Callable): Function that calls the API
+            (e.g., client.chat.completions.create on an openai.OpenAI client).
+        api_params (dict): Parameters to pass to the API function.
+        max_retries (int): Maximum number of retry attempts.
+
+    Returns:
+        response: Valid API response object.
+    """
+    for attempt in range(1, max_retries + 1):
+        try:
+            response = client_function(**api_params)
+
+            # Check for an explicit error field in the response object
+            if getattr(response, "error", None):
+                logging.warning(
+                    f"[Attempt {attempt}] API returned error: {response.error}. Retrying..."
+                )
+                continue
+
+            # Check for a valid response with choices
+            if hasattr(response, "choices") and response.choices:
+                logging.info(f"[Attempt {attempt}] API call succeeded.")
+                return response
+
+            logging.warning(
+                f"[Attempt {attempt}] API returned an empty or malformed response. Retrying..."
+            )
+
+        except openai.APIError as e:
+            logging.error(f"[Attempt {attempt}] APIError: {e}")
+            # openai>=1.0 exposes the HTTP status as `status_code` (on
+            # APIStatusError subclasses); there is no `http_status` attribute.
+            status = getattr(e, "status_code", None)
+            if status == 429:
+                logging.warning("Rate limit exceeded. Retrying...")
+            elif status is not None and status >= 500:
+                logging.warning("Server error encountered. Retrying...")
+            else:
+                logging.error("Non-retriable API error occurred.")
+                raise
+
+        except Exception as e:
+            logging.exception(f"[Attempt {attempt}] Unexpected exception occurred: {e}")
+            raise
+
+    logging.error("Exceeded maximum retry attempts. API call failed.")
+    raise RuntimeError("API call failed after maximum retries.")
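+
+
+# A minimal usage sketch for call_with_retries (assumes an `openai.OpenAI()`
+# client named `client`; any keyword-argument API callable works):
+#
+#     response = call_with_retries(
+#         client.chat.completions.create,
+#         {"model": "gpt-4o", "messages": [{"role": "user", "content": "Hi"}]},
+#     )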
+ """ + import os + + import openai + + client = openai.Client( + api_key=os.getenv("OPENROUTER_API_KEY"), base_url="https://openrouter.ai/api/v1" + ) + try: + response = client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": "Call the test tool"}], + tools=[ + { + "type": "function", + "function": { + "name": "dummy_tool", + "description": "Just a test tool", + "parameters": { + "type": "object", + "properties": {}, + }, + }, + } + ], + tool_choice="required", + ) + response = response.to_dict() + return "tool_calls" in response["choices"][0]["message"] + except Exception as e: + print(f"Model '{model_name}' error: {e}") + return False + + def retry_multiple( chat: "ChatModel", messages: "Discussion", diff --git a/src/agentlab/llm/response_api.py b/src/agentlab/llm/response_api.py index 3240430b..49a8a775 100644 --- a/src/agentlab/llm/response_api.py +++ b/src/agentlab/llm/response_api.py @@ -1,17 +1,26 @@ import json import logging +import os from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Type, Union import openai from anthropic import Anthropic from openai import OpenAI - +from .llm_utils import call_with_retries, supports_tool_calling_for_openrouter from agentlab.llm import tracking from .base_api import BaseModelArgs +"""This module contains utlity classes for building input messages and interacting with LLM APIs. +It includes: + 1. Message Builder for building input messages + 2. Base Reponse class for different LLM APIs (OpenAI, Anthropic, etc.) + 3. Factory classes (inherits from BaseModelArgs) for creating instances of LLM Response models. +""" + + type ContentItem = Dict[str, Any] type Message = Dict[str, Union[str, List[ContentItem]]] @@ -31,23 +40,32 @@ class MessageBuilder: def __init__(self, role: str): self.role = role self.content: List[ContentItem] = [] - self.tool_call_id = None + self.last_response: ResponseLLMOutput = None + self.tool_call_id: Optional[str] = None - @staticmethod - def system() -> "MessageBuilder": - return MessageBuilder(role="system") + @classmethod + def system(cls) -> "MessageBuilder": + return cls("system") - @staticmethod - def user() -> "MessageBuilder": - return MessageBuilder(role="user") + @classmethod + def user(cls) -> "MessageBuilder": + return cls("user") - @staticmethod - def assistant() -> "MessageBuilder": - return MessageBuilder(role="assistant") + @classmethod + def assistant(cls) -> "MessageBuilder": + return cls("assistant") - @staticmethod - def tool() -> "MessageBuilder": - return MessageBuilder(role="tool") + @classmethod + def tool(cls) -> "MessageBuilder": + return cls("tool") + + def update_last_raw_response(self, raw_response: Any) -> "MessageBuilder": + self.last_response = raw_response + return self + + def add_tool_id(self, id: str) -> "MessageBuilder": + self.tool_call_id = id + return self def add_text(self, text: str) -> "MessageBuilder": self.content.append({"text": text}) @@ -57,11 +75,35 @@ def add_image(self, image: str) -> "MessageBuilder": self.content.append({"image": image}) return self - def add_tool_id(self, tool_id: str) -> "MessageBuilder": - self.tool_call_id = tool_id + def to_markdown(self) -> str: + parts = [] + for item in self.content: + if "text" in item: + parts.append(item["text"]) + elif "image" in item: + parts.append(f"![Image]({item['image']})") + + markdown = f"## {self.role.capitalize()} 
+    def to_markdown(self) -> str:
+        parts = []
+        for item in self.content:
+            if "text" in item:
+                parts.append(item["text"])
+            elif "image" in item:
+                parts.append(f"![Image]({item['image']})")
+
+        markdown = f"## {self.role.capitalize()} Message\n\n"
+        markdown += "\n\n---\n\n".join(parts)
+
+        if self.role == "tool":
+            assert self.tool_call_id is not None, "Tool call ID is required for tool messages"
+            markdown += f"\n\n---\n\n**Tool Call ID:** `{self.tool_call_id}`"
+
+        return markdown
+
+
+class OpenAIResponseAPIMessageBuilder(MessageBuilder):
+
-    def to_openai(self) -> List[Message]:
+    def prepare_message(self) -> List[Message]:
         content = []
         for item in self.content:
             if "text" in item:
@@ -90,14 +132,25 @@ def to_openai(self) -> List[Message]:
 
         return res
 
-    def to_anthropic(self) -> List[Message]:
+
+class AnthropicAPIMessageBuilder(MessageBuilder):
+
+    def prepare_message(self) -> List[Message]:
         content = []
 
         if self.role == "system":
-            logging.warning(
-                "In the Anthropic API, system messages should be passed as a direct input to the client."
+            logging.info(
+                "Treating system message as 'user'. In the Anthropic API, system messages should be passed as a direct input to the client."
             )
-            return []
+            # self.content has not been serialized into `content` at this
+            # point, so build the text blocks here instead of returning an
+            # empty message.
+            content = [
+                {"type": "text", "text": item["text"]} for item in self.content if "text" in item
+            ]
+            return [{"role": "user", "content": content}]
 
         for item in self.content:
             if "text" in item:
@@ -132,30 +185,92 @@ def to_anthropic(self) -> List[Message]:
         ]
         return res
 
-    def to_chat_completion(self) -> List[Message]: ...
-
-    def to_markdown(self) -> str:
-        content = []
-        for item in self.content:
-            if "text" in item:
-                content.append(item["text"])
-            elif "image" in item:
-                content.append(f"![Image]({item['image']})")
-
-        # Format the role as a header
-        res = f"## {self.role.capitalize()} Message\n\n"
-
-        # Add content with spacing between items
-        res += "\n\n---\n\n".join(content)
-
-        # Add tool call ID if the role is "tool"
-        if self.role == "tool":
-            assert self.tool_call_id is not None, "Tool call ID is required for tool messages"
-            res += f"\n\n---\n\n**Tool Call ID:** `{self.tool_call_id}`"
-
-        return res
+
+class OpenAIChatCompletionAPIMessageBuilder(MessageBuilder):
+
+    def prepare_message(self) -> List[Message]:
+        """Prepare the message for the OpenAI Chat Completions API."""
+        content = []
+        for item in self.content:
+            if "text" in item:
+                content.append({"type": "text", "text": item["text"]})
+            elif "image" in item:
+                content.append({"type": "image_url", "image_url": {"url": item["image"]}})
+        res = [{"role": self.role, "content": content}]
+
+        if self.role == "tool":
+            assert self.tool_call_id is not None, "Tool call ID is required for tool messages"
+            # Chat Completions tool messages only take text, shaped as
+            # {"role": "tool", "tool_call_id": ..., "content": ...}.
+            # Use the first content element as the tool output if it is text,
+            # then forward any remaining items (e.g. screenshots) in a
+            # follow-up user message.
+            res[0]["tool_call_id"] = self.tool_call_id
+            text_content = (
+                content.pop(0)["text"]
+                if content and "text" in content[0]
+                else "Tool call answer in next message"
+            )
+            res[0]["content"] = text_content
+            if content:
+                res.append({"role": "user", "content": content})
+        return res
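+
+    # Sketch of the split performed by prepare_message above for a tool
+    # message holding [text, image] content:
+    #
+    #     {"role": "tool", "tool_call_id": "...", "content": "<text>"}
+    #     {"role": "user", "content": [{"type": "image_url", ...}]}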
+
+
+class OpenRouterAPIMessageBuilder(OpenAIChatCompletionAPIMessageBuilder):
+    """OpenRouter follows the OpenAI Chat Completions format, so the
+    message-preparation logic is shared with the builder above."""
 
 
+# Base class for all API endpoints.
 class BaseResponseModel(ABC):
     def __init__(
         self,
@@ -209,7 +324,7 @@ def _call_api(self, messages: list[Any | MessageBuilder]) -> dict:
         input = []
         for msg in messages:
             if isinstance(msg, MessageBuilder):
-                input += msg.to_openai()
+                input += msg.prepare_message()
             else:
                 input.append(msg)
         try:
@@ -226,6 +341,7 @@ def _call_api(self, messages: list[Any | MessageBuilder]) -> dict:
                 #     "summary": "detailed",
                 # },
             )
+
             return response
         except openai.OpenAIError as e:
             logging.error(f"Failed to get a response from the API: {e}")
@@ -254,6 +370,110 @@ def _parse_response(self, response: dict) -> dict:
         return result
 
 
+class OpenAIChatCompletionModel(BaseResponseModel):
+    def __init__(
+        self,
+        model_name: str,
+        client_args: Optional[Dict[str, Any]] = None,
+        temperature: float = 0.5,
+        max_tokens: int = 100,
+        extra_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            extra_kwargs=extra_kwargs,
+        )
+        self.extra_kwargs["tools"] = self.format_tools_for_chat_completion(
+            self.extra_kwargs.get("tools", [])
+        )
+        # Avoid a mutable default argument: fall back to an empty dict here.
+        self.client = OpenAI(**(client_args or {}))
+
+    def _call_api(self, messages: list[dict | MessageBuilder]) -> openai.types.chat.ChatCompletion:
+        chat_messages: List[Message] = []
+        for msg in messages:
+            if isinstance(msg, MessageBuilder):
+                chat_messages.extend(msg.prepare_message())
+            else:
+                # Assuming msg is already in the OpenAI Chat Completions message format
+                chat_messages.append(msg)  # type: ignore
+
+        api_params: Dict[str, Any] = {
+            "model": self.model_name,
+            "messages": chat_messages,
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "tool_choice": "auto",
+            **self.extra_kwargs,  # Pass tools, tool_choice, etc. here
+        }
+
+        response = call_with_retries(self.client.chat.completions.create, api_params)
+        # Basic token tracking (if usage information is available)
+        if response.usage:
+            input_tokens = response.usage.prompt_tokens
+            output_tokens = response.usage.completion_tokens
+            # Cost calculation would require pricing data
+            # cost = ...
+            # if hasattr(tracking.TRACKER, "instance") and isinstance(
+            #     tracking.TRACKER.instance, tracking.LLMTracker
+            # ):
+            #     tracking.TRACKER.instance(input_tokens, output_tokens, cost)  # Placeholder for cost
+
+        return response
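+
+    # _parse_response below turns the first tool call into an action string, e.g.
+    #
+    #     {"name": "mouse_click", "arguments": '{"x": 120, "y": 130}'}
+    #     -> "mouse_click(x=120, y=130)"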
+    def _parse_response(self, response: openai.types.chat.ChatCompletion) -> ResponseLLMOutput:
+
+        output = ResponseLLMOutput(
+            raw_response=response,
+            think="",
+            action="noop()",  # Default if no tool call
+            last_computer_call_id=None,
+            assistant_message={
+                "role": "assistant",
+                "content": response.choices[0].message.content,
+            },
+        )
+        message = response.choices[0].message.to_dict()
+
+        if tool_calls := message.get("tool_calls", None):
+            for tool_call in tool_calls:
+                function = tool_call["function"]
+                arguments = json.loads(function["arguments"])
+                # repr() the values so string arguments stay quoted in the
+                # resulting action string, e.g. send_msg_to_user(text='hi').
+                output.action = (
+                    f"{function['name']}({', '.join([f'{k}={v!r}' for k, v in arguments.items()])})"
+                )
+                output.last_computer_call_id = tool_call["id"]
+                output.assistant_message = {
+                    "role": "assistant",
+                    "tool_calls": message["tool_calls"],
+                }
+                break  # only the first tool call is used
+
+        elif "content" in message and message["content"]:
+            output.think = message["content"]
+
+        return output
+
+    @staticmethod
+    def format_tools_for_chat_completion(tools_flat):
+        """Reshape tool descriptions for the OpenAI Chat Completions API.
+
+        This is needed because actionset.to_tool_description() in bgym only
+        returns descriptions in the OpenAI Response API format.
+        """
+        return [
+            {
+                "type": tool["type"],
+                "function": {k: tool[k] for k in ("name", "description", "parameters")},
+            }
+            for tool in tools_flat
+        ]
+
+
 class ClaudeResponseModel(BaseResponseModel):
     def __init__(
         self,
@@ -290,7 +510,7 @@ def _call_api(self, messages: list[dict | MessageBuilder]) -> dict:
         input = []
         for msg in messages:
             if isinstance(msg, MessageBuilder):
-                input += msg.to_anthropic()
+                input += msg.prepare_message()
             else:
                 input.append(msg)
         try:
@@ -404,6 +624,7 @@ def cua_response_to_text(action):
         print(f"Error handling action {action}: {e}")
 
 
+# Factory classes to create the appropriate model based on the API endpoint.
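+#
+# A minimal sketch of how a factory is used (values are illustrative):
+#
+#     args = OpenAIResponseModelArgs(model_name="gpt-4.1", max_new_tokens=2_000)
+#     llm = args.make_model(extra_kwargs={"tools": tools})
+#     builder_cls = args.get_message_builder()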
 @dataclass
 class OpenAIResponseModelArgs(BaseModelArgs):
     """Serializable object for instantiating a generic chat model with an OpenAI
@@ -419,6 +640,9 @@ def make_model(self, extra_kwargs=None):
             extra_kwargs=extra_kwargs,
         )
 
+    def get_message_builder(self) -> Type[MessageBuilder]:
+        return OpenAIResponseAPIMessageBuilder
+
 
 @dataclass
 class ClaudeResponseModelArgs(BaseModelArgs):
@@ -434,3 +658,124 @@ def make_model(self, extra_kwargs=None):
             max_tokens=self.max_new_tokens,
             extra_kwargs=extra_kwargs,
         )
+
+    def get_message_builder(self) -> Type[MessageBuilder]:
+        return AnthropicAPIMessageBuilder
+
+
+@dataclass
+class OpenAIChatModelArgs(BaseModelArgs):
+    """Serializable object for instantiating a generic chat model with an OpenAI
+    model."""
+
+    api = "openai"
+
+    def make_model(self, extra_kwargs=None):
+        return OpenAIChatCompletionModel(
+            model_name=self.model_name,
+            temperature=self.temperature,
+            max_tokens=self.max_new_tokens,
+            extra_kwargs=extra_kwargs,
+        )
+
+    def get_message_builder(self) -> Type[MessageBuilder]:
+        return OpenAIChatCompletionAPIMessageBuilder
+
+
+@dataclass
+class OpenRouterModelArgs(BaseModelArgs):
+    """Serializable object for instantiating a generic chat model with an OpenRouter
+    model."""
+
+    api: str = "openai"  # tool description format used by actionset.to_tool_description() in bgym
+
+    def make_model(self, extra_kwargs=None):
+        return OpenAIChatCompletionModel(
+            client_args={
+                "base_url": "https://openrouter.ai/api/v1",
+                "api_key": os.getenv("OPENROUTER_API_KEY"),
+            },
+            model_name=self.model_name,
+            temperature=self.temperature,
+            max_tokens=self.max_new_tokens,
+            extra_kwargs=extra_kwargs,
+        )
+
+    def get_message_builder(self) -> Type[MessageBuilder]:
+        return OpenRouterAPIMessageBuilder
+
+    def __post_init__(self):
+        # Some runtime checks; note this performs a live API call at construction time.
+        assert supports_tool_calling_for_openrouter(
+            self.model_name
+        ), f"Model {self.model_name} does not support tool calling."
+
+
+@dataclass
+class VLLMModelArgs(BaseModelArgs):
+    """Serializable object for instantiating a generic chat model with a vLLM
+    model."""
+
+    api = "openai"  # tool description format used by actionset.to_tool_description() in bgym
+
+    def __post_init__(self):
+        # Runtime check; __post_init__ only runs because this class is itself
+        # decorated with @dataclass.
+        assert self.is_model_available(
+            self.model_name
+        ), f"Model {self.model_name} is not available on the vLLM server. \
+            Please check the model name or server configuration."
+
+    def make_model(self, extra_kwargs=None):
+        return OpenAIChatCompletionModel(
+            client_args={
+                "base_url": "http://localhost:8000/v1",
+                "api_key": os.getenv("VLLM_API_KEY", "EMPTY"),
+            },
+            model_name=self.model_name,  # this needs to be set
+            temperature=self.temperature,
+            max_tokens=self.max_new_tokens,
+            extra_kwargs=extra_kwargs,
+        )
+
+    def get_message_builder(self) -> Type[MessageBuilder]:
+        return OpenAIChatCompletionAPIMessageBuilder
+
+    ## Some tests for the vLLM server, still a work in progress:
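+    # vLLM serves the OpenAI-compatible REST routes, so a plain GET on
+    # {base_url}/models (the base URL already ends in /v1) lists the
+    # models the server is hosting.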
+    def test_vllm_server_reachability(self) -> bool:
+        import requests
+
+        # Mirror the connection settings used in make_model above.
+        base_url = "http://localhost:8000/v1"
+        api_key = os.getenv("VLLM_API_KEY", "EMPTY")
+        try:
+            response = requests.get(
+                f"{base_url}/models",
+                headers={"Authorization": f"Bearer {api_key}"},
+            )
+            return response.status_code == 200
+        except requests.RequestException as e:
+            logging.error(f"Error checking vLLM server reachability: {e}")
+            return False
+
+    def is_model_available(self, model_name: str) -> bool:
+        # import requests
+
+        # """Check if the model is available on the vLLM server."""
+        # if not self.test_vllm_server_reachability():
+        #     logging.error("vLLM server is not reachable.")
+        #     return False
+        # base_url = "http://localhost:8000/v1"
+        # api_key = os.getenv("VLLM_API_KEY", "EMPTY")
+        # try:
+        #     response = requests.get(
+        #         f"{base_url}/models",
+        #         headers={"Authorization": f"Bearer {api_key}"},
+        #     )
+        #     if response.status_code == 200:
+        #         models = response.json().get("data", [])
+        #         return any(model.get("id") == model_name for model in models)
+        #     else:
+        #         logging.error(
+        #             f"Failed to fetch vLLM-hosted models: {response.status_code} - {response.text}"
+        #         )
+        #         return False
+        # except requests.RequestException as e:
+        #     logging.error(f"Error checking model availability: {e}")
+        #     return False
+        return True  # availability check is stubbed out for now
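+
+
+# A usage sketch (hypothetical model name; assumes a vLLM server is running
+# on localhost:8000 and serving this model):
+#
+#     vllm_args = VLLMModelArgs(model_name="Qwen/Qwen2.5-VL-7B-Instruct")
+#     llm = vllm_args.make_model(extra_kwargs={"tools": tools})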