Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions AIP-discussion-response.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Response to AIP Discussion #1247

> Re: [Respecting AIP response payloads in HTTP](https://github.com/a2aproject/A2A/discussions/1247)

Thanks for this detailed explanation of the AIP conventions, @darrelmiller. I've been working on the a2a-python SDK migration from Pydantic to protobuf types ([PR #572](https://github.com/a2aproject/a2a-python/pull/572)) and wanted to share how we've implemented this.

## How we handle `SetTaskPushNotificationConfig` in the SDK

The key insight is that the request and response types serve different purposes:

**Request (`SetTaskPushNotificationConfigRequest`):**
```protobuf
message SetTaskPushNotificationConfigRequest {
string parent = 1; // e.g., "tasks/{task_id}"
string config_id = 2; // e.g., "my-config-id"
TaskPushNotificationConfig config = 3;
}
```

**Response (`TaskPushNotificationConfig`):**
```protobuf
message TaskPushNotificationConfig {
string name = 1; // Full resource name: "tasks/{task_id}/pushNotificationConfigs/{config_id}"
PushNotificationConfig push_notification_config = 2;
}
```

## Implementation in Python

In our `DefaultRequestHandler`, we construct the proper `name` field from the request's `parent` and `config_id`:

```python
async def on_set_task_push_notification_config(
self,
params: SetTaskPushNotificationConfigRequest,
context: ServerCallContext | None = None,
) -> TaskPushNotificationConfig:
task_id = _extract_task_id(params.parent) # Extract from "tasks/{task_id}"

# Store the config
await self._push_config_store.set_info(
task_id,
params.config.push_notification_config,
)

# Build response with proper AIP resource name
return TaskPushNotificationConfig(
name=f'{params.parent}/pushNotificationConfigs/{params.config_id}',
push_notification_config=params.config.push_notification_config,
)
```

## REST Handler Translation

For the HTTP binding, the REST handler extracts path parameters and constructs the request:

```python
async def set_push_notification(self, request: Request, context: ServerCallContext):
task_id = request.path_params['id']
body = await request.body()

params = SetTaskPushNotificationConfigRequest()
Parse(body, params)
params.parent = f'tasks/{task_id}' # Set from URL path

config = await self.request_handler.on_set_task_push_notification_config(params, context)
return MessageToDict(config) # Returns with proper `name` field
```

## JSON-RPC Handler

The JSON-RPC handler passes the full request directly:

```python
async def set_push_notification_config(
self,
request: SetTaskPushNotificationConfigRequest,
context: ServerCallContext | None = None,
) -> SetTaskPushNotificationConfigResponse:
result = await self.request_handler.on_set_task_push_notification_config(
request, context
)
return prepare_response_object(...)
```

## Key Takeaways

1. **The `name` field is constructed, not passed in** - The server builds the full resource name from `parent` + `config_id`

2. **Consistent across bindings** - Both gRPC and HTTP handlers ultimately call the same `on_set_task_push_notification_config` method

3. **AIP compliance** - The response always includes the full `name` field as required by [AIP-122](https://google.aip.dev/122)

4. **Helper functions for resource name parsing**:
```python
def _extract_task_id(resource_name: str) -> str:
"""Extract task ID from a resource name like 'tasks/{task_id}' or 'tasks/{task_id}/...'."""
match = re.match(r'^tasks/([^/]+)', resource_name)
if match:
return match.group(1)
return resource_name # Fall back for backwards compatibility

def _extract_config_id(resource_name: str) -> str | None:
"""Extract config ID from 'tasks/{task_id}/pushNotificationConfigs/{config_id}'."""
match = re.match(r'^tasks/[^/]+/pushNotificationConfigs/([^/]+)$', resource_name)
if match:
return match.group(1)
return None
```

## E2E Test Example

Here's how a client uses this in practice:

```python
# Client sets the push notification config
await a2a_client.set_task_callback(
SetTaskPushNotificationConfigRequest(
parent=f'tasks/{task.id}',
config_id='my-notification-config',
config=TaskPushNotificationConfig(
push_notification_config=PushNotificationConfig(
id='my-notification-config',
url=f'{notifications_server}/notifications',
token=token,
),
),
)
)
```

This approach keeps the abstract handler logic clean while ensuring AIP compliance at the protocol binding level.

---

**Related PRs:**
- [a2a-python PR #572](https://github.com/a2aproject/a2a-python/pull/572) - Proto migration with these changes
2 changes: 1 addition & 1 deletion buf.gen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
version: v2
inputs:
- git_repo: https://github.com/a2aproject/A2A.git
ref: transports
ref: main
subdir: specification/grpc
managed:
enabled: true
Expand Down
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ addopts = "-ra --strict-markers"
markers = [
"asyncio: mark a test as a coroutine that should be run by pytest-asyncio",
]
filterwarnings = [
# SQLAlchemy warning about duplicate class registration - this is a known limitation
# of the dynamic model creation pattern used in models.py for custom table names
"ignore:This declarative base already contains a class with the same class name:sqlalchemy.exc.SAWarning",
# ResourceWarnings from asyncio event loop/socket cleanup during garbage collection
# These appear intermittently between tests due to pytest-asyncio and sse-starlette timing
"ignore:unclosed event loop:ResourceWarning",
"ignore:unclosed transport:ResourceWarning",
"ignore:unclosed <socket.socket:ResourceWarning",
]

[tool.pytest-asyncio]
mode = "strict"
Expand Down
24 changes: 18 additions & 6 deletions src/a2a/client/auth/interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,36 @@
AgentCard,
APIKeySecurityScheme,
HTTPAuthSecurityScheme,
MutualTlsSecurityScheme,
OAuth2SecurityScheme,
OpenIdConnectSecurityScheme,
SecurityScheme,
)

logger = logging.getLogger(__name__)

_SecuritySchemeValue = (
APIKeySecurityScheme
| HTTPAuthSecurityScheme
| OAuth2SecurityScheme
| OpenIdConnectSecurityScheme
| MutualTlsSecurityScheme
| None
)


def _get_security_scheme_value(scheme: SecurityScheme):
def _get_security_scheme_value(scheme: SecurityScheme) -> _SecuritySchemeValue:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this change has just proven my point about moving this inline in the function below and changing how the "match" is done.

"""Extract the actual security scheme from the oneof union."""
which = scheme.WhichOneof('scheme')
if which == 'api_key_security_scheme':
return scheme.api_key_security_scheme
elif which == 'http_auth_security_scheme':
if which == 'http_auth_security_scheme':
return scheme.http_auth_security_scheme
elif which == 'oauth2_security_scheme':
if which == 'oauth2_security_scheme':
return scheme.oauth2_security_scheme
elif which == 'open_id_connect_security_scheme':
if which == 'open_id_connect_security_scheme':
return scheme.open_id_connect_security_scheme
elif which == 'mtls_security_scheme':
if which == 'mtls_security_scheme':
return scheme.mtls_security_scheme
return None

Expand Down Expand Up @@ -100,7 +110,9 @@ async def intercept(
return request_payload, http_kwargs

# Case 2: API Key in Header
case APIKeySecurityScheme() if scheme_def.location.lower() == 'header':
case APIKeySecurityScheme() if (
scheme_def.location.lower() == 'header'
):
headers[scheme_def.name] = credential
logger.debug(
"Added API Key Header for scheme '%s'.",
Expand Down
41 changes: 20 additions & 21 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,29 @@
from collections.abc import AsyncIterator, AsyncGenerator
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any

from a2a.client.client import (
Client,
ClientCallContext,
ClientConfig,
Consumer,
ClientEvent,
Consumer,
)
from a2a.client.client_task_manager import ClientTaskManager
from a2a.client.errors import A2AClientInvalidStateError
from a2a.client.middleware import ClientCallInterceptor
from a2a.client.transports.base import ClientTransport
from a2a.types.a2a_pb2 import (
AgentCard,
CancelTaskRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
Message,
SendMessageConfiguration,
SendMessageRequest,
Task,
TaskArtifactUpdateEvent,
SetTaskPushNotificationConfigRequest,
StreamResponse,
SubscribeToTaskRequest,
CancelTaskRequest,
Task,
TaskPushNotificationConfig,
GetTaskRequest,
TaskStatusUpdateEvent,
StreamResponse,
SetTaskPushNotificationConfigRequest,
GetExtendedAgentCardRequest,
GetTaskPushNotificationConfigRequest,
)


Expand Down Expand Up @@ -79,44 +75,48 @@ async def send_message(
else None
),
)
sendMessageRequest = SendMessageRequest(
send_message_request = SendMessageRequest(
request=request, configuration=config, metadata=request_metadata
)

if not self._config.streaming or not self._card.capabilities.streaming:
response = await self._transport.send_message(
sendMessageRequest, context=context, extensions=extensions
send_message_request, context=context, extensions=extensions
)

# In non-streaming case we convert to a StreamResponse so that the
# client always sees the same iterator.
stream_response = StreamResponse()
client_event: ClientEvent
if response.HasField("task"):
if response.HasField('task'):
stream_response.task.CopyFrom(response.task)
client_event = (stream_response, response.task)

elif response.HasField("msg"):
elif response.HasField('msg'):
stream_response.msg.CopyFrom(response.msg)
client_event = (stream_response, None)
else:
# Response must have either task or msg
raise ValueError('Response has neither task nor msg')

await self.consume(client_event, self._card)
yield client_event
return

stream = self._transport.send_message_streaming(
sendMessageRequest, context=context, extensions=extensions
send_message_request, context=context, extensions=extensions
)
async for client_event in self._process_stream(stream):
yield client_event

async def _process_stream(self, stream: AsyncIterator[StreamResponse]) -> AsyncGenerator[ClientEvent]:
async def _process_stream(
self, stream: AsyncIterator[StreamResponse]
) -> AsyncGenerator[ClientEvent]:
tracker = ClientTaskManager()
async for stream_response in stream:
client_event: ClientEvent
# When we get a message in the stream then we don't expect any
# further messages so yield and return
if stream_response.HasField("msg"):
if stream_response.HasField('msg'):
client_event = (stream_response, None)
await self.consume(client_event, self._card)
yield client_event
Expand Down Expand Up @@ -240,7 +240,6 @@ async def subscribe(
'client and/or server do not support resubscription.'
)

tracker = ClientTaskManager()
# Note: resubscribe can only be called on an existing task. As such,
# we should never see Message updates, despite the typing of the service
# definition indicating it may be possible.
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/client/card_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import httpx

from google.protobuf.json_format import ParseDict
from pydantic import ValidationError

from google.protobuf.json_format import ParseDict
from a2a.client.errors import (
A2AClientHTTPError,
A2AClientJSONError,
Expand Down
15 changes: 6 additions & 9 deletions src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,16 @@
from a2a.client.optionals import Channel
from a2a.types.a2a_pb2 import (
AgentCard,
CancelTaskRequest,
GetTaskPushNotificationConfigRequest,
GetTaskRequest,
Message,
PushNotificationConfig,
Task,
TaskArtifactUpdateEvent,
TaskPushNotificationConfig,
TaskStatusUpdateEvent,
StreamResponse,
SendMessageRequest,
GetTaskRequest,
CancelTaskRequest,
SetTaskPushNotificationConfigRequest,
GetTaskPushNotificationConfigRequest,
StreamResponse,
SubscribeToTaskRequest,
Task,
TaskPushNotificationConfig,
)


Expand Down
9 changes: 5 additions & 4 deletions src/a2a/client/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
AgentInterface,
)


TRANSPORT_PROTOCOLS_JSONRPC = 'JSONRPC'
TRANSPORT_PROTOCOLS_GRPC = 'GRPC'
TRANSPORT_PROTOCOLS_HTTP_JSON = 'HTTP+JSON'
Expand Down Expand Up @@ -71,9 +72,7 @@ def __init__(
self._registry: dict[str, TransportProducer] = {}
self._register_defaults(config.supported_protocol_bindings)

def _register_defaults(
self, supported: list[str]
) -> None:
def _register_defaults(self, supported: list[str]) -> None:
# Empty support list implies JSON-RPC only.
if TRANSPORT_PROTOCOLS_JSONRPC in supported or not supported:
self.register(
Expand Down Expand Up @@ -203,7 +202,9 @@ def create(
If there is no valid matching of the client configuration with the
server configuration, a `ValueError` is raised.
"""
server_preferred = card.preferred_transport or TRANSPORT_PROTOCOLS_JSONRPC
server_preferred = (
card.preferred_transport or TRANSPORT_PROTOCOLS_JSONRPC
)
server_set = {server_preferred: card.url}
if card.additional_interfaces:
server_set.update(
Expand Down
Loading