Skip to content

Commit 435b3cc

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Gen AI SDK client - Fix bug in GCS bucket creation for new agent engines.
PiperOrigin-RevId: 834146214
1 parent e8b12cb commit 435b3cc

File tree

3 files changed

+182
-3
lines changed

3 files changed

+182
-3
lines changed

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,77 @@ def test_create_base64_encoded_tarball_outside_project_dir_raises(self):
14501450
finally:
14511451
os.chdir(origin_dir)
14521452

1453+
@mock.patch.object(_agent_engines_utils, "_upload_requirements")
1454+
@mock.patch.object(_agent_engines_utils, "_upload_extra_packages")
1455+
@mock.patch.object(_agent_engines_utils, "_upload_agent_engine")
1456+
@mock.patch.object(_agent_engines_utils, "_scan_requirements")
1457+
@mock.patch.object(_agent_engines_utils, "_get_gcs_bucket")
1458+
def test_prepare_with_creds(
1459+
self,
1460+
mock_get_gcs_bucket,
1461+
mock_scan_requirements,
1462+
mock_upload_agent_engine,
1463+
mock_upload_extra_packages,
1464+
mock_upload_requirements,
1465+
):
1466+
mock_scan_requirements.return_value = {}
1467+
mock_creds = mock.Mock(spec=auth_credentials.AnonymousCredentials())
1468+
mock_creds.universe_domain = "googleapis.com"
1469+
_agent_engines_utils._prepare(
1470+
agent=self.test_agent,
1471+
project=_TEST_PROJECT,
1472+
location=_TEST_LOCATION,
1473+
staging_bucket=_TEST_STAGING_BUCKET,
1474+
credentials=mock_creds,
1475+
gcs_dir_name=_TEST_GCS_DIR_NAME,
1476+
requirements=[],
1477+
extra_packages=[],
1478+
)
1479+
mock_upload_agent_engine.assert_called_once_with(
1480+
agent=self.test_agent,
1481+
gcs_bucket=mock.ANY,
1482+
gcs_dir_name=_TEST_GCS_DIR_NAME,
1483+
)
1484+
1485+
@mock.patch.object(_agent_engines_utils, "_upload_requirements")
1486+
@mock.patch.object(_agent_engines_utils, "_upload_extra_packages")
1487+
@mock.patch.object(_agent_engines_utils, "_upload_agent_engine")
1488+
@mock.patch.object(_agent_engines_utils, "_scan_requirements")
1489+
@mock.patch("google.auth.default")
1490+
@mock.patch.object(_agent_engines_utils, "_get_gcs_bucket")
1491+
def test_prepare_without_creds(
1492+
self,
1493+
mock_get_gcs_bucket,
1494+
mock_auth_default,
1495+
mock_scan_requirements,
1496+
mock_upload_agent_engine,
1497+
mock_upload_extra_packages,
1498+
mock_upload_requirements,
1499+
):
1500+
mock_scan_requirements.return_value = {}
1501+
mock_creds = mock.Mock(spec=auth_credentials.AnonymousCredentials())
1502+
mock_auth_default.return_value = (mock_creds, _TEST_PROJECT)
1503+
_agent_engines_utils._prepare(
1504+
agent=self.test_agent,
1505+
project=_TEST_PROJECT,
1506+
location=_TEST_LOCATION,
1507+
staging_bucket=_TEST_STAGING_BUCKET,
1508+
gcs_dir_name=_TEST_GCS_DIR_NAME,
1509+
requirements=[],
1510+
extra_packages=[],
1511+
)
1512+
mock_get_gcs_bucket.assert_called_once_with(
1513+
project=_TEST_PROJECT,
1514+
location=_TEST_LOCATION,
1515+
staging_bucket=_TEST_STAGING_BUCKET,
1516+
credentials=None,
1517+
)
1518+
mock_upload_agent_engine.assert_called_once_with(
1519+
agent=self.test_agent,
1520+
gcs_bucket=mock.ANY,
1521+
gcs_dir_name=_TEST_GCS_DIR_NAME,
1522+
)
1523+
14531524

14541525
@pytest.mark.usefixtures("google_auth_mock")
14551526
class TestAgentEngine:
@@ -2622,6 +2693,109 @@ def test_operation_schemas(
26222693
want_operation_schemas.append(want_operation_schema)
26232694
assert test_agent_engine.operation_schemas() == want_operation_schemas
26242695

2696+
@mock.patch.object(_agent_engines_utils, "_prepare")
2697+
@mock.patch.object(agent_engines.AgentEngines, "_create")
2698+
@mock.patch.object(_agent_engines_utils, "_await_operation")
2699+
def test_create_agent_engine_with_creds(
2700+
self, mock_await_operation, mock_create, mock_prepare
2701+
):
2702+
mock_operation = mock.Mock()
2703+
mock_operation.name = _TEST_AGENT_ENGINE_OPERATION_NAME
2704+
mock_create.return_value = mock_operation
2705+
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
2706+
response=_genai_types.ReasoningEngine(
2707+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
2708+
spec=_TEST_AGENT_ENGINE_SPEC,
2709+
)
2710+
)
2711+
self.client.agent_engines.create(
2712+
agent=self.test_agent,
2713+
config=_genai_types.AgentEngineConfig(
2714+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
2715+
staging_bucket=_TEST_STAGING_BUCKET,
2716+
),
2717+
)
2718+
mock_args, mock_kwargs = mock_prepare.call_args
2719+
assert mock_kwargs["agent"] == self.test_agent
2720+
assert mock_kwargs["extra_packages"] == []
2721+
assert mock_kwargs["project"] == _TEST_PROJECT
2722+
assert mock_kwargs["location"] == _TEST_LOCATION
2723+
assert mock_kwargs["staging_bucket"] == _TEST_STAGING_BUCKET
2724+
assert mock_kwargs["credentials"] == _TEST_CREDENTIALS
2725+
assert mock_kwargs["gcs_dir_name"] == "agent_engine"
2726+
2727+
@mock.patch.object(_agent_engines_utils, "_prepare")
2728+
@mock.patch.object(agent_engines.AgentEngines, "_create")
2729+
@mock.patch("google.auth.default")
2730+
@mock.patch.object(_agent_engines_utils, "_await_operation")
2731+
def test_create_agent_engine_without_creds(
2732+
self, mock_await_operation, mock_auth_default, mock_create, mock_prepare
2733+
):
2734+
mock_operation = mock.Mock()
2735+
mock_operation.name = _TEST_AGENT_ENGINE_OPERATION_NAME
2736+
mock_create.return_value = mock_operation
2737+
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
2738+
response=_genai_types.ReasoningEngine(
2739+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
2740+
spec=_TEST_AGENT_ENGINE_SPEC,
2741+
)
2742+
)
2743+
mock_creds = mock.Mock(spec=auth_credentials.AnonymousCredentials())
2744+
mock_creds.quota_project_id = _TEST_PROJECT
2745+
mock_auth_default.return_value = (mock_creds, _TEST_PROJECT)
2746+
client = vertexai.Client(
2747+
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=mock_creds
2748+
)
2749+
client.agent_engines.create(
2750+
agent=self.test_agent,
2751+
config=_genai_types.AgentEngineConfig(
2752+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
2753+
staging_bucket=_TEST_STAGING_BUCKET,
2754+
),
2755+
)
2756+
mock_args, mock_kwargs = mock_prepare.call_args
2757+
assert mock_kwargs["agent"] == self.test_agent
2758+
assert mock_kwargs["extra_packages"] == []
2759+
assert mock_kwargs["project"] == _TEST_PROJECT
2760+
assert mock_kwargs["location"] == _TEST_LOCATION
2761+
assert mock_kwargs["staging_bucket"] == _TEST_STAGING_BUCKET
2762+
assert mock_kwargs["credentials"] == mock_creds
2763+
assert mock_kwargs["gcs_dir_name"] == "agent_engine"
2764+
2765+
@mock.patch.object(_agent_engines_utils, "_prepare")
2766+
@mock.patch.object(agent_engines.AgentEngines, "_create")
2767+
@mock.patch.object(_agent_engines_utils, "_await_operation")
2768+
def test_create_agent_engine_with_no_creds_in_client(
2769+
self, mock_await_operation, mock_create, mock_prepare
2770+
):
2771+
mock_operation = mock.Mock()
2772+
mock_operation.name = _TEST_AGENT_ENGINE_OPERATION_NAME
2773+
mock_create.return_value = mock_operation
2774+
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
2775+
response=_genai_types.ReasoningEngine(
2776+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
2777+
spec=_TEST_AGENT_ENGINE_SPEC,
2778+
)
2779+
)
2780+
client = vertexai.Client(
2781+
project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=None
2782+
)
2783+
client.agent_engines.create(
2784+
agent=self.test_agent,
2785+
config=_genai_types.AgentEngineConfig(
2786+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
2787+
staging_bucket=_TEST_STAGING_BUCKET,
2788+
),
2789+
)
2790+
mock_args, mock_kwargs = mock_prepare.call_args
2791+
assert mock_kwargs["agent"] == self.test_agent
2792+
assert mock_kwargs["extra_packages"] == []
2793+
assert mock_kwargs["project"] == _TEST_PROJECT
2794+
assert mock_kwargs["location"] == _TEST_LOCATION
2795+
assert mock_kwargs["staging_bucket"] == _TEST_STAGING_BUCKET
2796+
assert mock_kwargs["credentials"] is None
2797+
assert mock_kwargs["gcs_dir_name"] == "agent_engine"
2798+
26252799

26262800
@pytest.mark.usefixtures("google_auth_mock")
26272801
class TestAgentEngineErrors:

vertexai/_genai/_agent_engines_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -772,10 +772,11 @@ def _get_gcs_bucket(
772772
project: str,
773773
location: str,
774774
staging_bucket: str,
775+
credentials: Optional[Any] = None,
775776
) -> _StorageBucket:
776777
"""Gets or creates the GCS bucket."""
777778
storage = _import_cloud_storage_or_raise()
778-
storage_client = storage.Client(project=project)
779+
storage_client = storage.Client(project=project, credentials=credentials)
779780
staging_bucket = staging_bucket.replace("gs://", "")
780781
try:
781782
gcs_bucket = storage_client.get_bucket(staging_bucket)
@@ -910,6 +911,7 @@ def _prepare(
910911
location: str,
911912
staging_bucket: str,
912913
gcs_dir_name: str,
914+
credentials: Optional[Any] = None,
913915
) -> None:
914916
"""Prepares the agent engine for creation or updates in Vertex AI.
915917
@@ -926,15 +928,17 @@ def _prepare(
926928
project (str): The project for the staging bucket.
927929
location (str): The location for the staging bucket.
928930
staging_bucket (str): The staging bucket name in the form "gs://...".
929-
gcs_dir_name (str): The GCS bucket directory under `staging_bucket` to
930-
use for staging the artifacts needed.
931+
gcs_dir_name (str): The GCS bucket directory under `staging_bucket` to use
932+
for staging the artifacts needed.
933+
credentials: The credentials to use for the storage client.
931934
"""
932935
if agent is None:
933936
return
934937
gcs_bucket = _get_gcs_bucket(
935938
project=project,
936939
location=location,
937940
staging_bucket=staging_bucket,
941+
credentials=credentials,
938942
)
939943
_upload_agent_engine(
940944
agent=agent,

vertexai/_genai/agent_engines.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,7 @@ def _create_config(
10821082
staging_bucket=staging_bucket,
10831083
gcs_dir_name=gcs_dir_name,
10841084
extra_packages=extra_packages,
1085+
credentials=self._api_client._credentials,
10851086
)
10861087
# Update the package spec.
10871088
update_masks.append("spec.package_spec.pickle_object_gcs_uri")

0 commit comments

Comments
 (0)