From c44c86d7988cec5163588cd1bc43dce4188f5df0 Mon Sep 17 00:00:00 2001 From: wgd3 Date: Sat, 6 Dec 2025 12:03:22 -0500 Subject: [PATCH 1/3] fix: update A2A extension HTTP header --- src/a2a/extensions/common.py | 5 +++-- src/a2a/server/apps/jsonrpc/jsonrpc_app.py | 10 +++++++--- src/a2a/server/request_handlers/grpc_handler.py | 10 +++++++--- tests/client/transports/test_jsonrpc_client.py | 2 +- tests/client/transports/test_rest_client.py | 2 +- tests/integration/test_client_server_integration.py | 7 +++++-- 6 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index cba3517e..21c36fb6 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -3,7 +3,8 @@ from a2a.types import AgentCard, AgentExtension -HTTP_EXTENSION_HEADER = 'X-A2A-Extensions' +HTTP_EXTENSION_HEADER = 'A2A-Extensions' +HTTP_EXTENSION_HEADER_DEPRECATED = 'X-A2A-Extensions' def get_requested_extensions(values: list[str]) -> set[str]: @@ -33,7 +34,7 @@ def update_extension_header( http_kwargs: dict[str, Any] | None, extensions: list[str] | None, ) -> dict[str, Any]: - """Update the X-A2A-Extensions header with active extensions.""" + """Update the A2A-Extensions header with active extensions.""" http_kwargs = http_kwargs or {} if extensions is not None: headers = http_kwargs.setdefault('headers', {}) diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index 3e7c2854..246e6fc3 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -13,6 +13,7 @@ from a2a.auth.user import User as A2AUser from a2a.extensions.common import ( HTTP_EXTENSION_HEADER, + HTTP_EXTENSION_HEADER_DEPRECATED, get_requested_extensions, ) from a2a.server.context import ServerCallContext @@ -136,12 +137,15 @@ def build(self, request: Request) -> ServerCallContext: user = StarletteUserProxy(request.user) state['auth'] = request.auth state['headers'] = dict(request.headers) + extension_values = request.headers.getlist(HTTP_EXTENSION_HEADER) + if not extension_values: + extension_values = request.headers.getlist( + HTTP_EXTENSION_HEADER_DEPRECATED + ) return ServerCallContext( user=user, state=state, - requested_extensions=get_requested_extensions( - request.headers.getlist(HTTP_EXTENSION_HEADER) - ), + requested_extensions=get_requested_extensions(extension_values), ) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index e2ec69a1..3be3a5e1 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -26,6 +26,7 @@ from a2a.auth.user import UnauthenticatedUser from a2a.extensions.common import ( HTTP_EXTENSION_HEADER, + HTTP_EXTENSION_HEADER_DEPRECATED, get_requested_extensions, ) from a2a.grpc import a2a_pb2 @@ -72,12 +73,15 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: state = {} with contextlib.suppress(Exception): state['grpc_context'] = context + extension_values = _get_metadata_value(context, HTTP_EXTENSION_HEADER) + if not extension_values: + extension_values = _get_metadata_value( + context, HTTP_EXTENSION_HEADER_DEPRECATED + ) return ServerCallContext( user=user, state=state, - requested_extensions=get_requested_extensions( - _get_metadata_value(context, HTTP_EXTENSION_HEADER) - ), + requested_extensions=get_requested_extensions(extension_values), ) diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index edbcd6c7..f05000c1 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -847,7 +847,7 @@ async def test_send_message_streaming_with_new_extensions( mock_httpx_client: AsyncMock, mock_agent_card: MagicMock, ): - """Test X-A2A-Extensions header in send_message_streaming.""" + """Test A2A-Extensions header in send_message_streaming.""" new_extensions = ['https://example.com/test-ext/v2'] extensions = ['https://example.com/test-ext/v1'] client = JsonRpcTransport( diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index cd68b443..d54c1145 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -96,7 +96,7 @@ async def test_send_message_streaming_with_new_extensions( mock_httpx_client: AsyncMock, mock_agent_card: MagicMock, ): - """Test X-A2A-Extensions header in send_message_streaming.""" + """Test A2A-Extensions header in send_message_streaming.""" new_extensions = ['https://example.com/test-ext/v2'] extensions = ['https://example.com/test-ext/v1'] client = RestTransport( diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index e0a564ee..a8407991 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,4 +1,5 @@ import asyncio + from collections.abc import AsyncGenerator from typing import NamedTuple from unittest.mock import ANY, AsyncMock, patch @@ -7,6 +8,7 @@ import httpx import pytest import pytest_asyncio + from grpc.aio import Channel from a2a.client import ClientConfig @@ -38,6 +40,7 @@ TransportProtocol, ) + # --- Test Constants --- TASK_FROM_STREAM = Task( @@ -819,9 +822,9 @@ async def test_base_client_sends_message_with_extensions( call_args, _ = mock_send_request.call_args kwargs = call_args[1] headers = kwargs.get('headers', {}) - assert 'X-A2A-Extensions' in headers + assert 'A2A-Extensions' in headers assert ( - headers['X-A2A-Extensions'] + headers['A2A-Extensions'] == 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' ) From 738205ba3b812df7971f416eb3cd8d8c8511d7e7 Mon Sep 17 00:00:00 2001 From: Wallace Daniel Date: Sat, 6 Dec 2025 12:24:43 -0500 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/a2a/server/apps/jsonrpc/jsonrpc_app.py | 6 +----- src/a2a/server/request_handlers/grpc_handler.py | 6 +----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index 246e6fc3..e2992ff5 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -137,11 +137,7 @@ def build(self, request: Request) -> ServerCallContext: user = StarletteUserProxy(request.user) state['auth'] = request.auth state['headers'] = dict(request.headers) - extension_values = request.headers.getlist(HTTP_EXTENSION_HEADER) - if not extension_values: - extension_values = request.headers.getlist( - HTTP_EXTENSION_HEADER_DEPRECATED - ) + extension_values = request.headers.getlist(HTTP_EXTENSION_HEADER) or request.headers.getlist(HTTP_EXTENSION_HEADER_DEPRECATED) return ServerCallContext( user=user, state=state, diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 3be3a5e1..54455e5e 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -73,11 +73,7 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: state = {} with contextlib.suppress(Exception): state['grpc_context'] = context - extension_values = _get_metadata_value(context, HTTP_EXTENSION_HEADER) - if not extension_values: - extension_values = _get_metadata_value( - context, HTTP_EXTENSION_HEADER_DEPRECATED - ) + extension_values = _get_metadata_value(context, HTTP_EXTENSION_HEADER) or _get_metadata_value(context, HTTP_EXTENSION_HEADER_DEPRECATED) return ServerCallContext( user=user, state=state, From ee188562282d0e3d2353b1b5be095a83a840078e Mon Sep 17 00:00:00 2001 From: wgd3 Date: Sat, 6 Dec 2025 12:26:53 -0500 Subject: [PATCH 3/3] chore: updated formatting after code review commits --- src/a2a/server/apps/jsonrpc/jsonrpc_app.py | 4 +++- src/a2a/server/request_handlers/grpc_handler.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index e2992ff5..4ad892d1 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -137,7 +137,9 @@ def build(self, request: Request) -> ServerCallContext: user = StarletteUserProxy(request.user) state['auth'] = request.auth state['headers'] = dict(request.headers) - extension_values = request.headers.getlist(HTTP_EXTENSION_HEADER) or request.headers.getlist(HTTP_EXTENSION_HEADER_DEPRECATED) + extension_values = request.headers.getlist( + HTTP_EXTENSION_HEADER + ) or request.headers.getlist(HTTP_EXTENSION_HEADER_DEPRECATED) return ServerCallContext( user=user, state=state, diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 54455e5e..43cf23ab 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -73,7 +73,9 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: state = {} with contextlib.suppress(Exception): state['grpc_context'] = context - extension_values = _get_metadata_value(context, HTTP_EXTENSION_HEADER) or _get_metadata_value(context, HTTP_EXTENSION_HEADER_DEPRECATED) + extension_values = _get_metadata_value( + context, HTTP_EXTENSION_HEADER + ) or _get_metadata_value(context, HTTP_EXTENSION_HEADER_DEPRECATED) return ServerCallContext( user=user, state=state,