Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,27 +174,30 @@
**modified_kwargs,
) as event_source:
try:
event_source.response.raise_for_status()
async for sse in event_source.aiter_sse():
response = SendStreamingMessageResponse.model_validate(
json.loads(sse.data)
)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
yield response.root.result
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_request(
self,
rpc_request_payload: dict[str, Any],

Check notice on line 200 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (160-173)
http_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
try:
Expand Down
3 changes: 3 additions & 0 deletions src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,22 +152,25 @@
**modified_kwargs,
) as event_source:
try:
event_source.response.raise_for_status()
async for sse in event_source.aiter_sse():
event = a2a_pb2.StreamResponse()
Parse(sse.data, event)
yield proto_utils.FromProto.stream_response(event)
except httpx.HTTPStatusError as e:
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
) from e

async def _send_request(self, request: httpx.Request) -> dict[str, Any]:

Check notice on line 173 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (185-200)

Check notice on line 173 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (358-369)
try:
response = await self.httpx_client.send(request)
response.raise_for_status()
Expand Down
38 changes: 38 additions & 0 deletions tests/client/transports/test_jsonrpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,44 @@ async def test_send_message_streaming_with_new_extensions(
},
)

@pytest.mark.asyncio
@patch('a2a.client.transports.jsonrpc.aconnect_sse')
async def test_send_message_streaming_server_error_propagates(
self,
mock_aconnect_sse: AsyncMock,
mock_httpx_client: AsyncMock,
mock_agent_card: MagicMock,
):
"""Test that send_message_streaming propagates server errors (e.g., 403, 500) directly."""
client = JsonRpcTransport(
httpx_client=mock_httpx_client,
agent_card=mock_agent_card,
)
params = MessageSendParams(
message=create_text_message_object(content='Error stream')
)

mock_event_source = AsyncMock(spec=EventSource)
mock_response = MagicMock(spec=httpx.Response)
mock_response.status_code = 403
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
'Forbidden',
request=httpx.Request('POST', 'http://test.url'),
response=mock_response,
)
mock_event_source.response = mock_response
mock_event_source.aiter_sse.return_value = async_iterable_from_list([])
mock_aconnect_sse.return_value.__aenter__.return_value = (
mock_event_source
)

with pytest.raises(A2AClientHTTPError) as exc_info:
async for _ in client.send_message_streaming(request=params):
pass

assert exc_info.value.status_code == 403
mock_aconnect_sse.assert_called_once()

@pytest.mark.asyncio
async def test_get_card_no_card_provided_with_extensions(
self, mock_httpx_client: AsyncMock
Expand Down
42 changes: 40 additions & 2 deletions tests/client/transports/test_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
from httpx_sse import EventSource, ServerSentEvent

from a2a.client import create_text_message_object
from a2a.client.errors import A2AClientHTTPError
from a2a.client.transports.rest import RestTransport
from a2a.extensions.common import HTTP_EXTENSION_HEADER
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentSkill,
MessageSendParams,
Role,
)


Expand Down Expand Up @@ -130,6 +129,45 @@ async def test_send_message_streaming_with_new_extensions(
},
)

@pytest.mark.asyncio
@patch('a2a.client.transports.rest.aconnect_sse')
async def test_send_message_streaming_server_error_propagates(
self,
mock_aconnect_sse: AsyncMock,
mock_httpx_client: AsyncMock,
mock_agent_card: MagicMock,
):
"""Test that send_message_streaming propagates server errors (e.g., 403, 500) directly."""
client = RestTransport(
httpx_client=mock_httpx_client,
agent_card=mock_agent_card,
)
params = MessageSendParams(
message=create_text_message_object(content='Error stream')
)

mock_event_source = AsyncMock(spec=EventSource)
mock_response = MagicMock(spec=httpx.Response)
mock_response.status_code = 403
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
'Forbidden',
request=httpx.Request('POST', 'http://test.url'),
response=mock_response,
)
mock_event_source.response = mock_response
mock_event_source.aiter_sse.return_value = async_iterable_from_list([])
mock_aconnect_sse.return_value.__aenter__.return_value = (
mock_event_source
)

with pytest.raises(A2AClientHTTPError) as exc_info:
async for _ in client.send_message_streaming(request=params):
pass

assert exc_info.value.status_code == 403

mock_aconnect_sse.assert_called_once()

@pytest.mark.asyncio
async def test_get_card_no_card_provided_with_extensions(
self, mock_httpx_client: AsyncMock
Expand Down
Loading