Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/a2a/extensions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from a2a.types import AgentCard, AgentExtension


HTTP_EXTENSION_HEADER = 'X-A2A-Extensions'
HTTP_EXTENSION_HEADER = 'A2A-Extensions'
HTTP_EXTENSION_HEADER_DEPRECATED = 'X-A2A-Extensions'


def get_requested_extensions(values: list[str]) -> set[str]:
Expand Down Expand Up @@ -33,7 +34,7 @@ def update_extension_header(
http_kwargs: dict[str, Any] | None,
extensions: list[str] | None,
) -> dict[str, Any]:
"""Update the X-A2A-Extensions header with active extensions."""
"""Update the A2A-Extensions header with active extensions."""
http_kwargs = http_kwargs or {}
if extensions is not None:
headers = http_kwargs.setdefault('headers', {})
Expand Down
10 changes: 7 additions & 3 deletions src/a2a/server/apps/jsonrpc/jsonrpc_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from a2a.auth.user import User as A2AUser
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER_DEPRECATED,
get_requested_extensions,
)
from a2a.server.context import ServerCallContext
Expand Down Expand Up @@ -136,12 +137,15 @@ def build(self, request: Request) -> ServerCallContext:
user = StarletteUserProxy(request.user)
state['auth'] = request.auth
state['headers'] = dict(request.headers)
extension_values = request.headers.getlist(HTTP_EXTENSION_HEADER)
if not extension_values:
extension_values = request.headers.getlist(
HTTP_EXTENSION_HEADER_DEPRECATED
)
return ServerCallContext(
user=user,
state=state,
requested_extensions=get_requested_extensions(
request.headers.getlist(HTTP_EXTENSION_HEADER)
),
requested_extensions=get_requested_extensions(extension_values),
)


Expand Down
10 changes: 7 additions & 3 deletions src/a2a/server/request_handlers/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from a2a.auth.user import UnauthenticatedUser
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
HTTP_EXTENSION_HEADER_DEPRECATED,
get_requested_extensions,
)
from a2a.grpc import a2a_pb2
Expand Down Expand Up @@ -72,12 +73,15 @@ def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
state = {}
with contextlib.suppress(Exception):
state['grpc_context'] = context
extension_values = _get_metadata_value(context, HTTP_EXTENSION_HEADER)
if not extension_values:
extension_values = _get_metadata_value(
context, HTTP_EXTENSION_HEADER_DEPRECATED
)
return ServerCallContext(
user=user,
state=state,
requested_extensions=get_requested_extensions(
_get_metadata_value(context, HTTP_EXTENSION_HEADER)
),
requested_extensions=get_requested_extensions(extension_values),
)


Expand Down
2 changes: 1 addition & 1 deletion tests/client/transports/test_jsonrpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ async def test_send_message_streaming_with_new_extensions(
mock_httpx_client: AsyncMock,
mock_agent_card: MagicMock,
):
"""Test X-A2A-Extensions header in send_message_streaming."""
"""Test A2A-Extensions header in send_message_streaming."""
new_extensions = ['https://example.com/test-ext/v2']
extensions = ['https://example.com/test-ext/v1']
client = JsonRpcTransport(
Expand Down
2 changes: 1 addition & 1 deletion tests/client/transports/test_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def test_send_message_streaming_with_new_extensions(
mock_httpx_client: AsyncMock,
mock_agent_card: MagicMock,
):
"""Test X-A2A-Extensions header in send_message_streaming."""
"""Test A2A-Extensions header in send_message_streaming."""
new_extensions = ['https://example.com/test-ext/v2']
extensions = ['https://example.com/test-ext/v1']
client = RestTransport(
Expand Down
7 changes: 5 additions & 2 deletions tests/integration/test_client_server_integration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio

from collections.abc import AsyncGenerator
from typing import NamedTuple
from unittest.mock import ANY, AsyncMock, patch
Expand All @@ -7,6 +8,7 @@
import httpx
import pytest
import pytest_asyncio

from grpc.aio import Channel

from a2a.client import ClientConfig
Expand Down Expand Up @@ -38,6 +40,7 @@
TransportProtocol,
)


# --- Test Constants ---

TASK_FROM_STREAM = Task(
Expand Down Expand Up @@ -819,9 +822,9 @@ async def test_base_client_sends_message_with_extensions(
call_args, _ = mock_send_request.call_args
kwargs = call_args[1]
headers = kwargs.get('headers', {})
assert 'X-A2A-Extensions' in headers
assert 'A2A-Extensions' in headers
assert (
headers['X-A2A-Extensions']
headers['A2A-Extensions']
== 'https://example.com/test-ext/v1,https://example.com/test-ext/v2'
)

Expand Down
Loading