Skip to content

Commit eda15d5

Browse files
yeesiancopybara-github
authored andcommitted
feat: Allow list of events to be passed to AdkApp.async_stream_query
PiperOrigin-RevId: 844834249
1 parent df0976e commit eda15d5

File tree

1 file changed

+18
-1
lines changed
  • vertexai/agent_engines/templates

1 file changed

+18
-1
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,7 @@ async def async_stream_query(
932932
message: Union[str, Dict[str, Any]],
933933
user_id: str,
934934
session_id: Optional[str] = None,
935+
session_events: Optional[Dict[str, Any]] = None,
935936
run_config: Optional[Dict[str, Any]] = None,
936937
**kwargs,
937938
) -> AsyncIterable[Dict[str, Any]]:
@@ -944,7 +945,11 @@ async def async_stream_query(
944945
Required. The ID of the user.
945946
session_id (str):
946947
Optional. The ID of the session. If not provided, a new
947-
session will be created for the user.
948+
session will be created for the user. If this is specified, then
949+
`session_events` will be ignored.
950+
session_events (Optional[List[Dict[str, Any]]]):
951+
Optional. The session events to use for the query. This will be
952+
used to initialize the session if `session_id` is not provided.
948953
run_config (Optional[Dict[str, Any]]):
949954
Optional. The run config to use for the query. If you want to
950955
pass in a `run_config` pydantic object, you can pass in a dict
@@ -974,6 +979,18 @@ async def async_stream_query(
974979
if not session_id:
975980
session = await self.async_create_session(user_id=user_id)
976981
session_id = session.id
982+
if session_events is not None:
983+
# We allow for session_events to be an empty list.
984+
from google.adk.events.event import Event
985+
986+
session_service = self._tmpl_attrs.get("session_service")
987+
for event in session_events:
988+
if not isinstance(event, Event):
989+
event = Event.model_validate(event)
990+
await session_service.append_event(
991+
session=session,
992+
event=event,
993+
)
977994

978995
run_config = _validate_run_config(run_config)
979996
if run_config:

0 commit comments

Comments
 (0)