Commit b0c3cc6

Lin-Nikaido authored and GWeale committed

fix: Support file uploads for OpenAI/Azure in LiteLLM

This change expands the supported file MIME types and introduces provider-specific handling for file uploads. For providers like OpenAI and Azure, inline file data is now uploaded via `litellm.acreate_file` to obtain a `file_id`, which is then used in the message content. Other providers continue to use base64-encoded file data. Affected functions have been updated to be asynchronous.

Merge: #2863
Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 839996848

1 parent bb8a269 commit b0c3cc6
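
At a glance, the new handling branches on provider. A minimal sketch of that branch, assuming a hypothetical helper _file_part_for (the real logic is inlined in _get_content in the diff below; litellm.acreate_file and the content-part shapes are taken from the diff itself):

    import base64

    import litellm


    async def _file_part_for(provider: str, data: bytes, mime_type: str) -> dict:
      """Hypothetical helper condensing the diff's file-part logic."""
      if provider in ("openai", "azure"):
        # OpenAI/Azure: upload the raw bytes, then reference them by file_id.
        file_response = await litellm.acreate_file(
            file=data,
            purpose="assistants",
            custom_llm_provider=provider,
        )
        return {"type": "file", "file": {"file_id": file_response.id}}
      # Other providers: send the file inline as a base64 data URI.
      data_uri = f"data:{mime_type};base64,{base64.b64encode(data).decode('utf-8')}"
      return {"type": "file", "file": {"file_data": data_uri}}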

File tree

2 files changed: +334 −72 lines


src/google/adk/models/lite_llm.py

Lines changed: 76 additions & 13 deletions
@@ -82,9 +82,49 @@
     "content_filter": types.FinishReason.SAFETY,
 }
 
-_SUPPORTED_FILE_CONTENT_MIME_TYPES = set(
-    ["application/pdf", "application/json"]
-)
+# File MIME types supported for upload as file content (not decoded as text).
+# Note: text/* types are handled separately and decoded as text content.
+# These types are uploaded as files to providers that support it.
+_SUPPORTED_FILE_CONTENT_MIME_TYPES = frozenset({
+    # Documents
+    "application/pdf",
+    "application/msword",  # .doc
+    "application/vnd.openxmlformats-officedocument.wordprocessingml.document",  # .docx
+    "application/vnd.openxmlformats-officedocument.presentationml.presentation",  # .pptx
+    # Data formats
+    "application/json",
+    # Scripts (when not detected as text/*)
+    "application/x-sh",  # .sh (Python mimetypes returns this)
+})
+
+# Providers that require file_id instead of inline file_data
+_FILE_ID_REQUIRED_PROVIDERS = frozenset({"openai", "azure"})
+
+
+def _get_provider_from_model(model: str) -> str:
+  """Extracts the provider name from a LiteLLM model string.
+
+  Args:
+    model: The model string (e.g., "openai/gpt-4o", "azure/gpt-4").
+
+  Returns:
+    The provider name or empty string if not determinable.
+  """
+  if not model:
+    return ""
+  # LiteLLM uses "provider/model" format
+  if "/" in model:
+    provider, _ = model.split("/", 1)
+    return provider.lower()
+  # Fallback heuristics for common patterns
+  model_lower = model.lower()
+  if "azure" in model_lower:
+    return "azure"
+  # Note: The 'openai' check is based on current naming conventions (e.g., gpt-, o1).
+  # This might need updates if OpenAI introduces new model families with different prefixes.
+  if model_lower.startswith("gpt-") or model_lower.startswith("o1"):
+    return "openai"
+  return ""
 
 
 def _decode_inline_text_data(raw_bytes: bytes) -> str:
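
For reference, how _get_provider_from_model resolves a few model strings (values follow directly from the logic above; the examples are illustrative):

    _get_provider_from_model("openai/gpt-4o")   # -> "openai" (explicit prefix)
    _get_provider_from_model("azure/gpt-4")     # -> "azure"  (explicit prefix)
    _get_provider_from_model("gpt-4o-mini")     # -> "openai" (fallback heuristic)
    _get_provider_from_model("o1-preview")      # -> "openai" (fallback heuristic)
    _get_provider_from_model("gemini-1.5-pro")  # -> ""       (unknown: keeps inline file_data)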
@@ -349,8 +389,10 @@ def _extract_cached_prompt_tokens(usage: Any) -> int:
   return 0
 
 
-def _content_to_message_param(
+async def _content_to_message_param(
     content: types.Content,
+    *,
+    provider: str = "",
 ) -> Union[Message, list[Message]]:
   """Converts a types.Content to a litellm Message or list of Messages.
 
@@ -359,6 +401,7 @@ def _content_to_message_param(
 
   Args:
     content: The content to convert.
+    provider: The LLM provider name (e.g., "openai", "azure").
 
   Returns:
     A litellm Message, a list of litellm Messages.
@@ -379,7 +422,7 @@ def _content_to_message_param(
 
   # Handle user or assistant messages
   role = _to_litellm_role(content.role)
-  message_content = _get_content(content.parts) or None
+  message_content = await _get_content(content.parts, provider=provider) or None
 
   if role == "user":
     return ChatCompletionUserMessage(role="user", content=message_content)
@@ -418,13 +461,16 @@ def _content_to_message_param(
   )
 
 
-def _get_content(
+async def _get_content(
     parts: Iterable[types.Part],
+    *,
+    provider: str = "",
 ) -> Union[OpenAIMessageContent, str]:
   """Converts a list of parts to litellm content.
 
   Args:
     parts: The parts to convert.
+    provider: The LLM provider name (e.g., "openai", "azure").
 
   Returns:
     The litellm content.
@@ -474,10 +520,22 @@ def _get_content(
           "audio_url": {"url": data_uri},
       })
     elif part.inline_data.mime_type in _SUPPORTED_FILE_CONTENT_MIME_TYPES:
-      content_objects.append({
-          "type": "file",
-          "file": {"file_data": data_uri},
-      })
+      # OpenAI/Azure require file_id from uploaded file, not inline data
+      if provider in _FILE_ID_REQUIRED_PROVIDERS:
+        file_response = await litellm.acreate_file(
+            file=part.inline_data.data,
+            purpose="assistants",
+            custom_llm_provider=provider,
+        )
+        content_objects.append({
+            "type": "file",
+            "file": {"file_id": file_response.id},
+        })
+      else:
+        content_objects.append({
+            "type": "file",
+            "file": {"file_data": data_uri},
+        })
     else:
       raise ValueError(
           "LiteLlm(BaseLlm) does not support content part with MIME type "
@@ -954,7 +1012,7 @@ def _to_litellm_response_format(
   }
 
 
-def _get_completion_inputs(
+async def _get_completion_inputs(
    llm_request: LlmRequest,
 ) -> Tuple[
     List[Message],
@@ -971,10 +1029,15 @@ def _get_completion_inputs(
     The litellm inputs (message list, tool dictionary, response format and
     generation params).
   """
+  # Determine provider for file handling
+  provider = _get_provider_from_model(llm_request.model or "")
+
   # 1. Construct messages
   messages: List[Message] = []
   for content in llm_request.contents or []:
-    message_param_or_list = _content_to_message_param(content)
+    message_param_or_list = await _content_to_message_param(
+        content, provider=provider
+    )
     if isinstance(message_param_or_list, list):
       messages.extend(message_param_or_list)
     elif message_param_or_list:  # Ensure it's not None before appending
@@ -1240,7 +1303,7 @@ async def generate_content_async(
     logger.debug(_build_request_log(llm_request))
 
     messages, tools, response_format, generation_params = (
-        _get_completion_inputs(llm_request)
+        await _get_completion_inputs(llm_request)
     )
 
     if "functions" in self._additional_args:
