From 1305f33d6a320f70f9db537dca52af7c3a6e87d5 Mon Sep 17 00:00:00 2001
From: Aman Jaiswal <66757799+amanjaiswal73892@users.noreply.github.com>
Date: Thu, 31 Jul 2025 16:47:34 -0400
Subject: [PATCH 1/2] Add LiteLLM API integration
- Implement LiteLLMModel and LiteLLMAPIMessageBuilder for LiteLLM API interaction.
- Extend APIPayload with a reasoning_effort parameter.
- Introduce a get_pricing_litellm function for pricing information retrieval.
- Add tests for multi-action, single tool call, and forced tool call scenarios.
- Add an option to discard extra tool calls by setting use_only_first_toolcall.
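A minimal usage sketch (the model name and token budget are illustrative; it mirrors the
__main__ example in litellm_api.py):

    from agentlab.llm.litellm_api import LiteLLMModelArgs

    model_args = LiteLLMModelArgs(
        model_name="openai/gpt-4.1-mini",  # any model name LiteLLM can route
        max_new_tokens=2000,
        temperature=None,
        use_only_first_toolcall=False,
    )
    llm = model_args.make_model()                   # -> LiteLLMModel
    msg_builder = model_args.get_message_builder()  # -> LiteLLMAPIMessageBuilder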
---
src/agentlab/llm/litellm_api.py | 322 +++++++++++++++++++++++++++++++
src/agentlab/llm/response_api.py | 6 +
src/agentlab/llm/tracking.py | 26 ++-
tests/llm/test_litellm_api.py | 165 ++++++++++++++++
4 files changed, 515 insertions(+), 4 deletions(-)
create mode 100644 src/agentlab/llm/litellm_api.py
create mode 100644 tests/llm/test_litellm_api.py
diff --git a/src/agentlab/llm/litellm_api.py b/src/agentlab/llm/litellm_api.py
new file mode 100644
index 00000000..43113740
--- /dev/null
+++ b/src/agentlab/llm/litellm_api.py
@@ -0,0 +1,322 @@
+import json
+import logging
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Dict, List, Optional, Type
+
+import litellm
+from litellm import completion
+from openai.types.chat import ChatCompletion as OpenAIChatCompletion
+
+from agentlab.llm.base_api import BaseModelArgs
+from agentlab.llm.response_api import (
+ AgentlabAction,
+ APIPayload,
+ BaseModelWithPricing,
+ LLMOutput,
+ Message,
+ MessageBuilder,
+ OpenAIChatCompletionAPIMessageBuilder,
+ ToolCall,
+ ToolCalls,
+)
+
+litellm.modify_params = True
+
+
+class LiteLLMModel(BaseModelWithPricing):
+ def __init__(
+ self,
+ model_name: str,
+ base_url: Optional[str] = None,
+ api_key: Optional[str] = None,
+ temperature: float | None = None,
+ max_tokens: int | None = 100,
+ use_only_first_toolcall: bool = False,
+ ):
+ super().__init__(
+ model_name=model_name,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+        self.action_space_as_tools = True  # TODO: make this configurable
+ client_args = {}
+ if base_url is not None:
+ client_args["base_url"] = base_url
+ if api_key is not None:
+ client_args["api_key"] = api_key
+ self.client = partial(completion, **client_args)
+ self.init_pricing_tracker(pricing_api="litellm")
+ self.use_only_first_toolcall = use_only_first_toolcall
+ try:
+ self.litellm_info = litellm.get_model_info(model_name)
+ # maybe log this in xray
+
+ except Exception as e:
+ logging.error(f"Failed to get litellm model info: {e}")
+
+ def _call_api(self, payload: APIPayload) -> "OpenAIChatCompletion":
+ """
+ Calls the LiteLLM API with the given payload.
+
+ Args:
+ payload (APIPayload): The payload to send to the API.
+
+ Returns:
+            OpenAIChatCompletion: a LiteLLM response object exposing the same fields as an OpenAI ChatCompletion.
+ """
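+        # A representative set of arguments handed to litellm.completion() (illustrative values):
+        #   {"model": "openai/gpt-4.1", "messages": [...], "max_completion_tokens": 100,
+        #    "tools": [...], "tool_choice": "auto", "reasoning_effort": "low"}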
+ input = []
+ for msg in payload.messages: # type: ignore
+ input.extend(msg.prepare_message())
+ api_params: Dict[str, Any] = {
+ "model": self.model_name,
+ "messages": input,
+ }
+ if self.temperature is not None:
+ api_params["temperature"] = self.temperature
+
+ if self.max_tokens is not None:
+ api_params["max_completion_tokens"] = self.max_tokens
+
+ if payload.tools is not None:
+ api_params["tools"] = (
+ self.format_tools_for_chat_completion(payload.tools)
+ if "function" not in payload.tools[0] # convert if responses_api_tools
+ else payload.tools
+ )
+
+ if payload.tool_choice is not None and payload.force_call_tool is None:
+ api_params["tool_choice"] = (
+ "required" if payload.tool_choice in ("required", "any") else payload.tool_choice
+ )
+
+ if payload.force_call_tool is not None:
+ api_params["tool_choice"] = {
+ "type": "function",
+ "function": {"name": payload.force_call_tool},
+ }
+
+ if payload.reasoning_effort is not None:
+ api_params["reasoning_effort"] = payload.reasoning_effort
+
+ if "tools" in api_params and payload.cache_tool_definition:
+ api_params["tools"][-1]["cache_control"] = {"type": "ephemeral"} # type: ignore
+
+ if payload.cache_complete_prompt:
+ # Indicating cache control for the last message enables caching of the complete prompt.
+ api_params["messages"][-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}
+
+ response = self.client(**api_params, num_retries=5)
+
+ return response # type: ignore
+
+ def _parse_response(self, response: "OpenAIChatCompletion") -> LLMOutput:
+ think_output = self._extract_thinking_content_from_response(response)
+ tool_calls = self._extract_tool_calls_from_response(response)
+
+ if self.action_space_as_tools:
+ env_action = self._extract_env_actions_from_toolcalls(tool_calls) # type: ignore
+ else:
+ env_action = self._extract_env_actions_from_text_response(response)
+ return LLMOutput(
+ raw_response=response,
+ think=think_output,
+            action=env_action,
+            tool_calls=tool_calls,
+ )
+
+ def _extract_thinking_content_from_response(
+ self, response: OpenAIChatCompletion, wrap_tag="think"
+ ):
+        """Extracts the content from the message, including reasoning if available.
+
+        When the LLM produces 'text' and 'reasoning' in the same message, the reasoning is
+        wrapped in <think>...</think>-style tags (the tag name is set by ``wrap_tag``) so it
+        can be identified easily.
+        Note: the wrapping of 'thinking' content may not be needed and may be reconsidered.
+
+ Args:
+ response: The message object or dict containing content and reasoning.
+ wrap_tag: The tag name to wrap reasoning content (default: "think").
+
+ Returns:
+ str: The extracted content with reasoning wrapped in specified tags.
+ """
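+        # Illustrative return value when reasoning is present:
+        #   "<think>I need both the weather and the time tools.</think>\nCalling the tools now."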
+ message = response.choices[0].message
+ if not isinstance(message, dict):
+ message = message.to_dict()
+
+ reasoning_content = message.get("reasoning", None)
+        msg_content = message.get("text") or ""  # populated by OpenRouter
+ if reasoning_content:
+ # Wrap reasoning in tags with newlines for clarity
+            reasoning_content = f"<{wrap_tag}>{reasoning_content}</{wrap_tag}>\n"
+ logging.debug("Extracting content from response.choices[i].message.reasoning")
+ else:
+ reasoning_content = ""
+        return f"{reasoning_content}{msg_content}{message.get('content') or ''}"
+
+ def _extract_tool_calls_from_response(self, response: OpenAIChatCompletion) -> ToolCalls | None:
+ """Extracts tool calls from the response."""
+ message = response.choices[0].message.to_dict()
+ tool_calls = message.get("tool_calls", None)
+ if tool_calls is None:
+ return None
+ tool_call_list = []
+ for tc in tool_calls: # type: ignore
+ tool_call_list.append(
+ ToolCall(
+ name=tc["function"]["name"],
+ arguments=json.loads(tc["function"]["arguments"]),
+ raw_call=tc,
+ )
+ )
+ if self.use_only_first_toolcall:
+ break
+ return ToolCalls(tool_calls=tool_call_list, raw_calls=response) # type: ignore
+
+ def _extract_env_actions_from_toolcalls(self, toolcalls: ToolCalls) -> Any | None:
+ """Extracts actions from the response."""
+ if not toolcalls:
+ return None
+
+ actions = [
+ AgentlabAction.convert_toolcall_to_agentlab_action_format(call) for call in toolcalls
+ ]
+ actions = (
+ AgentlabAction.convert_multiactions_to_agentlab_action_format(actions)
+ if len(actions) > 1
+ else actions[0]
+ )
+ return actions
+
+ def _extract_env_actions_from_text_response(
+ self, response: "OpenAIChatCompletion"
+ ) -> str | None:
+ """Extracts environment actions from the text response."""
+ # Use when action space is not given as tools.
+ # TODO: Add support to pass action space as prompt in LiteLLM.
+ # Check: https://docs.litellm.ai/docs/completion/function_call#function-calling-for-models-wout-function-calling-support
+ pass
+
+ @staticmethod
+ def format_tools_for_chat_completion(tools):
+        """Converts tools from the OpenAI Responses API format to the Chat Completion API format.
+
+        This is needed because actionset.to_tool_description() in bgym only returns tool
+        descriptions in the format expected by the OpenAI Responses API.
+
+ Args:
+ tools: List of tool descriptions to format for Chat Completion API.
+
+ Returns:
+ Formatted tools list compatible with OpenAI Chat Completion API, or None if tools is None.
+ """
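+        # Example (shape only): a Responses-API style tool description
+        #   {"type": "function", "name": "get_weather", "description": "...", "parameters": {...}}
+        # becomes, for the Chat Completion API:
+        #   {"type": "function", "function": {"name": "get_weather", "description": "...", "parameters": {...}}}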
+ formatted_tools = None
+ if tools is not None:
+ formatted_tools = [
+ {
+ "type": tool["type"],
+ "function": {k: tool[k] for k in ("name", "description", "parameters")},
+ }
+ for tool in tools
+ ]
+ return formatted_tools
+
+
+class LiteLLMAPIMessageBuilder(OpenAIChatCompletionAPIMessageBuilder):
+ """Message builder for LiteLLM API, extending OpenAIChatCompletionAPIMessageBuilder."""
+
+ def prepare_message(self, use_only_first_toolcall: bool = False) -> List[Message]:
+ """Prepare the message for the OpenAI API."""
+ content = []
+ for item in self.content:
+ content.append(self.convert_content_to_expected_format(item))
+ output = [{"role": self.role, "content": content}]
+ return output if self.role != "tool" else self.handle_tool_call(use_only_first_toolcall)
+
+ def handle_tool_call(self, use_only_first_toolcall: bool = False) -> List[Message]:
+ """Handle the tool call response from the last raw response."""
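+        # The returned list is: the assistant message that issued the tool calls, followed by
+        # one {"role": "tool", ...} response dict per answered call, as the Chat Completion API expects.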
+ if self.responded_tool_calls is None:
+ raise ValueError("No tool calls found in responded_tool_calls")
+ output = []
+        assistant_msg = self.responded_tool_calls.raw_calls.choices[0].message  # type: ignore
+        if use_only_first_toolcall:
+            assistant_msg.tool_calls = assistant_msg.tool_calls[:1]
+        output.append(assistant_msg)  # the assistant message carrying the tool calls
+ for fn_call in self.responded_tool_calls:
+ raw_call = fn_call.raw_call
+ assert (
+ "image" not in fn_call.tool_response
+ ), "Image output is not supported in function calls response."
+            # A tool response dict carries "name", "role", "tool_call_id" and "content".
+            tool_call_response = {
+                "name": raw_call["function"]["name"],  # required by OpenRouter
+                "role": "tool",
+                "tool_call_id": raw_call["id"],
+                "content": self.convert_content_to_expected_format(fn_call.tool_response)["text"],
+            }
+            output.append(tool_call_response)
+
+ return output
+
+
+@dataclass
+class LiteLLMModelArgs(BaseModelArgs):
+    """Serializable arguments for LiteLLMModel."""
+
+ api = "openai" # tool description format used by actionset.to_tool_description() in bgym
+ base_url: Optional[str] = None
+ api_key: Optional[str] = None
+ use_only_first_toolcall: bool = False
+
+ def make_model(self):
+ return LiteLLMModel(
+ model_name=self.model_name,
+ base_url=self.base_url,
+ api_key=self.api_key,
+ max_tokens=self.max_new_tokens,
+ temperature=self.temperature,
+ use_only_first_toolcall=self.use_only_first_toolcall,
+ )
+
+ def get_message_builder(self) -> Type[MessageBuilder]:
+        """Returns a message builder for the LiteLLMModel."""
+ return LiteLLMAPIMessageBuilder
+
+
+if __name__ == "__main__":
+ """
+ Some simple tests to run the LiteLLMModel with different models.
+ """
+
+ from agentlab.agents.tool_use_agent import DEFAULT_PROMPT_CONFIG, ToolUseAgentArgs
+ from agentlab.experiments.study import Study
+ from agentlab.llm.litellm_api import LiteLLMModelArgs
+
+ def get_agent(model_name: str) -> ToolUseAgentArgs:
+ return ToolUseAgentArgs(
+ model_args=LiteLLMModelArgs(
+ model_name=model_name,
+ max_new_tokens=2000,
+ temperature=None,
+ ),
+ config=DEFAULT_PROMPT_CONFIG,
+ )
+
+ models = [
+ "openai/gpt-4.1",
+ "openai/gpt-4.1-mini",
+ "openai/gpt-4.1-nano",
+ "openai/o3-2025-04-16",
+ "anthropic/claude-3-7-sonnet-20250219",
+ "anthropic/claude-sonnet-4-20250514",
+ ## Add more models to test.
+ ]
+ agent_args = [get_agent(model) for model in models]
+
+ study = Study(agent_args, "miniwob_tiny_test", logging_level_stdout=logging.WARNING)
+ study.run(
+ n_jobs=5,
+ parallel_backend="ray",
+ strict_reproducibility=False,
+ n_relaunch=3,
+ )
diff --git a/src/agentlab/llm/response_api.py b/src/agentlab/llm/response_api.py
index 1bbeeebc..890b27c1 100644
--- a/src/agentlab/llm/response_api.py
+++ b/src/agentlab/llm/response_api.py
@@ -27,6 +27,7 @@
3. Factory classes (inherits from BaseModelArgs) for creating instances of LLM Response models.
"""
+logger = logging.getLogger(__name__)
ContentItem = Dict[str, Any]
Message = Dict[str, Union[str, List[ContentItem]]]
@@ -388,10 +389,15 @@ class APIPayload:
cache_complete_prompt: bool = (
False # If True, will cache the complete prompt in the last message.
)
+ reasoning_effort: Literal["low", "medium", "high"] | None = None
def __post_init__(self):
if self.tool_choice and self.force_call_tool:
raise ValueError("tool_choice and force_call_tool are mutually exclusive")
+ if self.reasoning_effort is not None:
+ logger.info(
+            "In AgentLab, reasoning_effort is currently honored by the LiteLLM API only. We will eventually shift all LLMs to the LiteLLM API."
+ )
# # Base class for all API Endpoints
diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py
index b8bcce7c..e761a7f6 100644
--- a/src/agentlab/llm/tracking.py
+++ b/src/agentlab/llm/tracking.py
@@ -5,11 +5,12 @@
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
-from functools import cache
+from functools import cache, partial
from typing import Optional
import requests
from langchain_community.callbacks import bedrock_anthropic_callback, openai_info
+from litellm import completion_cost, get_model_info
TRACKER = threading.local()
@@ -141,6 +142,21 @@ def get_pricing_anthropic():
return res
+def get_pricing_litellm(model_name):
+ """Returns a dictionary of model pricing for a LiteLLM model."""
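+    # Returned shape (per-token USD costs; the values below are illustrative):
+    #   {"openai/gpt-4.1": {"prompt": 2e-06, "completion": 8e-06}}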
+ try:
+ info = get_model_info(model_name)
+ except Exception as e:
+        logging.error(f"Error fetching model info for {model_name} from litellm: {e}")
+ info = {}
+ return {
+ model_name: {
+ "prompt": info.get("input_cost_per_token", 0.0),
+ "completion": info.get("output_cost_per_token", 0.0),
+ }
+ }
+
+
class TrackAPIPricingMixin:
"""Mixin class to handle pricing information for different models.
This populates the tracker.stats used by the cost_tracker_decorator
@@ -151,9 +167,8 @@ class TrackAPIPricingMixin:
def reset_stats(self):
self.stats = Stats()
- def init_pricing_tracker(
- self, pricing_api=None
- ): # TODO: Use this function in the base class init instead of having a init in the Mixin class.
+ def init_pricing_tracker(self, pricing_api=None):
+ """Initialize the pricing tracker with the given API."""
self._pricing_api = pricing_api
self.set_pricing_attributes()
self.reset_stats()
@@ -185,6 +200,7 @@ def fetch_pricing_information_from_provider(self) -> Optional[dict]:
"openai": get_pricing_openai,
"anthropic": get_pricing_anthropic,
"openrouter": get_pricing_openrouter,
+ "litellm": partial(get_pricing_litellm, self.model_name),
}
pricing_fn = pricing_fn_map.get(self._pricing_api, None)
if pricing_fn is None:
@@ -248,6 +264,8 @@ def get_effective_cost(self, response):
return self.get_effective_cost_from_antrophic_api(response)
elif self._pricing_api == "openai":
return self.get_effective_cost_from_openai_api(response)
+ elif self._pricing_api == "litellm":
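+            # litellm.completion_cost() derives the USD cost directly from the response object,
+            # so no per-token pricing lookup is needed here.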
+ return completion_cost(response)
else:
logging.warning(
f"Unsupported provider: {self._pricing_api}. No effective cost calculated."
diff --git a/tests/llm/test_litellm_api.py b/tests/llm/test_litellm_api.py
new file mode 100644
index 00000000..29e7a830
--- /dev/null
+++ b/tests/llm/test_litellm_api.py
@@ -0,0 +1,165 @@
+from functools import partial
+
+import pytest
+
+from agentlab.llm.litellm_api import LiteLLMModelArgs
+from agentlab.llm.response_api import APIPayload, LLMOutput
+
+chat_api_tools = [
+ {
+ "type": "function",
+ "name": "get_weather",
+ "description": "Get the current weather in a given location.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The location to get the weather for.",
+ },
+ "unit": {
+ "type": "string",
+ "enum": ["celsius", "fahrenheit"],
+ "description": "The unit of temperature.",
+ },
+ },
+ "required": ["location"],
+ },
+ },
+ {
+ "type": "function",
+ "name": "get_time",
+ "description": "Get the current time in a given location.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The location to get the time for.",
+ }
+ },
+ "required": ["location"],
+ },
+ },
+]
+
+
+# test_config (setting name, BaseModelArgs, model_name, tools)
+tool_test_configs = [
+ ("gpt-4.1", LiteLLMModelArgs, "openai/gpt-4.1-2025-04-14", chat_api_tools),
+ # ("claude-3", LiteLLMModelArgs, "anthropic/claude-3-haiku-20240307", anthropic_tools), # fails for parallel tool calls
+ # ("claude-3.7", LiteLLMModelArgs, "anthropic/claude-3-7-sonnet-20250219", anthropic_tools), # fails for parallel tool calls
+ ("claude-4-sonnet", LiteLLMModelArgs, "anthropic/claude-sonnet-4-20250514", chat_api_tools),
+ # ("gpt-o3", LiteLLMModelArgs, "openai/o3-2025-04-16", chat_api_tools), # fails for parallel tool calls
+ # add more models as needed
+]
+
+
+def add_user_messages(msg_builder):
+ return [
+ msg_builder.user().add_text("What is the weather in Paris and Delhi?"),
+ msg_builder.user().add_text("You must call multiple tools to achieve the task."),
+ ]
+
+
+## Test multiaction
+@pytest.mark.pricy
+def test_multi_action_tool_calls():
+ """
+ Test that the model can produce multiple tool calls in parallel.
+    Note: remove the assert and uncomment the commented lines to see the full behaviour of models and tool choices.
+ """
+ res_df = []
+ for tool_choice in [
+ # "none",
+ "required", # fails for Responses API
+ "any", # fails for Responses API
+ "auto",
+ # "get_weather", # force a specific tool call
+ ]:
+ for name, llm_class, checkpoint_name, tools in tool_test_configs:
+ model_args = llm_class(model_name=checkpoint_name, max_new_tokens=200, temperature=None)
+ llm, msg_builder = model_args.make_model(), model_args.get_message_builder()
+ messages = add_user_messages(msg_builder)
+ if tool_choice == "get_weather": # force a specific tool call
+ response: LLMOutput = llm(
+ APIPayload(messages=messages, tools=tools, force_call_tool=tool_choice)
+ )
+ else:
+ response: LLMOutput = llm(
+ APIPayload(messages=messages, tools=tools, tool_choice=tool_choice)
+ )
+ num_tool_calls = len(response.tool_calls) if response.tool_calls else 0
+ row = {
+ "model": name,
+ "checkpoint": checkpoint_name,
+ "tool_choice": tool_choice,
+ "num_tool_calls": num_tool_calls,
+ "action": response.action,
+ }
+ res_df.append(row)
+ assert (
+ num_tool_calls == 2
+ ), f"Expected 2 tool calls, but got {num_tool_calls} for {name} with tool choice {tool_choice}"
+ # import pandas as pd
+ # print(pd.DataFrame(res_df))
+
+
+@pytest.mark.pricy
+def test_single_tool_call():
+ """
+ Test that the LLMOutput contains only one tool call when use_only_first_toolcall is True.
+ """
+ for tool_choice in [
+ # 'none',
+ "required",
+ "any",
+ "auto",
+ ]:
+ for name, llm_class, checkpoint_name, tools in tool_test_configs:
+ print(name, "tool choice:", tool_choice, "\n", "**" * 10)
+ llm_class = partial(llm_class, use_only_first_toolcall=True)
+ model_args = llm_class(model_name=checkpoint_name, max_new_tokens=200, temperature=None)
+ llm, msg_builder = model_args.make_model(), model_args.get_message_builder()
+ messages = add_user_messages(msg_builder)
+ if tool_choice == "get_weather": # force a specific tool call
+ response: LLMOutput = llm(
+ APIPayload(messages=messages, tools=tools, force_call_tool=tool_choice)
+ )
+ else:
+ response: LLMOutput = llm(
+ APIPayload(messages=messages, tools=tools, tool_choice=tool_choice)
+ )
+ num_tool_calls = len(response.tool_calls) if response.tool_calls else 0
+ assert (
+ num_tool_calls == 1
+            ), f"Expected 1 tool call, but got {num_tool_calls} for {name} with tool choice {tool_choice}"
+
+
+@pytest.mark.pricy
+def test_force_tool_call():
+ """
+ Test that the model can produce a specific tool call when requested.
+    The user message asks for the weather, but we force the call of the "get_time" tool.
+    We check that 'get_time' is present in the tool calls.
+    Note: the model may produce other tool calls as well.
+ """
+ force_call_tool = "get_time"
+ for name, llm_class, checkpoint_name, tools in tool_test_configs:
+ model_args = llm_class(model_name=checkpoint_name, max_new_tokens=200, temperature=None)
+ llm, msg_builder = model_args.make_model(), model_args.get_message_builder()
+ messages = add_user_messages(msg_builder) # asks weather in Paris and Delhi
+ response: LLMOutput = llm(
+ APIPayload(messages=messages, tools=tools, force_call_tool=force_call_tool)
+ )
+ called_fn_names = [call.name for call in response.tool_calls] if response.tool_calls else []
+ assert response.tool_calls is not None
+ assert any(
+ fn_name == "get_time" for fn_name in called_fn_names
+        ), f"Model: {name}, expected 'get_time' among the tool calls, but got {called_fn_names} with force_call_tool={force_call_tool}"
+
+
+if __name__ == "__main__":
+ test_multi_action_tool_calls()
+ test_force_tool_call()
+ test_single_tool_call()
From 971a3e95464eb1867b140d0f1bd128c99adeee23 Mon Sep 17 00:00:00 2001
From: Aman Jaiswal <66757799+amanjaiswal73892@users.noreply.github.com>
Date: Fri, 1 Aug 2025 17:47:23 -0400
Subject: [PATCH 2/2] reduce LiteLLM logging in litellm_api script
---
src/agentlab/llm/litellm_api.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/src/agentlab/llm/litellm_api.py b/src/agentlab/llm/litellm_api.py
index 43113740..bc8f48ab 100644
--- a/src/agentlab/llm/litellm_api.py
+++ b/src/agentlab/llm/litellm_api.py
@@ -288,10 +288,14 @@ def get_message_builder(self) -> Type[MessageBuilder]:
Some simple tests to run the LiteLLMModel with different models.
"""
+ import os
+
from agentlab.agents.tool_use_agent import DEFAULT_PROMPT_CONFIG, ToolUseAgentArgs
from agentlab.experiments.study import Study
from agentlab.llm.litellm_api import LiteLLMModelArgs
+ os.environ["LITELLM_LOG"] = "WARNING"
+
def get_agent(model_name: str) -> ToolUseAgentArgs:
return ToolUseAgentArgs(
model_args=LiteLLMModelArgs(