From 30af2f1d48454355c4f8256703eba016d0f0c3ef Mon Sep 17 00:00:00 2001
From: Vedant Madane <6527493+VedantMadane@users.noreply.github.com>
Date: Wed, 21 Jan 2026 01:03:08 +0530
Subject: [PATCH 1/6] fix: Extract grounding_metadata from Live API server_content

Fixes #3542

Extract grounding_metadata from message.server_content.grounding_metadata
in the Live API receive() method and include it in LlmResponse events.
This allows VertexAiSearchTool grounding data to be accessible to agents.
---
 .../adk/models/gemini_llm_connection.py       |  49 ++++++--
 .../models/test_gemini_llm_connection.py      | 113 ++++++++++++++++++
 2 files changed, 155 insertions(+), 7 deletions(-)

diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index 158a5cabc1..495491a0b6 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -139,7 +139,11 @@ async def send_realtime(self, input: RealtimeInput):
     else:
       raise ValueError('Unsupported input type: %s' % type(input))
 
-  def __build_full_text_response(self, text: str):
+  def __build_full_text_response(
+      self,
+      text: str,
+      grounding_metadata: types.GroundingMetadata | None = None,
+  ):
     """Builds a full text response.
 
     The text should not partial and the returned LlmResponse is not be
@@ -147,6 +151,7 @@ def __build_full_text_response(self, text: str):
 
     Args:
       text: The text to be included in the response.
+      grounding_metadata: Optional grounding metadata to include.
 
     Returns:
       An LlmResponse containing the full text.
@@ -156,6 +161,7 @@ def __build_full_text_response(self, text: str):
             role='model',
             parts=[types.Part.from_text(text=text)],
         ),
+        grounding_metadata=grounding_metadata,
     )
 
   async def receive(self) -> AsyncGenerator[LlmResponse, None]:
@@ -166,6 +172,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
     """
 
     text = ''
+    last_grounding_metadata = None
     async with Aclosing(self._gemini_session.receive()) as agen:
       # TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
       # partial content and emit responses as needed.
@@ -179,17 +186,36 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
           )
         if message.server_content:
           content = message.server_content.model_turn
+          # Extract grounding_metadata from server_content (for VertexAiSearchTool, etc.)
+          grounding_metadata = message.server_content.grounding_metadata
+          if grounding_metadata:
+            last_grounding_metadata = grounding_metadata
+            # Warn if grounding_metadata is incomplete (has queries but no chunks)
+            # This helps identify backend issues with Vertex AI Search
+            if (
+                grounding_metadata.retrieval_queries
+                and not grounding_metadata.grounding_chunks
+            ):
+              logger.warning(
+                  'Incomplete grounding_metadata received: retrieval_queries=%s '
+                  'but grounding_chunks is empty. This may indicate a transient '
+                  'issue with the Vertex AI Search backend.',
+                  grounding_metadata.retrieval_queries,
+              )
           if content and content.parts:
             llm_response = LlmResponse(
-                content=content, interrupted=message.server_content.interrupted
+                content=content,
+                interrupted=message.server_content.interrupted,
+                grounding_metadata=grounding_metadata,
             )
             if content.parts[0].text:
               text += content.parts[0].text
               llm_response.partial = True
             # don't yield the merged text event when receiving audio data
             elif text and not content.parts[0].inline_data:
-              yield self.__build_full_text_response(text)
+              yield self.__build_full_text_response(text, last_grounding_metadata)
               text = ''
+              last_grounding_metadata = None
             yield llm_response
           # Note: in some cases, tool_call may arrive before
           # generation_complete, causing transcription to appear after
@@ -266,12 +292,15 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
             self._output_transcription_text = ''
          if message.server_content.turn_complete:
            if text:
-              yield self.__build_full_text_response(text)
+              yield self.__build_full_text_response(text, last_grounding_metadata)
              text = ''
+              last_grounding_metadata = None
            yield LlmResponse(
                turn_complete=True,
                interrupted=message.server_content.interrupted,
+                grounding_metadata=last_grounding_metadata,
            )
+            last_grounding_metadata = None  # Reset after yielding
            break
          # in case of empty content or parts, we sill surface it
          # in case it's an interrupted message, we merge the previous partial
          # safety threshold is triggered
          if message.server_content.interrupted:
            if text:
-              yield self.__build_full_text_response(text)
+              yield self.__build_full_text_response(text, last_grounding_metadata)
              text = ''
+              last_grounding_metadata = None
            else:
-              yield LlmResponse(interrupted=message.server_content.interrupted)
+              yield LlmResponse(
+                  interrupted=message.server_content.interrupted,
+                  grounding_metadata=last_grounding_metadata,
+              )
+              last_grounding_metadata = None  # Reset after yielding
        if message.tool_call:
          if text:
-            yield self.__build_full_text_response(text)
+            yield self.__build_full_text_response(text, last_grounding_metadata)
            text = ''
+            last_grounding_metadata = None
          parts = [
              types.Part(function_call=function_call)
              for function_call in message.tool_call.function_calls
diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py
index ac65b2ac2a..d3e9d9bb9f 100644
--- a/tests/unittests/models/test_gemini_llm_connection.py
+++ b/tests/unittests/models/test_gemini_llm_connection.py
@@ -202,6 +202,7 @@ async def test_receive_usage_metadata_and_server_content(
   mock_server_content.input_transcription = None
   mock_server_content.output_transcription = None
   mock_server_content.turn_complete = False
+  mock_server_content.grounding_metadata = None
 
   mock_message = mock.AsyncMock()
   mock_message.usage_metadata = usage_metadata
@@ -261,6 +262,7 @@ async def test_receive_transcript_finished_on_interrupt(
   message1.server_content.output_transcription = None
   message1.server_content.turn_complete = False
   message1.server_content.generation_complete = False
+  message1.server_content.grounding_metadata = None
   message1.tool_call = None
   message1.session_resumption_update = None
 
@@ -275,6 +277,7 @@ async def test_receive_transcript_finished_on_interrupt(
   )
   message2.server_content.turn_complete = False
   message2.server_content.generation_complete = False
+  message2.server_content.grounding_metadata = None
   message2.tool_call = None
   message2.session_resumption_update = None
 
@@ -287,6 +290,7 @@ async def test_receive_transcript_finished_on_interrupt(
   message3.server_content.output_transcription = None
   message3.server_content.turn_complete = False
   message3.server_content.generation_complete = False
+  message3.server_content.grounding_metadata = None
   message3.tool_call = None
   message3.session_resumption_update = None
 
@@ -408,6 +412,7 @@ async def test_receive_transcript_finished_on_turn_complete(
   message1.server_content.output_transcription = None
   message1.server_content.turn_complete = False
   message1.server_content.generation_complete = False
+  message1.server_content.grounding_metadata = None
   message1.tool_call = None
   message1.session_resumption_update = None
 
@@ -422,6 +427,7 @@ async def test_receive_transcript_finished_on_turn_complete(
   )
   message2.server_content.turn_complete = False
   message2.server_content.generation_complete = False
+  message2.server_content.grounding_metadata = None
   message2.tool_call = None
   message2.session_resumption_update = None
 
@@ -434,6 +440,7 @@ async def test_receive_transcript_finished_on_turn_complete(
   message3.server_content.output_transcription = None
   message3.server_content.turn_complete = True
   message3.server_content.generation_complete = False
+  message3.server_content.grounding_metadata = None
   message3.tool_call = None
   message3.session_resumption_update = None
 
@@ -774,3 +781,109 @@ async def test_send_history_filters_various_audio_mime_types(
 
   # No content should be sent since the only part is audio
   mock_gemini_session.send.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_receive_extracts_grounding_metadata(
+    gemini_connection, mock_gemini_session
+):
+  """Test that grounding_metadata is extracted from server_content and included in LlmResponse."""
+  mock_content = types.Content(
+      role='model', parts=[types.Part.from_text(text='response text')]
+  )
+  mock_grounding_metadata = types.GroundingMetadata(
+      retrieval_queries=['test query'],
+      web_search_queries=['web search query'],
+  )
+
+  mock_server_content = mock.Mock()
+  mock_server_content.model_turn = mock_content
+  mock_server_content.interrupted = False
+  mock_server_content.input_transcription = None
+  mock_server_content.output_transcription = None
+  mock_server_content.turn_complete = True
+  mock_server_content.generation_complete = False
+  mock_server_content.grounding_metadata = mock_grounding_metadata
+
+  mock_message = mock.Mock()
+  mock_message.usage_metadata = None
+  mock_message.server_content = mock_server_content
+  mock_message.tool_call = None
+  mock_message.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield mock_message
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  # Should have at least 2 responses: content with grounding and turn_complete
+  assert len(responses) >= 2
+
+  # Find response with content
+  content_response = next((r for r in responses if r.content), None)
+  assert content_response is not None
+  assert content_response.grounding_metadata == mock_grounding_metadata
+  assert content_response.grounding_metadata.retrieval_queries == ['test query']
+  assert content_response.grounding_metadata.web_search_queries == [
+      'web search query'
+  ]
+
+
+@pytest.mark.asyncio
+async def test_receive_grounding_metadata_at_turn_complete(
+    gemini_connection, mock_gemini_session
+):
+  """Test that grounding_metadata is included in turn_complete response if no text was built."""
+  mock_grounding_metadata = types.GroundingMetadata(
+      retrieval_queries=['test query'],
+  )
+
+  # First message with grounding but no content
+  mock_server_content1 = mock.Mock()
+  mock_server_content1.model_turn = None
+  mock_server_content1.interrupted = False
+  mock_server_content1.input_transcription = None
+  mock_server_content1.output_transcription = None
+  mock_server_content1.turn_complete = False
+  mock_server_content1.generation_complete = False
+  mock_server_content1.grounding_metadata = mock_grounding_metadata
+
+  message1 = mock.Mock()
+  message1.usage_metadata = None
+  message1.server_content = mock_server_content1
+  message1.tool_call = None
+  message1.session_resumption_update = None
+
+  # Second message with turn_complete
+  mock_server_content2 = mock.Mock()
+  mock_server_content2.model_turn = None
+  mock_server_content2.interrupted = False
+  mock_server_content2.input_transcription = None
+  mock_server_content2.output_transcription = None
+  mock_server_content2.turn_complete = True
+  mock_server_content2.generation_complete = False
+  mock_server_content2.grounding_metadata = None
+
+  message2 = mock.Mock()
+  message2.usage_metadata = None
+  message2.server_content = mock_server_content2
+  message2.tool_call = None
+  message2.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message1
+    yield message2
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  # Find turn_complete response
+  turn_complete_response = next((r for r in responses if r.turn_complete), None)
+  assert turn_complete_response is not None
+  # The grounding_metadata should be carried over to turn_complete
+  assert turn_complete_response.grounding_metadata == mock_grounding_metadata
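
For orientation between patches: a minimal sketch of how downstream agent code might consume the grounding data this first patch surfaces. The consume() helper and the printing logic are illustrative assumptions, not part of the PR; the metadata field names follow google.genai's types.GroundingMetadata.

# Illustrative only: consuming grounding_metadata from receive(), assuming
# `connection` is an already-established GeminiLlmConnection.
async def consume(connection):
  async for llm_response in connection.receive():
    metadata = llm_response.grounding_metadata
    if metadata:
      # Queries the backend ran against the data store (e.g. Vertex AI Search).
      print('retrieval queries:', metadata.retrieval_queries)
      # Source chunks backing the answer, when the backend returns them.
      for chunk in metadata.grounding_chunks or []:
        if chunk.retrieved_context:
          print('source:', chunk.retrieved_context.uri)
    if llm_response.turn_complete:
      break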
From b0bd310b9e69ad8290783e79be5e2a5261aded7b Mon Sep 17 00:00:00 2001
From: Vedant Madane <6527493+VedantMadane@users.noreply.github.com>
Date: Wed, 21 Jan 2026 01:16:03 +0530
Subject: [PATCH 2/6] Address Gemini Code Assist review feedback

- Fix critical bug: Remove premature reset of last_grounding_metadata
  before turn_complete response to prevent data loss
- Simplify duplicate reset logic in interrupted handling
- Add grounding_metadata propagation to tool_call responses
- Add test for grounding_metadata with text content + turn_complete
- Add test for grounding_metadata with tool_call responses

All 27 tests pass.
---
 .../adk/models/gemini_llm_connection.py       |  11 +-
 .../models/test_gemini_llm_connection.py      | 104 ++++++++++++++++++
 2 files changed, 110 insertions(+), 5 deletions(-)

diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index 495491a0b6..d9880b9328 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -294,7 +294,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
            if text:
              yield self.__build_full_text_response(text, last_grounding_metadata)
              text = ''
-              last_grounding_metadata = None
            yield LlmResponse(
                turn_complete=True,
                interrupted=message.server_content.interrupted,
@@ -310,23 +309,25 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
            if text:
              yield self.__build_full_text_response(text, last_grounding_metadata)
              text = ''
-              last_grounding_metadata = None
            else:
              yield LlmResponse(
                  interrupted=message.server_content.interrupted,
                  grounding_metadata=last_grounding_metadata,
              )
-              last_grounding_metadata = None  # Reset after yielding
+            last_grounding_metadata = None  # Reset after yielding
        if message.tool_call:
          if text:
            yield self.__build_full_text_response(text, last_grounding_metadata)
            text = ''
-            last_grounding_metadata = None
          parts = [
              types.Part(function_call=function_call)
              for function_call in message.tool_call.function_calls
          ]
-          yield LlmResponse(content=types.Content(role='model', parts=parts))
+          yield LlmResponse(
+              content=types.Content(role='model', parts=parts),
+              grounding_metadata=last_grounding_metadata,
+          )
+          last_grounding_metadata = None  # Reset after yielding
        if message.session_resumption_update:
          logger.debug('Received session resumption message: %s', message)
          yield (
diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py
index d3e9d9bb9f..9c56cc388a 100644
--- a/tests/unittests/models/test_gemini_llm_connection.py
+++ b/tests/unittests/models/test_gemini_llm_connection.py
@@ -887,3 +887,107 @@ async def mock_receive_generator():
   assert turn_complete_response is not None
   # The grounding_metadata should be carried over to turn_complete
   assert turn_complete_response.grounding_metadata == mock_grounding_metadata
+
+
+@pytest.mark.asyncio
+async def test_receive_grounding_metadata_with_text_and_turn_complete(
+    gemini_connection, mock_gemini_session
+):
+  """Test that grounding_metadata is preserved when text content is followed by turn_complete."""
+  mock_content = types.Content(
+      role='model', parts=[types.Part.from_text(text='response text')]
+  )
+  mock_grounding_metadata = types.GroundingMetadata(
+      retrieval_queries=['test query'],
+  )
+
+  # Message with both content and grounding, followed by turn_complete
+  mock_server_content = mock.Mock()
+  mock_server_content.model_turn = mock_content
+  mock_server_content.interrupted = False
+  mock_server_content.input_transcription = None
+  mock_server_content.output_transcription = None
+  mock_server_content.turn_complete = True
+  mock_server_content.generation_complete = False
+  mock_server_content.grounding_metadata = mock_grounding_metadata
+
+  mock_message = mock.Mock()
+  mock_message.usage_metadata = None
+  mock_message.server_content = mock_server_content
+  mock_message.tool_call = None
+  mock_message.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield mock_message
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  # Find content response with grounding
+  content_response = next((r for r in responses if r.content), None)
+  assert content_response is not None
+  assert content_response.grounding_metadata == mock_grounding_metadata
+
+  # Find turn_complete response - should also have grounding_metadata
+  turn_complete_response = next((r for r in responses if r.turn_complete), None)
+  assert turn_complete_response is not None
+  assert turn_complete_response.grounding_metadata == mock_grounding_metadata
+
+
+@pytest.mark.asyncio
+async def test_receive_grounding_metadata_with_tool_call(
+    gemini_connection, mock_gemini_session
+):
+  """Test that grounding_metadata is propagated with tool_call responses."""
+  mock_grounding_metadata = types.GroundingMetadata(
+      retrieval_queries=['test query'],
+  )
+
+  # First message with grounding metadata
+  mock_server_content1 = mock.Mock()
+  mock_server_content1.model_turn = None
+  mock_server_content1.interrupted = False
+  mock_server_content1.input_transcription = None
+  mock_server_content1.output_transcription = None
+  mock_server_content1.turn_complete = False
+  mock_server_content1.generation_complete = False
+  mock_server_content1.grounding_metadata = mock_grounding_metadata
+
+  message1 = mock.Mock()
+  message1.usage_metadata = None
+  message1.server_content = mock_server_content1
+  message1.tool_call = None
+  message1.session_resumption_update = None
+
+  # Second message with tool_call
+  mock_function_call = types.FunctionCall(
+      name='test_function', args={'param': 'value'}
+  )
+  mock_tool_call = mock.Mock()
+  mock_tool_call.function_calls = [mock_function_call]
+
+  message2 = mock.Mock()
+  message2.usage_metadata = None
+  message2.server_content = None
+  message2.tool_call = mock_tool_call
+  message2.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message1
+    yield message2
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  # Find tool_call response
+  tool_call_response = next(
+      (r for r in responses if r.content and r.content.parts[0].function_call),
+      None,
+  )
+  assert tool_call_response is not None
+  # The grounding_metadata should be carried over to tool_call
+  assert tool_call_response.grounding_metadata == mock_grounding_metadata
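
The "critical bug" bullet in patch 2 is easiest to see on a timeline. A toy sketch, assuming a simplified event stream (plain Python, not ADK code): resetting the remembered metadata when text is flushed would leave nothing for the turn_complete event that follows.

# Toy model of the reset-timing bug patch 2 fixes (illustrative only).
events = ['grounding', 'text', 'turn_complete']

last_metadata = None
emitted = []
for event in events:
  if event == 'grounding':
    last_metadata = {'retrieval_queries': ['q']}
  elif event == 'text':
    emitted.append(('full_text', last_metadata))
    # Resetting last_metadata here (the old behavior) would make the
    # turn_complete event below carry None instead of the metadata.
  elif event == 'turn_complete':
    emitted.append(('turn_complete', last_metadata))
    last_metadata = None  # reset only after the terminal event

assert emitted[1] == ('turn_complete', {'retrieval_queries': ['q']})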
From 2fabdb4c37619f7f7800dcd7cd493a18dc478916 Mon Sep 17 00:00:00 2001
From: Vedant Madane <6527493+VedantMadane@users.noreply.github.com>
Date: Fri, 23 Jan 2026 15:16:48 +0530
Subject: [PATCH 3/6] fix: Apply autoformat to gemini_llm_connection.py

Run autoformat.sh to fix formatting issues as requested in PR review.
---
 src/google/adk/models/gemini_llm_connection.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index b2c552d5d6..644e3de363 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -197,9 +197,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
                and not grounding_metadata.grounding_chunks
            ):
              logger.warning(
-                  'Incomplete grounding_metadata received: retrieval_queries=%s '
-                  'but grounding_chunks is empty. This may indicate a transient '
-                  'issue with the Vertex AI Search backend.',
+                  'Incomplete grounding_metadata received: retrieval_queries=%s'
+                  ' but grounding_chunks is empty. This may indicate a'
+                  ' transient issue with the Vertex AI Search backend.',
                  grounding_metadata.retrieval_queries,
              )
          if content and content.parts:
@@ -213,7 +213,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
              llm_response.partial = True
            # don't yield the merged text event when receiving audio data
            elif text and not content.parts[0].inline_data:
-              yield self.__build_full_text_response(text, last_grounding_metadata)
+              yield self.__build_full_text_response(
+                  text, last_grounding_metadata
+              )
              text = ''
              last_grounding_metadata = None
            yield llm_response
@@ -292,7 +294,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
            self._output_transcription_text = ''
          if message.server_content.turn_complete:
            if text:
-              yield self.__build_full_text_response(text, last_grounding_metadata)
+              yield self.__build_full_text_response(
+                  text, last_grounding_metadata
+              )
              text = ''
            yield LlmResponse(
                turn_complete=True,
@@ -307,7 +311,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
          # safety threshold is triggered
          if message.server_content.interrupted:
            if text:
-              yield self.__build_full_text_response(text, last_grounding_metadata)
+              yield self.__build_full_text_response(
+                  text, last_grounding_metadata
+              )
              text = ''
            else:
              yield LlmResponse(

From 8e9fee6ff64f69e45c7444a024ec84d957787a0a Mon Sep 17 00:00:00 2001
From: Vedant Madane <6527493+VedantMadane@users.noreply.github.com>
Date: Fri, 23 Jan 2026 15:17:53 +0530
Subject: [PATCH 4/6] fix: Don't reset grounding_metadata after tool_call

tool_call is part of an ongoing turn, not a terminal event. Removing
the premature reset of last_grounding_metadata ensures subsequent
messages in the same turn (like another tool_call or turn_complete)
retain the grounding information.
---
 src/google/adk/models/gemini_llm_connection.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index 644e3de363..2d4d109b5f 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -333,7 +333,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
              content=types.Content(role='model', parts=parts),
              grounding_metadata=last_grounding_metadata,
          )
-          last_grounding_metadata = None  # Reset after yielding
        if message.session_resumption_update:
          logger.debug('Received session resumption message: %s', message)
          yield (
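
Patches 2 and 4 together pin down a lifecycle for last_grounding_metadata: remember it on arrival, attach it to every response in the turn, and reset it only at a turn boundary. A compact sketch of that contract, with an assumed LiveEvent wrapper standing in for the real Live API messages (the actual logic lives inline in receive()):

from dataclasses import dataclass
from typing import Any, Optional, Tuple

@dataclass
class LiveEvent:
  kind: str  # 'grounding' | 'content' | 'tool_call' | 'turn_complete'
  metadata: Optional[Any] = None

def step(last: Optional[Any], event: LiveEvent) -> Tuple[Optional[Any], Optional[Any]]:
  """Returns (metadata_to_attach, new_last) for one Live API event."""
  if event.kind == 'grounding':
    return None, event.metadata  # remember for later responses
  if event.kind in ('content', 'tool_call'):
    return last, last            # attach, keep for the rest of the turn
  if event.kind == 'turn_complete':
    return last, None            # attach one last time, then reset
  return None, last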
From dfe53f757df65a207378663eac62f793cd90e90c Mon Sep 17 00:00:00 2001
From: Vedant Madane <6527493+VedantMadane@users.noreply.github.com>
Date: Fri, 23 Jan 2026 16:52:08 +0530
Subject: [PATCH 5/6] fix: Resolve HIGH priority grounding_metadata state management issues

- Add interrupted parameter to __build_full_text_response to preserve
  interrupted signal when flushing pending text
- Pass interrupted flag in turn_complete and interrupted blocks
- Remove premature reset of last_grounding_metadata after interrupted
  (not a terminal event)
- Add documentation for tool_call metadata persistence design decision

Addresses review comments:
- HIGH: Lost interrupted signal in full text response
- HIGH: Premature reset after interrupted
- MEDIUM: Duplicate reset logic (simplified by removing premature reset)
---
 src/google/adk/models/gemini_llm_connection.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index 2d4d109b5f..d770a9ce48 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -143,6 +143,7 @@ def __build_full_text_response(
       self,
       text: str,
       grounding_metadata: types.GroundingMetadata | None = None,
+      interrupted: bool = False,
   ):
     """Builds a full text response.
 
@@ -152,6 +153,7 @@ def __build_full_text_response(
     Args:
       text: The text to be included in the response.
       grounding_metadata: Optional grounding metadata to include.
+      interrupted: Whether this response was interrupted.
 
     Returns:
       An LlmResponse containing the full text.
@@ -162,6 +164,7 @@ def __build_full_text_response(
            parts=[types.Part.from_text(text=text)],
        ),
        grounding_metadata=grounding_metadata,
+        interrupted=interrupted,
    )
 
   async def receive(self) -> AsyncGenerator[LlmResponse, None]:
@@ -295,7 +298,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
          if message.server_content.turn_complete:
            if text:
              yield self.__build_full_text_response(
-                  text, last_grounding_metadata
+                  text,
+                  last_grounding_metadata,
+                  interrupted=message.server_content.interrupted,
              )
              text = ''
            yield LlmResponse(
@@ -312,7 +317,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
          if message.server_content.interrupted:
            if text:
              yield self.__build_full_text_response(
-                  text, last_grounding_metadata
+                  text, last_grounding_metadata, interrupted=True
              )
              text = ''
            else:
@@ -320,7 +325,6 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
              yield LlmResponse(
                  interrupted=message.server_content.interrupted,
                  grounding_metadata=last_grounding_metadata,
              )
-            last_grounding_metadata = None  # Reset after yielding
        if message.tool_call:
          if text:
            yield self.__build_full_text_response(text, last_grounding_metadata)
@@ -333,6 +337,10 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
              content=types.Content(role='model', parts=parts),
              grounding_metadata=last_grounding_metadata,
          )
+          # Note: last_grounding_metadata is NOT reset here because tool_call
+          # is part of an ongoing turn. The metadata persists until turn_complete
+          # or interrupted with break, ensuring subsequent messages in the same
+          # turn can access the grounding information.
        if message.session_resumption_update:
          logger.debug('Received session resumption message: %s', message)
          yield (

From 0cd44c828a120f9ee940cd00b518ce100fe0317b Mon Sep 17 00:00:00 2001
From: Vedant Madane <6527493+VedantMadane@users.noreply.github.com>
Date: Fri, 23 Jan 2026 16:53:28 +0530
Subject: [PATCH 6/6] test: Add test for interrupted signal preservation with pending text

Add test_receive_interrupted_with_pending_text_preserves_flag to verify:
- interrupted flag is preserved when flushing pending text
- grounding_metadata is carried through to the flushed response
- accumulated text is properly merged before interruption

Addresses MEDIUM priority review comment about missing test coverage
for edge cases.
---
 .../models/test_gemini_llm_connection.py | 89 +++++++++++++++++++
 1 file changed, 89 insertions(+)

diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py
index 44f6686b1d..c3d507ca26 100644
--- a/tests/unittests/models/test_gemini_llm_connection.py
+++ b/tests/unittests/models/test_gemini_llm_connection.py
@@ -991,3 +991,92 @@ async def mock_receive_generator():
   assert tool_call_response is not None
   # The grounding_metadata should be carried over to tool_call
   assert tool_call_response.grounding_metadata == mock_grounding_metadata
+
+
+@pytest.mark.asyncio
+async def test_receive_interrupted_with_pending_text_preserves_flag(
+    gemini_connection, mock_gemini_session
+):
+  """Test that interrupted flag is preserved when flushing pending text."""
+  mock_grounding_metadata = types.GroundingMetadata(
+      retrieval_queries=['test query'],
+  )
+
+  # First message with text content and grounding
+  mock_content1 = types.Content(
+      role='model', parts=[types.Part.from_text(text='partial')]
+  )
+  mock_server_content1 = mock.Mock()
+  mock_server_content1.model_turn = mock_content1
+  mock_server_content1.interrupted = False
+  mock_server_content1.input_transcription = None
+  mock_server_content1.output_transcription = None
+  mock_server_content1.turn_complete = False
+  mock_server_content1.generation_complete = False
+  mock_server_content1.grounding_metadata = mock_grounding_metadata
+
+  message1 = mock.Mock()
+  message1.usage_metadata = None
+  message1.server_content = mock_server_content1
+  message1.tool_call = None
+  message1.session_resumption_update = None
+
+  # Second message with more text
+  mock_content2 = types.Content(
+      role='model', parts=[types.Part.from_text(text=' text')]
+  )
+  mock_server_content2 = mock.Mock()
+  mock_server_content2.model_turn = mock_content2
+  mock_server_content2.interrupted = False
+  mock_server_content2.input_transcription = None
+  mock_server_content2.output_transcription = None
+  mock_server_content2.turn_complete = False
+  mock_server_content2.generation_complete = False
+  mock_server_content2.grounding_metadata = None
+
+  message2 = mock.Mock()
+  message2.usage_metadata = None
+  message2.server_content = mock_server_content2
+  message2.tool_call = None
+  message2.session_resumption_update = None
+
+  # Third message with interrupted signal
+  mock_server_content3 = mock.Mock()
+  mock_server_content3.model_turn = None
+  mock_server_content3.interrupted = True
+  mock_server_content3.input_transcription = None
+  mock_server_content3.output_transcription = None
+  mock_server_content3.turn_complete = False
+  mock_server_content3.generation_complete = False
+  mock_server_content3.grounding_metadata = None
+
+  message3 = mock.Mock()
+  message3.usage_metadata = None
+  message3.server_content = mock_server_content3
+  message3.tool_call = None
+  message3.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message1
+    yield message2
+    yield message3
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  # Find the full text response that should have been flushed with interrupted=True
+  full_text_responses = [
+      r for r in responses if r.content and not r.partial and r.interrupted
+  ]
+  assert (
+      len(full_text_responses) > 0
+  ), 'Should have interrupted full text response'
+
+  # The full text response should have the accumulated text
+  assert full_text_responses[0].content.parts[0].text == 'partial text'
+  # And should carry the grounding_metadata
+  assert full_text_responses[0].grounding_metadata == mock_grounding_metadata
+  # And should have interrupted=True
+  assert full_text_responses[0].interrupted is True
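
After patch 5, __build_full_text_response carries text, grounding_metadata, and interrupted. A standalone sketch of the LlmResponse shape the helper now produces, assuming the import paths used in this PR and made-up example values:

from google.adk.models.llm_response import LlmResponse
from google.genai import types

metadata = types.GroundingMetadata(retrieval_queries=['example query'])
response = LlmResponse(
    content=types.Content(
        role='model',
        parts=[types.Part.from_text(text='grounded answer')],
    ),
    grounding_metadata=metadata,  # carried over from server_content
    interrupted=False,            # preserved when pending text is flushed
)
assert response.grounding_metadata.retrieval_queries == ['example query']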