253 changes: 191 additions & 62 deletions src/strands/models/litellm.py
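This diff reworks `stream()` in `src/strands/models/litellm.py` so that streaming is controlled by the `stream` key in the model's `params` config (defaulting to `True`) rather than always invoking LiteLLM in streaming mode, and it factors response handling into `_handle_streaming_response`, `_handle_non_streaming_response`, and shared helpers. As a rough usage sketch only, assuming the model class is `LiteLLMModel` and that its constructor accepts `model_id` and `params` keyword arguments (neither is shown in this diff):

    # Hypothetical usage sketch; the constructor signature is an assumption, not part of this diff.
    from strands.models.litellm import LiteLLMModel

    # Passing params={"stream": False} would route stream() through
    # _handle_non_streaming_response; omitting the key keeps the default streaming path.
    model = LiteLLMModel(
        model_id="openai/gpt-4o",
        params={"stream": False},
    )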
@@ -269,75 +269,29 @@ async def stream(
         )
         logger.debug("request=<%s>", request)

-        logger.debug("invoking model")
-        try:
-            if kwargs.get("stream") is False:
-                raise ValueError("stream parameter cannot be explicitly set to False")
-            response = await litellm.acompletion(**self.client_args, **request)
-        except ContextWindowExceededError as e:
-            logger.warning("litellm client raised context window overflow")
-            raise ContextWindowOverflowException(e) from e
+        # Check if streaming is disabled in the params
+        config = self.get_config()
+        params = config.get("params") or {}
+        is_streaming = params.get("stream", True)

-        logger.debug("got response from model")
-        yield self.format_chunk({"chunk_type": "message_start"})
+        litellm_request = {**request}

-        tool_calls: dict[int, list[Any]] = {}
-        data_type: str | None = None
+        litellm_request["stream"] = is_streaming

-        async for event in response:
-            # Defensive: skip events with empty or missing choices
-            if not getattr(event, "choices", None):
-                continue
-            choice = event.choices[0]
+        logger.debug("invoking model with stream=%s", litellm_request.get("stream"))

-            if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content:
-                chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
-                for chunk in chunks:
+        try:
+            if is_streaming:
+                async for chunk in self._handle_streaming_response(litellm_request):
                     yield chunk
-
-                yield self.format_chunk(
-                    {
-                        "chunk_type": "content_delta",
-                        "data_type": data_type,
-                        "data": choice.delta.reasoning_content,
-                    }
-                )
-
-            if choice.delta.content:
-                chunks, data_type = self._stream_switch_content("text", data_type)
-                for chunk in chunks:
+            else:
+                async for chunk in self._handle_non_streaming_response(litellm_request):
                     yield chunk
+        except ContextWindowExceededError as e:
+            logger.warning("litellm client raised context window overflow")
+            raise ContextWindowOverflowException(e) from e

-                yield self.format_chunk(
-                    {"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content}
-                )
-
-            for tool_call in choice.delta.tool_calls or []:
-                tool_calls.setdefault(tool_call.index, []).append(tool_call)
-
-            if choice.finish_reason:
-                if data_type:
-                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
-                break
-
-        for tool_deltas in tool_calls.values():
-            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
-
-            for tool_delta in tool_deltas:
-                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
-
-            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
-
-        yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason})
-
-        # Skip remaining events as we don't have use for anything except the final usage payload
-        async for event in response:
-            _ = event
-
-            if event.usage:
-                yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
-
-        logger.debug("finished streaming response from model")
+        logger.debug("finished processing response from model")

     @override
     async def structured_output(
@@ -422,6 +376,181 @@ async def _structured_output_using_tool(
         except (json.JSONDecodeError, TypeError, ValueError) as e:
             raise ValueError(f"Failed to parse or load content into model: {e}") from e

+    async def _process_choice_content(
+        self, choice: Any, data_type: str | None, tool_calls: dict[int, list[Any]], is_streaming: bool = True
+    ) -> AsyncGenerator[tuple[str | None, StreamEvent], None]:
+        """Process content from a choice object (streaming or non-streaming).
+
+        Args:
+            choice: The choice object from the response.
+            data_type: Current data type being processed.
+            tool_calls: Dictionary to collect tool calls.
+            is_streaming: Whether this is from a streaming response.
+
+        Yields:
+            Tuples of (updated_data_type, stream_event).
+        """
+        # Get the content source - this is the only difference between streaming/non-streaming
+        # We use duck typing here: both choice.delta and choice.message have the same interface
+        # (reasoning_content, content, tool_calls attributes) but different object structures
+        content_source = choice.delta if is_streaming else choice.message
+
+        # Process reasoning content
+        if hasattr(content_source, "reasoning_content") and content_source.reasoning_content:
+            chunks, data_type = self._stream_switch_content("reasoning_content", data_type)
+            for chunk in chunks:
+                yield data_type, chunk
+            chunk = self.format_chunk(
+                {
+                    "chunk_type": "content_delta",
+                    "data_type": "reasoning_content",
+                    "data": content_source.reasoning_content,
+                }
+            )
+            yield data_type, chunk
+
+        # Process text content
+        if hasattr(content_source, "content") and content_source.content:
+            chunks, data_type = self._stream_switch_content("text", data_type)
+            for chunk in chunks:
+                yield data_type, chunk
+            chunk = self.format_chunk(
+                {
+                    "chunk_type": "content_delta",
+                    "data_type": "text",
+                    "data": content_source.content,
+                }
+            )
+            yield data_type, chunk
+
+        # Process tool calls
+        if hasattr(content_source, "tool_calls") and content_source.tool_calls:
+            if is_streaming:
+                # Streaming: tool calls have index attribute for out-of-order delivery
+                for tool_call in content_source.tool_calls:
+                    tool_calls.setdefault(tool_call.index, []).append(tool_call)
+            else:
+                # Non-streaming: tool calls arrive in order, use enumerated index
+                for i, tool_call in enumerate(content_source.tool_calls):
+                    tool_calls.setdefault(i, []).append(tool_call)
+
+    async def _process_tool_calls(self, tool_calls: dict[int, list[Any]]) -> AsyncGenerator[StreamEvent, None]:
+        """Process and yield tool call events.
+
+        Args:
+            tool_calls: Dictionary of tool calls indexed by their position.
+
+        Yields:
+            Formatted tool call chunks.
+        """
+        for tool_deltas in tool_calls.values():
+            yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]})
+
+            for tool_delta in tool_deltas:
+                yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta})
+
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"})
+
+    async def _handle_non_streaming_response(
+        self, litellm_request: dict[str, Any]
+    ) -> AsyncGenerator[StreamEvent, None]:
+        """Handle non-streaming response from LiteLLM.
+
+        Args:
+            litellm_request: The formatted request for LiteLLM.
+
+        Yields:
+            Formatted message chunks from the model.
+        """
+        response = await litellm.acompletion(**self.client_args, **litellm_request)
+
+        logger.debug("got non-streaming response from model")
+        yield self.format_chunk({"chunk_type": "message_start"})
+
+        tool_calls: dict[int, list[Any]] = {}
+        data_type: str | None = None
+        finish_reason: str | None = None
+
+        if hasattr(response, "choices") and response.choices and len(response.choices) > 0:
+            choice = response.choices[0]
+
+            if hasattr(choice, "message") and choice.message:
+                # Process content using shared logic
+                async for updated_data_type, chunk in self._process_choice_content(
+                    choice, data_type, tool_calls, is_streaming=False
+                ):
+                    data_type = updated_data_type
+                    yield chunk
+
+            if hasattr(choice, "finish_reason"):
+                finish_reason = choice.finish_reason
+
+        # Stop the current content block if we have one
+        if data_type:
+            yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
+
+        # Process tool calls
+        async for chunk in self._process_tool_calls(tool_calls):
+            yield chunk
+
+        yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})
+
+        # Add usage information if available
+        if hasattr(response, "usage"):
+            yield self.format_chunk({"chunk_type": "metadata", "data": response.usage})
+
+    async def _handle_streaming_response(self, litellm_request: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]:
+        """Handle streaming response from LiteLLM.
+
+        Args:
+            litellm_request: The formatted request for LiteLLM.
+
+        Yields:
+            Formatted message chunks from the model.
+        """
+        # For streaming, use the streaming API
+        response = await litellm.acompletion(**self.client_args, **litellm_request)
+
+        logger.debug("got response from model")
+        yield self.format_chunk({"chunk_type": "message_start"})
+
+        tool_calls: dict[int, list[Any]] = {}
+        data_type: str | None = None
+        finish_reason: str | None = None
+
+        async for event in response:
+            # Defensive: skip events with empty or missing choices
+            if not getattr(event, "choices", None):
+                continue
+            choice = event.choices[0]
+
+            # Process content using shared logic
+            async for updated_data_type, chunk in self._process_choice_content(
+                choice, data_type, tool_calls, is_streaming=True
+            ):
+                data_type = updated_data_type
+                yield chunk
+
+            if choice.finish_reason:
+                finish_reason = choice.finish_reason
+                if data_type:
+                    yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type})
+                break
+
+        # Process tool calls
+        async for chunk in self._process_tool_calls(tool_calls):
+            yield chunk
+
+        yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason})
+
+        # Skip remaining events as we don't have use for anything except the final usage payload
+        async for event in response:
+            _ = event
+            if event.usage:
+                yield self.format_chunk({"chunk_type": "metadata", "data": event.usage})
+
+        logger.debug("finished streaming response from model")
+
     def _apply_proxy_prefix(self) -> None:
         """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True.

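The shared `_process_choice_content` helper leans on duck typing: a streaming `choice.delta` and a non-streaming `choice.message` both expose `reasoning_content`, `content`, and `tool_calls`, so one code path can read either. A minimal, self-contained sketch of that idea (not part of the PR; the fake objects below are stand-ins, not LiteLLM types):

    # Standalone illustration of the duck typing noted in _process_choice_content.
    from types import SimpleNamespace

    def read_content(source):
        # Mirrors the attribute-access pattern used by _process_choice_content.
        return {
            "reasoning": getattr(source, "reasoning_content", None),
            "text": getattr(source, "content", None),
            "tool_calls": getattr(source, "tool_calls", None) or [],
        }

    streaming_delta = SimpleNamespace(reasoning_content=None, content="Hel", tool_calls=None)
    non_streaming_message = SimpleNamespace(reasoning_content="thinking...", content="Hello", tool_calls=None)

    assert read_content(streaming_delta)["text"] == "Hel"
    assert read_content(non_streaming_message)["reasoning"] == "thinking..."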