From 41870e488a61422ec52335c1978ae8a00ce39f0e Mon Sep 17 00:00:00 2001 From: Daniel Schleicher Date: Tue, 22 Jul 2025 14:41:19 +0200 Subject: [PATCH 1/4] fix: LiteLLM + disable streaming breaks for Azure OpenAI #477 --- src/strands/models/litellm.py | 172 ++++++++++++++++++--------- tests/strands/models/test_litellm.py | 79 ++++++++++++ 2 files changed, 194 insertions(+), 57 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 1f1e999d2..1915cff94 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -269,75 +269,133 @@ async def stream( ) logger.debug("request=<%s>", request) - logger.debug("invoking model") - try: - if kwargs.get("stream") is False: - raise ValueError("stream parameter cannot be explicitly set to False") - response = await litellm.acompletion(**self.client_args, **request) - except ContextWindowExceededError as e: - logger.warning("litellm client raised context window overflow") - raise ContextWindowOverflowException(e) from e + # Check if streaming is disabled in the params + params = self.get_config().get("params", {}) + is_streaming = params.get("stream", True) - logger.debug("got response from model") - yield self.format_chunk({"chunk_type": "message_start"}) - - tool_calls: dict[int, list[Any]] = {} - data_type: str | None = None - - async for event in response: - # Defensive: skip events with empty or missing choices - if not getattr(event, "choices", None): - continue - choice = event.choices[0] - - if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: - chunks, data_type = self._stream_switch_content("reasoning_content", data_type) - for chunk in chunks: - yield chunk - - yield self.format_chunk( - { - "chunk_type": "content_delta", - "data_type": data_type, - "data": choice.delta.reasoning_content, - } - ) + litellm_request = {**request} - if choice.delta.content: - chunks, data_type = self._stream_switch_content("text", data_type) - for chunk in chunks: - yield chunk + litellm_request["stream"] = is_streaming - yield self.format_chunk( - {"chunk_type": "content_delta", "data_type": data_type, "data": choice.delta.content} - ) + logger.debug("invoking model with stream=%s", litellm_request.get("stream")) + + if not is_streaming: + response = await litellm.acompletion(**self.client_args, **litellm_request) - for tool_call in choice.delta.tool_calls or []: - tool_calls.setdefault(tool_call.index, []).append(tool_call) + logger.debug("got non-streaming response from model") + yield self.format_chunk({"chunk_type": "message_start"}) + yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) - if choice.finish_reason: - if data_type: - yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type}) - break + tool_calls: dict[int, list[Any]] = {} + finish_reason = None - for tool_deltas in tool_calls.values(): - yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) + if hasattr(response, "choices") and response.choices and len(response.choices) > 0: + choice = response.choices[0] - for tool_delta in tool_deltas: - yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) + if hasattr(choice, "message") and choice.message: + if hasattr(choice.message, "content") and choice.message.content: + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": choice.message.content} + ) - yield self.format_chunk({"chunk_type": 
"content_stop", "data_type": "tool"}) + if hasattr(choice.message, "reasoning_content") and choice.message.reasoning_content: + yield self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": choice.message.reasoning_content, + } + ) - yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason}) + if hasattr(choice.message, "tool_calls") and choice.message.tool_calls: + for i, tool_call in enumerate(choice.message.tool_calls): + tool_calls.setdefault(i, []).append(tool_call) - # Skip remaining events as we don't have use for anything except the final usage payload - async for event in response: - _ = event + if hasattr(choice, "finish_reason"): + finish_reason = choice.finish_reason - if event.usage: - yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) - logger.debug("finished streaming response from model") + for tool_deltas in tool_calls.values(): + yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) + + for tool_delta in tool_deltas: + yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) + + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + + yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + # Add usage information if available + if hasattr(response, "usage"): + yield self.format_chunk({"chunk_type": "metadata", "data": response.usage}) + else: + # For streaming, use the streaming API + response = await litellm.acompletion(**self.client_args, **litellm_request) + + logger.debug("got streaming response from model") + yield self.format_chunk({"chunk_type": "message_start"}) + yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) + + tool_calls: dict[int, list[Any]] = {} + finish_reason = None + + try: + async for event in response: + # Defensive: skip events with empty or missing choices + if not getattr(event, "choices", None): + continue + choice = event.choices[0] + + if choice.delta.content: + yield self.format_chunk( + {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} + ) + + if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: + yield self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": choice.delta.reasoning_content, + } + ) + + for tool_call in choice.delta.tool_calls or []: + tool_calls.setdefault(tool_call.index, []).append(tool_call) + + if choice.finish_reason: + finish_reason = choice.finish_reason + break + except Exception as e: + logger.warning("Error processing streaming response: %s", e) + + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) + + # Process tool calls + for tool_deltas in tool_calls.values(): + yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) + + for tool_delta in tool_deltas: + yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) + + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + + yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + try: + last_event = None + async for event in response: + last_event = event + + # Use the last event for usage information + if last_event and hasattr(last_event, "usage"): + yield 
self.format_chunk({"chunk_type": "metadata", "data": last_event.usage}) + except Exception: + # If there's an error collecting remaining events, just continue + pass + + logger.debug("finished processing response from model") @override async def structured_output( diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 832b5c836..5cf31179c 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -478,3 +478,82 @@ def test_format_request_messages_cache_point_support(): ] assert result == expected +async def test_stream_non_streaming(litellm_acompletion, api_key, model_id, alist): + """Test LiteLLM model with streaming disabled (stream=False). + + This test verifies that the LiteLLM model works correctly when streaming is disabled, + which was the issue reported in GitHub issue #477. + """ + + mock_function = unittest.mock.Mock() + mock_function.name = "calculator" + mock_function.arguments = '{"expression": "123981723 + 234982734"}' + + mock_tool_call = unittest.mock.Mock(index=0, function=mock_function, id="tool_call_id_123") + + mock_message = unittest.mock.Mock() + mock_message.content = "I'll calculate that for you" + mock_message.reasoning_content = "Let me think about this calculation" + mock_message.tool_calls = [mock_tool_call] + + mock_choice = unittest.mock.Mock() + mock_choice.message = mock_message + mock_choice.finish_reason = "tool_calls" + + mock_response = unittest.mock.Mock() + mock_response.choices = [mock_choice] + mock_response.usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) + + litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=mock_response) + + model = LiteLLMModel( + client_args={"api_key": api_key}, + model_id=model_id, + params={"stream": False}, # This is the key setting that was causing the #477 isuue + ) + + messages = [{"role": "user", "content": [{"type": "text", "text": "What is 123981723 + 234982734?"}]}] + response = model.stream(messages) + + tru_events = await alist(response) + + exp_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "I'll calculate that for you"}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Let me think about this calculation"}}}}, + {"contentBlockStop": {}}, + { + "contentBlockStart": { + "start": {"toolUse": {"name": "calculator", "toolUseId": mock_message.tool_calls[0].id}} + } + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "123981723 + 234982734"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use"}}, + { + "metadata": { + "usage": { + "inputTokens": 10, + "outputTokens": 20, + "totalTokens": 30, + }, + "metrics": {"latencyMs": 0}, + } + }, + ] + + assert len(tru_events) == len(exp_events) + + for i, (tru, exp) in enumerate(zip(tru_events, exp_events, strict=False)): + assert tru == exp, f"Event {i} mismatch: {tru} != {exp}" + + expected_request = { + "api_key": api_key, + "model": model_id, + "messages": [{"role": "user", "content": [{"text": "What is 123981723 + 234982734?", "type": "text"}]}], + "stream": False, # Verify that stream=False was passed to litellm + "stream_options": {"include_usage": True}, + "tools": [], + } + litellm_acompletion.assert_called_once_with(**expected_request) From 483b3d505a2914fd610905e12f1deec1c7c83ca8 Mon Sep 17 00:00:00 2001 From: Daniel Schleicher Date: Tue, 22 Jul 2025 15:17:30 +0200 Subject: [PATCH 
2/4] Fixed errors from hatch run prepare --- src/strands/models/litellm.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 1915cff94..3b56ab3aa 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -270,7 +270,8 @@ async def stream( logger.debug("request=<%s>", request) # Check if streaming is disabled in the params - params = self.get_config().get("params", {}) + config = self.get_config() + params = config.get("params") or {} is_streaming = params.get("stream", True) litellm_request = {**request} @@ -337,7 +338,7 @@ async def stream( yield self.format_chunk({"chunk_type": "message_start"}) yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) - tool_calls: dict[int, list[Any]] = {} + streaming_tool_calls: dict[int, list[Any]] = {} finish_reason = None try: @@ -362,7 +363,7 @@ async def stream( ) for tool_call in choice.delta.tool_calls or []: - tool_calls.setdefault(tool_call.index, []).append(tool_call) + streaming_tool_calls.setdefault(tool_call.index, []).append(tool_call) if choice.finish_reason: finish_reason = choice.finish_reason @@ -373,7 +374,7 @@ async def stream( yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) # Process tool calls - for tool_deltas in tool_calls.values(): + for tool_deltas in streaming_tool_calls.values(): yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) for tool_delta in tool_deltas: From 0ffc14cdc99dbe63e8f2e1d74fe8fe669aa01f19 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 5 Jan 2026 10:01:25 -0500 Subject: [PATCH 3/4] refactor: deduplicate litellm streaming and non streaming --- src/strands/models/litellm.py | 300 ++++++++++++++--------- tests/strands/models/test_litellm.py | 55 ++++- tests_integ/models/test_model_litellm.py | 30 ++- 3 files changed, 253 insertions(+), 132 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 3b56ab3aa..c120b0eda 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -280,121 +280,16 @@ async def stream( logger.debug("invoking model with stream=%s", litellm_request.get("stream")) - if not is_streaming: - response = await litellm.acompletion(**self.client_args, **litellm_request) - - logger.debug("got non-streaming response from model") - yield self.format_chunk({"chunk_type": "message_start"}) - yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) - - tool_calls: dict[int, list[Any]] = {} - finish_reason = None - - if hasattr(response, "choices") and response.choices and len(response.choices) > 0: - choice = response.choices[0] - - if hasattr(choice, "message") and choice.message: - if hasattr(choice.message, "content") and choice.message.content: - yield self.format_chunk( - {"chunk_type": "content_delta", "data_type": "text", "data": choice.message.content} - ) - - if hasattr(choice.message, "reasoning_content") and choice.message.reasoning_content: - yield self.format_chunk( - { - "chunk_type": "content_delta", - "data_type": "reasoning_content", - "data": choice.message.reasoning_content, - } - ) - - if hasattr(choice.message, "tool_calls") and choice.message.tool_calls: - for i, tool_call in enumerate(choice.message.tool_calls): - tool_calls.setdefault(i, []).append(tool_call) - - if hasattr(choice, "finish_reason"): - finish_reason = choice.finish_reason - - yield 
self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) - - for tool_deltas in tool_calls.values(): - yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) - - for tool_delta in tool_deltas: - yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) - - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) - - yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason}) - - # Add usage information if available - if hasattr(response, "usage"): - yield self.format_chunk({"chunk_type": "metadata", "data": response.usage}) - else: - # For streaming, use the streaming API - response = await litellm.acompletion(**self.client_args, **litellm_request) - - logger.debug("got streaming response from model") - yield self.format_chunk({"chunk_type": "message_start"}) - yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"}) - - streaming_tool_calls: dict[int, list[Any]] = {} - finish_reason = None - - try: - async for event in response: - # Defensive: skip events with empty or missing choices - if not getattr(event, "choices", None): - continue - choice = event.choices[0] - - if choice.delta.content: - yield self.format_chunk( - {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} - ) - - if hasattr(choice.delta, "reasoning_content") and choice.delta.reasoning_content: - yield self.format_chunk( - { - "chunk_type": "content_delta", - "data_type": "reasoning_content", - "data": choice.delta.reasoning_content, - } - ) - - for tool_call in choice.delta.tool_calls or []: - streaming_tool_calls.setdefault(tool_call.index, []).append(tool_call) - - if choice.finish_reason: - finish_reason = choice.finish_reason - break - except Exception as e: - logger.warning("Error processing streaming response: %s", e) - - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"}) - - # Process tool calls - for tool_deltas in streaming_tool_calls.values(): - yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) - - for tool_delta in tool_deltas: - yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) - - yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) - - yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason}) - - try: - last_event = None - async for event in response: - last_event = event - - # Use the last event for usage information - if last_event and hasattr(last_event, "usage"): - yield self.format_chunk({"chunk_type": "metadata", "data": last_event.usage}) - except Exception: - # If there's an error collecting remaining events, just continue - pass + try: + if is_streaming: + async for chunk in self._handle_streaming_response(litellm_request): + yield chunk + else: + async for chunk in self._handle_non_streaming_response(litellm_request): + yield chunk + except ContextWindowExceededError as e: + logger.warning("litellm client raised context window overflow") + raise ContextWindowOverflowException(e) from e logger.debug("finished processing response from model") @@ -481,6 +376,181 @@ async def _structured_output_using_tool( except (json.JSONDecodeError, TypeError, ValueError) as e: raise ValueError(f"Failed to parse or load content into model: {e}") from e + async def _process_choice_content( + self, choice: Any, data_type: str | None, tool_calls: dict[int, 
list[Any]], is_streaming: bool = True + ) -> AsyncGenerator[tuple[str | None, StreamEvent], None]: + """Process content from a choice object (streaming or non-streaming). + + Args: + choice: The choice object from the response. + data_type: Current data type being processed. + tool_calls: Dictionary to collect tool calls. + is_streaming: Whether this is from a streaming response. + + Yields: + Tuples of (updated_data_type, stream_event). + """ + # Get the content source - this is the only difference between streaming/non-streaming + # We use duck typing here: both choice.delta and choice.message have the same interface + # (reasoning_content, content, tool_calls attributes) but different object structures + content_source = choice.delta if is_streaming else choice.message + + # Process reasoning content + if hasattr(content_source, "reasoning_content") and content_source.reasoning_content: + chunks, data_type = self._stream_switch_content("reasoning_content", data_type) + for chunk in chunks: + yield data_type, chunk + chunk = self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "reasoning_content", + "data": content_source.reasoning_content, + } + ) + yield data_type, chunk + + # Process text content + if hasattr(content_source, "content") and content_source.content: + chunks, data_type = self._stream_switch_content("text", data_type) + for chunk in chunks: + yield data_type, chunk + chunk = self.format_chunk( + { + "chunk_type": "content_delta", + "data_type": "text", + "data": content_source.content, + } + ) + yield data_type, chunk + + # Process tool calls + if hasattr(content_source, "tool_calls") and content_source.tool_calls: + if is_streaming: + # Streaming: tool calls have index attribute for out-of-order delivery + for tool_call in content_source.tool_calls: + tool_calls.setdefault(tool_call.index, []).append(tool_call) + else: + # Non-streaming: tool calls arrive in order, use enumerated index + for i, tool_call in enumerate(content_source.tool_calls): + tool_calls.setdefault(i, []).append(tool_call) + + async def _process_tool_calls(self, tool_calls: dict[int, list[Any]]) -> AsyncGenerator[StreamEvent, None]: + """Process and yield tool call events. + + Args: + tool_calls: Dictionary of tool calls indexed by their position. + + Yields: + Formatted tool call chunks. + """ + for tool_deltas in tool_calls.values(): + yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}) + + for tool_delta in tool_deltas: + yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}) + + yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool"}) + + async def _handle_non_streaming_response( + self, litellm_request: dict[str, Any] + ) -> AsyncGenerator[StreamEvent, None]: + """Handle non-streaming response from LiteLLM. + + Args: + litellm_request: The formatted request for LiteLLM. + + Yields: + Formatted message chunks from the model. 
+ """ + response = await litellm.acompletion(**self.client_args, **litellm_request) + + logger.debug("got non-streaming response from model") + yield self.format_chunk({"chunk_type": "message_start"}) + + tool_calls: dict[int, list[Any]] = {} + data_type: str | None = None + finish_reason: str | None = None + + if hasattr(response, "choices") and response.choices and len(response.choices) > 0: + choice = response.choices[0] + + if hasattr(choice, "message") and choice.message: + # Process content using shared logic + async for updated_data_type, chunk in self._process_choice_content( + choice, data_type, tool_calls, is_streaming=False + ): + data_type = updated_data_type + yield chunk + + if hasattr(choice, "finish_reason"): + finish_reason = choice.finish_reason + + # Stop the current content block if we have one + if data_type: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type}) + + # Process tool calls + async for chunk in self._process_tool_calls(tool_calls): + yield chunk + + yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + # Add usage information if available + if hasattr(response, "usage"): + yield self.format_chunk({"chunk_type": "metadata", "data": response.usage}) + + async def _handle_streaming_response(self, litellm_request: dict[str, Any]) -> AsyncGenerator[StreamEvent, None]: + """Handle streaming response from LiteLLM. + + Args: + litellm_request: The formatted request for LiteLLM. + + Yields: + Formatted message chunks from the model. + """ + # For streaming, use the streaming API + response = await litellm.acompletion(**self.client_args, **litellm_request) + + logger.debug("got response from model") + yield self.format_chunk({"chunk_type": "message_start"}) + + tool_calls: dict[int, list[Any]] = {} + data_type: str | None = None + finish_reason: str | None = None + + async for event in response: + # Defensive: skip events with empty or missing choices + if not getattr(event, "choices", None): + continue + choice = event.choices[0] + + # Process content using shared logic + async for updated_data_type, chunk in self._process_choice_content( + choice, data_type, tool_calls, is_streaming=True + ): + data_type = updated_data_type + yield chunk + + if choice.finish_reason: + finish_reason = choice.finish_reason + if data_type: + yield self.format_chunk({"chunk_type": "content_stop", "data_type": data_type}) + break + + # Process tool calls + async for chunk in self._process_tool_calls(tool_calls): + yield chunk + + yield self.format_chunk({"chunk_type": "message_stop", "data": finish_reason}) + + # Skip remaining events as we don't have use for anything except the final usage payload + async for event in response: + _ = event + if event.usage: + yield self.format_chunk({"chunk_type": "metadata", "data": event.usage}) + + logger.debug("finished streaming response from model") + def _apply_proxy_prefix(self) -> None: """Apply litellm_proxy/ prefix to model_id when use_litellm_proxy is True. 
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 5cf31179c..9ceaf59c3 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -285,7 +285,7 @@ async def test_stream_empty(litellm_acompletion, api_key, model_id, model, agene mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) - mock_event_3 = unittest.mock.Mock() + mock_event_3 = unittest.mock.Mock(usage=None) mock_event_4 = unittest.mock.Mock(usage=None) litellm_acompletion.side_effect = unittest.mock.AsyncMock( @@ -408,16 +408,6 @@ async def test_context_window_maps_to_typed_exception(litellm_acompletion, model pass -@pytest.mark.asyncio -async def test_stream_raises_error_when_stream_is_false(model): - """Test that stream raises ValueError when stream parameter is explicitly False.""" - messages = [{"role": "user", "content": [{"text": "test"}]}] - - with pytest.raises(ValueError, match="stream parameter cannot be explicitly set to False"): - async for _ in model.stream(messages, stream=False): - pass - - def test_format_request_messages_with_system_prompt_content(): """Test format_request_messages with system_prompt_content parameter.""" messages = [{"role": "user", "content": [{"text": "Hello"}]}] @@ -478,6 +468,9 @@ def test_format_request_messages_cache_point_support(): ] assert result == expected + + +@pytest.mark.asyncio async def test_stream_non_streaming(litellm_acompletion, api_key, model_id, alist): """Test LiteLLM model with streaming disabled (stream=False). @@ -502,7 +495,15 @@ async def test_stream_non_streaming(litellm_acompletion, api_key, model_id, alis mock_response = unittest.mock.Mock() mock_response.choices = [mock_choice] - mock_response.usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) + + # Create a more explicit usage mock that doesn't have cache-related attributes + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 10 + mock_usage.completion_tokens = 20 + mock_usage.total_tokens = 30 + mock_usage.prompt_tokens_details = None + mock_usage.cache_creation_input_tokens = None + mock_response.usage = mock_usage litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=mock_response) @@ -520,9 +521,11 @@ async def test_stream_non_streaming(litellm_acompletion, api_key, model_id, alis exp_events = [ {"messageStart": {"role": "assistant"}}, {"contentBlockStart": {"start": {}}}, - {"contentBlockDelta": {"delta": {"text": "I'll calculate that for you"}}}, {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Let me think about this calculation"}}}}, {"contentBlockStop": {}}, + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "I'll calculate that for you"}}}, + {"contentBlockStop": {}}, { "contentBlockStart": { "start": {"toolUse": {"name": "calculator", "toolUseId": mock_message.tool_calls[0].id}} @@ -557,3 +560,29 @@ async def test_stream_non_streaming(litellm_acompletion, api_key, model_id, alis "tools": [], } litellm_acompletion.assert_called_once_with(**expected_request) + + +@pytest.mark.asyncio +async def test_stream_path_validation(litellm_acompletion, api_key, model_id, model, agenerator, alist): + """Test that we're taking the correct streaming path and validate stream parameter.""" + mock_delta = unittest.mock.Mock(content=None, tool_calls=None, reasoning_content=None) + 
mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(usage=None) + + litellm_acompletion.side_effect = unittest.mock.AsyncMock(return_value=agenerator([mock_event_1, mock_event_2])) + + messages = [{"role": "user", "content": []}] + response = model.stream(messages) + + # Consume the response + await alist(response) + + # Validate that litellm.acompletion was called with the expected parameters + call_args = litellm_acompletion.call_args + assert call_args is not None, "litellm.acompletion should have been called" + + # Check if stream parameter is being set + called_kwargs = call_args.kwargs + + # Validate we're going down the streaming path (should have stream=True) + assert called_kwargs.get("stream") is True, f"Expected stream=True, got {called_kwargs.get('stream')}" diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index d72937641..80e21bdfd 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -14,6 +14,16 @@ def model(): return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0") +@pytest.fixture +def streaming_model(): + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0", params={"stream": True}) + + +@pytest.fixture +def non_streaming_model(): + return LiteLLMModel(model_id="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0", params={"stream": False}) + + @pytest.fixture def tools(): @strands.tool @@ -95,15 +105,21 @@ def lower(_, value): return Color(simple_color_name="yellow") -def test_agent_invoke(agent): +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) +def test_agent_invoke(model_fixture, tools, request): + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model, tools=tools) result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() assert all(string in text for string in ["12:00", "sunny"]) +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) @pytest.mark.asyncio -async def test_agent_invoke_async(agent): +async def test_agent_invoke_async(model_fixture, tools, request): + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model, tools=tools) result = await agent.invoke_async("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() @@ -138,14 +154,20 @@ def test_agent_invoke_reasoning(agent, model): assert result.message["content"][0]["reasoningContent"]["reasoningText"]["text"] -def test_structured_output(agent, weather): +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) +def test_structured_output(model_fixture, weather, request): + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model) tru_weather = agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny") exp_weather = weather assert tru_weather == exp_weather +@pytest.mark.parametrize("model_fixture", ["streaming_model", "non_streaming_model"]) @pytest.mark.asyncio -async def test_agent_structured_output_async(agent, weather): +async def test_agent_structured_output_async(model_fixture, weather, request): + model = request.getfixturevalue(model_fixture) + agent = Agent(model=model) tru_weather = await agent.structured_output_async(type(weather), "The time is 12:00 and the weather is sunny") 
exp_weather = weather assert tru_weather == exp_weather From 98e57a1e4129312ec8f57aff4424b2074ed54782 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 5 Jan 2026 10:25:39 -0500 Subject: [PATCH 4/4] test: increase test coverage --- tests/strands/models/test_litellm.py | 125 +++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 9ceaf59c3..99df22a3f 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -586,3 +586,128 @@ async def test_stream_path_validation(litellm_acompletion, api_key, model_id, mo # Validate we're going down the streaming path (should have stream=True) assert called_kwargs.get("stream") is True, f"Expected stream=True, got {called_kwargs.get('stream')}" + + +def test_format_request_message_content_reasoning(): + """Test formatting reasoning content.""" + content = {"reasoningContent": {"reasoningText": {"signature": "test_sig", "text": "test_thinking"}}} + + result = LiteLLMModel.format_request_message_content(content) + expected = {"signature": "test_sig", "thinking": "test_thinking", "type": "thinking"} + + assert result == expected + + +def test_format_request_message_content_video(): + """Test formatting video content.""" + content = {"video": {"source": {"bytes": "base64videodata"}}} + + result = LiteLLMModel.format_request_message_content(content) + expected = {"type": "video_url", "video_url": {"detail": "auto", "url": "base64videodata"}} + + assert result == expected + + +def test_apply_proxy_prefix_with_use_litellm_proxy(): + """Test _apply_proxy_prefix when use_litellm_proxy is True.""" + model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="openai/gpt-4") + + assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4" + + +def test_apply_proxy_prefix_already_has_prefix(): + """Test _apply_proxy_prefix when model_id already has prefix.""" + model = LiteLLMModel(client_args={"use_litellm_proxy": True}, model_id="litellm_proxy/openai/gpt-4") + + # Should not add another prefix + assert model.get_config()["model_id"] == "litellm_proxy/openai/gpt-4" + + +def test_apply_proxy_prefix_disabled(): + """Test _apply_proxy_prefix when use_litellm_proxy is False.""" + model = LiteLLMModel(client_args={"use_litellm_proxy": False}, model_id="openai/gpt-4") + + assert model.get_config()["model_id"] == "openai/gpt-4" + + +def test_format_chunk_metadata_with_cache_tokens(): + """Test format_chunk for metadata with cache tokens.""" + model = LiteLLMModel(model_id="test") + + # Mock usage data with cache tokens + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + + # Mock cache-related attributes + mock_tokens_details = unittest.mock.Mock() + mock_tokens_details.cached_tokens = 25 + mock_usage.prompt_tokens_details = mock_tokens_details + mock_usage.cache_creation_input_tokens = 10 + + event = {"chunk_type": "metadata", "data": mock_usage} + + result = model.format_chunk(event) + + assert result["metadata"]["usage"]["inputTokens"] == 100 + assert result["metadata"]["usage"]["outputTokens"] == 50 + assert result["metadata"]["usage"]["totalTokens"] == 150 + assert result["metadata"]["usage"]["cacheReadInputTokens"] == 25 + assert result["metadata"]["usage"]["cacheWriteInputTokens"] == 10 + + +def test_format_chunk_metadata_without_cache_tokens(): + """Test format_chunk for metadata without cache tokens.""" + model = 
LiteLLMModel(model_id="test") + + # Mock usage data without cache tokens + mock_usage = unittest.mock.Mock() + mock_usage.prompt_tokens = 100 + mock_usage.completion_tokens = 50 + mock_usage.total_tokens = 150 + mock_usage.prompt_tokens_details = None + mock_usage.cache_creation_input_tokens = None + + event = {"chunk_type": "metadata", "data": mock_usage} + + result = model.format_chunk(event) + + assert result["metadata"]["usage"]["inputTokens"] == 100 + assert result["metadata"]["usage"]["outputTokens"] == 50 + assert result["metadata"]["usage"]["totalTokens"] == 150 + assert "cacheReadInputTokens" not in result["metadata"]["usage"] + assert "cacheWriteInputTokens" not in result["metadata"]["usage"] + + +def test_stream_switch_content_same_type(): + """Test _stream_switch_content when data_type is the same as prev_data_type.""" + model = LiteLLMModel(model_id="test") + + chunks, data_type = model._stream_switch_content("text", "text") + + assert chunks == [] + assert data_type == "text" + + +def test_stream_switch_content_different_type_with_prev(): + """Test _stream_switch_content when switching from one type to another.""" + model = LiteLLMModel(model_id="test") + + chunks, data_type = model._stream_switch_content("text", "reasoning_content") + + assert len(chunks) == 2 + assert chunks[0]["contentBlockStop"] == {} + assert chunks[1]["contentBlockStart"] == {"start": {}} + assert data_type == "text" + + +def test_stream_switch_content_different_type_no_prev(): + """Test _stream_switch_content when switching to a type with no previous type.""" + model = LiteLLMModel(model_id="test") + + chunks, data_type = model._stream_switch_content("text", None) + + assert len(chunks) == 1 + assert chunks[0]["contentBlockStart"] == {"start": {}} + assert data_type == "text"