diff --git a/.github/workflows/linter.yaml b/.github/workflows/linter.yaml index a5e5da2b..bdd4c5b8 100644 --- a/.github/workflows/linter.yaml +++ b/.github/workflows/linter.yaml @@ -18,7 +18,7 @@ jobs: with: python-version-file: .python-version - name: Install uv - uses: astral-sh/setup-uv@v6 + uses: astral-sh/setup-uv@v7 - name: Add uv to PATH run: | echo "$HOME/.cargo/bin" >> $GITHUB_PATH diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 96e87d9e..decb3b1d 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -15,7 +15,7 @@ jobs: - uses: actions/checkout@v5 - name: Install uv - uses: astral-sh/setup-uv@v6 + uses: astral-sh/setup-uv@v7 - name: "Set up Python" uses: actions/setup-python@v6 @@ -26,7 +26,7 @@ jobs: run: uv build - name: Upload distributions - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: release-dists path: dist/ @@ -40,7 +40,7 @@ jobs: steps: - name: Retrieve release distributions - uses: actions/download-artifact@v5 + uses: actions/download-artifact@v6 with: name: release-dists path: dist/ diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index ce8d62ab..16052ba1 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -46,7 +46,7 @@ jobs: echo "MYSQL_TEST_DSN=mysql+aiomysql://a2a:a2a_password@localhost:3306/a2a_test" >> $GITHUB_ENV - name: Install uv for Python ${{ matrix.python-version }} - uses: astral-sh/setup-uv@v6 + uses: astral-sh/setup-uv@v7 with: python-version: ${{ matrix.python-version }} - name: Add uv to PATH diff --git a/.github/workflows/update-a2a-types.yml b/.github/workflows/update-a2a-types.yml index cb4071e7..c019afeb 100644 --- a/.github/workflows/update-a2a-types.yml +++ b/.github/workflows/update-a2a-types.yml @@ -18,7 +18,7 @@ jobs: with: python-version: '3.10' - name: Install uv - uses: astral-sh/setup-uv@v6 + uses: astral-sh/setup-uv@v7 - name: Configure uv shell run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH - name: Install dependencies (datamodel-code-generator) diff --git a/CHANGELOG.md b/CHANGELOG.md index 449438cc..f5e6048d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,40 @@ -# Changelog + # Changelog + +## [0.3.15](https://github.com/a2aproject/a2a-python/compare/v0.3.14...v0.3.15) (2025-11-19) + + +### Features + +* Add client-side extension support ([#525](https://github.com/a2aproject/a2a-python/issues/525)) ([9a92bd2](https://github.com/a2aproject/a2a-python/commit/9a92bd238e7560b195165ac5f78742981760525e)) +* **rest, jsonrpc:** Add client-side extension support ([9a92bd2](https://github.com/a2aproject/a2a-python/commit/9a92bd238e7560b195165ac5f78742981760525e)) + +## [0.3.14](https://github.com/a2aproject/a2a-python/compare/v0.3.13...v0.3.14) (2025-11-17) + + +### Features + +* **jsonrpc:** add option to disable oversized payload check in JSONRPC applications ([ba142df](https://github.com/a2aproject/a2a-python/commit/ba142df821d1c06be0b96e576fd43015120fcb0b)) + +## [0.3.13](https://github.com/a2aproject/a2a-python/compare/v0.3.12...v0.3.13) (2025-11-13) + + +### Bug Fixes + +* return entire history when history_length=0 ([#537](https://github.com/a2aproject/a2a-python/issues/537)) ([acdc0de](https://github.com/a2aproject/a2a-python/commit/acdc0de4fa03d34a6b287ab252ff51b19c3016b5)) + +## [0.3.12](https://github.com/a2aproject/a2a-python/compare/v0.3.11...v0.3.12) (2025-11-12) + + +### Bug Fixes + +* **grpc:** Add `extensions` to `Artifact` converters. ([#523](https://github.com/a2aproject/a2a-python/issues/523)) ([c03129b](https://github.com/a2aproject/a2a-python/commit/c03129b99a663ae1f1ae72f20e4ead7807ede941)) + +## [0.3.11](https://github.com/a2aproject/a2a-python/compare/v0.3.10...v0.3.11) (2025-11-07) + + +### Bug Fixes + +* add metadata to send message request ([12b4a1d](https://github.com/a2aproject/a2a-python/commit/12b4a1d565a53794f5b55c8bd1728221c906ed41)) ## [0.3.10](https://github.com/a2aproject/a2a-python/compare/v0.3.9...v0.3.10) (2025-10-21) diff --git a/scripts/generate_types.sh b/scripts/generate_types.sh index b8d7dedf..6c01cff5 100755 --- a/scripts/generate_types.sh +++ b/scripts/generate_types.sh @@ -4,7 +4,35 @@ # Treat unset variables as an error. set -euo pipefail -REMOTE_URL="https://raw.githubusercontent.com/a2aproject/A2A/refs/heads/main/specification/json/a2a.json" +# A2A specification version to use +# Can be overridden via environment variable: A2A_SPEC_VERSION=v1.2.0 ./generate_types.sh +# Or via command-line flag: ./generate_types.sh --version v1.2.0 output.py +# Use a specific git tag, branch name, or commit SHA +# Examples: "v1.0.0", "v1.2.0", "main", "abc123def" +A2A_SPEC_VERSION="${A2A_SPEC_VERSION:-v0.3.0}" + +# Build URL based on version format +# Tags use /refs/tags/, branches use /refs/heads/, commits use direct ref +build_remote_url() { + local version="$1" + local base_url="https://raw.githubusercontent.com/a2aproject/A2A" + local spec_path="specification/json/a2a.json" + local url_part + + if [[ "$version" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + # Looks like a version tag (v1.0.0, v1.2.3) + url_part="refs/tags/${version}" + elif [[ "$version" =~ ^[0-9a-f]{7,40}$ ]]; then + # Looks like a commit SHA (7+ hex chars) + url_part="${version}" + else + # Assume it's a branch name (main, develop, etc.) + url_part="refs/heads/${version}" + fi + echo "${base_url}/${url_part}/${spec_path}" +} + +REMOTE_URL=$(build_remote_url "$A2A_SPEC_VERSION") GENERATED_FILE="" INPUT_FILE="" @@ -12,20 +40,38 @@ INPUT_FILE="" # Parse command-line arguments while [[ $# -gt 0 ]]; do case "$1" in - --input-file) - INPUT_FILE="$2" - shift 2 - ;; - *) - GENERATED_FILE="$1" - shift 1 - ;; + --input-file) + INPUT_FILE="$2" + shift 2 + ;; + --version) + A2A_SPEC_VERSION="$2" + REMOTE_URL=$(build_remote_url "$A2A_SPEC_VERSION") + shift 2 + ;; + *) + GENERATED_FILE="$1" + shift 1 + ;; esac done if [ -z "$GENERATED_FILE" ]; then - echo "Error: Output file path must be provided." >&2 - echo "Usage: $0 [--input-file ] " + cat >&2 <] [--version ] +Options: + --input-file Use a local JSON schema file instead of fetching from remote + --version Specify A2A spec version (default: v0.3.0) + Can be a git tag (v1.0.0), branch (main), or commit SHA +Environment variables: + A2A_SPEC_VERSION Override default spec version +Examples: + $0 src/a2a/types.py + $0 --version v1.2.0 src/a2a/types.py + $0 --input-file local/a2a.json src/a2a/types.py + A2A_SPEC_VERSION=main $0 src/a2a/types.py +EOF exit 1 fi @@ -33,9 +79,30 @@ echo "Running datamodel-codegen..." declare -a source_args if [ -n "$INPUT_FILE" ]; then echo " - Source File: $INPUT_FILE" + if [ ! -f "$INPUT_FILE" ]; then + echo "Error: Input file does not exist: $INPUT_FILE" >&2 + exit 1 + fi source_args=("--input" "$INPUT_FILE") else + echo " - A2A Spec Version: $A2A_SPEC_VERSION" echo " - Source URL: $REMOTE_URL" + + # Validate that the remote URL is accessible + echo " - Validating remote URL..." + if ! curl --fail --silent --head "$REMOTE_URL" >/dev/null 2>&1; then + cat >&2 < AsyncIterator[ClientEvent | Message]: """Sends a message to the agent. @@ -57,6 +60,8 @@ async def send_message( Args: request: The message to send to the agent. context: The client call context. + request_metadata: Extensions Metadata attached to the request. + extensions: List of extensions to be activated. Yields: An async iterator of `ClientEvent` or a final `Message` response. @@ -70,11 +75,13 @@ async def send_message( else None ), ) - params = MessageSendParams(message=request, configuration=config) + params = MessageSendParams( + message=request, configuration=config, metadata=request_metadata + ) if not self._config.streaming or not self._card.capabilities.streaming: response = await self._transport.send_message( - params, context=context + params, context=context, extensions=extensions ) result = ( (response, None) if isinstance(response, Task) else response @@ -84,7 +91,9 @@ async def send_message( return tracker = ClientTaskManager() - stream = self._transport.send_message_streaming(params, context=context) + stream = self._transport.send_message_streaming( + params, context=context, extensions=extensions + ) first_event = await anext(stream) # The response from a server may be either exactly one Message or a @@ -121,74 +130,91 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task. Args: request: The `TaskQueryParams` object specifying the task ID. context: The client call context. + extensions: List of extensions to be activated. Returns: A `Task` object representing the current state of the task. """ - return await self._transport.get_task(request, context=context) + return await self._transport.get_task( + request, context=context, extensions=extensions + ) async def cancel_task( self, request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task. Args: request: The `TaskIdParams` object specifying the task ID. context: The client call context. + extensions: List of extensions to be activated. Returns: A `Task` object containing the updated task status. """ - return await self._transport.cancel_task(request, context=context) + return await self._transport.cancel_task( + request, context=context, extensions=extensions + ) async def set_task_callback( self, request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task. Args: request: The `TaskPushNotificationConfig` object with the new configuration. context: The client call context. + extensions: List of extensions to be activated. Returns: The created or updated `TaskPushNotificationConfig` object. """ - return await self._transport.set_task_callback(request, context=context) + return await self._transport.set_task_callback( + request, context=context, extensions=extensions + ) async def get_task_callback( self, request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task. Args: request: The `GetTaskPushNotificationConfigParams` object specifying the task. context: The client call context. + extensions: List of extensions to be activated. Returns: A `TaskPushNotificationConfig` object containing the configuration. """ - return await self._transport.get_task_callback(request, context=context) + return await self._transport.get_task_callback( + request, context=context, extensions=extensions + ) async def resubscribe( self, request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncIterator[ClientEvent]: """Resubscribes to a task's event stream. @@ -197,6 +223,7 @@ async def resubscribe( Args: request: Parameters to identify the task to resubscribe to. context: The client call context. + extensions: List of extensions to be activated. Yields: An async iterator of `ClientEvent` objects. @@ -214,12 +241,15 @@ async def resubscribe( # we should never see Message updates, despite the typing of the service # definition indicating it may be possible. async for event in self._transport.resubscribe( - request, context=context + request, context=context, extensions=extensions ): yield await self._process_response(tracker, event) async def get_card( - self, *, context: ClientCallContext | None = None + self, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card. @@ -228,11 +258,14 @@ async def get_card( Args: context: The client call context. + extensions: List of extensions to be activated. Returns: The `AgentCard` for the agent. """ - card = await self._transport.get_card(context=context) + card = await self._transport.get_card( + context=context, extensions=extensions + ) self._card = card return card diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 7cc10423..fd97b4d1 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -67,6 +67,9 @@ class ClientConfig: ) """Push notification callbacks to use for every request.""" + extensions: list[str] = dataclasses.field(default_factory=list) + """A list of extension URIs the client supports.""" + UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None # Alias for emitted events from client @@ -110,6 +113,8 @@ async def send_message( request: Message, *, context: ClientCallContext | None = None, + request_metadata: dict[str, Any] | None = None, + extensions: list[str] | None = None, ) -> AsyncIterator[ClientEvent | Message]: """Sends a message to the server. @@ -128,6 +133,7 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" @@ -137,6 +143,7 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" @@ -146,6 +153,7 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" @@ -155,6 +163,7 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" @@ -164,6 +173,7 @@ async def resubscribe( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncIterator[ClientEvent]: """Resubscribes to a task's event stream.""" return @@ -171,7 +181,10 @@ async def resubscribe( @abstractmethod async def get_card( - self, *, context: ClientCallContext | None = None + self, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index 65b3fb5f..fabd7270 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -80,6 +80,7 @@ def _register_defaults( card, url, interceptors, + config.extensions or None, ), ) if TransportProtocol.http_json in supported: @@ -90,6 +91,7 @@ def _register_defaults( card, url, interceptors, + config.extensions or None, ), ) if TransportProtocol.grpc in supported: @@ -113,6 +115,7 @@ async def connect( # noqa: PLR0913 relative_card_path: str | None = None, resolver_http_kwargs: dict[str, Any] | None = None, extra_transports: dict[str, TransportProducer] | None = None, + extensions: list[str] | None = None, ) -> Client: """Convenience method for constructing a client. @@ -142,6 +145,7 @@ async def connect( # noqa: PLR0913 A2AAgentCardResolver.get_agent_card as the http_kwargs parameter. extra_transports: Additional transport protocols to enable when constructing the client. + extensions: List of extensions to be activated. Returns: A `Client` object. @@ -166,7 +170,7 @@ async def connect( # noqa: PLR0913 factory = cls(client_config) for label, generator in (extra_transports or {}).items(): factory.register(label, generator) - return factory.create(card, consumers, interceptors) + return factory.create(card, consumers, interceptors, extensions) def register(self, label: str, generator: TransportProducer) -> None: """Register a new transport producer for a given transport label.""" @@ -177,6 +181,7 @@ def create( card: AgentCard, consumers: list[Consumer] | None = None, interceptors: list[ClientCallInterceptor] | None = None, + extensions: list[str] | None = None, ) -> Client: """Create a new `Client` for the provided `AgentCard`. @@ -186,6 +191,7 @@ def create( interceptors: A list of interceptors to use for each request. These are used for things like attaching credentials or http headers to all outbound requests. + extensions: List of extensions to be activated. Returns: A `Client` object. @@ -226,12 +232,21 @@ def create( if consumers: all_consumers.extend(consumers) + all_extensions = self._config.extensions.copy() + if extensions: + all_extensions.extend(extensions) + self._config.extensions = all_extensions + transport = self._registry[transport_protocol]( card, transport_url, self._config, interceptors or [] ) return BaseClient( - card, self._config, transport, all_consumers, interceptors or [] + card, + self._config, + transport, + all_consumers, + interceptors or [], ) diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py index 3573cb7c..8f114d95 100644 --- a/src/a2a/client/transports/base.py +++ b/src/a2a/client/transports/base.py @@ -25,6 +25,7 @@ async def send_message( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent.""" @@ -34,6 +35,7 @@ async def send_message_streaming( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -47,6 +49,7 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" @@ -56,6 +59,7 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" @@ -65,6 +69,7 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" @@ -74,6 +79,7 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" @@ -83,6 +89,7 @@ async def resubscribe( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -95,6 +102,7 @@ async def get_card( self, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the AgentCard.""" diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index e50b0ea8..4e27953a 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -12,10 +12,12 @@ "'pip install a2a-sdk[grpc]'" ) from e + from a2a.client.client import ClientConfig from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.optionals import Channel from a2a.client.transports.base import ClientTransport +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCard, @@ -44,6 +46,7 @@ def __init__( self, channel: Channel, agent_card: AgentCard | None, + extensions: list[str] | None = None, ): """Initializes the GrpcTransport.""" self.agent_card = agent_card @@ -54,6 +57,18 @@ def __init__( if agent_card else True ) + self.extensions = extensions + + def _get_grpc_metadata( + self, + extensions: list[str] | None = None, + ) -> list[tuple[str, str]] | None: + """Creates gRPC metadata for extensions.""" + if extensions is not None: + return [(HTTP_EXTENSION_HEADER, ','.join(extensions))] + if self.extensions is not None: + return [(HTTP_EXTENSION_HEADER, ','.join(self.extensions))] + return None @classmethod def create( @@ -66,16 +81,14 @@ def create( """Creates a gRPC transport for the A2A client.""" if config.grpc_channel_factory is None: raise ValueError('grpc_channel_factory is required when using gRPC') - return cls( - config.grpc_channel_factory(url), - card, - ) + return cls(config.grpc_channel_factory(url), card, config.extensions) async def send_message( self, request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent.""" response = await self.stub.SendMessage( @@ -85,7 +98,8 @@ async def send_message( request.configuration ), metadata=proto_utils.ToProto.metadata(request.metadata), - ) + ), + metadata=self._get_grpc_metadata(extensions), ) if response.HasField('task'): return proto_utils.FromProto.task(response.task) @@ -96,6 +110,7 @@ async def send_message_streaming( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -107,7 +122,8 @@ async def send_message_streaming( request.configuration ), metadata=proto_utils.ToProto.metadata(request.metadata), - ) + ), + metadata=self._get_grpc_metadata(extensions), ) while True: response = await stream.read() @@ -116,13 +132,18 @@ async def send_message_streaming( yield proto_utils.FromProto.stream_response(response) async def resubscribe( - self, request: TaskIdParams, *, context: ClientCallContext | None = None + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: """Reconnects to get task updates.""" stream = self.stub.TaskSubscription( - a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}') + a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}'), + metadata=self._get_grpc_metadata(extensions), ) while True: response = await stream.read() @@ -135,13 +156,15 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" task = await self.stub.GetTask( a2a_pb2.GetTaskRequest( name=f'tasks/{request.id}', history_length=request.history_length, - ) + ), + metadata=self._get_grpc_metadata(extensions), ) return proto_utils.FromProto.task(task) @@ -150,10 +173,12 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" task = await self.stub.CancelTask( - a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') + a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}'), + metadata=self._get_grpc_metadata(extensions), ) return proto_utils.FromProto.task(task) @@ -162,6 +187,7 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" config = await self.stub.CreateTaskPushNotificationConfig( @@ -171,7 +197,8 @@ async def set_task_callback( config=proto_utils.ToProto.task_push_notification_config( request ), - ) + ), + metadata=self._get_grpc_metadata(extensions), ) return proto_utils.FromProto.task_push_notification_config(config) @@ -180,12 +207,14 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" config = await self.stub.GetTaskPushNotificationConfig( a2a_pb2.GetTaskPushNotificationConfigRequest( name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', - ) + ), + metadata=self._get_grpc_metadata(extensions), ) return proto_utils.FromProto.task_push_notification_config(config) @@ -193,6 +222,7 @@ async def get_card( self, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" card = self.agent_card @@ -203,6 +233,7 @@ async def get_card( card_pb = await self.stub.GetAgentCard( a2a_pb2.GetAgentCardRequest(), + metadata=self._get_grpc_metadata(extensions), ) card = proto_utils.FromProto.agent_card(card_pb) self.agent_card = card diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index bfba09d7..d8011cf4 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -18,6 +18,7 @@ ) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport +from a2a.extensions.common import update_extension_header from a2a.types import ( AgentCard, CancelTaskRequest, @@ -62,6 +63,7 @@ def __init__( agent_card: AgentCard | None = None, url: str | None = None, interceptors: list[ClientCallInterceptor] | None = None, + extensions: list[str] | None = None, ): """Initializes the JsonRpcTransport.""" if url: @@ -79,6 +81,7 @@ def __init__( if agent_card else True ) + self.extensions = extensions async def _apply_interceptors( self, @@ -113,13 +116,18 @@ async def send_message( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent.""" rpc_request = SendMessageRequest(params=request, id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'message/send', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -133,6 +141,7 @@ async def send_message_streaming( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -140,13 +149,16 @@ async def send_message_streaming( rpc_request = SendStreamingMessageRequest( params=request, id=str(uuid4()) ) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'message/stream', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) - modified_kwargs.setdefault( 'timeout', self.httpx_client.timeout.as_dict().get('read', None) ) @@ -207,13 +219,18 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" rpc_request = GetTaskRequest(params=request, id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/get', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -227,13 +244,18 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" rpc_request = CancelTaskRequest(params=request, id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/cancel', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -247,15 +269,20 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" rpc_request = SetTaskPushNotificationConfigRequest( params=request, id=str(uuid4()) ) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/pushNotificationConfig/set', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -271,15 +298,20 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" rpc_request = GetTaskPushNotificationConfigRequest( params=request, id=str(uuid4()) ) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/pushNotificationConfig/get', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -295,18 +327,22 @@ async def resubscribe( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: """Reconnects to get task updates.""" rpc_request = TaskResubscriptionRequest(params=request, id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/resubscribe', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) - modified_kwargs.setdefault('timeout', None) async with aconnect_sse( @@ -339,6 +375,7 @@ async def get_card( self, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" card = self.agent_card @@ -356,13 +393,16 @@ async def get_card( return card request = GetAuthenticatedExtendedCardRequest(id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( request.method, request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) - response_data = await self._send_request( payload, modified_kwargs, diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index eef7b0f2..83c26787 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -13,6 +13,7 @@ from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport +from a2a.extensions.common import update_extension_header from a2a.grpc import a2a_pb2 from a2a.types import ( AgentCard, @@ -43,6 +44,7 @@ def __init__( agent_card: AgentCard | None = None, url: str | None = None, interceptors: list[ClientCallInterceptor] | None = None, + extensions: list[str] | None = None, ): """Initializes the RestTransport.""" if url: @@ -61,6 +63,7 @@ def __init__( if agent_card else True ) + self.extensions = extensions async def _apply_interceptors( self, @@ -79,7 +82,10 @@ def _get_http_args( return context.state.get('http_kwargs') if context else None async def _prepare_send_message( - self, request: MessageSendParams, context: ClientCallContext | None + self, + request: MessageSendParams, + context: ClientCallContext | None, + extensions: list[str] | None = None, ) -> tuple[dict[str, Any], dict[str, Any]]: pb = a2a_pb2.SendMessageRequest( request=proto_utils.ToProto.message(request.message), @@ -93,9 +99,13 @@ async def _prepare_send_message( ), ) payload = MessageToDict(pb) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( payload, - self._get_http_args(context), + modified_kwargs, context, ) return payload, modified_kwargs @@ -105,10 +115,11 @@ async def send_message( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent.""" payload, modified_kwargs = await self._prepare_send_message( - request, context + request, context, extensions ) response_data = await self._send_post_request( '/v1/message:send', payload, modified_kwargs @@ -122,12 +133,13 @@ async def send_message_streaming( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message ]: """Sends a streaming message request to the agent and yields responses as they arrive.""" payload, modified_kwargs = await self._prepare_send_message( - request, context + request, context, extensions ) modified_kwargs.setdefault('timeout', None) @@ -204,11 +216,16 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) _payload, modified_kwargs = await self._apply_interceptors( request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_get_request( @@ -227,13 +244,18 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') payload = MessageToDict(pb) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( payload, - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_post_request( @@ -248,6 +270,7 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" pb = a2a_pb2.CreateTaskPushNotificationConfigRequest( @@ -256,8 +279,12 @@ async def set_task_callback( config=proto_utils.ToProto.task_push_notification_config(request), ) payload = MessageToDict(pb) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( - payload, self._get_http_args(context), context + payload, modified_kwargs, context ) response_data = await self._send_post_request( f'/v1/tasks/{request.task_id}/pushNotificationConfigs', @@ -273,15 +300,20 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" pb = a2a_pb2.GetTaskPushNotificationConfigRequest( name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', ) payload = MessageToDict(pb) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( payload, - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_get_request( @@ -298,18 +330,22 @@ async def resubscribe( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message ]: """Reconnects to get task updates.""" - http_kwargs = self._get_http_args(context) or {} - http_kwargs.setdefault('timeout', None) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) + modified_kwargs.setdefault('timeout', None) async with aconnect_sse( self.httpx_client, 'GET', f'{self.url}/v1/tasks/{request.id}:subscribe', - **http_kwargs, + **modified_kwargs, ) as event_source: try: async for sse in event_source.aiter_sse(): @@ -331,6 +367,7 @@ async def get_card( self, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" card = self.agent_card @@ -347,9 +384,13 @@ async def get_card( if not self._needs_extended_card: return card + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) _, modified_kwargs = await self._apply_interceptors( {}, - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_get_request( diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index 2f752caa..cba3517e 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -1,3 +1,5 @@ +from typing import Any + from a2a.types import AgentCard, AgentExtension @@ -25,3 +27,15 @@ def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None: return ext return None + + +def update_extension_header( + http_kwargs: dict[str, Any] | None, + extensions: list[str] | None, +) -> dict[str, Any]: + """Update the X-A2A-Extensions header with active extensions.""" + http_kwargs = http_kwargs or {} + if extensions is not None: + headers = http_kwargs.setdefault('headers', {}) + headers[HTTP_EXTENSION_HEADER] = ','.join(extensions) + return http_kwargs diff --git a/src/a2a/server/apps/jsonrpc/fastapi_app.py b/src/a2a/server/apps/jsonrpc/fastapi_app.py index 4ba7fdce..ace2c6ae 100644 --- a/src/a2a/server/apps/jsonrpc/fastapi_app.py +++ b/src/a2a/server/apps/jsonrpc/fastapi_app.py @@ -77,6 +77,7 @@ def __init__( # noqa: PLR0913 [AgentCard, ServerCallContext], AgentCard ] | None = None, + max_content_length: int | None = 10 * 1024 * 1024, # 10MB ) -> None: """Initializes the A2AFastAPIApplication. @@ -94,6 +95,8 @@ def __init__( # noqa: PLR0913 extended_card_modifier: An optional callback to dynamically modify the extended agent card before it is served. It receives the call context. + max_content_length: The maximum allowed content length for incoming + requests. Defaults to 10MB. Set to None for unbounded maximum. """ if not _package_fastapi_installed: raise ImportError( @@ -108,6 +111,7 @@ def __init__( # noqa: PLR0913 context_builder=context_builder, card_modifier=card_modifier, extended_card_modifier=extended_card_modifier, + max_content_length=max_content_length, ) def add_routes_to_app( diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index d258916c..3e7c2854 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -91,8 +91,6 @@ Response = Any HTTP_413_REQUEST_ENTITY_TOO_LARGE = Any -MAX_CONTENT_LENGTH = 10_000_000 - class StarletteUserProxy(A2AUser): """Adapts the Starlette User class to the A2A user representation.""" @@ -185,6 +183,7 @@ def __init__( # noqa: PLR0913 [AgentCard, ServerCallContext], AgentCard ] | None = None, + max_content_length: int | None = 10 * 1024 * 1024, # 10MB ) -> None: """Initializes the JSONRPCApplication. @@ -202,6 +201,8 @@ def __init__( # noqa: PLR0913 extended_card_modifier: An optional callback to dynamically modify the extended agent card before it is served. It receives the call context. + max_content_length: The maximum allowed content length for incoming + requests. Defaults to 10MB. Set to None for unbounded maximum. """ if not _package_starlette_installed: raise ImportError( @@ -220,6 +221,7 @@ def __init__( # noqa: PLR0913 extended_card_modifier=extended_card_modifier, ) self._context_builder = context_builder or DefaultCallContextBuilder() + self._max_content_length = max_content_length def _generate_error_response( self, request_id: str | int | None, error: JSONRPCError | A2AError @@ -261,6 +263,22 @@ def _generate_error_response( status_code=200, ) + def _allowed_content_length(self, request: Request) -> bool: + """Checks if the request content length is within the allowed maximum. + + Args: + request: The incoming Starlette Request object. + + Returns: + False if the content length is larger than the allowed maximum, True otherwise. + """ + if self._max_content_length is not None: + with contextlib.suppress(ValueError): + content_length = int(request.headers.get('content-length', '0')) + if content_length and content_length > self._max_content_length: + return False + return True + async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911 """Handles incoming POST requests to the main A2A endpoint. @@ -291,18 +309,14 @@ async def _handle_requests(self, request: Request) -> Response: # noqa: PLR0911 request_id, str | int ): request_id = None - # Treat very large payloads as invalid request (-32600) before routing - with contextlib.suppress(Exception): - content_length = int(request.headers.get('content-length', '0')) - if content_length and content_length > MAX_CONTENT_LENGTH: - return self._generate_error_response( - request_id, - A2AError( - root=InvalidRequestError( - message='Payload too large' - ) - ), - ) + # Treat payloads lager than allowed as invalid request (-32600) before routing + if not self._allowed_content_length(request): + return self._generate_error_response( + request_id, + A2AError( + root=InvalidRequestError(message='Payload too large') + ), + ) logger.debug('Request body: %s', body) # 1) Validate base JSON-RPC structure only (-32600 on failure) try: diff --git a/src/a2a/server/apps/jsonrpc/starlette_app.py b/src/a2a/server/apps/jsonrpc/starlette_app.py index b268d043..1effa9d5 100644 --- a/src/a2a/server/apps/jsonrpc/starlette_app.py +++ b/src/a2a/server/apps/jsonrpc/starlette_app.py @@ -59,6 +59,7 @@ def __init__( # noqa: PLR0913 [AgentCard, ServerCallContext], AgentCard ] | None = None, + max_content_length: int | None = 10 * 1024 * 1024, # 10MB ) -> None: """Initializes the A2AStarletteApplication. @@ -76,6 +77,8 @@ def __init__( # noqa: PLR0913 extended_card_modifier: An optional callback to dynamically modify the extended agent card before it is served. It receives the call context. + max_content_length: The maximum allowed content length for incoming + requests. Defaults to 10MB. Set to None for unbounded maximum. """ if not _package_starlette_installed: raise ImportError( @@ -90,6 +93,7 @@ def __init__( # noqa: PLR0913 context_builder=context_builder, card_modifier=card_modifier, extended_card_modifier=extended_card_modifier, + max_content_length=max_content_length, ) def routes( diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index e619cd72..d077d62b 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -57,7 +57,7 @@ def make_dict_serializable(value: Any) -> Any: Returns: A serializable value. """ - if isinstance(value, (str, int, float, bool)) or value is None: + if isinstance(value, str | int | float | bool) or value is None: return value if isinstance(value, dict): return {k: make_dict_serializable(v) for k, v in value.items()} @@ -140,6 +140,7 @@ def message(cls, message: types.Message | None) -> a2a_pb2.Message | None: task_id=message.task_id or '', role=cls.role(message.role), metadata=cls.metadata(message.metadata), + extensions=message.extensions or [], ) @classmethod @@ -239,6 +240,7 @@ def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact: metadata=cls.metadata(artifact.metadata), name=artifact.name, parts=[cls.part(p) for p in artifact.parts], + extensions=artifact.extensions or [], ) @classmethod @@ -695,6 +697,7 @@ def artifact(cls, artifact: a2a_pb2.Artifact) -> types.Artifact: metadata=cls.metadata(artifact.metadata), name=artifact.name, parts=[cls.part(p) for p in artifact.parts], + extensions=artifact.extensions or None, ) @classmethod diff --git a/src/a2a/utils/task.py b/src/a2a/utils/task.py index 5c5f3f07..d8215cec 100644 --- a/src/a2a/utils/task.py +++ b/src/a2a/utils/task.py @@ -83,11 +83,9 @@ def apply_history_length(task: Task, history_length: int | None) -> Task: A new task object with limited history """ # Apply historyLength parameter if specified - if history_length is not None and task.history: + if history_length is not None and history_length > 0 and task.history: # Limit history to the most recent N messages - limited_history = ( - task.history[-history_length:] if history_length > 0 else [] - ) + limited_history = task.history[-history_length:] # Create a new task instance with limited history return task.model_copy(update={'history': limited_history}) diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index d93a2203..f5ab2543 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -73,9 +73,14 @@ async def create_stream(*args, **kwargs): mock_transport.send_message_streaming.return_value = create_stream() - events = [event async for event in base_client.send_message(sample_message)] + meta = {'test': 1} + stream = base_client.send_message(sample_message, request_metadata=meta) + events = [event async for event in stream] mock_transport.send_message_streaming.assert_called_once() + assert ( + mock_transport.send_message_streaming.call_args[0][0].metadata == meta + ) assert not mock_transport.send_message.called assert len(events) == 1 assert events[0][0].id == 'task-123' @@ -92,9 +97,12 @@ async def test_send_message_non_streaming( status=TaskStatus(state=TaskState.completed), ) - events = [event async for event in base_client.send_message(sample_message)] + meta = {'test': 1} + stream = base_client.send_message(sample_message, request_metadata=meta) + events = [event async for event in stream] mock_transport.send_message.assert_called_once() + assert mock_transport.send_message.call_args[0][0].metadata == meta assert not mock_transport.send_message_streaming.called assert len(events) == 1 assert events[0][0].id == 'task-456' diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py index 847b256f..16a1433f 100644 --- a/tests/client/test_client_factory.py +++ b/tests/client/test_client_factory.py @@ -39,12 +39,14 @@ def test_client_factory_selects_preferred_transport(base_agent_card: AgentCard): TransportProtocol.jsonrpc, TransportProtocol.http_json, ], + extensions=['https://example.com/test-ext/v0'], ) factory = ClientFactory(config) client = factory.create(base_agent_card) assert isinstance(client._transport, JsonRpcTransport) assert client._transport.url == 'http://primary-url.com' + assert ['https://example.com/test-ext/v0'] == client._transport.extensions def test_client_factory_selects_secondary_transport_url( @@ -65,12 +67,14 @@ def test_client_factory_selects_secondary_transport_url( TransportProtocol.jsonrpc, ], use_client_preference=True, + extensions=['https://example.com/test-ext/v0'], ) factory = ClientFactory(config) client = factory.create(base_agent_card) assert isinstance(client._transport, RestTransport) assert client._transport.url == 'http://secondary-url.com' + assert ['https://example.com/test-ext/v0'] == client._transport.extensions def test_client_factory_server_preference(base_agent_card: AgentCard): diff --git a/tests/client/test_grpc_client.py b/tests/client/transports/test_grpc_client.py similarity index 80% rename from tests/client/test_grpc_client.py rename to tests/client/transports/test_grpc_client.py index 6dab75e9..111e44ba 100644 --- a/tests/client/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -4,6 +4,7 @@ import pytest from a2a.client.transports.grpc import GrpcTransport +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCapabilities, @@ -64,7 +65,14 @@ def grpc_transport( ) -> GrpcTransport: """Provides a GrpcTransport instance.""" channel = AsyncMock() - transport = GrpcTransport(channel=channel, agent_card=sample_agent_card) + transport = GrpcTransport( + channel=channel, + agent_card=sample_agent_card, + extensions=[ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ], + ) transport.stub = mock_grpc_stub return transport @@ -185,9 +193,19 @@ async def test_send_message_task_response( task=proto_utils.ToProto.task(sample_task) ) - response = await grpc_transport.send_message(sample_message_send_params) + response = await grpc_transport.send_message( + sample_message_send_params, + extensions=['https://example.com/test-ext/v3'], + ) mock_grpc_stub.SendMessage.assert_awaited_once() + _, kwargs = mock_grpc_stub.SendMessage.call_args + assert kwargs['metadata'] == [ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v3', + ) + ] assert isinstance(response, Task) assert response.id == sample_task.id @@ -207,6 +225,13 @@ async def test_send_message_message_response( response = await grpc_transport.send_message(sample_message_send_params) mock_grpc_stub.SendMessage.assert_awaited_once() + _, kwargs = mock_grpc_stub.SendMessage.call_args + assert kwargs['metadata'] == [ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', + ) + ] assert isinstance(response, Message) assert response.message_id == sample_message.message_id assert get_text_parts(response.parts) == get_text_parts( @@ -255,6 +280,13 @@ async def test_send_message_streaming( # noqa: PLR0913 ] mock_grpc_stub.SendStreamingMessage.assert_called_once() + _, kwargs = mock_grpc_stub.SendStreamingMessage.call_args + assert kwargs['metadata'] == [ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', + ) + ] assert isinstance(responses[0], Message) assert responses[0].message_id == sample_message.message_id assert isinstance(responses[1], Task) @@ -278,7 +310,13 @@ async def test_get_task( mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest( name=f'tasks/{sample_task.id}', history_length=None - ) + ), + metadata=[ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', + ) + ], ) assert response.id == sample_task.id @@ -297,7 +335,13 @@ async def test_get_task_with_history( mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest( name=f'tasks/{sample_task.id}', history_length=history_len - ) + ), + metadata=[ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', + ) + ], ) @@ -312,11 +356,14 @@ async def test_cancel_task( cancelled_task ) params = TaskIdParams(id=sample_task.id) - - response = await grpc_transport.cancel_task(params) + extensions = [ + 'https://example.com/test-ext/v3', + ] + response = await grpc_transport.cancel_task(params, extensions=extensions) mock_grpc_stub.CancelTask.assert_awaited_once_with( - a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}') + a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}'), + metadata=[(HTTP_EXTENSION_HEADER, 'https://example.com/test-ext/v3')], ) assert response.status.state == TaskState.canceled @@ -345,7 +392,13 @@ async def test_set_task_callback_with_valid_task( config=proto_utils.ToProto.task_push_notification_config( sample_task_push_notification_config ), - ) + ), + metadata=[ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', + ) + ], ) assert response.task_id == sample_task_push_notification_config.task_id @@ -402,7 +455,13 @@ async def test_get_task_callback_with_valid_task( f'tasks/{params.id}/' f'pushNotificationConfigs/{params.push_notification_config_id}' ), - ) + ), + metadata=[ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', + ) + ], ) assert response.task_id == sample_task_push_notification_config.task_id @@ -434,3 +493,50 @@ async def test_get_task_callback_with_invalid_task( 'Bad TaskPushNotificationConfig resource name' in exc_info.value.error.message ) + + +@pytest.mark.parametrize( + 'initial_extensions, input_extensions, expected_metadata', + [ + ( + None, + None, + None, + ), # Case 1: No initial, No input + ( + ['ext1'], + None, + [(HTTP_EXTENSION_HEADER, 'ext1')], + ), # Case 2: Initial, No input + ( + None, + ['ext2'], + [(HTTP_EXTENSION_HEADER, 'ext2')], + ), # Case 3: No initial, Input + ( + ['ext1'], + ['ext2'], + [(HTTP_EXTENSION_HEADER, 'ext2')], + ), # Case 4: Initial, Input (override) + ( + ['ext1'], + ['ext2', 'ext3'], + [(HTTP_EXTENSION_HEADER, 'ext2,ext3')], + ), # Case 5: Initial, Multiple inputs (override) + ( + ['ext1', 'ext2'], + ['ext3'], + [(HTTP_EXTENSION_HEADER, 'ext3')], + ), # Case 6: Multiple initial, Single input (override) + ], +) +def test_get_grpc_metadata( + grpc_transport: GrpcTransport, + initial_extensions: list[str] | None, + input_extensions: list[str] | None, + expected_metadata: list[tuple[str, str]] | None, +) -> None: + """Tests _get_grpc_metadata for correct metadata generation and self.extensions update.""" + grpc_transport.extensions = initial_extensions + metadata = grpc_transport._get_grpc_metadata(input_extensions) + assert metadata == expected_metadata diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py similarity index 89% rename from tests/client/test_jsonrpc_client.py rename to tests/client/transports/test_jsonrpc_client.py index 58feec25..bd705d93 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -17,6 +17,7 @@ create_text_message_object, ) from a2a.client.transports.jsonrpc import JsonRpcTransport +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.types import ( AgentCapabilities, AgentCard, @@ -785,3 +786,92 @@ async def test_close(self, mock_httpx_client: AsyncMock): ) await client.close() mock_httpx_client.aclose.assert_called_once() + + +class TestJsonRpcTransportExtensions: + @pytest.mark.asyncio + async def test_send_message_with_default_extensions( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + """Test that send_message adds extension headers when extensions are provided.""" + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + extensions=extensions, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + success_response = create_text_message_object( + role=Role.agent, content='Hi there!' + ) + rpc_response = SendMessageSuccessResponse( + id='123', jsonrpc='2.0', result=success_response + ) + # Mock the response from httpx_client.post + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = rpc_response.model_dump(mode='json') + mock_httpx_client.post.return_value = mock_response + + await client.send_message(request=params) + + mock_httpx_client.post.assert_called_once() + _, mock_kwargs = mock_httpx_client.post.call_args + + headers = mock_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + header_value = headers[HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + } + assert len(actual_extensions_list) == 2 + assert actual_extensions == expected_extensions + + @pytest.mark.asyncio + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_with_new_extensions( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + """Test X-A2A-Extensions header in send_message_streaming.""" + new_extensions = ['https://example.com/test-ext/v2'] + extensions = ['https://example.com/test-ext/v1'] + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + extensions=extensions, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.return_value = async_iterable_from_list([]) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + async for _ in client.send_message_streaming( + request=params, extensions=new_extensions + ): + pass + + mock_aconnect_sse.assert_called_once() + _, kwargs = mock_aconnect_sse.call_args + + headers = kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert ( + headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2' + ) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py new file mode 100644 index 00000000..04bd1036 --- /dev/null +++ b/tests/client/transports/test_rest_client.py @@ -0,0 +1,121 @@ +from collections.abc import AsyncGenerator +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from httpx_sse import EventSource, ServerSentEvent + +from a2a.client import create_text_message_object +from a2a.client.transports.rest import RestTransport +from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.types import AgentCard, MessageSendParams, Role + + +@pytest.fixture +def mock_httpx_client() -> AsyncMock: + return AsyncMock(spec=httpx.AsyncClient) + + +@pytest.fixture +def mock_agent_card() -> MagicMock: + mock = MagicMock(spec=AgentCard, url='http://agent.example.com/api') + mock.supports_authenticated_extended_card = False + return mock + + +async def async_iterable_from_list( + items: list[ServerSentEvent], +) -> AsyncGenerator[ServerSentEvent, None]: + """Helper to create an async iterable from a list.""" + for item in items: + yield item + + +class TestRestTransportExtensions: + @pytest.mark.asyncio + async def test_send_message_with_default_extensions( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + """Test that send_message adds extensions to headers.""" + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + client = RestTransport( + httpx_client=mock_httpx_client, + extensions=extensions, + agent_card=mock_agent_card, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + + # Mock the build_request method to capture its inputs + mock_build_request = MagicMock( + return_value=AsyncMock(spec=httpx.Request) + ) + mock_httpx_client.build_request = mock_build_request + + # Mock the send method + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_httpx_client.send.return_value = mock_response + + await client.send_message(request=params) + + mock_build_request.assert_called_once() + _, kwargs = mock_build_request.call_args + + headers = kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + header_value = kwargs['headers'][HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + } + assert len(actual_extensions_list) == 2 + assert actual_extensions == expected_extensions + + @pytest.mark.asyncio + @patch('a2a.client.transports.rest.aconnect_sse') + async def test_send_message_streaming_with_new_extensions( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + """Test X-A2A-Extensions header in send_message_streaming.""" + new_extensions = ['https://example.com/test-ext/v2'] + extensions = ['https://example.com/test-ext/v1'] + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + extensions=extensions, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.return_value = async_iterable_from_list([]) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + async for _ in client.send_message_streaming( + request=params, extensions=new_extensions + ): + pass + + mock_aconnect_sse.assert_called_once() + _, kwargs = mock_aconnect_sse.call_args + + headers = kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert ( + headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2' + ) diff --git a/tests/extensions/test_common.py b/tests/extensions/test_common.py index 137e64c9..b3123028 100644 --- a/tests/extensions/test_common.py +++ b/tests/extensions/test_common.py @@ -1,6 +1,9 @@ +import pytest from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, find_extension_by_uri, get_requested_extensions, + update_extension_header, ) from a2a.types import AgentCapabilities, AgentCard, AgentExtension @@ -56,3 +59,88 @@ def test_find_extension_by_uri_no_extensions(): ) assert find_extension_by_uri(card, 'foo') is None + + +@pytest.mark.parametrize( + 'extensions, header, expected_extensions', + [ + ( + ['ext1', 'ext2'], # extensions + '', # header + { + 'ext1', + 'ext2', + }, # expected_extensions + ), # Case 1: New extensions provided, empty header. + ( + None, # extensions + 'ext1, ext2', # header + { + 'ext1', + 'ext2', + }, # expected_extensions + ), # Case 2: Extensions is None, existing header extensions. + ( + [], # extensions + 'ext1', # header + {}, # expected_extensions + ), # Case 3: New extensions is empty list, existing header extensions. + ( + ['ext1', 'ext2'], # extensions + 'ext3', # header + { + 'ext1', + 'ext2', + }, # expected_extensions + ), # Case 4: New extensions provided, and an existing header. New extensions should override active extensions. + ], +) +def test_update_extension_header_merge_with_existing_extensions( + extensions: list[str], + header: str, + expected_extensions: set[str], +): + http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: header}} + result_kwargs = update_extension_header(http_kwargs, extensions) + header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] + if not header_value: + actual_extensions = {} + else: + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + assert actual_extensions == expected_extensions + + +def test_update_extension_header_with_other_headers(): + extensions = ['ext'] + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = update_extension_header(http_kwargs, extensions) + headers = result_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert headers[HTTP_EXTENSION_HEADER] == 'ext' + assert headers['X_Other'] == 'Test' + + +@pytest.mark.parametrize( + 'http_kwargs', + [ + None, + {}, + ], +) +def test_update_extension_header_headers_not_in_kwargs( + http_kwargs: dict[str, str] | None, +): + extensions = ['ext'] + http_kwargs = {} + result_kwargs = update_extension_header(http_kwargs, extensions) + headers = result_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert headers[HTTP_EXTENSION_HEADER] == 'ext' + + +def test_update_extension_header_with_other_headers_extensions_none(): + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = update_extension_header(http_kwargs, None) + assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] + assert result_kwargs['headers']['X_Other'] == 'Test' diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 88d4d3d1..e0a564ee 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,7 +1,7 @@ import asyncio from collections.abc import AsyncGenerator from typing import NamedTuple -from unittest.mock import ANY, AsyncMock +from unittest.mock import ANY, AsyncMock, patch import grpc import httpx @@ -9,6 +9,8 @@ import pytest_asyncio from grpc.aio import Channel +from a2a.client import ClientConfig +from a2a.client.base_client import BaseClient from a2a.client.transports import JsonRpcTransport, RestTransport from a2a.client.transports.base import ClientTransport from a2a.client.transports.grpc import GrpcTransport @@ -767,3 +769,61 @@ def channel_factory(address: str) -> Channel: assert transport._needs_extended_card is False await transport.close() + + +@pytest.mark.asyncio +async def test_base_client_sends_message_with_extensions( + jsonrpc_setup: TransportSetup, agent_card: AgentCard +) -> None: + """ + Integration test for BaseClient with JSON-RPC transport to ensure extensions are included in headers. + """ + transport = jsonrpc_setup.transport + agent_card.capabilities.streaming = False + + # Create a BaseClient instance + client = BaseClient( + card=agent_card, + config=ClientConfig(streaming=False), + transport=transport, + consumers=[], + middleware=[], + ) + + message_to_send = Message( + role=Role.user, + message_id='msg-integration-test-extensions', + parts=[Part(root=TextPart(text='Hello, extensions test!'))], + ) + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + + with patch.object( + transport, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': TASK_FROM_BLOCKING.model_dump(mode='json'), + } + + # Call send_message on the BaseClient + async for _ in client.send_message( + request=message_to_send, extensions=extensions + ): + pass + + mock_send_request.assert_called_once() + call_args, _ = mock_send_request.call_args + kwargs = call_args[1] + headers = kwargs.get('headers', {}) + assert 'X-A2A-Extensions' in headers + assert ( + headers['X-A2A-Extensions'] + == 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' + ) + + if hasattr(transport, 'close'): + await transport.close() diff --git a/tests/server/apps/jsonrpc/test_serialization.py b/tests/server/apps/jsonrpc/test_serialization.py index 9365017b..f6778046 100644 --- a/tests/server/apps/jsonrpc/test_serialization.py +++ b/tests/server/apps/jsonrpc/test_serialization.py @@ -136,6 +136,42 @@ def test_handle_oversized_payload(agent_card_with_api_key: AgentCard): assert data['error']['code'] == InvalidRequestError().code +@pytest.mark.parametrize( + 'max_content_length', + [ + None, + 11 * 1024 * 1024, + 30 * 1024 * 1024, + ], +) +def test_handle_oversized_payload_with_max_content_length( + agent_card_with_api_key: AgentCard, + max_content_length: int | None, +): + """Test handling of JSON payloads with sizes within custom max_content_length.""" + handler = mock.AsyncMock() + app_instance = A2AStarletteApplication( + agent_card_with_api_key, handler, max_content_length=max_content_length + ) + client = TestClient(app_instance.build()) + + large_string = 'a' * 11 * 1_000_000 # 11MB string + payload = { + 'jsonrpc': '2.0', + 'method': 'test', + 'id': 1, + 'params': {'data': large_string}, + } + + response = client.post('/', json=payload) + assert response.status_code == 200 + data = response.json() + # When max_content_length is set, requests up to that size should not be + # rejected due to payload size. The request might fail for other reasons, + # but it shouldn't be an InvalidRequestError related to the content length. + assert data['error']['code'] != InvalidRequestError().code + + def test_handle_unicode_characters(agent_card_with_api_key: AgentCard): """Test handling of unicode characters in JSON payload.""" handler = mock.AsyncMock() diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 5268af11..88dd77ab 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -834,6 +834,11 @@ async def test_on_message_send_non_blocking(): assert task is not None assert task.status.state == TaskState.completed + assert ( + result.history + and task.history + and len(result.history) == len(task.history) + ) @pytest.mark.asyncio @@ -855,7 +860,7 @@ async def test_on_message_send_limit_history(): configuration=MessageSendConfiguration( blocking=True, accepted_output_modes=['text/plain'], - history_length=0, + history_length=1, ), ) @@ -866,17 +871,17 @@ async def test_on_message_send_limit_history(): # verify that history_length is honored assert result is not None assert isinstance(result, Task) - assert result.history is not None and len(result.history) == 0 + assert result.history is not None and len(result.history) == 1 assert result.status.state == TaskState.completed # verify that history is still persisted to the store task = await task_store.get(result.id) assert task is not None - assert task.history is not None and len(task.history) > 0 + assert task.history is not None and len(task.history) > 1 @pytest.mark.asyncio -async def test_on_task_get_limit_history(): +async def test_on_get_task_limit_history(): task_store = InMemoryTaskStore() push_store = InMemoryPushNotificationConfigStore() @@ -892,7 +897,8 @@ async def test_on_task_get_limit_history(): parts=[Part(root=TextPart(text='Hi'))], ), configuration=MessageSendConfiguration( - blocking=True, accepted_output_modes=['text/plain'] + blocking=True, + accepted_output_modes=['text/plain'], ), ) @@ -904,14 +910,14 @@ async def test_on_task_get_limit_history(): assert isinstance(result, Task) get_task_result = await request_handler.on_get_task( - TaskQueryParams(id=result.id, history_length=0), + TaskQueryParams(id=result.id, history_length=1), create_server_call_context(), ) assert get_task_result is not None assert isinstance(get_task_result, Task) assert ( get_task_result.history is not None - and len(get_task_result.history) == 0 + and len(get_task_result.history) == 1 )