Skip to content

Commit c46308b

Browse files
seanzhougooglecopybara-github
authored andcommitted
chore: Add session patch endpoint to api server for state update
This is allow user to update session state without running the agent. e.g. if I want to test some case when session has certain state on adk web. PiperOrigin-RevId: 814252851
1 parent 822efe0 commit c46308b

File tree

4 files changed

+214
-0
lines changed

4 files changed

+214
-0
lines changed

src/google/adk/cli/adk_web_server.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,13 @@ class UpdateMemoryRequest(common.BaseModel):
220220
"""The ID of the session to add to memory."""
221221

222222

223+
class UpdateSessionRequest(common.BaseModel):
224+
"""Request to update session state without running the agent."""
225+
226+
state_delta: dict[str, Any]
227+
"""The state changes to apply to the session."""
228+
229+
223230
class RunEvalResult(common.BaseModel):
224231
eval_set_file: str
225232
eval_set_id: str
@@ -767,6 +774,56 @@ async def delete_session(
767774
app_name=app_name, user_id=user_id, session_id=session_id
768775
)
769776

777+
@app.patch(
778+
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
779+
response_model_exclude_none=True,
780+
)
781+
async def update_session(
782+
app_name: str,
783+
user_id: str,
784+
session_id: str,
785+
req: UpdateSessionRequest,
786+
) -> Session:
787+
"""Updates session state without running the agent.
788+
789+
Args:
790+
app_name: The name of the application.
791+
user_id: The ID of the user.
792+
session_id: The ID of the session to update.
793+
req: The patch request containing state changes.
794+
795+
Returns:
796+
The updated session.
797+
798+
Raises:
799+
HTTPException: If the session is not found.
800+
"""
801+
session = await self.session_service.get_session(
802+
app_name=app_name, user_id=user_id, session_id=session_id
803+
)
804+
if not session:
805+
raise HTTPException(status_code=404, detail="Session not found")
806+
807+
# Create an event to record the state change
808+
import uuid
809+
810+
from ..events.event import Event
811+
from ..events.event import EventActions
812+
813+
state_update_event = Event(
814+
invocation_id="p-" + str(uuid.uuid4()),
815+
author="user",
816+
actions=EventActions(state_delta=req.state_delta),
817+
)
818+
819+
# Append the event to the session
820+
# This will automatically update the session state through __update_session_state
821+
await self.session_service.append_event(
822+
session=session, event=state_update_event
823+
)
824+
825+
return session
826+
770827
@app.post(
771828
"/apps/{app_name}/eval-sets",
772829
response_model_exclude_none=True,

src/google/adk/cli/conformance/adk_web_server_client.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,36 @@ async def delete_session(
176176
)
177177
response.raise_for_status()
178178

179+
async def update_session(
180+
self,
181+
*,
182+
app_name: str,
183+
user_id: str,
184+
session_id: str,
185+
state_delta: Dict[str, Any],
186+
) -> Session:
187+
"""Update session state without running the agent.
188+
189+
Args:
190+
app_name: Name of the application
191+
user_id: User identifier
192+
session_id: Session identifier to update
193+
state_delta: The state changes to apply to the session
194+
195+
Returns:
196+
The updated Session object
197+
198+
Raises:
199+
httpx.HTTPStatusError: If the request fails or session not found
200+
"""
201+
async with self._get_client() as client:
202+
response = await client.patch(
203+
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
204+
json={"state_delta": state_delta},
205+
)
206+
response.raise_for_status()
207+
return Session.model_validate(response.json())
208+
179209
async def run_agent(
180210
self,
181211
request: RunAgentRequest,

tests/unittests/cli/conformance/test_adk_web_server_client.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,43 @@ async def test_delete_session():
127127
mock_response.raise_for_status.assert_called_once()
128128

129129

130+
@pytest.mark.asyncio
131+
async def test_update_session():
132+
client = AdkWebServerClient()
133+
134+
# Mock the HTTP response
135+
mock_response = MagicMock()
136+
mock_response.json.return_value = {
137+
"id": "test_session",
138+
"app_name": "test_app",
139+
"user_id": "test_user",
140+
"events": [],
141+
"state": {"key": "updated_value", "new_key": "new_value"},
142+
}
143+
144+
with patch("httpx.AsyncClient") as mock_client_class:
145+
mock_client = AsyncMock()
146+
mock_client.patch.return_value = mock_response
147+
mock_client_class.return_value = mock_client
148+
149+
state_delta = {"key": "updated_value", "new_key": "new_value"}
150+
session = await client.update_session(
151+
app_name="test_app",
152+
user_id="test_user",
153+
session_id="test_session",
154+
state_delta=state_delta,
155+
)
156+
157+
assert isinstance(session, Session)
158+
assert session.id == "test_session"
159+
assert session.state == {"key": "updated_value", "new_key": "new_value"}
160+
mock_client.patch.assert_called_once_with(
161+
"/apps/test_app/users/test_user/sessions/test_session",
162+
json={"state_delta": state_delta},
163+
)
164+
mock_response.raise_for_status.assert_called_once()
165+
166+
130167
@pytest.mark.asyncio
131168
async def test_run_agent():
132169
client = AdkWebServerClient()

tests/unittests/cli/test_fast_api.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,22 @@ async def delete_session(self, app_name, user_id, session_id):
288288
):
289289
del session_data[app_name][user_id][session_id]
290290

291+
async def append_event(self, session, event):
292+
"""Append an event to a session."""
293+
# Update session state if event has state_delta
294+
if event.actions and event.actions.state_delta:
295+
session["state"].update(event.actions.state_delta)
296+
297+
# Add event to session events
298+
session["events"].append(event.model_dump())
299+
300+
# Update the session in storage
301+
session_data[session["app_name"]][session["user_id"]][
302+
session["id"]
303+
] = session
304+
305+
return event
306+
291307
# Return an instance of our mock service
292308
return MockSessionService()
293309

@@ -725,6 +741,80 @@ def test_delete_session(test_app, create_test_session):
725741
logger.info("Session deleted successfully")
726742

727743

744+
def test_update_session(test_app, create_test_session):
745+
"""Test patching a session state."""
746+
info = create_test_session
747+
url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/{info['session_id']}"
748+
749+
# Get the original session
750+
response = test_app.get(url)
751+
assert response.status_code == 200
752+
original_session = response.json()
753+
original_state = original_session.get("state", {})
754+
755+
# Prepare state delta
756+
state_delta = {"test_key": "test_value", "counter": 42}
757+
758+
# Patch the session
759+
response = test_app.patch(url, json={"state_delta": state_delta})
760+
assert response.status_code == 200
761+
762+
# Verify the response
763+
patched_session = response.json()
764+
assert patched_session["id"] == info["session_id"]
765+
766+
# Verify state was updated correctly
767+
expected_state = {**original_state, **state_delta}
768+
assert patched_session["state"] == expected_state
769+
770+
# Verify the session was actually updated in storage
771+
response = test_app.get(url)
772+
assert response.status_code == 200
773+
retrieved_session = response.json()
774+
assert retrieved_session["state"] == expected_state
775+
776+
# Verify an event was created for the state change
777+
events = retrieved_session.get("events", [])
778+
assert len(events) > len(original_session.get("events", []))
779+
780+
# Find the state patch event (looking for "p-" prefix pattern)
781+
state_patch_events = [
782+
event
783+
for event in events
784+
if (
785+
event.get("invocationId") or event.get("invocation_id", "")
786+
).startswith("p-")
787+
]
788+
789+
assert len(state_patch_events) == 1, (
790+
f"Expected 1 state_patch event, found {len(state_patch_events)}. Events:"
791+
f" {events}"
792+
)
793+
state_patch_event = state_patch_events[0]
794+
assert state_patch_event["author"] == "user"
795+
796+
# Check for actions in both camelCase and snake_case
797+
actions = state_patch_event.get("actions") or state_patch_event.get("actions")
798+
assert actions is not None, f"No actions found in event: {state_patch_event}"
799+
state_delta_in_event = actions.get("state_delta") or actions.get("stateDelta")
800+
assert state_delta_in_event == state_delta
801+
802+
logger.info("Session state patched successfully")
803+
804+
805+
def test_patch_session_not_found(test_app, test_session_info):
806+
"""Test patching a non-existent session."""
807+
info = test_session_info
808+
url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/nonexistent"
809+
810+
state_delta = {"test_key": "test_value"}
811+
response = test_app.patch(url, json={"state_delta": state_delta})
812+
813+
assert response.status_code == 404
814+
assert "Session not found" in response.json()["detail"]
815+
logger.info("Patch session not found test passed")
816+
817+
728818
def test_agent_run(test_app, create_test_session):
729819
"""Test running an agent with a message."""
730820
info = create_test_session

0 commit comments

Comments
 (0)