diff --git a/src/agentlab/llm/litellm_api.py b/src/agentlab/llm/litellm_api.py
new file mode 100644
index 00000000..bc8f48ab
--- /dev/null
+++ b/src/agentlab/llm/litellm_api.py
@@ -0,0 +1,326 @@
+import json
+import logging
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Dict, List, Optional, Type
+
+import litellm
+from litellm import completion
+from openai.types.chat import ChatCompletion as OpenAIChatCompletion
+
+from agentlab.llm.base_api import BaseModelArgs
+from agentlab.llm.response_api import (
+    AgentlabAction,
+    APIPayload,
+    BaseModelWithPricing,
+    LLMOutput,
+    Message,
+    MessageBuilder,
+    OpenAIChatCompletionAPIMessageBuilder,
+    ToolCall,
+    ToolCalls,
+)
+
+litellm.modify_params = True
+
+
+class LiteLLMModel(BaseModelWithPricing):
+    def __init__(
+        self,
+        model_name: str,
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
+        temperature: float | None = None,
+        max_tokens: int | None = 100,
+        use_only_first_toolcall: bool = False,
+    ):
+        super().__init__(
+            model_name=model_name,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+        self.action_space_as_tools = True  # this should be a config
+        client_args = {}
+        if base_url is not None:
+            client_args["base_url"] = base_url
+        if api_key is not None:
+            client_args["api_key"] = api_key
+        self.client = partial(completion, **client_args)
+        self.init_pricing_tracker(pricing_api="litellm")
+        self.use_only_first_toolcall = use_only_first_toolcall
+        try:
+            self.litellm_info = litellm.get_model_info(model_name)
+            # maybe log this in xray
+
+        except Exception as e:
+            logging.error(f"Failed to get litellm model info: {e}")
+
+    def _call_api(self, payload: APIPayload) -> "OpenAIChatCompletion":
+        """
+        Calls the LiteLLM API with the given payload.
+
+        Args:
+            payload (APIPayload): The payload to send to the API.
+
+        Returns:
+            OpenAIChatCompletion: An object with the same keys as OpenAIChatCompletion.
+        """
+        input = []
+        for msg in payload.messages:  # type: ignore
+            input.extend(msg.prepare_message())
+        api_params: Dict[str, Any] = {
+            "model": self.model_name,
+            "messages": input,
+        }
+        if self.temperature is not None:
+            api_params["temperature"] = self.temperature
+
+        if self.max_tokens is not None:
+            api_params["max_completion_tokens"] = self.max_tokens
+
+        if payload.tools is not None:
+            api_params["tools"] = (
+                self.format_tools_for_chat_completion(payload.tools)
+                if "function" not in payload.tools[0]  # convert if responses_api_tools
+                else payload.tools
+            )
+
+        if payload.tool_choice is not None and payload.force_call_tool is None:
+            api_params["tool_choice"] = (
+                "required" if payload.tool_choice in ("required", "any") else payload.tool_choice
+            )
+
+        if payload.force_call_tool is not None:
+            api_params["tool_choice"] = {
+                "type": "function",
+                "function": {"name": payload.force_call_tool},
+            }
+
+        if payload.reasoning_effort is not None:
+            api_params["reasoning_effort"] = payload.reasoning_effort
+
+        if "tools" in api_params and payload.cache_tool_definition:
+            api_params["tools"][-1]["cache_control"] = {"type": "ephemeral"}  # type: ignore
+
+        if payload.cache_complete_prompt:
+            # Indicating cache control for the last message enables caching of the complete prompt.
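+            # (cache_control markers follow the Anthropic-style prompt-caching convention; LiteLLM
+            # forwards them, and providers without explicit prompt caching may ignore or reject the field.)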
+            api_params["messages"][-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}
+
+        response = self.client(**api_params, num_retries=5)
+
+        return response  # type: ignore
+
+    def _parse_response(self, response: "OpenAIChatCompletion") -> LLMOutput:
+        think_output = self._extract_thinking_content_from_response(response)
+        tool_calls = self._extract_tool_calls_from_response(response)
+
+        if self.action_space_as_tools:
+            env_action = self._extract_env_actions_from_toolcalls(tool_calls)  # type: ignore
+        else:
+            env_action = self._extract_env_actions_from_text_response(response)
+        return LLMOutput(
+            raw_response=response,
+            think=think_output,
+            action=env_action if env_action is not None else None,
+            tool_calls=tool_calls if tool_calls is not None else None,
+        )
+
+    def _extract_thinking_content_from_response(
+        self, response: OpenAIChatCompletion, wrap_tag="think"
+    ):
+        """Extracts the content from the message, including reasoning if available.
+
+        When the LLM produces both 'text' and 'reasoning' in the same message, the reasoning is
+        wrapped in <wrap_tag>...</wrap_tag> tags so it can easily be identified.
+        Note: Wrapping the 'thinking' content may not be needed and may be reconsidered.
+
+        Args:
+            response: The message object or dict containing content and reasoning.
+            wrap_tag: The tag name to wrap reasoning content (default: "think").
+
+        Returns:
+            str: The extracted content with reasoning wrapped in the specified tags.
+        """
+        message = response.choices[0].message
+        if not isinstance(message, dict):
+            message = message.to_dict()
+
+        reasoning_content = message.get("reasoning", None)
+        msg_content = message.get("text", "")  # works for OpenRouter
+        if reasoning_content:
+            # Wrap reasoning in tags, with a trailing newline for clarity
+            reasoning_content = f"<{wrap_tag}>{reasoning_content}</{wrap_tag}>\n"
+            logging.debug("Extracting content from response.choices[i].message.reasoning")
+        else:
+            reasoning_content = ""
+        return f"{reasoning_content}{msg_content}{message.get('content', '')}"
+
+    def _extract_tool_calls_from_response(self, response: OpenAIChatCompletion) -> ToolCalls | None:
+        """Extracts tool calls from the response."""
+        message = response.choices[0].message.to_dict()
+        tool_calls = message.get("tool_calls", None)
+        if tool_calls is None:
+            return None
+        tool_call_list = []
+        for tc in tool_calls:  # type: ignore
+            tool_call_list.append(
+                ToolCall(
+                    name=tc["function"]["name"],
+                    arguments=json.loads(tc["function"]["arguments"]),
+                    raw_call=tc,
+                )
+            )
+            if self.use_only_first_toolcall:
+                break
+        return ToolCalls(tool_calls=tool_call_list, raw_calls=response)  # type: ignore
+
+    def _extract_env_actions_from_toolcalls(self, toolcalls: ToolCalls) -> Any | None:
+        """Extracts environment actions from the tool calls."""
+        if not toolcalls:
+            return None
+
+        actions = [
+            AgentlabAction.convert_toolcall_to_agentlab_action_format(call) for call in toolcalls
+        ]
+        actions = (
+            AgentlabAction.convert_multiactions_to_agentlab_action_format(actions)
+            if len(actions) > 1
+            else actions[0]
+        )
+        return actions
+
+    def _extract_env_actions_from_text_response(
+        self, response: "OpenAIChatCompletion"
+    ) -> str | None:
+        """Extracts environment actions from the text response."""
+        # Use when the action space is not given as tools.
+        # TODO: Add support to pass the action space as a prompt in LiteLLM.
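+        # One possible approach (not implemented here): describe the action space in the prompt and
+        # parse the action string out of message.content, e.g. from a fenced code block.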
+        # Check: https://docs.litellm.ai/docs/completion/function_call#function-calling-for-models-wout-function-calling-support
+        pass
+
+    @staticmethod
+    def format_tools_for_chat_completion(tools):
+        """Formats Responses-API-style tool descriptions for the OpenAI Chat Completion API.
+
+        This is needed because actionset.to_tool_description() in bgym only returns descriptions
+        in the format expected by the OpenAI Responses API.
+
+        Args:
+            tools: List of tool descriptions to format for the Chat Completion API.
+
+        Returns:
+            Formatted tools list compatible with the OpenAI Chat Completion API, or None if tools is None.
+        """
+        formatted_tools = None
+        if tools is not None:
+            formatted_tools = [
+                {
+                    "type": tool["type"],
+                    "function": {k: tool[k] for k in ("name", "description", "parameters")},
+                }
+                for tool in tools
+            ]
+        return formatted_tools
+
+
+class LiteLLMAPIMessageBuilder(OpenAIChatCompletionAPIMessageBuilder):
+    """Message builder for the LiteLLM API, extending OpenAIChatCompletionAPIMessageBuilder."""
+
+    def prepare_message(self, use_only_first_toolcall: bool = False) -> List[Message]:
+        """Prepare the message for the OpenAI API."""
+        content = []
+        for item in self.content:
+            content.append(self.convert_content_to_expected_format(item))
+        output = [{"role": self.role, "content": content}]
+        return output if self.role != "tool" else self.handle_tool_call(use_only_first_toolcall)
+
+    def handle_tool_call(self, use_only_first_toolcall: bool = False) -> List[Message]:
+        """Handle the tool call response from the last raw response."""
+        if self.responded_tool_calls is None:
+            raise ValueError("No tool calls found in responded_tool_calls")
+        output = []
+        raw_call = self.responded_tool_calls.raw_calls.choices[0].message  # type: ignore
+        if use_only_first_toolcall:
+            raw_call.tool_calls = raw_call.tool_calls[:1]
+        output.append(raw_call)  # add raw calls to output
+        for fn_call in self.responded_tool_calls:
+            raw_call = fn_call.raw_call
+            assert (
+                "image" not in fn_call.tool_response
+            ), "Image output is not supported in function call responses."
+            # a function_call_output dict has the keys "role", "tool_call_id" and "content"
+            tool_call_response = {
+                "name": raw_call["function"]["name"],  # required with OpenRouter
+                "role": "tool",
+                "tool_call_id": raw_call["id"],
+                "content": self.convert_content_to_expected_format(fn_call.tool_response)["text"],
+            }
+            output.append(tool_call_response)
+
+        return output
+
+
+@dataclass
+class LiteLLMModelArgs(BaseModelArgs):
+    """Serializable arguments for LiteLLMModel."""
+
+    api = "openai"  # tool description format used by actionset.to_tool_description() in bgym
+    base_url: Optional[str] = None
+    api_key: Optional[str] = None
+    use_only_first_toolcall: bool = False
+
+    def make_model(self):
+        return LiteLLMModel(
+            model_name=self.model_name,
+            base_url=self.base_url,
+            api_key=self.api_key,
+            max_tokens=self.max_new_tokens,
+            temperature=self.temperature,
+            use_only_first_toolcall=self.use_only_first_toolcall,
+        )
+
+    def get_message_builder(self) -> Type[MessageBuilder]:
+        """Returns a message builder for LiteLLMModel."""
+        return LiteLLMAPIMessageBuilder
+
+
+if __name__ == "__main__":
+    """
+    Some simple tests to run LiteLLMModel with different models.
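+    Note: these run live requests through LiteLLM, so the relevant provider API keys
+    (e.g. OPENAI_API_KEY, ANTHROPIC_API_KEY) must be set and API costs will be incurred.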
+    """
+
+    import os
+
+    from agentlab.agents.tool_use_agent import DEFAULT_PROMPT_CONFIG, ToolUseAgentArgs
+    from agentlab.experiments.study import Study
+    from agentlab.llm.litellm_api import LiteLLMModelArgs
+
+    os.environ["LITELLM_LOG"] = "WARNING"
+
+    def get_agent(model_name: str) -> ToolUseAgentArgs:
+        return ToolUseAgentArgs(
+            model_args=LiteLLMModelArgs(
+                model_name=model_name,
+                max_new_tokens=2000,
+                temperature=None,
+            ),
+            config=DEFAULT_PROMPT_CONFIG,
+        )
+
+    models = [
+        "openai/gpt-4.1",
+        "openai/gpt-4.1-mini",
+        "openai/gpt-4.1-nano",
+        "openai/o3-2025-04-16",
+        "anthropic/claude-3-7-sonnet-20250219",
+        "anthropic/claude-sonnet-4-20250514",
+        ## Add more models to test.
+    ]
+    agent_args = [get_agent(model) for model in models]
+
+    study = Study(agent_args, "miniwob_tiny_test", logging_level_stdout=logging.WARNING)
+    study.run(
+        n_jobs=5,
+        parallel_backend="ray",
+        strict_reproducibility=False,
+        n_relaunch=3,
+    )
diff --git a/src/agentlab/llm/response_api.py b/src/agentlab/llm/response_api.py
index 1bbeeebc..890b27c1 100644
--- a/src/agentlab/llm/response_api.py
+++ b/src/agentlab/llm/response_api.py
@@ -27,6 +27,7 @@
 3. Factory classes (inherits from BaseModelArgs) for creating instances of LLM Response models.
 """
 
+logger = logging.getLogger(__name__)
 
 ContentItem = Dict[str, Any]
 Message = Dict[str, Union[str, List[ContentItem]]]
@@ -388,10 +389,15 @@ class APIPayload:
     cache_complete_prompt: bool = (
         False  # If True, will cache the complete prompt in the last message.
     )
+    reasoning_effort: Literal["low", "medium", "high"] | None = None
 
     def __post_init__(self):
         if self.tool_choice and self.force_call_tool:
             raise ValueError("tool_choice and force_call_tool are mutually exclusive")
+        if self.reasoning_effort is not None:
+            logger.info(
+                "In agentlab, reasoning_effort is currently honored by the LiteLLM API only. We will eventually switch to the LiteLLM API for all LLMs."
+            )
 
 
 # # Base class for all API Endpoints
diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py
index b8bcce7c..e761a7f6 100644
--- a/src/agentlab/llm/tracking.py
+++ b/src/agentlab/llm/tracking.py
@@ -5,11 +5,12 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass, field
-from functools import cache
+from functools import cache, partial
 from typing import Optional
 
 import requests
 from langchain_community.callbacks import bedrock_anthropic_callback, openai_info
+from litellm import completion_cost, get_model_info
 
 TRACKER = threading.local()
 
@@ -141,6 +142,21 @@ def get_pricing_anthropic():
     return res
 
 
+def get_pricing_litellm(model_name):
+    """Returns a dictionary of model pricing for a LiteLLM model."""
+    try:
+        info = get_model_info(model_name)
+    except Exception as e:
+        logging.error(f"Error fetching model info for {model_name} from litellm: {e}")
+        info = {}
+    return {
+        model_name: {
+            "prompt": info.get("input_cost_per_token", 0.0),
+            "completion": info.get("output_cost_per_token", 0.0),
+        }
+    }
+
+
 class TrackAPIPricingMixin:
     """Mixin class to handle pricing information for different models.
     This populates the tracker.stats used by the cost_tracker_decorator
@@ -151,9 +167,8 @@ class TrackAPIPricingMixin:
     def reset_stats(self):
         self.stats = Stats()
 
-    def init_pricing_tracker(
-        self, pricing_api=None
-    ):  # TODO: Use this function in the base class init instead of having a init in the Mixin class.
+    def init_pricing_tracker(self, pricing_api=None):
+        """Initialize the pricing tracker with the given API."""
         self._pricing_api = pricing_api
         self.set_pricing_attributes()
         self.reset_stats()
@@ -185,6 +200,7 @@ def fetch_pricing_information_from_provider(self) -> Optional[dict]:
             "openai": get_pricing_openai,
             "anthropic": get_pricing_anthropic,
             "openrouter": get_pricing_openrouter,
+            "litellm": partial(get_pricing_litellm, self.model_name),
         }
         pricing_fn = pricing_fn_map.get(self._pricing_api, None)
         if pricing_fn is None:
@@ -248,6 +264,8 @@ def get_effective_cost(self, response):
             return self.get_effective_cost_from_antrophic_api(response)
         elif self._pricing_api == "openai":
             return self.get_effective_cost_from_openai_api(response)
+        elif self._pricing_api == "litellm":
+            return completion_cost(response)
         else:
             logging.warning(
                 f"Unsupported provider: {self._pricing_api}. No effective cost calculated."
diff --git a/tests/llm/test_litellm_api.py b/tests/llm/test_litellm_api.py
new file mode 100644
index 00000000..29e7a830
--- /dev/null
+++ b/tests/llm/test_litellm_api.py
@@ -0,0 +1,165 @@
+from functools import partial
+
+import pytest
+
+from agentlab.llm.litellm_api import LiteLLMModelArgs
+from agentlab.llm.response_api import APIPayload, LLMOutput
+
+chat_api_tools = [
+    {
+        "type": "function",
+        "name": "get_weather",
+        "description": "Get the current weather in a given location.",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "location": {
+                    "type": "string",
+                    "description": "The location to get the weather for.",
+                },
+                "unit": {
+                    "type": "string",
+                    "enum": ["celsius", "fahrenheit"],
+                    "description": "The unit of temperature.",
+                },
+            },
+            "required": ["location"],
+        },
+    },
+    {
+        "type": "function",
+        "name": "get_time",
+        "description": "Get the current time in a given location.",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "location": {
+                    "type": "string",
+                    "description": "The location to get the time for.",
+                }
+            },
+            "required": ["location"],
+        },
+    },
+]
+
+
+# test config: (setting name, BaseModelArgs subclass, model name, tools)
+tool_test_configs = [
+    ("gpt-4.1", LiteLLMModelArgs, "openai/gpt-4.1-2025-04-14", chat_api_tools),
+    # ("claude-3", LiteLLMModelArgs, "anthropic/claude-3-haiku-20240307", anthropic_tools),  # fails for parallel tool calls
+    # ("claude-3.7", LiteLLMModelArgs, "anthropic/claude-3-7-sonnet-20250219", anthropic_tools),  # fails for parallel tool calls
+    ("claude-4-sonnet", LiteLLMModelArgs, "anthropic/claude-sonnet-4-20250514", chat_api_tools),
+    # ("gpt-o3", LiteLLMModelArgs, "openai/o3-2025-04-16", chat_api_tools),  # fails for parallel tool calls
+    # add more models as needed
+]
+
+
+def add_user_messages(msg_builder):
+    return [
+        msg_builder.user().add_text("What is the weather in Paris and Delhi?"),
+        msg_builder.user().add_text("You must call multiple tools to achieve the task."),
+    ]
+
+
+## Test multiaction
+@pytest.mark.pricy
+def test_multi_action_tool_calls():
+    """
+    Test that the model can produce multiple tool calls in parallel.
+    Note: Remove the assert and uncomment the commented lines to see the full behaviour of models and tool choices.
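+    The prompt asks for the weather in two cities, so a model that supports parallel tool calls
+    is expected to return two tool calls.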
+    """
+    res_df = []
+    for tool_choice in [
+        # "none",
+        "required",  # fails for Responses API
+        "any",  # fails for Responses API
+        "auto",
+        # "get_weather",  # force a specific tool call
+    ]:
+        for name, llm_class, checkpoint_name, tools in tool_test_configs:
+            model_args = llm_class(model_name=checkpoint_name, max_new_tokens=200, temperature=None)
+            llm, msg_builder = model_args.make_model(), model_args.get_message_builder()
+            messages = add_user_messages(msg_builder)
+            if tool_choice == "get_weather":  # force a specific tool call
+                response: LLMOutput = llm(
+                    APIPayload(messages=messages, tools=tools, force_call_tool=tool_choice)
+                )
+            else:
+                response: LLMOutput = llm(
+                    APIPayload(messages=messages, tools=tools, tool_choice=tool_choice)
+                )
+            num_tool_calls = len(response.tool_calls) if response.tool_calls else 0
+            row = {
+                "model": name,
+                "checkpoint": checkpoint_name,
+                "tool_choice": tool_choice,
+                "num_tool_calls": num_tool_calls,
+                "action": response.action,
+            }
+            res_df.append(row)
+            assert (
+                num_tool_calls == 2
+            ), f"Expected 2 tool calls, but got {num_tool_calls} for {name} with tool choice {tool_choice}"
+    # import pandas as pd
+    # print(pd.DataFrame(res_df))
+
+
+@pytest.mark.pricy
+def test_single_tool_call():
+    """
+    Test that the LLMOutput contains only one tool call when use_only_first_toolcall is True.
+    """
+    for tool_choice in [
+        # 'none',
+        "required",
+        "any",
+        "auto",
+    ]:
+        for name, llm_class, checkpoint_name, tools in tool_test_configs:
+            print(name, "tool choice:", tool_choice, "\n", "**" * 10)
+            llm_class = partial(llm_class, use_only_first_toolcall=True)
+            model_args = llm_class(model_name=checkpoint_name, max_new_tokens=200, temperature=None)
+            llm, msg_builder = model_args.make_model(), model_args.get_message_builder()
+            messages = add_user_messages(msg_builder)
+            if tool_choice == "get_weather":  # force a specific tool call
+                response: LLMOutput = llm(
+                    APIPayload(messages=messages, tools=tools, force_call_tool=tool_choice)
+                )
+            else:
+                response: LLMOutput = llm(
+                    APIPayload(messages=messages, tools=tools, tool_choice=tool_choice)
+                )
+            num_tool_calls = len(response.tool_calls) if response.tool_calls else 0
+            assert (
+                num_tool_calls == 1
+            ), f"Expected 1 tool call, but got {num_tool_calls} for {name} with tool choice {tool_choice}"
+
+
+@pytest.mark.pricy
+def test_force_tool_call():
+    """
+    Test that the model can produce a specific tool call when requested.
+    The user message asks for the weather, but we force a call to the tool "get_time".
+    We test whether 'get_time' is present in the tool calls.
+    Note: The model may make other tool calls as well.
+    """
+    force_call_tool = "get_time"
+    for name, llm_class, checkpoint_name, tools in tool_test_configs:
+        model_args = llm_class(model_name=checkpoint_name, max_new_tokens=200, temperature=None)
+        llm, msg_builder = model_args.make_model(), model_args.get_message_builder()
+        messages = add_user_messages(msg_builder)  # asks for the weather in Paris and Delhi
+        response: LLMOutput = llm(
+            APIPayload(messages=messages, tools=tools, force_call_tool=force_call_tool)
+        )
+        called_fn_names = [call.name for call in response.tool_calls] if response.tool_calls else []
+        assert response.tool_calls is not None
+        assert any(
+            fn_name == "get_time" for fn_name in called_fn_names
+        ), f"Model: {name}, expected a 'get_time' tool call, but got {called_fn_names} with force call {force_call_tool}"
+
+
+if __name__ == "__main__":
+    test_multi_action_tool_calls()
+    test_force_tool_call()
+    test_single_tool_call()