Skip to content

Commit 6eac090

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add support for Vertex Express Mode API key in AdkApp
PiperOrigin-RevId: 825638989
1 parent 59e3004 commit 6eac090

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

tests/unit/vertex_adk/test_agent_engine_templates_adk.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(self, name: str, model: str):
5454

5555
_TEST_LOCATION = "us-central1"
5656
_TEST_PROJECT = "test-project"
57+
_TEST_API_KEY = "test-api-key"
5758
_TEST_MODEL = "gemini-2.0-flash"
5859
_TEST_USER_ID = "test_user_id"
5960
_TEST_AGENT_NAME = "test_agent"
@@ -852,6 +853,41 @@ def test_dump_event_for_json():
852853
assert base64.b64decode(part["thought_signature"]) == raw_signature
853854

854855

856+
def test_adk_app_initialization_with_api_key():
857+
importlib.reload(initializer)
858+
importlib.reload(vertexai)
859+
try:
860+
vertexai.init(api_key=_TEST_API_KEY)
861+
app = agent_engines.AdkApp(agent=_TEST_AGENT)
862+
assert app._tmpl_attrs.get("project") is None
863+
assert app._tmpl_attrs.get("location") is None
864+
assert app._tmpl_attrs.get("express_mode_api_key") == _TEST_API_KEY
865+
assert app._tmpl_attrs.get("runner") is None
866+
app.set_up()
867+
assert app._tmpl_attrs.get("runner") is not None
868+
assert os.environ.get("GOOGLE_API_KEY") == _TEST_API_KEY
869+
assert "GOOGLE_CLOUD_LOCATION" not in os.environ
870+
assert "GOOGLE_CLOUD_PROJECT" not in os.environ
871+
finally:
872+
initializer.global_pool.shutdown(wait=True)
873+
874+
875+
def test_adk_app_initialization_with_env_api_key():
876+
try:
877+
os.environ["GOOGLE_API_KEY"] == _TEST_API_KEY
878+
app = agent_engines.AdkApp(agent=_TEST_AGENT)
879+
assert app._tmpl_attrs.get("project") is None
880+
assert app._tmpl_attrs.get("location") is None
881+
assert app._tmpl_attrs.get("express_mode_api_key") == _TEST_API_KEY
882+
assert app._tmpl_attrs.get("runner") is None
883+
app.set_up()
884+
assert app._tmpl_attrs.get("runner") is not None
885+
assert "GOOGLE_CLOUD_LOCATION" not in os.environ
886+
assert "GOOGLE_CLOUD_PROJECT" not in os.environ
887+
finally:
888+
initializer.global_pool.shutdown(wait=True)
889+
890+
855891
@pytest.mark.usefixtures("mock_adk_version")
856892
class TestAdkAppErrors:
857893
@pytest.mark.asyncio

vertexai/agent_engines/templates/adk.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,7 @@ def __init__(
534534
If not provided, a default instrumentor builder will be used.
535535
This parameter is ignored if `enable_tracing` is False.
536536
"""
537+
import os
537538
from google.cloud.aiplatform import initializer
538539

539540
adk_version = get_adk_version()
@@ -571,6 +572,10 @@ def __init__(
571572
"artifact_service_builder": artifact_service_builder,
572573
"memory_service_builder": memory_service_builder,
573574
"instrumentor_builder": instrumentor_builder,
575+
"express_mode_api_key": (
576+
initializer.global_config.api_key
577+
or os.environ.get("GOOGLE_API_KEY")
578+
),
574579
}
575580

576581
async def _init_session(
@@ -708,9 +713,18 @@ def set_up(self):
708713

709714
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1"
710715
project = self._tmpl_attrs.get("project")
711-
os.environ["GOOGLE_CLOUD_PROJECT"] = project
716+
if project:
717+
os.environ["GOOGLE_CLOUD_PROJECT"] = project
712718
location = self._tmpl_attrs.get("location")
713-
os.environ["GOOGLE_CLOUD_LOCATION"] = location
719+
if location:
720+
os.environ["GOOGLE_CLOUD_LOCATION"] = location
721+
express_mode_api_key = self._tmpl_attrs.get("express_mode_api_key")
722+
if express_mode_api_key and not project:
723+
os.environ["GOOGLE_API_KEY"] = express_mode_api_key
724+
# Clear location and project env vars if express mode api key is provided.
725+
os.environ.pop("GOOGLE_CLOUD_LOCATION", None)
726+
os.environ.pop("GOOGLE_CLOUD_PROJECT", None)
727+
location = None
714728

715729
# Disable content capture in custom ADK spans unless user enabled
716730
# tracing explicitly with the old flag
@@ -757,6 +771,8 @@ def set_up(self):
757771
VertexAiSessionService,
758772
)
759773

774+
# If the express mode api key is set, it will be read from the
775+
# environment variable when initializing the session service.
760776
self._tmpl_attrs["session_service"] = VertexAiSessionService(
761777
project=project,
762778
location=location,
@@ -767,6 +783,8 @@ def set_up(self):
767783
VertexAiSessionService,
768784
)
769785

786+
# If the express mode api key is set, it will be read from the
787+
# environment variable when initializing the session service.
770788
self._tmpl_attrs["session_service"] = VertexAiSessionService(
771789
project=project,
772790
location=location,
@@ -787,6 +805,8 @@ def set_up(self):
787805
VertexAiMemoryBankService,
788806
)
789807

808+
# If the express mode api key is set, it will be read from the
809+
# environment variable when initializing the memory service.
790810
self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService(
791811
project=project,
792812
location=location,

0 commit comments

Comments
 (0)