Skip to content

Commit 5438209

Browse files
Tongzhou-Jiangcopybara-github
authored andcommitted
fix: investigate save artifact in init_session
PiperOrigin-RevId: 833983514
1 parent 26b7e51 commit 5438209

File tree

1 file changed

+11
-3
lines changed
  • vertexai/agent_engines/templates

1 file changed

+11
-3
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,10 @@ def is_version_sufficient(version_to_check: str) -> bool:
135135

136136
class _ArtifactVersion:
137137
def __init__(self, **kwargs):
138+
from google.genai import types
139+
138140
self.version: Optional[str] = kwargs.get("version")
139-
self.data = kwargs.get("data")
141+
self.data: Optional[types.Part] = kwargs.get("data")
140142

141143
def dump(self) -> Dict[str, Any]:
142144
result = {}
@@ -603,15 +605,20 @@ async def _init_session(
603605
"""Initializes the session, and returns the session id."""
604606
from google.adk.events.event import Event
605607

608+
from google.cloud.aiplatform import base
609+
610+
_LOGGER = base.Logger(__name__)
611+
606612
session_state = None
607613
if request.authorizations:
608614
session_state = {}
609615
for auth_id, auth in request.authorizations.items():
610616
auth = _Authorization(**auth)
611617
session_state[auth_id] = auth.access_token
612618

619+
app = self._tmpl_attrs.get("app")
613620
session = await session_service.create_session(
614-
app_name=self._tmpl_attrs.get("app_name"),
621+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
615622
user_id=request.user_id,
616623
state=session_state,
617624
)
@@ -627,8 +634,9 @@ async def _init_session(
627634
artifact.versions, key=lambda x: x["version"]
628635
):
629636
version_data = _ArtifactVersion(**version_data)
637+
_LOGGER.info(f'Saving artifact {version_data.data}, type {type(version_data.data)}')
630638
saved_version = await artifact_service.save_artifact(
631-
app_name=self._tmpl_attrs.get("app_name"),
639+
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
632640
user_id=request.user_id,
633641
session_id=session.id,
634642
filename=artifact.file_name,

0 commit comments

Comments
 (0)