Add Litellm API integration #273
New file in this PR, the `agentlab.llm.litellm_api` module (`@@ -0,0 +1,326 @@`):

```python
import json
import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Type

import litellm
from litellm import completion
from openai.types.chat import ChatCompletion as OpenAIChatCompletion

from agentlab.llm.base_api import BaseModelArgs
from agentlab.llm.response_api import (
    AgentlabAction,
    APIPayload,
    BaseModelWithPricing,
    LLMOutput,
    Message,
    MessageBuilder,
    OpenAIChatCompletionAPIMessageBuilder,
    ToolCall,
    ToolCalls,
)

litellm.modify_params = True


class LiteLLMModel(BaseModelWithPricing):
    def __init__(
        self,
        model_name: str,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        temperature: float | None = None,
        max_tokens: int | None = 100,
        use_only_first_toolcall: bool = False,
    ):
        super().__init__(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self.action_space_as_tools = True  # this should be a config
        client_args = {}
        if base_url is not None:
            client_args["base_url"] = base_url
        if api_key is not None:
            client_args["api_key"] = api_key
        self.client = partial(completion, **client_args)
        self.init_pricing_tracker(pricing_api="litellm")
        self.use_only_first_toolcall = use_only_first_toolcall
        try:
            self.litellm_info = litellm.get_model_info(model_name)
            # maybe log this in xray
        except Exception as e:
            logging.error(f"Failed to get litellm model info: {e}")

    def _call_api(self, payload: APIPayload) -> "OpenAIChatCompletion":
        """Calls the LiteLLM API with the given payload.

        Args:
            payload (APIPayload): The payload to send to the API.

        Returns:
            OpenAIChatCompletion: An object with the same keys as OpenAIChatCompletion.
        """
        input = []
        for msg in payload.messages:  # type: ignore
            input.extend(msg.prepare_message())
```
Korbit review comment on lines +69 to +70: Inefficient Message List Building

What is the issue? Messages are processed one at a time with `extend()` calls, which can be inefficient for large message histories.

Why this matters: Multiple `extend()` operations on a list cause repeated memory allocations and copies, which becomes more significant with larger message histories.

Suggested change: use a list comprehension to build the input list in one operation:

```python
input = [msg for messages in payload.messages for msg in messages.prepare_message()]  # type: ignore
```
```python
        api_params: Dict[str, Any] = {
            "model": self.model_name,
            "messages": input,
        }
        if self.temperature is not None:
            api_params["temperature"] = self.temperature

        if self.max_tokens is not None:
            api_params["max_completion_tokens"] = self.max_tokens

        if payload.tools is not None:
            api_params["tools"] = (
                self.format_tools_for_chat_completion(payload.tools)
                if "function" not in payload.tools[0]  # convert if responses_api_tools
                else payload.tools
            )

        if payload.tool_choice is not None and payload.force_call_tool is None:
            api_params["tool_choice"] = (
                "required" if payload.tool_choice in ("required", "any") else payload.tool_choice
            )

        if payload.force_call_tool is not None:
            api_params["tool_choice"] = {
                "type": "function",
                "function": {"name": payload.force_call_tool},
            }

        if payload.reasoning_effort is not None:
            api_params["reasoning_effort"] = payload.reasoning_effort

        if "tools" in api_params and payload.cache_tool_definition:
            api_params["tools"][-1]["cache_control"] = {"type": "ephemeral"}  # type: ignore

        if payload.cache_complete_prompt:
            # Indicating cache control for the last message enables caching of the complete prompt.
            api_params["messages"][-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}

        response = self.client(**api_params, num_retries=5)

        return response  # type: ignore

    def _parse_response(self, response: "OpenAIChatCompletion") -> LLMOutput:
        think_output = self._extract_thinking_content_from_response(response)
        tool_calls = self._extract_tool_calls_from_response(response)

        if self.action_space_as_tools:
            env_action = self._extract_env_actions_from_toolcalls(tool_calls)  # type: ignore
        else:
            env_action = self._extract_env_actions_from_text_response(response)
        return LLMOutput(
            raw_response=response,
            think=think_output,
            action=env_action if env_action is not None else None,
            tool_calls=tool_calls if tool_calls is not None else None,
        )

    def _extract_thinking_content_from_response(
        self, response: OpenAIChatCompletion, wrap_tag="think"
    ):
        """Extracts the content from the message, including reasoning if available.

        It wraps the reasoning in <think>...</think> tags for easy identification of
        reasoning content when the LLM produces 'text' and 'reasoning' in the same message.
        Note: The wrapping of 'thinking' content may not be needed and may be reconsidered.

        Args:
            response: The message object or dict containing content and reasoning.
            wrap_tag: The tag name to wrap reasoning content (default: "think").

        Returns:
            str: The extracted content with reasoning wrapped in the specified tags.
        """
        message = response.choices[0].message
        if not isinstance(message, dict):
            message = message.to_dict()

        reasoning_content = message.get("reasoning", None)
        msg_content = message.get("text", "")  # works for OpenRouter
        if reasoning_content:
            # Wrap reasoning in <think> tags with newlines for clarity
            reasoning_content = f"<{wrap_tag}>{reasoning_content}</{wrap_tag}>\n"
            logging.debug("Extracting content from response.choices[i].message.reasoning")
        else:
            reasoning_content = ""
        return f"{reasoning_content}{msg_content}{message.get('content', '')}"

    def _extract_tool_calls_from_response(self, response: OpenAIChatCompletion) -> ToolCalls | None:
        """Extracts tool calls from the response."""
        message = response.choices[0].message.to_dict()
        tool_calls = message.get("tool_calls", None)
        if tool_calls is None:
            return None
        tool_call_list = []
        for tc in tool_calls:  # type: ignore
```
Korbit review comment: Unexplained type ignore

What is the issue? A type-ignore comment is used without an explanation of why type checking needs to be suppressed.

Why this matters: Unexplained type ignores make it difficult to understand the code's type-safety assumptions and hurt maintainability.

Suggested change: add a comment explaining why the type ignore is necessary:

```python
for tc in tool_calls:  # type: ignore  # tool_calls is dynamically typed from API response
```
```python
            tool_call_list.append(
                ToolCall(
                    name=tc["function"]["name"],
                    arguments=json.loads(tc["function"]["arguments"]),
                    raw_call=tc,
                )
            )
            if self.use_only_first_toolcall:
                break
        return ToolCalls(tool_calls=tool_call_list, raw_calls=response)  # type: ignore

    def _extract_env_actions_from_toolcalls(self, toolcalls: ToolCalls) -> Any | None:
        """Extracts actions from the response."""
        if not toolcalls:
            return None

        actions = [
            AgentlabAction.convert_toolcall_to_agentlab_action_format(call) for call in toolcalls
        ]
        actions = (
            AgentlabAction.convert_multiactions_to_agentlab_action_format(actions)
            if len(actions) > 1
            else actions[0]
        )
        return actions

    def _extract_env_actions_from_text_response(
        self, response: "OpenAIChatCompletion"
    ) -> str | None:
        """Extracts environment actions from the text response."""
        # Use when action space is not given as tools.
        # TODO: Add support to pass action space as prompt in LiteLLM.
        # Check: https://docs.litellm.ai/docs/completion/function_call#function-calling-for-models-wout-function-calling-support
        pass

    @staticmethod
    def format_tools_for_chat_completion(tools):
        """Formats Response API tool descriptions for the OpenAI Chat Completion API.

        Why is this needed? actionset.to_tool_description() in bgym only returns a
        description format valid for the OpenAI Response API.

        Args:
            tools: List of tool descriptions to format for the Chat Completion API.

        Returns:
            Formatted tools list compatible with the OpenAI Chat Completion API, or None if tools is None.
        """
        formatted_tools = None
        if tools is not None:
            formatted_tools = [
                {
                    "type": tool["type"],
                    "function": {k: tool[k] for k in ("name", "description", "parameters")},
                }
                for tool in tools
            ]
        return formatted_tools


class LiteLLMAPIMessageBuilder(OpenAIChatCompletionAPIMessageBuilder):
    """Message builder for LiteLLM API, extending OpenAIChatCompletionAPIMessageBuilder."""

    def prepare_message(self, use_only_first_toolcall: bool = False) -> List[Message]:
        """Prepare the message for the OpenAI API."""
        content = []
        for item in self.content:
            content.append(self.convert_content_to_expected_format(item))
        output = [{"role": self.role, "content": content}]
        return output if self.role != "tool" else self.handle_tool_call(use_only_first_toolcall)

    def handle_tool_call(self, use_only_first_toolcall: bool = False) -> List[Message]:
        """Handle the tool call response from the last raw response."""
        if self.responded_tool_calls is None:
            raise ValueError("No tool calls found in responded_tool_calls")
        output = []
        raw_call = self.responded_tool_calls.raw_calls.choices[0].message  # type: ignore
        if use_only_first_toolcall:
            raw_call.tool_calls = raw_call.tool_calls[:1]
        output.append(raw_call)  # add raw calls to output
        for fn_call in self.responded_tool_calls:
            raw_call = fn_call.raw_call
            assert (
                "image" not in fn_call.tool_response
            ), "Image output is not supported in function calls response."
            # a function_call_output dict has keys "role", "tool_call_id" and "content"
            tool_call_response = {
                "name": raw_call["function"]["name"],  # required with OpenRouter
                "role": "tool",
                "tool_call_id": raw_call["id"],
                "content": self.convert_content_to_expected_format(fn_call.tool_response)["text"],
            }
            output.append(tool_call_response)

        return output


@dataclass
class LiteLLMModelArgs(BaseModelArgs):
    """Serializable arguments for LiteLLMModel."""

    api = "openai"  # tool description format used by actionset.to_tool_description() in bgym
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    use_only_first_toolcall: bool = False

    def make_model(self):
        return LiteLLMModel(
            model_name=self.model_name,
            base_url=self.base_url,
            api_key=self.api_key,
            max_tokens=self.max_new_tokens,
            temperature=self.temperature,
            use_only_first_toolcall=self.use_only_first_toolcall,
        )

    def get_message_builder(self) -> Type[MessageBuilder]:
        """Returns a message builder for the LiteLLMModel."""
        return LiteLLMAPIMessageBuilder


if __name__ == "__main__":
    """
    Some simple tests to run the LiteLLMModel with different models.
    """

    import os

    from agentlab.agents.tool_use_agent import DEFAULT_PROMPT_CONFIG, ToolUseAgentArgs
    from agentlab.experiments.study import Study
    from agentlab.llm.litellm_api import LiteLLMModelArgs

    os.environ["LITELLM_LOG"] = "WARNING"

    def get_agent(model_name: str) -> ToolUseAgentArgs:
        return ToolUseAgentArgs(
            model_args=LiteLLMModelArgs(
                model_name=model_name,
                max_new_tokens=2000,
                temperature=None,
            ),
            config=DEFAULT_PROMPT_CONFIG,
        )

    models = [
        "openai/gpt-4.1",
        "openai/gpt-4.1-mini",
        "openai/gpt-4.1-nano",
        "openai/o3-2025-04-16",
        "anthropic/claude-3-7-sonnet-20250219",
        "anthropic/claude-sonnet-4-20250514",
        ## Add more models to test.
    ]
    agent_args = [get_agent(model) for model in models]

    study = Study(agent_args, "miniwob_tiny_test", logging_level_stdout=logging.WARNING)
    study.run(
        n_jobs=5,
        parallel_backend="ray",
        strict_reproducibility=False,
        n_relaunch=3,
    )
```
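Beyond the built-in test block above, a hypothetical usage sketch of the new `LiteLLMModelArgs` entry point: routing requests to a self-hosted OpenAI-compatible endpoint through LiteLLM. The endpoint URL, API key, and the `openai/<model>` provider-prefix convention are illustrative assumptions; only the fields themselves come from this diff.

```python
from agentlab.llm.litellm_api import LiteLLMModelArgs

# Hypothetical values: any OpenAI-compatible server reachable through LiteLLM.
local_args = LiteLLMModelArgs(
    model_name="openai/llama-3.1-8b-instruct",  # LiteLLM provider/model string (assumed)
    base_url="http://localhost:8000/v1",        # forwarded to litellm.completion via partial()
    api_key="not-needed-locally",
    max_new_tokens=1000,
    temperature=0.0,
)
model = local_args.make_model()                 # builds a LiteLLMModel wired to the endpoint
builder_cls = local_args.get_message_builder()  # LiteLLMAPIMessageBuilder
```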
Second changed file, which defines `APIPayload` (imported above from `agentlab.llm.response_api`):

```python
# @@ -27,6 +27,7 @@
3. Factory classes (inherits from BaseModelArgs) for creating instances of LLM Response models.
"""

logger = logging.getLogger(__name__)

ContentItem = Dict[str, Any]
Message = Dict[str, Union[str, List[ContentItem]]]
```

```python
# @@ -388,10 +389,15 @@ class APIPayload:
    cache_complete_prompt: bool = (
        False  # If True, will cache the complete prompt in the last message.
    )
    reasoning_effort: Literal["low", "medium", "high"] | None = None
```
Korbit review comment: Unused Reasoning Effort Parameter

What is the issue? The `reasoning_effort` parameter is introduced but not utilized in any of the existing response models (OpenAI, Claude, or OpenAIChatCompletion).

Why this matters: Adding unused parameters creates confusion and technical debt. The code suggests this parameter is only for the LiteLLM API, but that implementation is missing elsewhere.

Suggested change: either remove the parameter until the LiteLLM integration is implemented, or implement it in the `_call_api` methods. For example:

```python
def _call_api(self, payload: APIPayload) -> Any:
    api_params = {...}
    if payload.reasoning_effort is not None:
        api_params["reasoning_effort"] = payload.reasoning_effort
    # rest of the code
```
```python
    def __post_init__(self):
        if self.tool_choice and self.force_call_tool:
            raise ValueError("tool_choice and force_call_tool are mutually exclusive")
        if self.reasoning_effort is not None:
            logger.info(
                "In agentlab reasoning_effort is used by LiteLLM API only. We will eventually shift to LiteLLM API for all LLMs."
            )


# # Base class for all API Endpoints
```
Korbit review comment: Hardcoded Configuration Parameter

What is the issue? The `action_space_as_tools` parameter is hardcoded to `True` but commented as "should be a config", without being included in the initialization parameters.

Why this matters: Users cannot configure this setting, forcing them to use the tools-based action space even if they need text-based actions.

Suggested change: add the parameter to the constructor.
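A minimal sketch of that suggestion, based on the `LiteLLMModel.__init__` shown in the first file; the new parameter and its default are an assumption, not part of the PR:

```python
# Sketch of the suggested change; not part of the actual diff.
def __init__(
    self,
    model_name: str,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
    temperature: float | None = None,
    max_tokens: int | None = 100,
    use_only_first_toolcall: bool = False,
    action_space_as_tools: bool = True,  # expose the previously hardcoded flag
):
    super().__init__(model_name=model_name, temperature=temperature, max_tokens=max_tokens)
    self.action_space_as_tools = action_space_as_tools
    # ... rest of the existing __init__ body unchanged
```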