From d8076b68fb8dd73e1ba11f2a8bc32c270372ebc3 Mon Sep 17 00:00:00 2001 From: ftnext Date: Sat, 31 Jan 2026 15:15:22 +0900 Subject: [PATCH] fix(a2a): avoid UUID session IDs by mapping A2A context IDs --- .../adk/a2a/converters/request_converter.py | 9 ++- .../adk/a2a/executor/a2a_agent_executor.py | 74 +++++++++++++------ .../a2a/converters/test_request_converter.py | 44 +++++++++-- .../a2a/executor/test_a2a_agent_executor.py | 44 +++++++---- 4 files changed, 127 insertions(+), 44 deletions(-) diff --git a/src/google/adk/a2a/converters/request_converter.py b/src/google/adk/a2a/converters/request_converter.py index 17989374d6..84287795cc 100644 --- a/src/google/adk/a2a/converters/request_converter.py +++ b/src/google/adk/a2a/converters/request_converter.py @@ -26,6 +26,7 @@ from ..experimental import a2a_experimental from .part_converter import A2APartToGenAIPartConverter from .part_converter import convert_a2a_part_to_genai_part +from .utils import _from_a2a_context_id @a2a_experimental @@ -70,6 +71,10 @@ def _get_user_id(request: RequestContext) -> str: ): return request.call_context.user.user_name + _, user_id, _ = _from_a2a_context_id(request.context_id) + if user_id: + return user_id + # Get user from context id return f'A2A_USER_{request.context_id}' @@ -106,9 +111,11 @@ def convert_a2a_request_to_agent_run_request( genai_parts = [genai_parts] if genai_parts else [] output_parts.extend(genai_parts) + _, _, session_id = _from_a2a_context_id(request.context_id) + return AgentRunRequest( user_id=_get_user_id(request), - session_id=request.context_id, + session_id=session_id, new_message=genai_types.Content( role='user', parts=output_parts, diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index cca728dbfd..097d842dca 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -49,6 +49,7 @@ from ..converters.request_converter import AgentRunRequest from ..converters.request_converter import convert_a2a_request_to_agent_run_request from ..converters.utils import _get_adk_metadata_key +from ..converters.utils import _to_a2a_context_id from ..experimental import a2a_experimental from .task_result_aggregator import TaskResultAggregator @@ -135,21 +136,6 @@ async def execute( if not context.message: raise ValueError('A2A request must have a message') - # for new task, create a task submitted event - if not context.current_task: - await event_queue.enqueue_event( - TaskStatusUpdateEvent( - task_id=context.task_id, - status=TaskStatus( - state=TaskState.submitted, - message=context.message, - timestamp=datetime.now(timezone.utc).isoformat(), - ), - context_id=context.context_id, - final=False, - ) - ) - # Handle the request and publish updates to the event queue try: await self._handle_request(context, event_queue) @@ -194,6 +180,27 @@ async def _handle_request( # ensure the session exists session = await self._prepare_session(context, run_request, runner) + response_context_id = self._get_response_context_id( + context=context, + runner=runner, + run_request=run_request, + session_id=session.id, + ) + + # for new task, create a task submitted event + if not context.current_task: + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=context.task_id, + status=TaskStatus( + state=TaskState.submitted, + message=context.message, + timestamp=datetime.now(timezone.utc).isoformat(), + ), + context_id=response_context_id, + final=False, + ) + ) # create invocation context invocation_context = runner._new_invocation_context( @@ -210,7 +217,7 @@ async def _handle_request( state=TaskState.working, timestamp=datetime.now(timezone.utc).isoformat(), ), - context_id=context.context_id, + context_id=response_context_id, final=False, metadata={ _get_adk_metadata_key('app_name'): runner.app_name, @@ -227,7 +234,7 @@ async def _handle_request( adk_event, invocation_context, context.task_id, - context.context_id, + response_context_id, self._config.gen_ai_part_converter, ): task_result_aggregator.process_event(a2a_event) @@ -245,7 +252,7 @@ async def _handle_request( TaskArtifactUpdateEvent( task_id=context.task_id, last_chunk=True, - context_id=context.context_id, + context_id=response_context_id, artifact=Artifact( artifact_id=str(uuid.uuid4()), parts=task_result_aggregator.task_status_message.parts, @@ -260,7 +267,7 @@ async def _handle_request( state=TaskState.completed, timestamp=datetime.now(timezone.utc).isoformat(), ), - context_id=context.context_id, + context_id=response_context_id, final=True, ) ) @@ -273,21 +280,44 @@ async def _handle_request( timestamp=datetime.now(timezone.utc).isoformat(), message=task_result_aggregator.task_status_message, ), - context_id=context.context_id, + context_id=response_context_id, final=True, ) ) + def _get_response_context_id( + self, + *, + context: RequestContext, + runner: Runner, + run_request: AgentRunRequest, + session_id: str, + ) -> str: + try: + return _to_a2a_context_id( + runner.app_name, run_request.user_id, session_id + ) + except ValueError: + return context.context_id + async def _prepare_session( self, context: RequestContext, run_request: AgentRunRequest, runner: Runner, ): - session_id = run_request.session_id - # create a new session if not exists user_id = run_request.user_id + if not session_id: + session = await runner.session_service.create_session( + app_name=runner.app_name, + user_id=user_id, + state={}, + ) + run_request.session_id = session.id + return session + + # create a new session if not exists session = await runner.session_service.get_session( app_name=runner.app_name, user_id=user_id, diff --git a/tests/unittests/a2a/converters/test_request_converter.py b/tests/unittests/a2a/converters/test_request_converter.py index cd284ea313..4fbb29d377 100644 --- a/tests/unittests/a2a/converters/test_request_converter.py +++ b/tests/unittests/a2a/converters/test_request_converter.py @@ -18,6 +18,7 @@ from a2a.server.agent_execution import RequestContext from google.adk.a2a.converters.request_converter import _get_user_id from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request +from google.adk.a2a.converters.utils import _to_a2a_context_id from google.adk.runners import RunConfig from google.genai import types as genai_types import pytest @@ -58,6 +59,16 @@ def test_get_user_id_from_context_when_no_call_context(self): # Assert assert result == "A2A_USER_test_context" + def test_get_user_id_from_adk_context_id(self): + """Test getting user ID from ADK-formatted context id.""" + request = Mock(spec=RequestContext) + request.call_context = None + request.context_id = _to_a2a_context_id("app", "user-123", "session-456") + + result = _get_user_id(request) + + assert result == "user-123" + def test_get_user_id_from_context_when_call_context_has_no_user(self): """Test getting user ID from context when call context has no user.""" # Arrange @@ -129,6 +140,27 @@ def test_get_user_id_with_none_context_id(self): class TestConvertA2aRequestToAgentRunRequest: """Test cases for convert_a2a_request_to_agent_run_request function.""" + def test_convert_a2a_request_with_adk_context_id(self): + """Test conversion uses ADK context id for user/session.""" + mock_message = Mock() + mock_message.parts = [Mock()] + + request = Mock(spec=RequestContext) + request.message = mock_message + request.context_id = _to_a2a_context_id("app", "user-1", "session-1") + request.call_context = None + request.metadata = {} + + mock_genai_part = genai_types.Part(text="test part") + mock_convert_part = Mock(return_value=mock_genai_part) + + result = convert_a2a_request_to_agent_run_request( + request, mock_convert_part + ) + + assert result.user_id == "user-1" + assert result.session_id == "session-1" + def test_convert_a2a_request_basic(self): """Test basic conversion of A2A request to ADK AgentRunRequest.""" # Arrange @@ -164,7 +196,7 @@ def test_convert_a2a_request_basic(self): # Assert assert result is not None assert result.user_id == "test_user" - assert result.session_id == "test_context_123" + assert result.session_id is None assert isinstance(result.new_message, genai_types.Content) assert result.new_message.role == "user" assert result.new_message.parts == [mock_genai_part1, mock_genai_part2] @@ -213,7 +245,7 @@ def test_convert_a2a_request_multiple_parts(self): # Assert assert result is not None assert result.user_id == "test_user" - assert result.session_id == "test_context_123" + assert result.session_id is None assert isinstance(result.new_message, genai_types.Content) assert result.new_message.role == "user" assert result.new_message.parts == [ @@ -261,7 +293,7 @@ def test_convert_a2a_request_empty_parts(self): # Assert assert result is not None assert result.user_id == "A2A_USER_test_context_123" - assert result.session_id == "test_context_123" + assert result.session_id is None assert isinstance(result.new_message, genai_types.Content) assert result.new_message.role == "user" assert result.new_message.parts == [] @@ -328,7 +360,7 @@ def test_convert_a2a_request_no_auth(self): # Assert assert result is not None assert result.user_id == "A2A_USER_session_123" - assert result.session_id == "session_123" + assert result.session_id is None assert isinstance(result.new_message, genai_types.Content) assert result.new_message.role == "user" assert result.new_message.parts == [mock_genai_part] @@ -370,7 +402,7 @@ def test_end_to_end_conversion_with_auth_user(self): # Assert assert result is not None assert result.user_id == "auth_user" # Should use authenticated user - assert result.session_id == "mysession" + assert result.session_id is None assert isinstance(result.new_message, genai_types.Content) assert result.new_message.role == "user" assert result.new_message.parts == [mock_genai_part] @@ -404,7 +436,7 @@ def test_end_to_end_conversion_with_fallback_user(self): assert ( result.user_id == "A2A_USER_test_session_456" ) # Should fall back to context ID - assert result.session_id == "test_session_456" + assert result.session_id is None assert isinstance(result.new_message, genai_types.Content) assert result.new_message.role == "user" assert result.new_message.parts == [mock_genai_part] diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 40736d959c..2f461c137a 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -22,6 +22,7 @@ from a2a.types import TaskState from a2a.types import TextPart from google.adk.a2a.converters.request_converter import AgentRunRequest +from google.adk.a2a.converters.utils import _to_a2a_context_id from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig from google.adk.events.event import Event @@ -107,6 +108,10 @@ async def mock_run_async(**kwargs): # Execute await self.executor.execute(self.mock_context, self.mock_event_queue) + expected_context_id = _to_a2a_context_id( + self.mock_runner.app_name, "test-user", "test-session" + ) + # Verify request converter was called with proper arguments self.mock_request_converter.assert_called_once_with( self.mock_context, self.mock_a2a_part_converter @@ -117,7 +122,7 @@ async def mock_run_async(**kwargs): mock_event, mock_invocation_context, self.mock_context.task_id, - self.mock_context.context_id, + expected_context_id, self.mock_gen_ai_part_converter, ) @@ -128,11 +133,13 @@ async def mock_run_async(**kwargs): ] assert submitted_event.status.state == TaskState.submitted assert submitted_event.final == False + assert submitted_event.context_id == expected_context_id # Verify working event was enqueued working_event = self.mock_event_queue.enqueue_event.call_args_list[1][0][0] assert working_event.status.state == TaskState.working assert working_event.final == False + assert working_event.context_id == expected_context_id # Verify final event was enqueued with proper message field final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] @@ -141,6 +148,7 @@ async def mock_run_async(**kwargs): # are processed, it will publish a status event with the current state assert hasattr(final_event.status, "message") assert final_event.status.state == TaskState.working + assert final_event.context_id == expected_context_id @pytest.mark.asyncio async def test_execute_no_message_error(self): @@ -190,6 +198,10 @@ async def mock_run_async(**kwargs): # Execute await self.executor.execute(self.mock_context, self.mock_event_queue) + expected_context_id = _to_a2a_context_id( + self.mock_runner.app_name, "test-user", "test-session" + ) + # Verify request converter was called with proper arguments self.mock_request_converter.assert_called_once_with( self.mock_context, self.mock_a2a_part_converter @@ -200,7 +212,7 @@ async def mock_run_async(**kwargs): mock_event, mock_invocation_context, self.mock_context.task_id, - self.mock_context.context_id, + expected_context_id, self.mock_gen_ai_part_converter, ) @@ -208,6 +220,7 @@ async def mock_run_async(**kwargs): working_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][0] assert working_event.status.state == TaskState.working assert working_event.final == False + assert working_event.context_id == expected_context_id # Verify final event was enqueued with proper message field final_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] @@ -216,6 +229,7 @@ async def mock_run_async(**kwargs): # are processed, it will publish a status event with the current state assert hasattr(final_event.status, "message") assert final_event.status.state == TaskState.working + assert final_event.context_id == expected_context_id @pytest.mark.asyncio async def test_prepare_session_new_session(self): @@ -613,16 +627,8 @@ async def test_execute_with_exception_handling(self): # Execute (should not raise since we catch the exception) await self.executor.execute(self.mock_context, self.mock_event_queue) - # Verify both submitted and failure events were enqueued - # First call should be submitted event, last should be failure event - assert self.mock_event_queue.enqueue_event.call_count >= 2 - - # Check submitted event (first) - submitted_event = self.mock_event_queue.enqueue_event.call_args_list[0][0][ - 0 - ] - assert submitted_event.status.state == TaskState.submitted - assert submitted_event.final == False + # Request converter error happens before submitted event is enqueued. + assert self.mock_event_queue.enqueue_event.call_count >= 1 # Check failure event (last) failure_event = self.mock_event_queue.enqueue_event.call_args_list[-1][0][0] @@ -846,6 +852,10 @@ async def mock_run_async(**kwargs): self.mock_context, self.mock_event_queue ) + expected_context_id = _to_a2a_context_id( + self.mock_runner.app_name, "test-user", "test-session" + ) + # Verify artifact update event was published artifact_events = [ call[0][0] @@ -855,7 +865,7 @@ async def mock_run_async(**kwargs): assert len(artifact_events) == 1 artifact_event = artifact_events[0] assert artifact_event.task_id == "test-task-id" - assert artifact_event.context_id == "test-context-id" + assert artifact_event.context_id == expected_context_id # Check that artifact parts correspond to message parts assert len(artifact_event.artifact.parts) == len(test_message.parts) assert artifact_event.artifact.parts == test_message.parts @@ -870,7 +880,7 @@ async def mock_run_async(**kwargs): final_event = final_events[-1] # Get the last final event assert final_event.status.state == TaskState.completed assert final_event.task_id == "test-task-id" - assert final_event.context_id == "test-context-id" + assert final_event.context_id == expected_context_id @pytest.mark.asyncio async def test_handle_request_with_non_working_state_publishes_status_only( @@ -939,6 +949,10 @@ async def mock_run_async(**kwargs): self.mock_context, self.mock_event_queue ) + expected_context_id = _to_a2a_context_id( + self.mock_runner.app_name, "test-user", "test-session" + ) + # Verify no artifact update event was published artifact_events = [ call[0][0] @@ -958,4 +972,4 @@ async def mock_run_async(**kwargs): assert final_event.status.state == TaskState.auth_required assert final_event.status.message == test_message assert final_event.task_id == "test-task-id" - assert final_event.context_id == "test-context-id" + assert final_event.context_id == expected_context_id