diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index 3cb2e0bceb..8ee2080cab 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -683,6 +683,20 @@ async def _init_session( if request.events: for event in request.events: await session_service.append_event(session, Event(**event)) + if request.artifacts: + await self._save_artifacts( + session.id, artifact_service, request + ) + return session + + async def _save_artifacts( + self, + session_id: str, + artifact_service: "BaseArtifactService", + request: _StreamRunRequest, + ): + """Saves the artifacts.""" + app = self._tmpl_attrs.get("app") if request.artifacts: for artifact in request.artifacts: artifact = _Artifact(**artifact) @@ -693,7 +707,7 @@ async def _init_session( saved_version = await artifact_service.save_artifact( app_name=app.name if app else self._tmpl_attrs.get("app_name"), user_id=request.user_id, - session_id=session.id, + session_id=session_id, filename=artifact.file_name, artifact=version_data.data, ) @@ -707,7 +721,6 @@ async def _init_session( saved_version, version_data.version, ) - return session async def _convert_response_events( self, @@ -1209,6 +1222,11 @@ async def streaming_agent_run_with_events(self, request_json: str): user_id=request.user_id, session_id=request.session_id, ) + self._save_artifacts( + session_id=session.id, + artifact_service=artifact_service, + request=request + ) except ClientError: pass if not session: