Skip to content

Commit 798d005

Browse files
GWealecopybara-github
authored andcommitted
fix: Stream errors as simple JSON objects in ADK web server SSE
The ADK web server's /run_sse endpoint now yields a JSON object like {"error": "..."} when an exception occurs during event generation. The adk_web_server_client is updated to detect this error payload and raise a RuntimeError. Close #4291 Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 863475838
1 parent d0102ec commit 798d005

File tree

4 files changed

+84
-11
lines changed

4 files changed

+84
-11
lines changed

src/google/adk/cli/adk_web_server.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,17 +1474,7 @@ async def event_generator():
14741474
yield f"data: {sse_event}\n\n"
14751475
except Exception as e:
14761476
logger.exception("Error in event_generator: %s", e)
1477-
# Yield a proper Event object for the error
1478-
error_event = Event(
1479-
author="system",
1480-
content=types.Content(
1481-
role="model", parts=[types.Part(text=f"Error: {e}")]
1482-
),
1483-
)
1484-
yield (
1485-
"data:"
1486-
f" {error_event.model_dump_json(by_alias=True, exclude_none=True)}\n\n"
1487-
)
1477+
yield f"data: {json.dumps({'error': str(e)})}\n\n"
14881478

14891479
# Returns a streaming response with the proper media type for SSE
14901480
return StreamingResponse(

src/google/adk/cli/conformance/adk_web_server_client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ async def run_agent(
228228
ValueError: If mode is provided but test_case_dir or user_message_index is None
229229
httpx.HTTPStatusError: If the request fails
230230
json.JSONDecodeError: If event data cannot be parsed
231+
RuntimeError: If the server streams an error payload
231232
"""
232233
# Add recording parameters to state_delta for conformance tests
233234
if mode:
@@ -262,6 +263,8 @@ async def run_agent(
262263
async for line in response.aiter_lines():
263264
if line.startswith("data:") and (data := line[5:].strip()):
264265
event_data = json.loads(data)
266+
if isinstance(event_data, dict) and "error" in event_data:
267+
raise RuntimeError(event_data["error"])
265268
yield Event.model_validate(event_data)
266269
else:
267270
logger.debug("Non data line received: %s", line)

tests/unittests/cli/conformance/test_adk_web_server_client.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,44 @@ def mock_stream(*_args, **_kwargs):
224224
assert events[1].invocation_id == "test_invocation_2"
225225

226226

227+
@pytest.mark.asyncio
228+
async def test_run_agent_raises_on_streamed_error():
229+
client = AdkWebServerClient()
230+
231+
class MockStreamResponse:
232+
233+
def raise_for_status(self):
234+
pass
235+
236+
async def aiter_lines(self):
237+
yield 'data: {"error": "boom"}'
238+
239+
async def __aenter__(self):
240+
return self
241+
242+
async def __aexit__(self, exc_type, exc_val, exc_tb):
243+
pass
244+
245+
def mock_stream(*_args, **_kwargs):
246+
return MockStreamResponse()
247+
248+
with patch("httpx.AsyncClient") as mock_client_class:
249+
mock_client = AsyncMock()
250+
mock_client.stream = mock_stream
251+
mock_client_class.return_value = mock_client
252+
253+
request = RunAgentRequest(
254+
app_name="test_app",
255+
user_id="test_user",
256+
session_id="test_session",
257+
new_message=types.Content(role="user", parts=[types.Part(text="Hi")]),
258+
)
259+
260+
with pytest.raises(RuntimeError, match="boom"):
261+
async for _ in client.run_agent(request):
262+
pass
263+
264+
227265
@pytest.mark.asyncio
228266
async def test_close():
229267
client = AdkWebServerClient()

tests/unittests/cli/test_fast_api.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,48 @@ async def run_async_with_artifact_delta(
10191019
assert sse_events[1]["actions"]["artifactDelta"] == {"artifact.txt": 0}
10201020

10211021

1022+
def test_agent_run_sse_yields_error_object_on_exception(
1023+
test_app, create_test_session, monkeypatch
1024+
):
1025+
"""Test /run_sse streams an error object if streaming raises."""
1026+
info = create_test_session
1027+
1028+
async def run_async_raises(
1029+
self,
1030+
*,
1031+
user_id: str,
1032+
session_id: str,
1033+
invocation_id: Optional[str] = None,
1034+
new_message: Optional[types.Content] = None,
1035+
state_delta: Optional[dict[str, Any]] = None,
1036+
run_config: Optional[RunConfig] = None,
1037+
):
1038+
del user_id, session_id, invocation_id, new_message, state_delta, run_config
1039+
raise ValueError("boom")
1040+
if False: # pylint: disable=using-constant-test
1041+
yield _event_1()
1042+
1043+
monkeypatch.setattr(Runner, "run_async", run_async_raises)
1044+
1045+
payload = {
1046+
"app_name": info["app_name"],
1047+
"user_id": info["user_id"],
1048+
"session_id": info["session_id"],
1049+
"new_message": {"role": "user", "parts": [{"text": "Hello agent"}]},
1050+
"streaming": True,
1051+
}
1052+
1053+
response = test_app.post("/run_sse", json=payload)
1054+
assert response.status_code == 200
1055+
1056+
sse_events = [
1057+
json.loads(line.removeprefix("data: "))
1058+
for line in response.text.splitlines()
1059+
if line.startswith("data: ")
1060+
]
1061+
assert sse_events == [{"error": "boom"}]
1062+
1063+
10221064
def test_list_artifact_names(test_app, create_test_session):
10231065
"""Test listing artifact names for a session."""
10241066
info = create_test_session

0 commit comments

Comments
 (0)