Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions src/a2a/server/apps/rest/fastapi_app.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import logging

from collections.abc import Callable
from typing import TYPE_CHECKING, Any


if TYPE_CHECKING:
from fastapi import APIRouter, FastAPI, Request, Response
from fastapi.responses import JSONResponse

_package_fastapi_installed = True
else:
try:
from fastapi import APIRouter, FastAPI, Request, Response
from fastapi.responses import JSONResponse

_package_fastapi_installed = True
except ImportError:
Expand All @@ -23,6 +26,7 @@

from a2a.server.apps.jsonrpc.jsonrpc_app import CallContextBuilder
from a2a.server.apps.rest.rest_adapter import RESTAdapter
from a2a.server.context import ServerCallContext
from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.types import AgentCard
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
Expand All @@ -39,24 +43,35 @@
(SSE).
"""

def __init__(
def __init__( # noqa: PLR0913
self,
agent_card: AgentCard,
http_handler: RequestHandler,
extended_agent_card: AgentCard | None = None,
context_builder: CallContextBuilder | None = None,
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
extended_card_modifier: Callable[
[AgentCard, ServerCallContext], AgentCard
]
| None = None,
):
"""Initializes the A2ARESTFastAPIApplication.

Args:
agent_card: The AgentCard describing the agent's capabilities.
http_handler: The handler instance responsible for processing A2A
requests via http.
extended_agent_card: An optional, distinct AgentCard to be served
at the authenticated extended card endpoint.
context_builder: The CallContextBuilder used to construct the
ServerCallContext passed to the http_handler. If None, no
ServerCallContext is passed.
card_modifier: An optional callback to dynamically modify the public
agent card before it is served.
extended_card_modifier: An optional callback to dynamically modify
the extended agent card before it is served. It receives the
call context.
"""

Check notice on line 74 in src/a2a/server/apps/rest/fastapi_app.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/apps/rest/rest_adapter.py (55-83)
if not _package_fastapi_installed:
raise ImportError(
'The `fastapi` package is required to use the'
Expand All @@ -66,7 +81,10 @@
self._adapter = RESTAdapter(
agent_card=agent_card,
http_handler=http_handler,
extended_agent_card=extended_agent_card,
context_builder=context_builder,
card_modifier=card_modifier,
extended_card_modifier=extended_card_modifier,
)

def build(
Expand Down Expand Up @@ -95,7 +113,8 @@

@router.get(f'{rpc_url}{agent_card_url}')
async def get_agent_card(request: Request) -> Response:
return await self._adapter.handle_get_agent_card(request)
card = await self._adapter.handle_get_agent_card(request)
return JSONResponse(card)

app.include_router(router)
return app
67 changes: 52 additions & 15 deletions src/a2a/server/apps/rest/rest_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,35 @@
manages response generation including Server-Sent Events (SSE).
"""

def __init__(
def __init__( # noqa: PLR0913
self,
agent_card: AgentCard,
http_handler: RequestHandler,
extended_agent_card: AgentCard | None = None,
context_builder: CallContextBuilder | None = None,
card_modifier: Callable[[AgentCard], AgentCard] | None = None,
extended_card_modifier: Callable[
[AgentCard, ServerCallContext], AgentCard
]
| None = None,
):

Check notice on line 66 in src/a2a/server/apps/rest/rest_adapter.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/apps/jsonrpc/starlette_app.py (51-62)
"""Initializes the RESTApplication.

Args:
agent_card: The AgentCard describing the agent's capabilities.
http_handler: The handler instance responsible for processing A2A
requests via http.
extended_agent_card: An optional, distinct AgentCard to be served
at the authenticated extended card endpoint.
context_builder: The CallContextBuilder used to construct the
ServerCallContext passed to the http_handler. If None, no
ServerCallContext is passed.
card_modifier: An optional callback to dynamically modify the public
agent card before it is served.

Check notice on line 79 in src/a2a/server/apps/rest/rest_adapter.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/apps/jsonrpc/fastapi_app.py (69-97)

Check notice on line 79 in src/a2a/server/apps/rest/rest_adapter.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/apps/jsonrpc/jsonrpc_app.py (154-182)
extended_card_modifier: An optional callback to dynamically modify
the extended agent card before it is served. It receives the
call context.
"""

Check notice on line 83 in src/a2a/server/apps/rest/rest_adapter.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/server/apps/rest/fastapi_app.py (46-74)
if not _package_starlette_installed:
raise ImportError(
'Packages `starlette` and `sse-starlette` are required to use'
Expand All @@ -75,9 +88,20 @@
' optional dependencies, `a2a-sdk[http-server]`.'
)
self.agent_card = agent_card
self.extended_agent_card = extended_agent_card
self.card_modifier = card_modifier
self.extended_card_modifier = extended_card_modifier
self.handler = RESTHandler(
agent_card=agent_card, request_handler=http_handler
)
if (
self.agent_card.supports_authenticated_extended_card
and self.extended_agent_card is None
and self.extended_card_modifier is None
):
logger.error(
'AgentCard.supports_authenticated_extended_card is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.'
)
self._context_builder = context_builder or DefaultCallContextBuilder()

@rest_error_handler
Expand Down Expand Up @@ -108,33 +132,35 @@
event_generator(method(request, call_context))
)

@rest_error_handler
async def handle_get_agent_card(self, request: Request) -> JSONResponse:
async def handle_get_agent_card(
self, request: Request, call_context: ServerCallContext | None = None
) -> dict[str, Any]:
"""Handles GET requests for the agent card endpoint.

Args:
request: The incoming Starlette Request object.
call_context: ServerCallContext

Returns:
A JSONResponse containing the agent card data.
"""
# The public agent card is a direct serialization of the agent_card
# provided at initialization.
return JSONResponse(
self.agent_card.model_dump(mode='json', exclude_none=True)
)
card_to_serve = self.agent_card
if self.card_modifier:
card_to_serve = self.card_modifier(card_to_serve)

return card_to_serve.model_dump(mode='json', exclude_none=True)

@rest_error_handler
async def handle_authenticated_agent_card(
self, request: Request
) -> JSONResponse:
self, request: Request, call_context: ServerCallContext | None = None
) -> dict[str, Any]:
"""Hook for per credential agent card response.

If a dynamic card is needed based on the credentials provided in the request
override this method and return the customized content.

Args:
request: The incoming Starlette Request object.
call_context: ServerCallContext

Returns:
A JSONResponse containing the authenticated card.
Expand All @@ -145,9 +171,18 @@
message='Authenticated card not supported'
)
)
return JSONResponse(
self.agent_card.model_dump(mode='json', exclude_none=True)
)
card_to_serve = self.extended_agent_card

if not card_to_serve:
card_to_serve = self.agent_card

if self.extended_card_modifier:
context = self._context_builder.build(request)
# If no base extended card is provided, pass the public card to the modifier
base_card = card_to_serve if card_to_serve else self.agent_card
card_to_serve = self.extended_card_modifier(base_card, context)

return card_to_serve.model_dump(mode='json', exclude_none=True)

def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
"""Constructs a dictionary of API routes and their corresponding handlers.
Expand Down Expand Up @@ -201,6 +236,8 @@
),
}
if self.agent_card.supports_authenticated_extended_card:
routes[('/v1/card', 'GET')] = self.handle_authenticated_agent_card
routes[('/v1/card', 'GET')] = functools.partial(
self._handle_request, self.handle_authenticated_agent_card
)

return routes
36 changes: 29 additions & 7 deletions tests/integration/test_client_server_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio

from collections.abc import AsyncGenerator
from typing import NamedTuple
from unittest.mock import ANY, AsyncMock
Expand All @@ -8,7 +7,6 @@
import httpx
import pytest
import pytest_asyncio

from grpc.aio import Channel

from a2a.client.transports import JsonRpcTransport, RestTransport
Expand Down Expand Up @@ -38,7 +36,6 @@
TransportProtocol,
)


# --- Test Constants ---

TASK_FROM_STREAM = Task(
Expand Down Expand Up @@ -130,7 +127,7 @@ def agent_card() -> AgentCard:
default_input_modes=['text/plain'],
default_output_modes=['text/plain'],
preferred_transport=TransportProtocol.jsonrpc,
supports_authenticated_extended_card=True,
supports_authenticated_extended_card=False,
additional_interfaces=[
AgentInterface(
transport=TransportProtocol.http_json, url='http://testserver'
Expand Down Expand Up @@ -709,9 +706,7 @@ async def test_http_transport_get_card(
transport_setup_fixture
)
transport = transport_setup.transport

# The transport starts with a minimal card, get_card() fetches the full one
transport.agent_card.supports_authenticated_extended_card = True
# Get the base card.
result = await transport.get_card()

assert result.name == agent_card.name
Expand All @@ -722,6 +717,33 @@ async def test_http_transport_get_card(
await transport.close()


@pytest.mark.asyncio
async def test_http_transport_get_authenticated_card(
agent_card: AgentCard,
mock_request_handler: AsyncMock,
) -> None:
agent_card.supports_authenticated_extended_card = True
extended_agent_card = agent_card.model_copy(deep=True)
extended_agent_card.name = 'Extended Agent Card'

app_builder = A2ARESTFastAPIApplication(
agent_card,
mock_request_handler,
extended_agent_card=extended_agent_card,
)
app = app_builder.build()
httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app))

transport = RestTransport(httpx_client=httpx_client, agent_card=agent_card)
result = await transport.get_card()
assert result.name == extended_agent_card.name
assert transport.agent_card.name == extended_agent_card.name
assert transport._needs_extended_card is False

if hasattr(transport, 'close'):
await transport.close()


@pytest.mark.asyncio
async def test_grpc_transport_get_card(
grpc_server_and_handler: tuple[str, AsyncMock],
Expand Down
Loading