diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index cba3517e..21c36fb6 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -3,7 +3,8 @@ from a2a.types import AgentCard, AgentExtension -HTTP_EXTENSION_HEADER = 'X-A2A-Extensions' +HTTP_EXTENSION_HEADER = 'A2A-Extensions' +HTTP_EXTENSION_HEADER_DEPRECATED = 'X-A2A-Extensions' def get_requested_extensions(values: list[str]) -> set[str]: @@ -33,7 +34,7 @@ def update_extension_header( http_kwargs: dict[str, Any] | None, extensions: list[str] | None, ) -> dict[str, Any]: - """Update the X-A2A-Extensions header with active extensions.""" + """Update the A2A-Extensions header with active extensions.""" http_kwargs = http_kwargs or {} if extensions is not None: headers = http_kwargs.setdefault('headers', {}) diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index 3e7c2854..4ad892d1 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -13,6 +13,7 @@ from a2a.auth.user import User as A2AUser from a2a.extensions.common import ( HTTP_EXTENSION_HEADER, + HTTP_EXTENSION_HEADER_DEPRECATED, get_requested_extensions, ) from a2a.server.context import ServerCallContext @@ -136,12 +137,13 @@ def build(self, request: Request) -> ServerCallContext: user = StarletteUserProxy(request.user) state['auth'] = request.auth state['headers'] = dict(request.headers) + extension_values = request.headers.getlist( + HTTP_EXTENSION_HEADER + ) or request.headers.getlist(HTTP_EXTENSION_HEADER_DEPRECATED) return ServerCallContext( user=user, state=state, - requested_extensions=get_requested_extensions( - request.headers.getlist(HTTP_EXTENSION_HEADER) - ), + requested_extensions=get_requested_extensions(extension_values), ) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index e2ec69a1..43cf23ab 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -26,6 +26,7 @@ from a2a.auth.user import UnauthenticatedUser from a2a.extensions.common import ( HTTP_EXTENSION_HEADER, + HTTP_EXTENSION_HEADER_DEPRECATED, get_requested_extensions, ) from a2a.grpc import a2a_pb2 @@ -72,12 +73,13 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext: state = {} with contextlib.suppress(Exception): state['grpc_context'] = context + extension_values = _get_metadata_value( + context, HTTP_EXTENSION_HEADER + ) or _get_metadata_value(context, HTTP_EXTENSION_HEADER_DEPRECATED) return ServerCallContext( user=user, state=state, - requested_extensions=get_requested_extensions( - _get_metadata_value(context, HTTP_EXTENSION_HEADER) - ), + requested_extensions=get_requested_extensions(extension_values), ) diff --git a/tests/client/transports/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py index edbcd6c7..f05000c1 100644 --- a/tests/client/transports/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -847,7 +847,7 @@ async def test_send_message_streaming_with_new_extensions( mock_httpx_client: AsyncMock, mock_agent_card: MagicMock, ): - """Test X-A2A-Extensions header in send_message_streaming.""" + """Test A2A-Extensions header in send_message_streaming.""" new_extensions = ['https://example.com/test-ext/v2'] extensions = ['https://example.com/test-ext/v1'] client = JsonRpcTransport( diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index cd68b443..d54c1145 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -96,7 +96,7 @@ async def test_send_message_streaming_with_new_extensions( mock_httpx_client: AsyncMock, mock_agent_card: MagicMock, ): - """Test X-A2A-Extensions header in send_message_streaming.""" + """Test A2A-Extensions header in send_message_streaming.""" new_extensions = ['https://example.com/test-ext/v2'] extensions = ['https://example.com/test-ext/v1'] client = RestTransport( diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index e0a564ee..a8407991 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,4 +1,5 @@ import asyncio + from collections.abc import AsyncGenerator from typing import NamedTuple from unittest.mock import ANY, AsyncMock, patch @@ -7,6 +8,7 @@ import httpx import pytest import pytest_asyncio + from grpc.aio import Channel from a2a.client import ClientConfig @@ -38,6 +40,7 @@ TransportProtocol, ) + # --- Test Constants --- TASK_FROM_STREAM = Task( @@ -819,9 +822,9 @@ async def test_base_client_sends_message_with_extensions( call_args, _ = mock_send_request.call_args kwargs = call_args[1] headers = kwargs.get('headers', {}) - assert 'X-A2A-Extensions' in headers + assert 'A2A-Extensions' in headers assert ( - headers['X-A2A-Extensions'] + headers['A2A-Extensions'] == 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' )