diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index b8697d86..5719bc1b 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -49,6 +49,7 @@ async def send_message( *, context: ClientCallContext | None = None, request_metadata: dict[str, Any] | None = None, + extensions: list[str] | None = None, ) -> AsyncIterator[ClientEvent | Message]: """Sends a message to the agent. @@ -60,6 +61,7 @@ async def send_message( request: The message to send to the agent. context: The client call context. request_metadata: Extensions Metadata attached to the request. + extensions: List of extensions to be activated. Yields: An async iterator of `ClientEvent` or a final `Message` response. @@ -79,7 +81,7 @@ async def send_message( if not self._config.streaming or not self._card.capabilities.streaming: response = await self._transport.send_message( - params, context=context + params, context=context, extensions=extensions ) result = ( (response, None) if isinstance(response, Task) else response @@ -89,7 +91,9 @@ async def send_message( return tracker = ClientTaskManager() - stream = self._transport.send_message_streaming(params, context=context) + stream = self._transport.send_message_streaming( + params, context=context, extensions=extensions + ) first_event = await anext(stream) # The response from a server may be either exactly one Message or a @@ -126,74 +130,91 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task. Args: request: The `TaskQueryParams` object specifying the task ID. context: The client call context. + extensions: List of extensions to be activated. Returns: A `Task` object representing the current state of the task. """ - return await self._transport.get_task(request, context=context) + return await self._transport.get_task( + request, context=context, extensions=extensions + ) async def cancel_task( self, request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task. Args: request: The `TaskIdParams` object specifying the task ID. context: The client call context. + extensions: List of extensions to be activated. Returns: A `Task` object containing the updated task status. """ - return await self._transport.cancel_task(request, context=context) + return await self._transport.cancel_task( + request, context=context, extensions=extensions + ) async def set_task_callback( self, request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task. Args: request: The `TaskPushNotificationConfig` object with the new configuration. context: The client call context. + extensions: List of extensions to be activated. Returns: The created or updated `TaskPushNotificationConfig` object. """ - return await self._transport.set_task_callback(request, context=context) + return await self._transport.set_task_callback( + request, context=context, extensions=extensions + ) async def get_task_callback( self, request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task. Args: request: The `GetTaskPushNotificationConfigParams` object specifying the task. context: The client call context. + extensions: List of extensions to be activated. Returns: A `TaskPushNotificationConfig` object containing the configuration. """ - return await self._transport.get_task_callback(request, context=context) + return await self._transport.get_task_callback( + request, context=context, extensions=extensions + ) async def resubscribe( self, request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncIterator[ClientEvent]: """Resubscribes to a task's event stream. @@ -202,6 +223,7 @@ async def resubscribe( Args: request: Parameters to identify the task to resubscribe to. context: The client call context. + extensions: List of extensions to be activated. Yields: An async iterator of `ClientEvent` objects. @@ -219,12 +241,15 @@ async def resubscribe( # we should never see Message updates, despite the typing of the service # definition indicating it may be possible. async for event in self._transport.resubscribe( - request, context=context + request, context=context, extensions=extensions ): yield await self._process_response(tracker, event) async def get_card( - self, *, context: ClientCallContext | None = None + self, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card. @@ -233,11 +258,14 @@ async def get_card( Args: context: The client call context. + extensions: List of extensions to be activated. Returns: The `AgentCard` for the agent. """ - card = await self._transport.get_card(context=context) + card = await self._transport.get_card( + context=context, extensions=extensions + ) self._card = card return card diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 0e1c4323..fd97b4d1 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -67,6 +67,9 @@ class ClientConfig: ) """Push notification callbacks to use for every request.""" + extensions: list[str] = dataclasses.field(default_factory=list) + """A list of extension URIs the client supports.""" + UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None # Alias for emitted events from client @@ -111,6 +114,7 @@ async def send_message( *, context: ClientCallContext | None = None, request_metadata: dict[str, Any] | None = None, + extensions: list[str] | None = None, ) -> AsyncIterator[ClientEvent | Message]: """Sends a message to the server. @@ -129,6 +133,7 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" @@ -138,6 +143,7 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" @@ -147,6 +153,7 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" @@ -156,6 +163,7 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" @@ -165,6 +173,7 @@ async def resubscribe( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncIterator[ClientEvent]: """Resubscribes to a task's event stream.""" return @@ -172,7 +181,10 @@ async def resubscribe( @abstractmethod async def get_card( - self, *, context: ClientCallContext | None = None + self, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index 65b3fb5f..fabd7270 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -80,6 +80,7 @@ def _register_defaults( card, url, interceptors, + config.extensions or None, ), ) if TransportProtocol.http_json in supported: @@ -90,6 +91,7 @@ def _register_defaults( card, url, interceptors, + config.extensions or None, ), ) if TransportProtocol.grpc in supported: @@ -113,6 +115,7 @@ async def connect( # noqa: PLR0913 relative_card_path: str | None = None, resolver_http_kwargs: dict[str, Any] | None = None, extra_transports: dict[str, TransportProducer] | None = None, + extensions: list[str] | None = None, ) -> Client: """Convenience method for constructing a client. @@ -142,6 +145,7 @@ async def connect( # noqa: PLR0913 A2AAgentCardResolver.get_agent_card as the http_kwargs parameter. extra_transports: Additional transport protocols to enable when constructing the client. + extensions: List of extensions to be activated. Returns: A `Client` object. @@ -166,7 +170,7 @@ async def connect( # noqa: PLR0913 factory = cls(client_config) for label, generator in (extra_transports or {}).items(): factory.register(label, generator) - return factory.create(card, consumers, interceptors) + return factory.create(card, consumers, interceptors, extensions) def register(self, label: str, generator: TransportProducer) -> None: """Register a new transport producer for a given transport label.""" @@ -177,6 +181,7 @@ def create( card: AgentCard, consumers: list[Consumer] | None = None, interceptors: list[ClientCallInterceptor] | None = None, + extensions: list[str] | None = None, ) -> Client: """Create a new `Client` for the provided `AgentCard`. @@ -186,6 +191,7 @@ def create( interceptors: A list of interceptors to use for each request. These are used for things like attaching credentials or http headers to all outbound requests. + extensions: List of extensions to be activated. Returns: A `Client` object. @@ -226,12 +232,21 @@ def create( if consumers: all_consumers.extend(consumers) + all_extensions = self._config.extensions.copy() + if extensions: + all_extensions.extend(extensions) + self._config.extensions = all_extensions + transport = self._registry[transport_protocol]( card, transport_url, self._config, interceptors or [] ) return BaseClient( - card, self._config, transport, all_consumers, interceptors or [] + card, + self._config, + transport, + all_consumers, + interceptors or [], ) diff --git a/src/a2a/client/transports/base.py b/src/a2a/client/transports/base.py index 3573cb7c..8f114d95 100644 --- a/src/a2a/client/transports/base.py +++ b/src/a2a/client/transports/base.py @@ -25,6 +25,7 @@ async def send_message( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent.""" @@ -34,6 +35,7 @@ async def send_message_streaming( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -47,6 +49,7 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" @@ -56,6 +59,7 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" @@ -65,6 +69,7 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" @@ -74,6 +79,7 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" @@ -83,6 +89,7 @@ async def resubscribe( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -95,6 +102,7 @@ async def get_card( self, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the AgentCard.""" diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index e50b0ea8..4e27953a 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -12,10 +12,12 @@ "'pip install a2a-sdk[grpc]'" ) from e + from a2a.client.client import ClientConfig from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.optionals import Channel from a2a.client.transports.base import ClientTransport +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCard, @@ -44,6 +46,7 @@ def __init__( self, channel: Channel, agent_card: AgentCard | None, + extensions: list[str] | None = None, ): """Initializes the GrpcTransport.""" self.agent_card = agent_card @@ -54,6 +57,18 @@ def __init__( if agent_card else True ) + self.extensions = extensions + + def _get_grpc_metadata( + self, + extensions: list[str] | None = None, + ) -> list[tuple[str, str]] | None: + """Creates gRPC metadata for extensions.""" + if extensions is not None: + return [(HTTP_EXTENSION_HEADER, ','.join(extensions))] + if self.extensions is not None: + return [(HTTP_EXTENSION_HEADER, ','.join(self.extensions))] + return None @classmethod def create( @@ -66,16 +81,14 @@ def create( """Creates a gRPC transport for the A2A client.""" if config.grpc_channel_factory is None: raise ValueError('grpc_channel_factory is required when using gRPC') - return cls( - config.grpc_channel_factory(url), - card, - ) + return cls(config.grpc_channel_factory(url), card, config.extensions) async def send_message( self, request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent.""" response = await self.stub.SendMessage( @@ -85,7 +98,8 @@ async def send_message( request.configuration ), metadata=proto_utils.ToProto.metadata(request.metadata), - ) + ), + metadata=self._get_grpc_metadata(extensions), ) if response.HasField('task'): return proto_utils.FromProto.task(response.task) @@ -96,6 +110,7 @@ async def send_message_streaming( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -107,7 +122,8 @@ async def send_message_streaming( request.configuration ), metadata=proto_utils.ToProto.metadata(request.metadata), - ) + ), + metadata=self._get_grpc_metadata(extensions), ) while True: response = await stream.read() @@ -116,13 +132,18 @@ async def send_message_streaming( yield proto_utils.FromProto.stream_response(response) async def resubscribe( - self, request: TaskIdParams, *, context: ClientCallContext | None = None + self, + request: TaskIdParams, + *, + context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: """Reconnects to get task updates.""" stream = self.stub.TaskSubscription( - a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}') + a2a_pb2.TaskSubscriptionRequest(name=f'tasks/{request.id}'), + metadata=self._get_grpc_metadata(extensions), ) while True: response = await stream.read() @@ -135,13 +156,15 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" task = await self.stub.GetTask( a2a_pb2.GetTaskRequest( name=f'tasks/{request.id}', history_length=request.history_length, - ) + ), + metadata=self._get_grpc_metadata(extensions), ) return proto_utils.FromProto.task(task) @@ -150,10 +173,12 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" task = await self.stub.CancelTask( - a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') + a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}'), + metadata=self._get_grpc_metadata(extensions), ) return proto_utils.FromProto.task(task) @@ -162,6 +187,7 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" config = await self.stub.CreateTaskPushNotificationConfig( @@ -171,7 +197,8 @@ async def set_task_callback( config=proto_utils.ToProto.task_push_notification_config( request ), - ) + ), + metadata=self._get_grpc_metadata(extensions), ) return proto_utils.FromProto.task_push_notification_config(config) @@ -180,12 +207,14 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" config = await self.stub.GetTaskPushNotificationConfig( a2a_pb2.GetTaskPushNotificationConfigRequest( name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', - ) + ), + metadata=self._get_grpc_metadata(extensions), ) return proto_utils.FromProto.task_push_notification_config(config) @@ -193,6 +222,7 @@ async def get_card( self, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" card = self.agent_card @@ -203,6 +233,7 @@ async def get_card( card_pb = await self.stub.GetAgentCard( a2a_pb2.GetAgentCardRequest(), + metadata=self._get_grpc_metadata(extensions), ) card = proto_utils.FromProto.agent_card(card_pb) self.agent_card = card diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index bfba09d7..d8011cf4 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -18,6 +18,7 @@ ) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport +from a2a.extensions.common import update_extension_header from a2a.types import ( AgentCard, CancelTaskRequest, @@ -62,6 +63,7 @@ def __init__( agent_card: AgentCard | None = None, url: str | None = None, interceptors: list[ClientCallInterceptor] | None = None, + extensions: list[str] | None = None, ): """Initializes the JsonRpcTransport.""" if url: @@ -79,6 +81,7 @@ def __init__( if agent_card else True ) + self.extensions = extensions async def _apply_interceptors( self, @@ -113,13 +116,18 @@ async def send_message( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent.""" rpc_request = SendMessageRequest(params=request, id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'message/send', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -133,6 +141,7 @@ async def send_message_streaming( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: @@ -140,13 +149,16 @@ async def send_message_streaming( rpc_request = SendStreamingMessageRequest( params=request, id=str(uuid4()) ) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'message/stream', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) - modified_kwargs.setdefault( 'timeout', self.httpx_client.timeout.as_dict().get('read', None) ) @@ -207,13 +219,18 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" rpc_request = GetTaskRequest(params=request, id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/get', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -227,13 +244,18 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" rpc_request = CancelTaskRequest(params=request, id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/cancel', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -247,15 +269,20 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" rpc_request = SetTaskPushNotificationConfigRequest( params=request, id=str(uuid4()) ) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/pushNotificationConfig/set', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -271,15 +298,20 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" rpc_request = GetTaskPushNotificationConfigRequest( params=request, id=str(uuid4()) ) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/pushNotificationConfig/get', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_request(payload, modified_kwargs) @@ -295,18 +327,22 @@ async def resubscribe( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent ]: """Reconnects to get task updates.""" rpc_request = TaskResubscriptionRequest(params=request, id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( 'tasks/resubscribe', rpc_request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) - modified_kwargs.setdefault('timeout', None) async with aconnect_sse( @@ -339,6 +375,7 @@ async def get_card( self, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" card = self.agent_card @@ -356,13 +393,16 @@ async def get_card( return card request = GetAuthenticatedExtendedCardRequest(id=str(uuid4())) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( request.method, request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) - response_data = await self._send_request( payload, modified_kwargs, diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index eef7b0f2..83c26787 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -13,6 +13,7 @@ from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.transports.base import ClientTransport +from a2a.extensions.common import update_extension_header from a2a.grpc import a2a_pb2 from a2a.types import ( AgentCard, @@ -43,6 +44,7 @@ def __init__( agent_card: AgentCard | None = None, url: str | None = None, interceptors: list[ClientCallInterceptor] | None = None, + extensions: list[str] | None = None, ): """Initializes the RestTransport.""" if url: @@ -61,6 +63,7 @@ def __init__( if agent_card else True ) + self.extensions = extensions async def _apply_interceptors( self, @@ -79,7 +82,10 @@ def _get_http_args( return context.state.get('http_kwargs') if context else None async def _prepare_send_message( - self, request: MessageSendParams, context: ClientCallContext | None + self, + request: MessageSendParams, + context: ClientCallContext | None, + extensions: list[str] | None = None, ) -> tuple[dict[str, Any], dict[str, Any]]: pb = a2a_pb2.SendMessageRequest( request=proto_utils.ToProto.message(request.message), @@ -93,9 +99,13 @@ async def _prepare_send_message( ), ) payload = MessageToDict(pb) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( payload, - self._get_http_args(context), + modified_kwargs, context, ) return payload, modified_kwargs @@ -105,10 +115,11 @@ async def send_message( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task | Message: """Sends a non-streaming message request to the agent.""" payload, modified_kwargs = await self._prepare_send_message( - request, context + request, context, extensions ) response_data = await self._send_post_request( '/v1/message:send', payload, modified_kwargs @@ -122,12 +133,13 @@ async def send_message_streaming( request: MessageSendParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message ]: """Sends a streaming message request to the agent and yields responses as they arrive.""" payload, modified_kwargs = await self._prepare_send_message( - request, context + request, context, extensions ) modified_kwargs.setdefault('timeout', None) @@ -204,11 +216,16 @@ async def get_task( request: TaskQueryParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Retrieves the current state and history of a specific task.""" + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) _payload, modified_kwargs = await self._apply_interceptors( request.model_dump(mode='json', exclude_none=True), - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_get_request( @@ -227,13 +244,18 @@ async def cancel_task( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> Task: """Requests the agent to cancel a specific task.""" pb = a2a_pb2.CancelTaskRequest(name=f'tasks/{request.id}') payload = MessageToDict(pb) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( payload, - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_post_request( @@ -248,6 +270,7 @@ async def set_task_callback( request: TaskPushNotificationConfig, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Sets or updates the push notification configuration for a specific task.""" pb = a2a_pb2.CreateTaskPushNotificationConfigRequest( @@ -256,8 +279,12 @@ async def set_task_callback( config=proto_utils.ToProto.task_push_notification_config(request), ) payload = MessageToDict(pb) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( - payload, self._get_http_args(context), context + payload, modified_kwargs, context ) response_data = await self._send_post_request( f'/v1/tasks/{request.task_id}/pushNotificationConfigs', @@ -273,15 +300,20 @@ async def get_task_callback( request: GetTaskPushNotificationConfigParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> TaskPushNotificationConfig: """Retrieves the push notification configuration for a specific task.""" pb = a2a_pb2.GetTaskPushNotificationConfigRequest( name=f'tasks/{request.id}/pushNotificationConfigs/{request.push_notification_config_id}', ) payload = MessageToDict(pb) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) payload, modified_kwargs = await self._apply_interceptors( payload, - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_get_request( @@ -298,18 +330,22 @@ async def resubscribe( request: TaskIdParams, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AsyncGenerator[ Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent | Message ]: """Reconnects to get task updates.""" - http_kwargs = self._get_http_args(context) or {} - http_kwargs.setdefault('timeout', None) + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) + modified_kwargs.setdefault('timeout', None) async with aconnect_sse( self.httpx_client, 'GET', f'{self.url}/v1/tasks/{request.id}:subscribe', - **http_kwargs, + **modified_kwargs, ) as event_source: try: async for sse in event_source.aiter_sse(): @@ -331,6 +367,7 @@ async def get_card( self, *, context: ClientCallContext | None = None, + extensions: list[str] | None = None, ) -> AgentCard: """Retrieves the agent's card.""" card = self.agent_card @@ -347,9 +384,13 @@ async def get_card( if not self._needs_extended_card: return card + modified_kwargs = update_extension_header( + self._get_http_args(context), + extensions if extensions is not None else self.extensions, + ) _, modified_kwargs = await self._apply_interceptors( {}, - self._get_http_args(context), + modified_kwargs, context, ) response_data = await self._send_get_request( diff --git a/src/a2a/extensions/common.py b/src/a2a/extensions/common.py index 2f752caa..cba3517e 100644 --- a/src/a2a/extensions/common.py +++ b/src/a2a/extensions/common.py @@ -1,3 +1,5 @@ +from typing import Any + from a2a.types import AgentCard, AgentExtension @@ -25,3 +27,15 @@ def find_extension_by_uri(card: AgentCard, uri: str) -> AgentExtension | None: return ext return None + + +def update_extension_header( + http_kwargs: dict[str, Any] | None, + extensions: list[str] | None, +) -> dict[str, Any]: + """Update the X-A2A-Extensions header with active extensions.""" + http_kwargs = http_kwargs or {} + if extensions is not None: + headers = http_kwargs.setdefault('headers', {}) + headers[HTTP_EXTENSION_HEADER] = ','.join(extensions) + return http_kwargs diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py index 847b256f..16a1433f 100644 --- a/tests/client/test_client_factory.py +++ b/tests/client/test_client_factory.py @@ -39,12 +39,14 @@ def test_client_factory_selects_preferred_transport(base_agent_card: AgentCard): TransportProtocol.jsonrpc, TransportProtocol.http_json, ], + extensions=['https://example.com/test-ext/v0'], ) factory = ClientFactory(config) client = factory.create(base_agent_card) assert isinstance(client._transport, JsonRpcTransport) assert client._transport.url == 'http://primary-url.com' + assert ['https://example.com/test-ext/v0'] == client._transport.extensions def test_client_factory_selects_secondary_transport_url( @@ -65,12 +67,14 @@ def test_client_factory_selects_secondary_transport_url( TransportProtocol.jsonrpc, ], use_client_preference=True, + extensions=['https://example.com/test-ext/v0'], ) factory = ClientFactory(config) client = factory.create(base_agent_card) assert isinstance(client._transport, RestTransport) assert client._transport.url == 'http://secondary-url.com' + assert ['https://example.com/test-ext/v0'] == client._transport.extensions def test_client_factory_server_preference(base_agent_card: AgentCard): diff --git a/tests/client/test_grpc_client.py b/tests/client/transports/test_grpc_client.py similarity index 80% rename from tests/client/test_grpc_client.py rename to tests/client/transports/test_grpc_client.py index 6dab75e9..111e44ba 100644 --- a/tests/client/test_grpc_client.py +++ b/tests/client/transports/test_grpc_client.py @@ -4,6 +4,7 @@ import pytest from a2a.client.transports.grpc import GrpcTransport +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.grpc import a2a_pb2, a2a_pb2_grpc from a2a.types import ( AgentCapabilities, @@ -64,7 +65,14 @@ def grpc_transport( ) -> GrpcTransport: """Provides a GrpcTransport instance.""" channel = AsyncMock() - transport = GrpcTransport(channel=channel, agent_card=sample_agent_card) + transport = GrpcTransport( + channel=channel, + agent_card=sample_agent_card, + extensions=[ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ], + ) transport.stub = mock_grpc_stub return transport @@ -185,9 +193,19 @@ async def test_send_message_task_response( task=proto_utils.ToProto.task(sample_task) ) - response = await grpc_transport.send_message(sample_message_send_params) + response = await grpc_transport.send_message( + sample_message_send_params, + extensions=['https://example.com/test-ext/v3'], + ) mock_grpc_stub.SendMessage.assert_awaited_once() + _, kwargs = mock_grpc_stub.SendMessage.call_args + assert kwargs['metadata'] == [ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v3', + ) + ] assert isinstance(response, Task) assert response.id == sample_task.id @@ -207,6 +225,13 @@ async def test_send_message_message_response( response = await grpc_transport.send_message(sample_message_send_params) mock_grpc_stub.SendMessage.assert_awaited_once() + _, kwargs = mock_grpc_stub.SendMessage.call_args + assert kwargs['metadata'] == [ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', + ) + ] assert isinstance(response, Message) assert response.message_id == sample_message.message_id assert get_text_parts(response.parts) == get_text_parts( @@ -255,6 +280,13 @@ async def test_send_message_streaming( # noqa: PLR0913 ] mock_grpc_stub.SendStreamingMessage.assert_called_once() + _, kwargs = mock_grpc_stub.SendStreamingMessage.call_args + assert kwargs['metadata'] == [ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', + ) + ] assert isinstance(responses[0], Message) assert responses[0].message_id == sample_message.message_id assert isinstance(responses[1], Task) @@ -278,7 +310,13 @@ async def test_get_task( mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest( name=f'tasks/{sample_task.id}', history_length=None - ) + ), + metadata=[ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', + ) + ], ) assert response.id == sample_task.id @@ -297,7 +335,13 @@ async def test_get_task_with_history( mock_grpc_stub.GetTask.assert_awaited_once_with( a2a_pb2.GetTaskRequest( name=f'tasks/{sample_task.id}', history_length=history_len - ) + ), + metadata=[ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', + ) + ], ) @@ -312,11 +356,14 @@ async def test_cancel_task( cancelled_task ) params = TaskIdParams(id=sample_task.id) - - response = await grpc_transport.cancel_task(params) + extensions = [ + 'https://example.com/test-ext/v3', + ] + response = await grpc_transport.cancel_task(params, extensions=extensions) mock_grpc_stub.CancelTask.assert_awaited_once_with( - a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}') + a2a_pb2.CancelTaskRequest(name=f'tasks/{sample_task.id}'), + metadata=[(HTTP_EXTENSION_HEADER, 'https://example.com/test-ext/v3')], ) assert response.status.state == TaskState.canceled @@ -345,7 +392,13 @@ async def test_set_task_callback_with_valid_task( config=proto_utils.ToProto.task_push_notification_config( sample_task_push_notification_config ), - ) + ), + metadata=[ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', + ) + ], ) assert response.task_id == sample_task_push_notification_config.task_id @@ -402,7 +455,13 @@ async def test_get_task_callback_with_valid_task( f'tasks/{params.id}/' f'pushNotificationConfigs/{params.push_notification_config_id}' ), - ) + ), + metadata=[ + ( + HTTP_EXTENSION_HEADER, + 'https://example.com/test-ext/v1,https://example.com/test-ext/v2', + ) + ], ) assert response.task_id == sample_task_push_notification_config.task_id @@ -434,3 +493,50 @@ async def test_get_task_callback_with_invalid_task( 'Bad TaskPushNotificationConfig resource name' in exc_info.value.error.message ) + + +@pytest.mark.parametrize( + 'initial_extensions, input_extensions, expected_metadata', + [ + ( + None, + None, + None, + ), # Case 1: No initial, No input + ( + ['ext1'], + None, + [(HTTP_EXTENSION_HEADER, 'ext1')], + ), # Case 2: Initial, No input + ( + None, + ['ext2'], + [(HTTP_EXTENSION_HEADER, 'ext2')], + ), # Case 3: No initial, Input + ( + ['ext1'], + ['ext2'], + [(HTTP_EXTENSION_HEADER, 'ext2')], + ), # Case 4: Initial, Input (override) + ( + ['ext1'], + ['ext2', 'ext3'], + [(HTTP_EXTENSION_HEADER, 'ext2,ext3')], + ), # Case 5: Initial, Multiple inputs (override) + ( + ['ext1', 'ext2'], + ['ext3'], + [(HTTP_EXTENSION_HEADER, 'ext3')], + ), # Case 6: Multiple initial, Single input (override) + ], +) +def test_get_grpc_metadata( + grpc_transport: GrpcTransport, + initial_extensions: list[str] | None, + input_extensions: list[str] | None, + expected_metadata: list[tuple[str, str]] | None, +) -> None: + """Tests _get_grpc_metadata for correct metadata generation and self.extensions update.""" + grpc_transport.extensions = initial_extensions + metadata = grpc_transport._get_grpc_metadata(input_extensions) + assert metadata == expected_metadata diff --git a/tests/client/test_jsonrpc_client.py b/tests/client/transports/test_jsonrpc_client.py similarity index 89% rename from tests/client/test_jsonrpc_client.py rename to tests/client/transports/test_jsonrpc_client.py index 58feec25..bd705d93 100644 --- a/tests/client/test_jsonrpc_client.py +++ b/tests/client/transports/test_jsonrpc_client.py @@ -17,6 +17,7 @@ create_text_message_object, ) from a2a.client.transports.jsonrpc import JsonRpcTransport +from a2a.extensions.common import HTTP_EXTENSION_HEADER from a2a.types import ( AgentCapabilities, AgentCard, @@ -785,3 +786,92 @@ async def test_close(self, mock_httpx_client: AsyncMock): ) await client.close() mock_httpx_client.aclose.assert_called_once() + + +class TestJsonRpcTransportExtensions: + @pytest.mark.asyncio + async def test_send_message_with_default_extensions( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + """Test that send_message adds extension headers when extensions are provided.""" + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + extensions=extensions, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + success_response = create_text_message_object( + role=Role.agent, content='Hi there!' + ) + rpc_response = SendMessageSuccessResponse( + id='123', jsonrpc='2.0', result=success_response + ) + # Mock the response from httpx_client.post + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = rpc_response.model_dump(mode='json') + mock_httpx_client.post.return_value = mock_response + + await client.send_message(request=params) + + mock_httpx_client.post.assert_called_once() + _, mock_kwargs = mock_httpx_client.post.call_args + + headers = mock_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + header_value = headers[HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + } + assert len(actual_extensions_list) == 2 + assert actual_extensions == expected_extensions + + @pytest.mark.asyncio + @patch('a2a.client.transports.jsonrpc.aconnect_sse') + async def test_send_message_streaming_with_new_extensions( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + """Test X-A2A-Extensions header in send_message_streaming.""" + new_extensions = ['https://example.com/test-ext/v2'] + extensions = ['https://example.com/test-ext/v1'] + client = JsonRpcTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + extensions=extensions, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.return_value = async_iterable_from_list([]) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + async for _ in client.send_message_streaming( + request=params, extensions=new_extensions + ): + pass + + mock_aconnect_sse.assert_called_once() + _, kwargs = mock_aconnect_sse.call_args + + headers = kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert ( + headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2' + ) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py new file mode 100644 index 00000000..04bd1036 --- /dev/null +++ b/tests/client/transports/test_rest_client.py @@ -0,0 +1,121 @@ +from collections.abc import AsyncGenerator +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from httpx_sse import EventSource, ServerSentEvent + +from a2a.client import create_text_message_object +from a2a.client.transports.rest import RestTransport +from a2a.extensions.common import HTTP_EXTENSION_HEADER +from a2a.types import AgentCard, MessageSendParams, Role + + +@pytest.fixture +def mock_httpx_client() -> AsyncMock: + return AsyncMock(spec=httpx.AsyncClient) + + +@pytest.fixture +def mock_agent_card() -> MagicMock: + mock = MagicMock(spec=AgentCard, url='http://agent.example.com/api') + mock.supports_authenticated_extended_card = False + return mock + + +async def async_iterable_from_list( + items: list[ServerSentEvent], +) -> AsyncGenerator[ServerSentEvent, None]: + """Helper to create an async iterable from a list.""" + for item in items: + yield item + + +class TestRestTransportExtensions: + @pytest.mark.asyncio + async def test_send_message_with_default_extensions( + self, mock_httpx_client: AsyncMock, mock_agent_card: MagicMock + ): + """Test that send_message adds extensions to headers.""" + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + client = RestTransport( + httpx_client=mock_httpx_client, + extensions=extensions, + agent_card=mock_agent_card, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + + # Mock the build_request method to capture its inputs + mock_build_request = MagicMock( + return_value=AsyncMock(spec=httpx.Request) + ) + mock_httpx_client.build_request = mock_build_request + + # Mock the send method + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_httpx_client.send.return_value = mock_response + + await client.send_message(request=params) + + mock_build_request.assert_called_once() + _, kwargs = mock_build_request.call_args + + headers = kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + header_value = kwargs['headers'][HTTP_EXTENSION_HEADER] + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + + expected_extensions = { + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + } + assert len(actual_extensions_list) == 2 + assert actual_extensions == expected_extensions + + @pytest.mark.asyncio + @patch('a2a.client.transports.rest.aconnect_sse') + async def test_send_message_streaming_with_new_extensions( + self, + mock_aconnect_sse: AsyncMock, + mock_httpx_client: AsyncMock, + mock_agent_card: MagicMock, + ): + """Test X-A2A-Extensions header in send_message_streaming.""" + new_extensions = ['https://example.com/test-ext/v2'] + extensions = ['https://example.com/test-ext/v1'] + client = RestTransport( + httpx_client=mock_httpx_client, + agent_card=mock_agent_card, + extensions=extensions, + ) + params = MessageSendParams( + message=create_text_message_object(content='Hello stream') + ) + + mock_event_source = AsyncMock(spec=EventSource) + mock_event_source.aiter_sse.return_value = async_iterable_from_list([]) + mock_aconnect_sse.return_value.__aenter__.return_value = ( + mock_event_source + ) + + async for _ in client.send_message_streaming( + request=params, extensions=new_extensions + ): + pass + + mock_aconnect_sse.assert_called_once() + _, kwargs = mock_aconnect_sse.call_args + + headers = kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert ( + headers[HTTP_EXTENSION_HEADER] == 'https://example.com/test-ext/v2' + ) diff --git a/tests/extensions/test_common.py b/tests/extensions/test_common.py index 137e64c9..b3123028 100644 --- a/tests/extensions/test_common.py +++ b/tests/extensions/test_common.py @@ -1,6 +1,9 @@ +import pytest from a2a.extensions.common import ( + HTTP_EXTENSION_HEADER, find_extension_by_uri, get_requested_extensions, + update_extension_header, ) from a2a.types import AgentCapabilities, AgentCard, AgentExtension @@ -56,3 +59,88 @@ def test_find_extension_by_uri_no_extensions(): ) assert find_extension_by_uri(card, 'foo') is None + + +@pytest.mark.parametrize( + 'extensions, header, expected_extensions', + [ + ( + ['ext1', 'ext2'], # extensions + '', # header + { + 'ext1', + 'ext2', + }, # expected_extensions + ), # Case 1: New extensions provided, empty header. + ( + None, # extensions + 'ext1, ext2', # header + { + 'ext1', + 'ext2', + }, # expected_extensions + ), # Case 2: Extensions is None, existing header extensions. + ( + [], # extensions + 'ext1', # header + {}, # expected_extensions + ), # Case 3: New extensions is empty list, existing header extensions. + ( + ['ext1', 'ext2'], # extensions + 'ext3', # header + { + 'ext1', + 'ext2', + }, # expected_extensions + ), # Case 4: New extensions provided, and an existing header. New extensions should override active extensions. + ], +) +def test_update_extension_header_merge_with_existing_extensions( + extensions: list[str], + header: str, + expected_extensions: set[str], +): + http_kwargs = {'headers': {HTTP_EXTENSION_HEADER: header}} + result_kwargs = update_extension_header(http_kwargs, extensions) + header_value = result_kwargs['headers'][HTTP_EXTENSION_HEADER] + if not header_value: + actual_extensions = {} + else: + actual_extensions_list = [e.strip() for e in header_value.split(',')] + actual_extensions = set(actual_extensions_list) + assert actual_extensions == expected_extensions + + +def test_update_extension_header_with_other_headers(): + extensions = ['ext'] + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = update_extension_header(http_kwargs, extensions) + headers = result_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert headers[HTTP_EXTENSION_HEADER] == 'ext' + assert headers['X_Other'] == 'Test' + + +@pytest.mark.parametrize( + 'http_kwargs', + [ + None, + {}, + ], +) +def test_update_extension_header_headers_not_in_kwargs( + http_kwargs: dict[str, str] | None, +): + extensions = ['ext'] + http_kwargs = {} + result_kwargs = update_extension_header(http_kwargs, extensions) + headers = result_kwargs.get('headers', {}) + assert HTTP_EXTENSION_HEADER in headers + assert headers[HTTP_EXTENSION_HEADER] == 'ext' + + +def test_update_extension_header_with_other_headers_extensions_none(): + http_kwargs = {'headers': {'X_Other': 'Test'}} + result_kwargs = update_extension_header(http_kwargs, None) + assert HTTP_EXTENSION_HEADER not in result_kwargs['headers'] + assert result_kwargs['headers']['X_Other'] == 'Test' diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 88d4d3d1..e0a564ee 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1,7 +1,7 @@ import asyncio from collections.abc import AsyncGenerator from typing import NamedTuple -from unittest.mock import ANY, AsyncMock +from unittest.mock import ANY, AsyncMock, patch import grpc import httpx @@ -9,6 +9,8 @@ import pytest_asyncio from grpc.aio import Channel +from a2a.client import ClientConfig +from a2a.client.base_client import BaseClient from a2a.client.transports import JsonRpcTransport, RestTransport from a2a.client.transports.base import ClientTransport from a2a.client.transports.grpc import GrpcTransport @@ -767,3 +769,61 @@ def channel_factory(address: str) -> Channel: assert transport._needs_extended_card is False await transport.close() + + +@pytest.mark.asyncio +async def test_base_client_sends_message_with_extensions( + jsonrpc_setup: TransportSetup, agent_card: AgentCard +) -> None: + """ + Integration test for BaseClient with JSON-RPC transport to ensure extensions are included in headers. + """ + transport = jsonrpc_setup.transport + agent_card.capabilities.streaming = False + + # Create a BaseClient instance + client = BaseClient( + card=agent_card, + config=ClientConfig(streaming=False), + transport=transport, + consumers=[], + middleware=[], + ) + + message_to_send = Message( + role=Role.user, + message_id='msg-integration-test-extensions', + parts=[Part(root=TextPart(text='Hello, extensions test!'))], + ) + extensions = [ + 'https://example.com/test-ext/v1', + 'https://example.com/test-ext/v2', + ] + + with patch.object( + transport, '_send_request', new_callable=AsyncMock + ) as mock_send_request: + mock_send_request.return_value = { + 'id': '123', + 'jsonrpc': '2.0', + 'result': TASK_FROM_BLOCKING.model_dump(mode='json'), + } + + # Call send_message on the BaseClient + async for _ in client.send_message( + request=message_to_send, extensions=extensions + ): + pass + + mock_send_request.assert_called_once() + call_args, _ = mock_send_request.call_args + kwargs = call_args[1] + headers = kwargs.get('headers', {}) + assert 'X-A2A-Extensions' in headers + assert ( + headers['X-A2A-Extensions'] + == 'https://example.com/test-ext/v1,https://example.com/test-ext/v2' + ) + + if hasattr(transport, 'close'): + await transport.close()