Skip to content

Commit c934f1e

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add the identity type option for the agent engine and add effective identity to the resource
PiperOrigin-RevId: 819733517
1 parent 7dd2629 commit c934f1e

File tree

3 files changed

+90
-7
lines changed

3 files changed

+90
-7
lines changed

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,9 @@ def register_operations(self) -> Dict[str, List[str]]:
526526
}
527527
_TEST_AGENT_ENGINE_CONTAINER_CONCURRENCY = 4
528528
_TEST_AGENT_ENGINE_CUSTOM_SERVICE_ACCOUNT = "test-custom-service-account"
529+
_TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT = (
530+
_genai_types.IdentityType.SERVICE_ACCOUNT
531+
)
529532
_TEST_AGENT_ENGINE_ENCRYPTION_SPEC = {"kms_key_name": "test-kms-key"}
530533
_TEST_AGENT_ENGINE_SPEC = _genai_types.ReasoningEngineSpecDict(
531534
agent_framework=_TEST_AGENT_ENGINE_FRAMEWORK,
@@ -552,6 +555,7 @@ def register_operations(self) -> Dict[str, List[str]]:
552555
requirements_gcs_uri=_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
553556
),
554557
service_account=_TEST_AGENT_ENGINE_CUSTOM_SERVICE_ACCOUNT,
558+
identity_type=_TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT,
555559
)
556560
_TEST_AGENT_ENGINE_STREAM_QUERY_RESPONSE = [{"output": "hello"}, {"output": "world"}]
557561
_TEST_AGENT_ENGINE_OPERATION_SCHEMAS = []
@@ -858,6 +862,7 @@ def test_create_agent_engine_config_full(self, mock_prepare):
858862
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
859863
env_vars=_TEST_AGENT_ENGINE_ENV_VARS_INPUT,
860864
service_account=_TEST_AGENT_ENGINE_CUSTOM_SERVICE_ACCOUNT,
865+
identity_type=_TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT,
861866
psc_interface_config=_TEST_AGENT_ENGINE_PSC_INTERFACE_CONFIG,
862867
min_instances=_TEST_AGENT_ENGINE_MIN_INSTANCES,
863868
max_instances=_TEST_AGENT_ENGINE_MAX_INSTANCES,
@@ -900,6 +905,10 @@ def test_create_agent_engine_config_full(self, mock_prepare):
900905
config["spec"]["service_account"]
901906
== _TEST_AGENT_ENGINE_CUSTOM_SERVICE_ACCOUNT
902907
)
908+
assert (
909+
config["spec"]["identity_type"]
910+
== _TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT
911+
)
903912

904913
@mock.patch.object(_agent_engines_utils, "_prepare")
905914
def test_update_agent_engine_config_full(self, mock_prepare):
@@ -914,6 +923,7 @@ def test_update_agent_engine_config_full(self, mock_prepare):
914923
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
915924
env_vars=_TEST_AGENT_ENGINE_ENV_VARS_INPUT,
916925
service_account=_TEST_AGENT_ENGINE_CUSTOM_SERVICE_ACCOUNT,
926+
identity_type=_TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT,
917927
)
918928
assert config["display_name"] == _TEST_AGENT_ENGINE_DISPLAY_NAME
919929
assert config["description"] == _TEST_AGENT_ENGINE_DESCRIPTION
@@ -944,21 +954,45 @@ def test_update_agent_engine_config_full(self, mock_prepare):
944954
config["spec"]["service_account"]
945955
== _TEST_AGENT_ENGINE_CUSTOM_SERVICE_ACCOUNT
946956
)
957+
assert (
958+
config["spec"]["identity_type"]
959+
== _TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT
960+
)
947961
assert config["update_mask"] == ",".join(
948962
[
949963
"display_name",
950964
"description",
965+
"spec.identity_type",
966+
"spec.service_account",
951967
"spec.package_spec.pickle_object_gcs_uri",
952968
"spec.package_spec.dependency_files_gcs_uri",
953969
"spec.package_spec.requirements_gcs_uri",
954970
"spec.deployment_spec.env",
955971
"spec.deployment_spec.secret_env",
956-
"spec.service_account",
957972
"spec.class_methods",
958973
"spec.agent_framework",
959974
]
960975
)
961976

977+
@mock.patch.object(_agent_engines_utils, "_prepare")
978+
def test_update_agent_engine_clear_service_account(self, mock_prepare):
979+
config = self.client.agent_engines._create_config(
980+
mode="update",
981+
service_account="",
982+
identity_type=_TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT,
983+
)
984+
assert "service_account" not in config["spec"].keys()
985+
assert (
986+
config["spec"]["identity_type"]
987+
== _TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT
988+
)
989+
assert config["update_mask"] == ",".join(
990+
[
991+
"spec.identity_type",
992+
"spec.service_account",
993+
]
994+
)
995+
962996
def test_get_agent_operation(self):
963997
with mock.patch.object(
964998
self.client.agent_engines._api_client, "request"
@@ -1355,6 +1389,7 @@ def test_create_agent_engine_with_env_vars_dict(
13551389
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
13561390
env_vars=_TEST_AGENT_ENGINE_ENV_VARS_INPUT,
13571391
service_account=None,
1392+
identity_type=None,
13581393
context_spec=None,
13591394
psc_interface_config=None,
13601395
min_instances=None,
@@ -1403,6 +1438,7 @@ def test_create_agent_engine_with_custom_service_account(
14031438
},
14041439
"class_methods": [_TEST_AGENT_ENGINE_CLASS_METHOD_1],
14051440
"service_account": _TEST_AGENT_ENGINE_CUSTOM_SERVICE_ACCOUNT,
1441+
"identity_type": _TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT,
14061442
"agent_framework": _TEST_AGENT_ENGINE_FRAMEWORK,
14071443
},
14081444
}
@@ -1424,6 +1460,7 @@ def test_create_agent_engine_with_custom_service_account(
14241460
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
14251461
staging_bucket=_TEST_STAGING_BUCKET,
14261462
service_account=_TEST_AGENT_ENGINE_CUSTOM_SERVICE_ACCOUNT,
1463+
identity_type=_TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT,
14271464
),
14281465
)
14291466
mock_create_config.assert_called_with(
@@ -1437,6 +1474,7 @@ def test_create_agent_engine_with_custom_service_account(
14371474
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
14381475
env_vars=None,
14391476
service_account=_TEST_AGENT_ENGINE_CUSTOM_SERVICE_ACCOUNT,
1477+
identity_type=_TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT,
14401478
context_spec=None,
14411479
psc_interface_config=None,
14421480
min_instances=None,
@@ -1463,6 +1501,7 @@ def test_create_agent_engine_with_custom_service_account(
14631501
"requirements_gcs_uri": _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI,
14641502
},
14651503
"service_account": _TEST_AGENT_ENGINE_CUSTOM_SERVICE_ACCOUNT,
1504+
"identity_type": _TEST_AGENT_ENGINE_IDENTITY_TYPE_SERVICE_ACCOUNT,
14661505
},
14671506
},
14681507
None,
@@ -1521,6 +1560,7 @@ def test_create_agent_engine_with_experimental_mode(
15211560
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
15221561
env_vars=None,
15231562
service_account=None,
1563+
identity_type=None,
15241564
context_spec=None,
15251565
psc_interface_config=None,
15261566
min_instances=None,
@@ -1603,6 +1643,7 @@ def test_create_agent_engine_with_class_methods(
16031643
extra_packages=[_TEST_AGENT_ENGINE_EXTRA_PACKAGE_PATH],
16041644
env_vars=None,
16051645
service_account=None,
1646+
identity_type=None,
16061647
context_spec=None,
16071648
psc_interface_config=None,
16081649
min_instances=None,

vertexai/_genai/agent_engines.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,7 @@ def create(
853853
api_config = self._create_config(
854854
mode="create",
855855
agent=agent,
856+
identity_type=config.identity_type,
856857
staging_bucket=config.staging_bucket,
857858
requirements=config.requirements,
858859
display_name=config.display_name,
@@ -912,6 +913,7 @@ def _create_config(
912913
*,
913914
mode: str,
914915
agent: Any = None,
916+
identity_type: Optional[types.IdentityType] = None,
915917
staging_bucket: Optional[str] = None,
916918
requirements: Optional[Union[str, Sequence[str]]] = None,
917919
display_name: Optional[str] = None,
@@ -957,6 +959,17 @@ def _create_config(
957959
if labels is not None:
958960
update_masks.append("labels")
959961
config["labels"] = labels
962+
963+
agent_engine_spec: types.ReasoningEngineSpecDict = {}
964+
if identity_type is not None:
965+
agent_engine_spec["identity_type"] = identity_type
966+
update_masks.append("spec.identity_type")
967+
if service_account is not None:
968+
# Clear the field in case of empty service_account.
969+
if service_account:
970+
agent_engine_spec["service_account"] = service_account
971+
update_masks.append("spec.service_account")
972+
960973
if agent is not None:
961974
project = self._api_client.project
962975
if project is None:
@@ -1013,9 +1026,7 @@ def _create_config(
10131026
gcs_dir_name,
10141027
_agent_engines_utils._REQUIREMENTS_FILE,
10151028
)
1016-
agent_engine_spec: types.ReasoningEngineSpecDict = {
1017-
"package_spec": package_spec,
1018-
}
1029+
agent_engine_spec["package_spec"] = package_spec
10191030
if (
10201031
env_vars is not None
10211032
or psc_interface_config is not None
@@ -1037,9 +1048,6 @@ def _create_config(
10371048
)
10381049
update_masks.extend(deployment_update_masks)
10391050
agent_engine_spec["deployment_spec"] = deployment_spec
1040-
if service_account is not None:
1041-
agent_engine_spec["service_account"] = service_account
1042-
update_masks.append("spec.service_account")
10431051

10441052
update_masks.append("spec.class_methods")
10451053
class_methods_spec = []
@@ -1076,7 +1084,10 @@ def _create_config(
10761084
_agent_engines_utils._get_agent_framework(agent=agent)
10771085
)
10781086
update_masks.append("spec.agent_framework")
1087+
1088+
if agent_engine_spec.items():
10791089
config["spec"] = agent_engine_spec
1090+
10801091
if update_masks and mode == "update":
10811092
config["update_mask"] = ",".join(update_masks)
10821093
return config
@@ -1280,6 +1291,7 @@ def update(
12801291
api_config = self._create_config(
12811292
mode="update",
12821293
agent=agent,
1294+
identity_type=config.identity_type,
12831295
staging_bucket=config.staging_bucket,
12841296
requirements=config.requirements,
12851297
display_name=config.display_name,

vertexai/_genai/types.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,16 @@ class GenerateMemoriesResponseGeneratedMemoryAction(_common.CaseInSensitiveEnum)
360360
"""The memory was deleted."""
361361

362362

363+
class IdentityType(_common.CaseInSensitiveEnum):
364+
"""The identity type for the agent engine."""
365+
IDENTITY_TYPE_UNSPECIFIED = "IDENTITY_TYPE_UNSPECIFIED"
366+
"""The identity type is unspecified. Use a custom service account if the `service_account` field is set, otherwise use the default Vertex AI Reasoning Engine Service Agent in the project."""
367+
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
368+
"""Use a custom service account if the `service_account` field is set, otherwise use the default Vertex AI Reasoning Engine Service Agent in the project."""
369+
AGENT_IDENTITY = "AGENT_IDENTITY"
370+
"""Use the Agent Identity to access the resources."""
371+
372+
363373
class SamplingConfig(_common.BaseModel):
364374
"""Sampling config for a BigQuery request set."""
365375

@@ -4790,6 +4800,14 @@ class ReasoningEngineSpec(_common.BaseModel):
47904800
default=None,
47914801
description="""Optional. The service account that the Reasoning Engine artifact runs as. It should have "roles/storage.objectViewer" for reading the user project's Cloud Storage and "roles/aiplatform.user" for using Vertex extensions. If not specified, the Vertex AI Reasoning Engine Service Agent in the project will be used.""",
47924802
)
4803+
identity_type: Optional[IdentityType] = Field(
4804+
default=None,
4805+
description="""Optional. The identity type for the Reasoning Engine. If not specified, the default value is `IDENTITY_TYPE_UNSPECIFIED`.""",
4806+
)
4807+
effective_identity: Optional[str] = Field(
4808+
default=None,
4809+
description="""Output only. The identity to be used for the Reasoning Engine. If not specified, the default value is the service account specified in `service_account`.""",
4810+
)
47934811

47944812

47954813
class ReasoningEngineSpecDict(TypedDict, total=False):
@@ -4810,6 +4828,11 @@ class ReasoningEngineSpecDict(TypedDict, total=False):
48104828
service_account: Optional[str]
48114829
"""Optional. The service account that the Reasoning Engine artifact runs as. It should have "roles/storage.objectViewer" for reading the user project's Cloud Storage and "roles/aiplatform.user" for using Vertex extensions. If not specified, the Vertex AI Reasoning Engine Service Agent in the project will be used."""
48124830

4831+
identity_type: Optional[IdentityType]
4832+
"""Optional. The identity type for the Reasoning Engine. If not specified, the default value is `IDENTITY_TYPE_UNSPECIFIED`."""
4833+
4834+
effective_identity: Optional[str]
4835+
"""Output only. The identity to be used for the Reasoning Engine. If not specified, the default value is the service account specified in `service_account`."""
48134836

48144837
ReasoningEngineSpecOrDict = Union[ReasoningEngineSpec, ReasoningEngineSpecDict]
48154838

@@ -12198,6 +12221,10 @@ class AgentEngineConfig(_common.BaseModel):
1219812221
agent_server_mode: Optional[AgentServerMode] = Field(
1219912222
default=None, description="""The agent server mode to use for deployment."""
1220012223
)
12224+
identity_type: Optional[IdentityType] = Field(
12225+
default=None,
12226+
description="""The identity type to use for the Agent Engine.""",
12227+
)
1220112228
class_methods: Optional[list[dict[str, Any]]] = Field(
1220212229
default=None,
1220312230
description="""The class methods to be used for the Agent Engine.
@@ -12293,6 +12320,9 @@ class AgentEngineConfigDict(TypedDict, total=False):
1229312320
agent_server_mode: Optional[AgentServerMode]
1229412321
"""The agent server mode to use for deployment."""
1229512322

12323+
identity_type: Optional[IdentityType]
12324+
"""The identity type to use for the Agent Engine."""
12325+
1229612326
class_methods: Optional[list[dict[str, Any]]]
1229712327
"""The class methods to be used for the Agent Engine.
1229812328
If specified, they'll override the class methods that are autogenerated by

0 commit comments

Comments
 (0)