Skip to content

Commit bb8a269

Browse files
GWealecopybara-github
authored andcommitted
fix: Extract and propagate task_id in RemoteA2aAgent
The RemoteA2aAgent now extracts a "task_id" from the custom metadata of the last agent event in the session, alongside the existing "context_id". This task_id is then included in the A2AMessage sent to the remote A2A service. Close #3765 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 839996774
1 parent 507424a commit bb8a269

File tree

2 files changed

+88
-34
lines changed

2 files changed

+88
-34
lines changed

src/google/adk/agents/remote_a2a_agent.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -343,18 +343,18 @@ def _create_a2a_request_for_user_function_response(
343343

344344
def _construct_message_parts_from_session(
345345
self, ctx: InvocationContext
346-
) -> tuple[list[A2APart], Optional[str]]:
346+
) -> tuple[list[A2APart], Optional[str], Optional[str]]:
347347
"""Construct A2A message parts from session events.
348348
349349
Args:
350350
ctx: The invocation context
351351
352352
Returns:
353-
List of A2A parts extracted from session events, context ID,
354-
request metadata
353+
List of A2A parts extracted from session events, context ID, task ID
355354
"""
356355
message_parts: list[A2APart] = []
357356
context_id = None
357+
task_id = None
358358

359359
events_to_process = []
360360
for event in reversed(ctx.session.events):
@@ -364,6 +364,7 @@ def _construct_message_parts_from_session(
364364
if event.custom_metadata:
365365
metadata = event.custom_metadata
366366
context_id = metadata.get(A2A_METADATA_PREFIX + "context_id")
367+
task_id = metadata.get(A2A_METADATA_PREFIX + "task_id")
367368
break
368369
events_to_process.append(event)
369370

@@ -384,7 +385,7 @@ def _construct_message_parts_from_session(
384385
else:
385386
logger.warning("Failed to convert part to A2A format: %s", part)
386387

387-
return message_parts, context_id
388+
return message_parts, context_id, task_id
388389

389390
async def _handle_a2a_response(
390391
self, a2a_response: A2AClientEvent | A2AMessage, ctx: InvocationContext
@@ -500,8 +501,8 @@ async def _run_async_impl(
500501
# Create A2A request for function response or regular message
501502
a2a_request = self._create_a2a_request_for_user_function_response(ctx)
502503
if not a2a_request:
503-
message_parts, context_id = self._construct_message_parts_from_session(
504-
ctx
504+
message_parts, context_id, task_id = (
505+
self._construct_message_parts_from_session(ctx)
505506
)
506507

507508
if not message_parts:
@@ -521,6 +522,7 @@ async def _run_async_impl(
521522
parts=message_parts,
522523
role="user",
523524
context_id=context_id,
525+
task_id=task_id,
524526
)
525527

526528
logger.debug(build_a2a_request_log(a2a_request))

tests/unittests/agents/test_remote_a2a_agent.py

Lines changed: 80 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import json
1616
from pathlib import Path
1717
import tempfile
18+
from unittest import mock
1819
from unittest.mock import AsyncMock
1920
from unittest.mock import create_autospec
2021
from unittest.mock import Mock
@@ -612,13 +613,14 @@ def test_construct_message_parts_from_session_success(self):
612613
mock_a2a_part = Mock()
613614
self.mock_genai_part_converter.return_value = mock_a2a_part
614615

615-
parts, context_id = self.agent._construct_message_parts_from_session(
616-
self.mock_context
616+
parts, context_id, task_id = (
617+
self.agent._construct_message_parts_from_session(self.mock_context)
617618
)
618619

619620
assert len(parts) == 1
620621
assert parts[0] == mock_a2a_part
621622
assert context_id is None
623+
assert task_id is None
622624

623625
def test_construct_message_parts_from_session_success_multiple_parts(self):
624626
"""Test successful message parts construction from session."""
@@ -646,23 +648,54 @@ def test_construct_message_parts_from_session_success_multiple_parts(self):
646648
mock_a2a_part2,
647649
]
648650

649-
parts, context_id = self.agent._construct_message_parts_from_session(
650-
self.mock_context
651+
parts, context_id, task_id = (
652+
self.agent._construct_message_parts_from_session(self.mock_context)
651653
)
652654

653655
assert parts == [mock_a2a_part1, mock_a2a_part2]
654656
assert context_id is None
657+
assert task_id is None
655658

656659
def test_construct_message_parts_from_session_empty_events(self):
657660
"""Test message parts construction with empty events."""
658661
self.mock_session.events = []
659662

660-
parts, context_id = self.agent._construct_message_parts_from_session(
661-
self.mock_context
663+
parts, context_id, task_id = (
664+
self.agent._construct_message_parts_from_session(self.mock_context)
662665
)
663666

664667
assert parts == []
665668
assert context_id is None
669+
assert task_id is None
670+
671+
def test_construct_message_parts_from_session_reads_ids_from_metadata(self):
672+
"""Metadata from last agent event is reused for context and task IDs."""
673+
mock_part = Mock()
674+
mock_part.text = "User message"
675+
mock_content = Mock()
676+
mock_content.parts = [mock_part]
677+
user_event = Mock()
678+
user_event.content = mock_content
679+
user_event.author = "user"
680+
681+
agent_event = Mock()
682+
agent_event.author = self.agent.name
683+
agent_event.custom_metadata = {
684+
A2A_METADATA_PREFIX + "context_id": "context-xyz",
685+
A2A_METADATA_PREFIX + "task_id": "task-abc",
686+
}
687+
688+
# Agent reply is before the latest user message (chronological order).
689+
self.mock_session.events = [agent_event, user_event]
690+
self.mock_genai_part_converter.return_value = Mock()
691+
692+
parts, context_id, task_id = (
693+
self.agent._construct_message_parts_from_session(self.mock_context)
694+
)
695+
696+
assert len(parts) == 1 # the latest user message
697+
assert context_id == "context-xyz"
698+
assert task_id == "task-abc"
666699

667700
@pytest.mark.asyncio
668701
async def test_handle_a2a_response_success_with_message(self):
@@ -786,13 +819,14 @@ def mock_converter(part):
786819

787820
self.mock_genai_part_converter.side_effect = mock_converter
788821

789-
parts, context_id = self.agent._construct_message_parts_from_session(
790-
self.mock_context
822+
parts, context_id, task_id = (
823+
self.agent._construct_message_parts_from_session(self.mock_context)
791824
)
792825

793826
# Verify the parts are in correct order
794827
assert len(parts) == 3 # 1 user part + 2 other agent parts
795828
assert context_id is None
829+
assert task_id is None
796830

797831
# Verify order: user part, then "For context:", then agent message
798832
assert converted_parts[0].original_text == "User question"
@@ -1109,24 +1143,14 @@ def test_construct_message_parts_from_session_success(self):
11091143
mock_a2a_part = Mock()
11101144
mock_convert_part.return_value = mock_a2a_part
11111145

1112-
parts, context_id = self.agent._construct_message_parts_from_session(
1113-
self.mock_context
1146+
parts, context_id, task_id = (
1147+
self.agent._construct_message_parts_from_session(self.mock_context)
11141148
)
11151149

11161150
assert len(parts) == 1
11171151
assert parts[0] == mock_a2a_part
11181152
assert context_id is None
1119-
1120-
def test_construct_message_parts_from_session_empty_events(self):
1121-
"""Test message parts construction with empty events."""
1122-
self.mock_session.events = []
1123-
1124-
parts, context_id = self.agent._construct_message_parts_from_session(
1125-
self.mock_context
1126-
)
1127-
1128-
assert parts == []
1129-
assert context_id is None
1153+
assert task_id is None
11301154

11311155
@pytest.mark.asyncio
11321156
async def test_handle_a2a_response_success_with_message(self):
@@ -1463,7 +1487,8 @@ async def test_run_async_impl_no_message_parts(self):
14631487
mock_construct.return_value = (
14641488
[],
14651489
None,
1466-
) # Tuple with empty parts and no context_id
1490+
None,
1491+
) # Tuple with empty parts and no context/task ids
14671492

14681493
events = []
14691494
async for event in self.agent._run_async_impl(self.mock_context):
@@ -1493,7 +1518,8 @@ async def test_run_async_impl_successful_request(self):
14931518
mock_construct.return_value = (
14941519
[mock_a2a_part],
14951520
"context-123",
1496-
) # Tuple with parts and context_id
1521+
"task-789",
1522+
) # Tuple with parts and context/task ids
14971523

14981524
# Mock A2A client
14991525
mock_a2a_client = create_autospec(spec=A2AClient, instance=True)
@@ -1545,6 +1571,13 @@ async def test_run_async_impl_successful_request(self):
15451571
A2A_METADATA_PREFIX + "request"
15461572
in mock_event.custom_metadata
15471573
)
1574+
mock_message_class.assert_called_once_with(
1575+
message_id=mock.ANY,
1576+
parts=[mock_a2a_part],
1577+
role="user",
1578+
context_id="context-123",
1579+
task_id="task-789",
1580+
)
15481581

15491582
@pytest.mark.asyncio
15501583
async def test_run_async_impl_a2a_client_error(self):
@@ -1565,7 +1598,8 @@ async def test_run_async_impl_a2a_client_error(self):
15651598
mock_construct.return_value = (
15661599
[mock_a2a_part],
15671600
"context-123",
1568-
) # Tuple with parts and context_id
1601+
"task-789",
1602+
) # Tuple with parts and context/task ids
15691603

15701604
# Mock A2A client that throws an exception
15711605
mock_a2a_client = AsyncMock()
@@ -1632,7 +1666,8 @@ async def test_run_async_impl_with_meta_provider(self):
16321666
mock_construct.return_value = (
16331667
[mock_a2a_part],
16341668
"context-123",
1635-
) # Tuple with parts and context_id
1669+
"task-789",
1670+
) # Tuple with parts and context/task ids
16361671

16371672
# Mock A2A client
16381673
mock_a2a_client = create_autospec(spec=A2AClient, instance=True)
@@ -1683,6 +1718,13 @@ async def test_run_async_impl_with_meta_provider(self):
16831718
request=mock_message,
16841719
request_metadata=request_metadata,
16851720
)
1721+
mock_message_class.assert_called_once_with(
1722+
message_id=mock.ANY,
1723+
parts=[mock_a2a_part],
1724+
role="user",
1725+
context_id="context-123",
1726+
task_id="task-789",
1727+
)
16861728

16871729

16881730
class TestRemoteA2aAgentExecutionFromFactory:
@@ -1737,7 +1779,8 @@ async def test_run_async_impl_no_message_parts(self):
17371779
mock_construct.return_value = (
17381780
[],
17391781
None,
1740-
) # Tuple with empty parts and no context_id
1782+
None,
1783+
) # Tuple with empty parts and no context/task ids
17411784

17421785
events = []
17431786
async for event in self.agent._run_async_impl(self.mock_context):
@@ -1767,7 +1810,8 @@ async def test_run_async_impl_successful_request(self):
17671810
mock_construct.return_value = (
17681811
[mock_a2a_part],
17691812
"context-123",
1770-
) # Tuple with parts and context_id
1813+
"task-789",
1814+
) # Tuple with parts and context/task ids
17711815

17721816
# Mock A2A client
17731817
mock_a2a_client = create_autospec(spec=A2AClient, instance=True)
@@ -1821,6 +1865,13 @@ async def test_run_async_impl_successful_request(self):
18211865
A2A_METADATA_PREFIX + "request"
18221866
in mock_event.custom_metadata
18231867
)
1868+
mock_message_class.assert_called_once_with(
1869+
message_id=mock.ANY,
1870+
parts=[mock_a2a_part],
1871+
role="user",
1872+
context_id="context-123",
1873+
task_id="task-789",
1874+
)
18241875

18251876
@pytest.mark.asyncio
18261877
async def test_run_async_impl_a2a_client_error(self):
@@ -1841,7 +1892,8 @@ async def test_run_async_impl_a2a_client_error(self):
18411892
mock_construct.return_value = (
18421893
[mock_a2a_part],
18431894
"context-123",
1844-
) # Tuple with parts and context_id
1895+
"task-789",
1896+
) # Tuple with parts and context/task ids
18451897

18461898
# Mock A2A client that throws an exception
18471899
mock_a2a_client = AsyncMock()

0 commit comments

Comments
 (0)