Skip to content

Commit f00ef3e

Browse files
committed
fix: correctly handle streaming errors
1 parent 3bfbea9 commit f00ef3e

File tree

4 files changed

+84
-2
lines changed

4 files changed

+84
-2
lines changed

src/a2a/client/transports/jsonrpc.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,16 @@ async def send_message_streaming(
174174
**modified_kwargs,
175175
) as event_source:
176176
try:
177+
event_source.response.raise_for_status()
177178
async for sse in event_source.aiter_sse():
178179
response = SendStreamingMessageResponse.model_validate(
179180
json.loads(sse.data)
180181
)
181182
if isinstance(response.root, JSONRPCErrorResponse):
182183
raise A2AClientJSONRPCError(response.root)
183184
yield response.root.result
185+
except httpx.HTTPStatusError as e:
186+
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
184187
except SSEError as e:
185188
raise A2AClientHTTPError(
186189
400, f'Invalid SSE response or protocol error: {e}'

src/a2a/client/transports/rest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,13 @@ async def send_message_streaming(
152152
**modified_kwargs,
153153
) as event_source:
154154
try:
155+
event_source.response.raise_for_status()
155156
async for sse in event_source.aiter_sse():
156157
event = a2a_pb2.StreamResponse()
157158
Parse(sse.data, event)
158159
yield proto_utils.FromProto.stream_response(event)
160+
except httpx.HTTPStatusError as e:
161+
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
159162
except SSEError as e:
160163
raise A2AClientHTTPError(
161164
400, f'Invalid SSE response or protocol error: {e}'

tests/client/transports/test_jsonrpc_client.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,44 @@ async def test_send_message_streaming_with_new_extensions(
880880
},
881881
)
882882

883+
@pytest.mark.asyncio
884+
@patch('a2a.client.transports.jsonrpc.aconnect_sse')
885+
async def test_send_message_streaming_server_error_propagates(
886+
self,
887+
mock_aconnect_sse: AsyncMock,
888+
mock_httpx_client: AsyncMock,
889+
mock_agent_card: MagicMock,
890+
):
891+
"""Test that send_message_streaming propagates server errors (e.g., 403, 500) directly."""
892+
client = JsonRpcTransport(
893+
httpx_client=mock_httpx_client,
894+
agent_card=mock_agent_card,
895+
)
896+
params = MessageSendParams(
897+
message=create_text_message_object(content='Error stream')
898+
)
899+
900+
mock_event_source = AsyncMock(spec=EventSource)
901+
mock_response = MagicMock(spec=httpx.Response)
902+
mock_response.status_code = 403
903+
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
904+
'Forbidden',
905+
request=httpx.Request('POST', 'http://test.url'),
906+
response=mock_response,
907+
)
908+
mock_event_source.response = mock_response
909+
mock_event_source.aiter_sse.return_value = async_iterable_from_list([])
910+
mock_aconnect_sse.return_value.__aenter__.return_value = (
911+
mock_event_source
912+
)
913+
914+
with pytest.raises(A2AClientHTTPError) as exc_info:
915+
async for _ in client.send_message_streaming(request=params):
916+
pass
917+
918+
assert exc_info.value.status_code == 403
919+
mock_aconnect_sse.assert_called_once()
920+
883921
@pytest.mark.asyncio
884922
async def test_get_card_no_card_provided_with_extensions(
885923
self, mock_httpx_client: AsyncMock

tests/client/transports/test_rest_client.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@
77
from httpx_sse import EventSource, ServerSentEvent
88

99
from a2a.client import create_text_message_object
10+
from a2a.client.errors import A2AClientHTTPError
1011
from a2a.client.transports.rest import RestTransport
1112
from a2a.extensions.common import HTTP_EXTENSION_HEADER
1213
from a2a.types import (
1314
AgentCapabilities,
1415
AgentCard,
15-
AgentSkill,
1616
MessageSendParams,
17-
Role,
1817
)
1918

2019

@@ -130,6 +129,45 @@ async def test_send_message_streaming_with_new_extensions(
130129
},
131130
)
132131

132+
@pytest.mark.asyncio
133+
@patch('a2a.client.transports.rest.aconnect_sse')
134+
async def test_send_message_streaming_server_error_propagates(
135+
self,
136+
mock_aconnect_sse: AsyncMock,
137+
mock_httpx_client: AsyncMock,
138+
mock_agent_card: MagicMock,
139+
):
140+
"""Test that send_message_streaming propagates server errors (e.g., 403, 500) directly."""
141+
client = RestTransport(
142+
httpx_client=mock_httpx_client,
143+
agent_card=mock_agent_card,
144+
)
145+
params = MessageSendParams(
146+
message=create_text_message_object(content='Error stream')
147+
)
148+
149+
mock_event_source = AsyncMock(spec=EventSource)
150+
mock_response = MagicMock(spec=httpx.Response)
151+
mock_response.status_code = 403
152+
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
153+
'Forbidden',
154+
request=httpx.Request('POST', 'http://test.url'),
155+
response=mock_response,
156+
)
157+
mock_event_source.response = mock_response
158+
mock_event_source.aiter_sse.return_value = async_iterable_from_list([])
159+
mock_aconnect_sse.return_value.__aenter__.return_value = (
160+
mock_event_source
161+
)
162+
163+
with pytest.raises(A2AClientHTTPError) as exc_info:
164+
async for _ in client.send_message_streaming(request=params):
165+
pass
166+
167+
assert exc_info.value.status_code == 403
168+
169+
mock_aconnect_sse.assert_called_once()
170+
133171
@pytest.mark.asyncio
134172
async def test_get_card_no_card_provided_with_extensions(
135173
self, mock_httpx_client: AsyncMock

0 commit comments

Comments
 (0)