diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py
index 1e56c51a2b..d770a9ce48 100644
--- a/src/google/adk/models/gemini_llm_connection.py
+++ b/src/google/adk/models/gemini_llm_connection.py
@@ -139,7 +139,12 @@ async def send_realtime(self, input: RealtimeInput):
     else:
       raise ValueError('Unsupported input type: %s' % type(input))
 
-  def __build_full_text_response(self, text: str):
+  def __build_full_text_response(
+      self,
+      text: str,
+      grounding_metadata: types.GroundingMetadata | None = None,
+      interrupted: bool = False,
+  ):
     """Builds a full text response.
 
     The text should not partial and the returned LlmResponse is not be
@@ -147,6 +152,8 @@ def __build_full_text_response(self, text: str):
 
     Args:
       text: The text to be included in the response.
+      grounding_metadata: Optional grounding metadata to include.
+      interrupted: Whether this response was interrupted.
 
     Returns:
       An LlmResponse containing the full text.
@@ -156,6 +163,8 @@ def __build_full_text_response(self, text: str):
             role='model',
             parts=[types.Part.from_text(text=text)],
         ),
+        grounding_metadata=grounding_metadata,
+        interrupted=interrupted,
     )
 
   async def receive(self) -> AsyncGenerator[LlmResponse, None]:
@@ -166,6 +175,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
     """
 
     text = ''
+    last_grounding_metadata = None
     async with Aclosing(self._gemini_session.receive()) as agen:
       # TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
       # partial content and emit responses as needed.
@@ -179,17 +189,38 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
           )
         if message.server_content:
           content = message.server_content.model_turn
+          # Extract grounding_metadata from server_content (for VertexAiSearchTool, etc.)
+          grounding_metadata = message.server_content.grounding_metadata
+          if grounding_metadata:
+            last_grounding_metadata = grounding_metadata
+            # Warn if grounding_metadata is incomplete (has queries but no chunks)
+            # This helps identify backend issues with Vertex AI Search
+            if (
+                grounding_metadata.retrieval_queries
+                and not grounding_metadata.grounding_chunks
+            ):
+              logger.warning(
+                  'Incomplete grounding_metadata received: retrieval_queries=%s'
+                  ' but grounding_chunks is empty. This may indicate a'
+                  ' transient issue with the Vertex AI Search backend.',
+                  grounding_metadata.retrieval_queries,
+              )
           if content and content.parts:
             llm_response = LlmResponse(
-                content=content, interrupted=message.server_content.interrupted
+                content=content,
+                interrupted=message.server_content.interrupted,
+                grounding_metadata=grounding_metadata,
             )
             if content.parts[0].text:
               text += content.parts[0].text
               llm_response.partial = True
             # don't yield the merged text event when receiving audio data
             elif text and not content.parts[0].inline_data:
-              yield self.__build_full_text_response(text)
+              yield self.__build_full_text_response(
+                  text, last_grounding_metadata
+              )
               text = ''
+              last_grounding_metadata = None
             yield llm_response
           # Note: in some cases, tool_call may arrive before
           # generation_complete, causing transcription to appear after
@@ -266,12 +297,18 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
             self._output_transcription_text = ''
           if message.server_content.turn_complete:
             if text:
-              yield self.__build_full_text_response(text)
+              yield self.__build_full_text_response(
+                  text,
+                  last_grounding_metadata,
+                  interrupted=message.server_content.interrupted,
+              )
               text = ''
             yield LlmResponse(
                 turn_complete=True,
                 interrupted=message.server_content.interrupted,
+                grounding_metadata=last_grounding_metadata,
             )
+            last_grounding_metadata = None  # Reset after yielding
             break
           # in case of empty content or parts, we sill surface it
           # in case it's an interrupted message, we merge the previous partial
@@ -279,19 +316,31 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
           # safety threshold is triggered
          if message.server_content.interrupted:
            if text:
-              yield self.__build_full_text_response(text)
+              yield self.__build_full_text_response(
+                  text, last_grounding_metadata, interrupted=True
+              )
              text = ''
            else:
-              yield LlmResponse(interrupted=message.server_content.interrupted)
+              yield LlmResponse(
+                  interrupted=message.server_content.interrupted,
+                  grounding_metadata=last_grounding_metadata,
+              )
        if message.tool_call:
          if text:
-            yield self.__build_full_text_response(text)
+            yield self.__build_full_text_response(text, last_grounding_metadata)
            text = ''
          parts = [
              types.Part(function_call=function_call)
              for function_call in message.tool_call.function_calls
          ]
-          yield LlmResponse(content=types.Content(role='model', parts=parts))
+          yield LlmResponse(
+              content=types.Content(role='model', parts=parts),
+              grounding_metadata=last_grounding_metadata,
+          )
+          # Note: last_grounding_metadata is NOT reset here because tool_call
+          # is part of an ongoing turn. The metadata persists until turn_complete
+          # or interrupted with break, ensuring subsequent messages in the same
+          # turn can access the grounding information.
        if message.session_resumption_update:
          logger.debug('Received session resumption message: %s', message)
          yield (
diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py
index d065661c69..c3d507ca26 100644
--- a/tests/unittests/models/test_gemini_llm_connection.py
+++ b/tests/unittests/models/test_gemini_llm_connection.py
@@ -202,6 +202,7 @@ async def test_receive_usage_metadata_and_server_content(
   mock_server_content.input_transcription = None
   mock_server_content.output_transcription = None
   mock_server_content.turn_complete = False
+  mock_server_content.grounding_metadata = None
 
   mock_message = mock.AsyncMock()
   mock_message.usage_metadata = usage_metadata
@@ -261,6 +262,7 @@ async def test_receive_transcript_finished_on_interrupt(
   message1.server_content.output_transcription = None
   message1.server_content.turn_complete = False
   message1.server_content.generation_complete = False
+  message1.server_content.grounding_metadata = None
 
   message1.tool_call = None
   message1.session_resumption_update = None
@@ -275,6 +277,7 @@ async def test_receive_transcript_finished_on_interrupt(
   )
   message2.server_content.turn_complete = False
   message2.server_content.generation_complete = False
+  message2.server_content.grounding_metadata = None
 
   message2.tool_call = None
   message2.session_resumption_update = None
@@ -287,6 +290,7 @@ async def test_receive_transcript_finished_on_interrupt(
   message3.server_content.output_transcription = None
   message3.server_content.turn_complete = False
   message3.server_content.generation_complete = False
+  message3.server_content.grounding_metadata = None
 
   message3.tool_call = None
   message3.session_resumption_update = None
@@ -408,6 +412,7 @@ async def test_receive_transcript_finished_on_turn_complete(
   message1.server_content.output_transcription = None
   message1.server_content.turn_complete = False
   message1.server_content.generation_complete = False
+  message1.server_content.grounding_metadata = None
 
   message1.tool_call = None
   message1.session_resumption_update = None
@@ -422,6 +427,7 @@ async def test_receive_transcript_finished_on_turn_complete(
   )
   message2.server_content.turn_complete = False
   message2.server_content.generation_complete = False
+  message2.server_content.grounding_metadata = None
 
   message2.tool_call = None
   message2.session_resumption_update = None
@@ -434,6 +440,7 @@ async def test_receive_transcript_finished_on_turn_complete(
   message3.server_content.output_transcription = None
   message3.server_content.turn_complete = True
   message3.server_content.generation_complete = False
+  message3.server_content.grounding_metadata = None
 
   message3.tool_call = None
   message3.session_resumption_update = None
@@ -774,3 +781,302 @@ async def test_send_history_filters_various_audio_mime_types(
 
   # No content should be sent since the only part is audio
   mock_gemini_session.send.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_receive_extracts_grounding_metadata(
+    gemini_connection, mock_gemini_session
+):
+  """Test that grounding_metadata is extracted from server_content and included in LlmResponse."""
+  mock_content = types.Content(
+      role='model', parts=[types.Part.from_text(text='response text')]
+  )
+  mock_grounding_metadata = types.GroundingMetadata(
+      retrieval_queries=['test query'],
+      web_search_queries=['web search query'],
+  )
+
+  mock_server_content = mock.Mock()
+  mock_server_content.model_turn = mock_content
+  mock_server_content.interrupted = False
+  mock_server_content.input_transcription = None
+  mock_server_content.output_transcription = None
+  mock_server_content.turn_complete = True
+  mock_server_content.generation_complete = False
+  mock_server_content.grounding_metadata = mock_grounding_metadata
+
+  mock_message = mock.Mock()
+  mock_message.usage_metadata = None
+  mock_message.server_content = mock_server_content
+  mock_message.tool_call = None
+  mock_message.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield mock_message
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  # Should have at least 2 responses: content with grounding and turn_complete
+  assert len(responses) >= 2
+
+  # Find response with content
+  content_response = next((r for r in responses if r.content), None)
+  assert content_response is not None
+  assert content_response.grounding_metadata == mock_grounding_metadata
+  assert content_response.grounding_metadata.retrieval_queries == ['test query']
+  assert content_response.grounding_metadata.web_search_queries == [
+      'web search query'
+  ]
+
+
+@pytest.mark.asyncio
+async def test_receive_grounding_metadata_at_turn_complete(
+    gemini_connection, mock_gemini_session
+):
+  """Test that grounding_metadata is included in turn_complete response if no text was built."""
+  mock_grounding_metadata = types.GroundingMetadata(
+      retrieval_queries=['test query'],
+  )
+
+  # First message with grounding but no content
+  mock_server_content1 = mock.Mock()
+  mock_server_content1.model_turn = None
+  mock_server_content1.interrupted = False
+  mock_server_content1.input_transcription = None
+  mock_server_content1.output_transcription = None
+  mock_server_content1.turn_complete = False
+  mock_server_content1.generation_complete = False
+  mock_server_content1.grounding_metadata = mock_grounding_metadata
+
+  message1 = mock.Mock()
+  message1.usage_metadata = None
+  message1.server_content = mock_server_content1
+  message1.tool_call = None
+  message1.session_resumption_update = None
+
+  # Second message with turn_complete
+  mock_server_content2 = mock.Mock()
+  mock_server_content2.model_turn = None
+  mock_server_content2.interrupted = False
+  mock_server_content2.input_transcription = None
+  mock_server_content2.output_transcription = None
+  mock_server_content2.turn_complete = True
+  mock_server_content2.generation_complete = False
+  mock_server_content2.grounding_metadata = None
+
+  message2 = mock.Mock()
+  message2.usage_metadata = None
+  message2.server_content = mock_server_content2
+  message2.tool_call = None
+  message2.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message1
+    yield message2
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  # Find turn_complete response
+  turn_complete_response = next((r for r in responses if r.turn_complete), None)
+  assert turn_complete_response is not None
+  # The grounding_metadata should be carried over to turn_complete
+  assert turn_complete_response.grounding_metadata == mock_grounding_metadata
+
+
+@pytest.mark.asyncio
+async def test_receive_grounding_metadata_with_text_and_turn_complete(
+    gemini_connection, mock_gemini_session
+):
+  """Test that grounding_metadata is preserved when text content is followed by turn_complete."""
+  mock_content = types.Content(
+      role='model',
+      parts=[types.Part.from_text(text='response text')]
+  )
+  mock_grounding_metadata = types.GroundingMetadata(
+      retrieval_queries=['test query'],
+  )
+
+  # Message with both content and grounding, followed by turn_complete
+  mock_server_content = mock.Mock()
+  mock_server_content.model_turn = mock_content
+  mock_server_content.interrupted = False
+  mock_server_content.input_transcription = None
+  mock_server_content.output_transcription = None
+  mock_server_content.turn_complete = True
+  mock_server_content.generation_complete = False
+  mock_server_content.grounding_metadata = mock_grounding_metadata
+
+  mock_message = mock.Mock()
+  mock_message.usage_metadata = None
+  mock_message.server_content = mock_server_content
+  mock_message.tool_call = None
+  mock_message.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield mock_message
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  # Find content response with grounding
+  content_response = next((r for r in responses if r.content), None)
+  assert content_response is not None
+  assert content_response.grounding_metadata == mock_grounding_metadata
+
+  # Find turn_complete response - should also have grounding_metadata
+  turn_complete_response = next((r for r in responses if r.turn_complete), None)
+  assert turn_complete_response is not None
+  assert turn_complete_response.grounding_metadata == mock_grounding_metadata
+
+
+@pytest.mark.asyncio
+async def test_receive_grounding_metadata_with_tool_call(
+    gemini_connection, mock_gemini_session
+):
+  """Test that grounding_metadata is propagated with tool_call responses."""
+  mock_grounding_metadata = types.GroundingMetadata(
+      retrieval_queries=['test query'],
+  )
+
+  # First message with grounding metadata
+  mock_server_content1 = mock.Mock()
+  mock_server_content1.model_turn = None
+  mock_server_content1.interrupted = False
+  mock_server_content1.input_transcription = None
+  mock_server_content1.output_transcription = None
+  mock_server_content1.turn_complete = False
+  mock_server_content1.generation_complete = False
+  mock_server_content1.grounding_metadata = mock_grounding_metadata
+
+  message1 = mock.Mock()
+  message1.usage_metadata = None
+  message1.server_content = mock_server_content1
+  message1.tool_call = None
+  message1.session_resumption_update = None
+
+  # Second message with tool_call
+  mock_function_call = types.FunctionCall(
+      name='test_function', args={'param': 'value'}
+  )
+  mock_tool_call = mock.Mock()
+  mock_tool_call.function_calls = [mock_function_call]
+
+  message2 = mock.Mock()
+  message2.usage_metadata = None
+  message2.server_content = None
+  message2.tool_call = mock_tool_call
+  message2.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message1
+    yield message2
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  # Find tool_call response
+  tool_call_response = next(
+      (r for r in responses if r.content and r.content.parts[0].function_call),
+      None,
+  )
+  assert tool_call_response is not None
+  # The grounding_metadata should be carried over to tool_call
+  assert tool_call_response.grounding_metadata == mock_grounding_metadata
+
+
+@pytest.mark.asyncio
+async def test_receive_interrupted_with_pending_text_preserves_flag(
+    gemini_connection,
+    mock_gemini_session
+):
+  """Test that interrupted flag is preserved when flushing pending text."""
+  mock_grounding_metadata = types.GroundingMetadata(
+      retrieval_queries=['test query'],
+  )
+
+  # First message with text content and grounding
+  mock_content1 = types.Content(
+      role='model', parts=[types.Part.from_text(text='partial')]
+  )
+  mock_server_content1 = mock.Mock()
+  mock_server_content1.model_turn = mock_content1
+  mock_server_content1.interrupted = False
+  mock_server_content1.input_transcription = None
+  mock_server_content1.output_transcription = None
+  mock_server_content1.turn_complete = False
+  mock_server_content1.generation_complete = False
+  mock_server_content1.grounding_metadata = mock_grounding_metadata
+
+  message1 = mock.Mock()
+  message1.usage_metadata = None
+  message1.server_content = mock_server_content1
+  message1.tool_call = None
+  message1.session_resumption_update = None
+
+  # Second message with more text
+  mock_content2 = types.Content(
+      role='model', parts=[types.Part.from_text(text=' text')]
+  )
+  mock_server_content2 = mock.Mock()
+  mock_server_content2.model_turn = mock_content2
+  mock_server_content2.interrupted = False
+  mock_server_content2.input_transcription = None
+  mock_server_content2.output_transcription = None
+  mock_server_content2.turn_complete = False
+  mock_server_content2.generation_complete = False
+  mock_server_content2.grounding_metadata = None
+
+  message2 = mock.Mock()
+  message2.usage_metadata = None
+  message2.server_content = mock_server_content2
+  message2.tool_call = None
+  message2.session_resumption_update = None
+
+  # Third message with interrupted signal
+  mock_server_content3 = mock.Mock()
+  mock_server_content3.model_turn = None
+  mock_server_content3.interrupted = True
+  mock_server_content3.input_transcription = None
+  mock_server_content3.output_transcription = None
+  mock_server_content3.turn_complete = False
+  mock_server_content3.generation_complete = False
+  mock_server_content3.grounding_metadata = None
+
+  message3 = mock.Mock()
+  message3.usage_metadata = None
+  message3.server_content = mock_server_content3
+  message3.tool_call = None
+  message3.session_resumption_update = None
+
+  async def mock_receive_generator():
+    yield message1
+    yield message2
+    yield message3
+
+  receive_mock = mock.Mock(return_value=mock_receive_generator())
+  mock_gemini_session.receive = receive_mock
+
+  responses = [resp async for resp in gemini_connection.receive()]
+
+  # Find the full text response that should have been flushed with interrupted=True
+  full_text_responses = [
+      r for r in responses if r.content and not r.partial and r.interrupted
+  ]
+  assert (
+      len(full_text_responses) > 0
+  ), 'Should have interrupted full text response'
+
+  # The full text response should have the accumulated text
+  assert full_text_responses[0].content.parts[0].text == 'partial text'
+  # And should carry the grounding_metadata
+  assert full_text_responses[0].grounding_metadata == mock_grounding_metadata
+  # And should have interrupted=True
+  assert full_text_responses[0].interrupted is True
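
Usage note (illustrative, not part of the patch): a minimal sketch of how a caller might consume the grounding metadata this change propagates through the live stream. The helper name collect_grounded_turn and the connection variable are hypothetical; only receive(), and the LlmResponse fields content, partial, grounding_metadata, and turn_complete, come from the code above.

# Hypothetical consumer: drain one model turn from an active
# GeminiLlmConnection and keep the last grounding metadata seen.
async def collect_grounded_turn(connection):
  """Returns (full_text, grounding_metadata) for a single turn."""
  text_parts = []
  grounding = None
  async for response in connection.receive():
    if response.grounding_metadata:
      # With this change, grounding metadata rides along on content,
      # tool_call, interrupted, and turn_complete responses.
      grounding = response.grounding_metadata
    if response.content and response.content.parts:
      part = response.content.parts[0]
      # Only non-partial text responses carry the merged turn text,
      # so skipping partials avoids double counting.
      if part.text and not response.partial:
        text_parts.append(part.text)
    if response.turn_complete:
      break
  return ''.join(text_parts), grounding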