Skip to content

Commit a643e8d

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: Fixing a regression in InMemorySessionService
PiperOrigin-RevId: 862801950
1 parent 7019d39 commit a643e8d

File tree

2 files changed

+59
-23
lines changed

2 files changed

+59
-23
lines changed

core/src/main/java/com/google/adk/sessions/InMemorySessionService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ public Single<Event> appendEvent(Session session, Event event) {
259259
.computeIfAbsent(userId, unused -> new ConcurrentHashMap<>())
260260
.put(userStateKey, value);
261261
}
262-
} else if (!key.startsWith(State.TEMP_PREFIX)) {
262+
} else {
263263
if (value == State.REMOVED) {
264264
session.state().remove(key);
265265
} else {

core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java

Lines changed: 58 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ public void lifecycle_listSessions() {
8989

9090
ConcurrentMap<String, Object> stateDelta = new ConcurrentHashMap<>();
9191
stateDelta.put("sessionKey", "sessionValue");
92-
stateDelta.put("app:appKey", "appValue");
93-
stateDelta.put("user:userKey", "userValue");
92+
stateDelta.put("_app_appKey", "appValue");
93+
stateDelta.put("_user_userKey", "userValue");
9494
stateDelta.put("temp:tempKey", "tempValue");
9595

9696
Event event =
@@ -106,9 +106,9 @@ public void lifecycle_listSessions() {
106106
assertThat(listedSession.id()).isEqualTo(session.id());
107107
assertThat(listedSession.events()).isEmpty();
108108
assertThat(listedSession.state()).containsEntry("sessionKey", "sessionValue");
109-
assertThat(listedSession.state()).containsEntry("app:appKey", "appValue");
110-
assertThat(listedSession.state()).containsEntry("user:userKey", "userValue");
111-
assertThat(listedSession.state()).doesNotContainKey("temp:tempKey");
109+
assertThat(listedSession.state()).containsEntry("_app_appKey", "appValue");
110+
assertThat(listedSession.state()).containsEntry("_user_userKey", "userValue");
111+
assertThat(listedSession.state()).containsEntry("temp:tempKey", "tempValue");
112112
}
113113

114114
@Test
@@ -136,8 +136,8 @@ public void appendEvent_updatesSessionState() {
136136

137137
ConcurrentMap<String, Object> stateDelta = new ConcurrentHashMap<>();
138138
stateDelta.put("sessionKey", "sessionValue");
139-
stateDelta.put("app:appKey", "appValue");
140-
stateDelta.put("user:userKey", "userValue");
139+
stateDelta.put("_app_appKey", "appValue");
140+
stateDelta.put("_user_userKey", "userValue");
141141
stateDelta.put("temp:tempKey", "tempValue");
142142

143143
Event event =
@@ -148,19 +148,19 @@ public void appendEvent_updatesSessionState() {
148148
// After appendEvent, session state in memory should contain session-specific state from delta
149149
// and merged global state.
150150
assertThat(session.state()).containsEntry("sessionKey", "sessionValue");
151-
assertThat(session.state()).containsEntry("app:appKey", "appValue");
152-
assertThat(session.state()).containsEntry("user:userKey", "userValue");
153-
assertThat(session.state()).doesNotContainKey("temp:tempKey");
151+
assertThat(session.state()).containsEntry("_app_appKey", "appValue");
152+
assertThat(session.state()).containsEntry("_user_userKey", "userValue");
153+
assertThat(session.state()).containsEntry("temp:tempKey", "tempValue");
154154

155155
// getSession should return session with merged state.
156156
Session retrievedSession =
157157
sessionService
158158
.getSession(session.appName(), session.userId(), session.id(), Optional.empty())
159159
.blockingGet();
160160
assertThat(retrievedSession.state()).containsEntry("sessionKey", "sessionValue");
161-
assertThat(retrievedSession.state()).containsEntry("app:appKey", "appValue");
162-
assertThat(retrievedSession.state()).containsEntry("user:userKey", "userValue");
163-
assertThat(retrievedSession.state()).doesNotContainKey("temp:tempKey");
161+
assertThat(retrievedSession.state()).containsEntry("_app_appKey", "appValue");
162+
assertThat(retrievedSession.state()).containsEntry("_user_userKey", "userValue");
163+
assertThat(retrievedSession.state()).containsEntry("temp:tempKey", "tempValue");
164164
}
165165

166166
@Test
@@ -173,8 +173,8 @@ public void appendEvent_removesState() {
173173

174174
ConcurrentMap<String, Object> stateDeltaAdd = new ConcurrentHashMap<>();
175175
stateDeltaAdd.put("sessionKey", "sessionValue");
176-
stateDeltaAdd.put("app:appKey", "appValue");
177-
stateDeltaAdd.put("user:userKey", "userValue");
176+
stateDeltaAdd.put("_app_appKey", "appValue");
177+
stateDeltaAdd.put("_user_userKey", "userValue");
178178
stateDeltaAdd.put("temp:tempKey", "tempValue");
179179

180180
Event eventAdd =
@@ -188,15 +188,15 @@ public void appendEvent_removesState() {
188188
.getSession(session.appName(), session.userId(), session.id(), Optional.empty())
189189
.blockingGet();
190190
assertThat(retrievedSessionAdd.state()).containsEntry("sessionKey", "sessionValue");
191-
assertThat(retrievedSessionAdd.state()).containsEntry("app:appKey", "appValue");
192-
assertThat(retrievedSessionAdd.state()).containsEntry("user:userKey", "userValue");
193-
assertThat(retrievedSessionAdd.state()).doesNotContainKey("temp:tempKey");
191+
assertThat(retrievedSessionAdd.state()).containsEntry("_app_appKey", "appValue");
192+
assertThat(retrievedSessionAdd.state()).containsEntry("_user_userKey", "userValue");
193+
assertThat(retrievedSessionAdd.state()).containsEntry("temp:tempKey", "tempValue");
194194

195195
// Prepare and append event to remove state
196196
ConcurrentMap<String, Object> stateDeltaRemove = new ConcurrentHashMap<>();
197197
stateDeltaRemove.put("sessionKey", State.REMOVED);
198-
stateDeltaRemove.put("app:appKey", State.REMOVED);
199-
stateDeltaRemove.put("user:userKey", State.REMOVED);
198+
stateDeltaRemove.put("_app_appKey", State.REMOVED);
199+
stateDeltaRemove.put("_user_userKey", State.REMOVED);
200200
stateDeltaRemove.put("temp:tempKey", State.REMOVED);
201201

202202
Event eventRemove =
@@ -212,8 +212,44 @@ public void appendEvent_removesState() {
212212
.getSession(session.appName(), session.userId(), session.id(), Optional.empty())
213213
.blockingGet();
214214
assertThat(retrievedSessionRemove.state()).doesNotContainKey("sessionKey");
215-
assertThat(retrievedSessionRemove.state()).doesNotContainKey("app:appKey");
216-
assertThat(retrievedSessionRemove.state()).doesNotContainKey("user:userKey");
215+
assertThat(retrievedSessionRemove.state()).doesNotContainKey("_app_appKey");
216+
assertThat(retrievedSessionRemove.state()).doesNotContainKey("_user_userKey");
217217
assertThat(retrievedSessionRemove.state()).doesNotContainKey("temp:tempKey");
218218
}
219+
220+
@Test
221+
public void sequentialAgents_shareTempState() {
222+
InMemorySessionService sessionService = new InMemorySessionService();
223+
Session session =
224+
sessionService
225+
.createSession("app", "user", new ConcurrentHashMap<>(), "session1")
226+
.blockingGet();
227+
228+
// Agent 1 writes to temp state
229+
ConcurrentMap<String, Object> stateDelta1 = new ConcurrentHashMap<>();
230+
stateDelta1.put("temp:agent1_output", "data");
231+
Event event1 =
232+
Event.builder().actions(EventActions.builder().stateDelta(stateDelta1).build()).build();
233+
var unused = sessionService.appendEvent(session, event1).blockingGet();
234+
235+
// Verify agent 1 output is in session state
236+
assertThat(session.state()).containsEntry("temp:agent1_output", "data");
237+
238+
// Agent 2 reads "agent1_output", processes it, writes "agent2_output", and removes
239+
// "agent1_output"
240+
ConcurrentMap<String, Object> stateDelta2 = new ConcurrentHashMap<>();
241+
stateDelta2.put("temp:agent2_output", "processed_data");
242+
stateDelta2.put("temp:agent1_output", State.REMOVED);
243+
Event event2 =
244+
Event.builder().actions(EventActions.builder().stateDelta(stateDelta2).build()).build();
245+
unused = sessionService.appendEvent(session, event2).blockingGet();
246+
247+
// Verify final state after agent 2 processing
248+
Session retrievedSession =
249+
sessionService
250+
.getSession(session.appName(), session.userId(), session.id(), Optional.empty())
251+
.blockingGet();
252+
assertThat(retrievedSession.state()).doesNotContainKey("temp:agent1_output");
253+
assertThat(retrievedSession.state()).containsEntry("temp:agent2_output", "processed_data");
254+
}
219255
}

0 commit comments

Comments
 (0)