Skip to content

Commit 88950d7

Browse files
committed
fix(fastapi): pass state_delta to runner in /run endpoint
1 parent 831e2e6 commit 88950d7

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

src/google/adk/cli/adk_web_server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -993,6 +993,7 @@ async def run_agent(req: RunAgentRequest) -> list[Event]:
993993
user_id=req.user_id,
994994
session_id=req.session_id,
995995
new_message=req.new_message,
996+
state_delta=req.state_delta,
996997
)
997998
) as agen:
998999
events = [event async for event in agen]

tests/unittests/cli/test_fast_api.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,19 @@ def _event_3():
9494
)
9595

9696

97+
def _event_state_delta(state_delta: dict[str, Any]):
98+
return Event(
99+
author="dummy agent",
100+
invocation_id="invocation_id",
101+
content=types.Content(
102+
role="model",
103+
parts=[
104+
types.Part(text=json.dumps(state_delta, sort_keys=True))
105+
],
106+
),
107+
)
108+
109+
97110
# Define mocked async generator functions for the Runner
98111
async def dummy_run_live(self, session, live_request_queue):
99112
yield _event_1()
@@ -110,6 +123,7 @@ async def dummy_run_async(
110123
user_id,
111124
session_id,
112125
new_message,
126+
state_delta=None,
113127
run_config: RunConfig = RunConfig(),
114128
):
115129
yield _event_1()
@@ -119,6 +133,10 @@ async def dummy_run_async(
119133
await asyncio.sleep(0)
120134

121135
yield _event_3()
136+
await asyncio.sleep(0)
137+
138+
if state_delta is not None:
139+
yield _event_state_delta(state_delta)
122140

123141

124142
# Define a local mock for EvalCaseResult specific to fast_api tests
@@ -743,6 +761,26 @@ def test_agent_run(test_app, create_test_session):
743761

744762
logger.info("Agent run test completed successfully")
745763

764+
def test_agent_run_passes_state_delta(test_app, create_test_session):
765+
"""Test /run forwards state_delta and surfaces it in events."""
766+
info = create_test_session
767+
payload = {
768+
"app_name": info["app_name"],
769+
"user_id": info["user_id"],
770+
"session_id": info["session_id"],
771+
"new_message": {"role": "user", "parts": [{"text": "Hello"}]},
772+
"streaming": False,
773+
"state_delta": {"k": "v", "count": 1},
774+
}
775+
776+
response = test_app.post("/run", json=payload)
777+
assert response.status_code == 200
778+
data = response.json()
779+
assert isinstance(data, list)
780+
assert len(data) == 4
781+
782+
text = data[3]["content"]["parts"][0]["text"]
783+
assert json.loads(text) == payload["state_delta"]
746784

747785
def test_list_artifact_names(test_app, create_test_session):
748786
"""Test listing artifact names for a session."""

0 commit comments

Comments
 (0)