Skip to content

Commit 4216790

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat!: GenAI SDK client - Enabling new data driven prompt optimization for prompts from Android API by passing
`method=types.PromptOptimizerMethod.GEMINI_NANO` Replaces the raw string `method="vapo"` with enum. Available options: `types.OptimizerMethod.VAPO` and `types.PromptOptimizerMethod.GEMINI_NANO`. PiperOrigin-RevId: 825719137
1 parent 30826b1 commit 4216790

File tree

6 files changed

+166
-48
lines changed

6 files changed

+166
-48
lines changed

tests/unit/vertexai/genai/replays/test_prompt_optimizer_optimize_job_state.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
#
1515
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
1616

17+
import logging
1718
import os
18-
1919
from tests.unit.vertexai.genai.replays import pytest_helper
2020
from vertexai._genai import types
2121
from google.genai import types as genai_types
@@ -38,7 +38,7 @@ def test_optimize(client):
3838

3939
_raise_for_unset_env_vars()
4040

41-
config = types.PromptOptimizerVAPOConfig(
41+
config = types.PromptOptimizerConfig(
4242
config_path=os.environ.get("VAPO_CONFIG_PATH"),
4343
wait_for_completion=True,
4444
service_account_project_number=os.environ.get(
@@ -47,7 +47,33 @@ def test_optimize(client):
4747
optimizer_job_display_name="optimizer_job_test",
4848
)
4949
job = client.prompt_optimizer.optimize(
50-
method="vapo",
50+
method=types.PromptOptimizerMethod.VAPO,
51+
config=config,
52+
)
53+
assert isinstance(job, types.CustomJob)
54+
assert job.state == genai_types.JobState.JOB_STATE_SUCCEEDED
55+
56+
57+
def test_optimize_nano(client):
58+
"""Tests the optimize request parameters method."""
59+
60+
_raise_for_unset_env_vars()
61+
62+
config_path = os.environ.get("VAPO_CONFIG_PATH")
63+
root, ext = os.path.splitext(config_path)
64+
nano_path = f"{root}_nano{ext}"
65+
66+
config = types.PromptOptimizerConfig(
67+
config_path=nano_path,
68+
wait_for_completion=True,
69+
service_account_project_number=os.environ.get(
70+
"VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"
71+
),
72+
optimizer_job_display_name="optimizer_job_test",
73+
)
74+
75+
job = client.prompt_optimizer.optimize(
76+
method=types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO,
5177
config=config,
5278
)
5379
assert isinstance(job, types.CustomJob)
@@ -68,15 +94,37 @@ def test_optimize(client):
6894
async def test_optimize_async(client):
6995
_raise_for_unset_env_vars()
7096

71-
config = types.PromptOptimizerVAPOConfig(
97+
config = types.PromptOptimizerConfig(
7298
config_path=os.environ.get("VAPO_CONFIG_PATH"),
7399
service_account_project_number=os.environ.get(
74100
"VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"
75101
),
76102
optimizer_job_display_name="optimizer_job_test",
77103
)
78104
job = await client.aio.prompt_optimizer.optimize(
79-
method="vapo",
105+
method=types.PromptOptimizerMethod.VAPO,
106+
config=config,
107+
)
108+
assert isinstance(job, types.CustomJob)
109+
assert job.state == genai_types.JobState.JOB_STATE_PENDING
110+
111+
112+
@pytest.mark.asyncio
113+
async def test_optimize_nano_async(client):
114+
_raise_for_unset_env_vars()
115+
config_path = os.environ.get("VAPO_CONFIG_PATH")
116+
root, ext = os.path.splitext(config_path)
117+
nano_path = f"{root}_nano{ext}"
118+
119+
config = types.PromptOptimizerConfig(
120+
config_path=nano_path,
121+
service_account_project_number=os.environ.get(
122+
"VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"
123+
),
124+
optimizer_job_display_name="optimizer_job_test",
125+
)
126+
job = await client.aio.prompt_optimizer.optimize(
127+
method=types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO,
80128
config=config,
81129
)
82130
assert isinstance(job, types.CustomJob)
@@ -86,8 +134,9 @@ async def test_optimize_async(client):
86134
@pytest.mark.asyncio
87135
async def test_optimize_async_with_config_wait_for_completion(client, caplog):
88136
_raise_for_unset_env_vars()
137+
caplog.set_level(logging.INFO)
89138

90-
config = types.PromptOptimizerVAPOConfig(
139+
config = types.PromptOptimizerConfig(
91140
config_path=os.environ.get("VAPO_CONFIG_PATH"),
92141
service_account_project_number=os.environ.get(
93142
"VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"
@@ -96,7 +145,7 @@ async def test_optimize_async_with_config_wait_for_completion(client, caplog):
96145
wait_for_completion=True,
97146
)
98147
job = await client.aio.prompt_optimizer.optimize(
99-
method="vapo",
148+
method=types.PromptOptimizerMethod.VAPO,
100149
config=config,
101150
)
102151
assert isinstance(job, types.CustomJob)

tests/unit/vertexai/genai/test_prompt_optimizer.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,24 @@ def test_prompt_optimizer_optimize(self, mock_custom_job, mock_client):
5454
"""Test that prompt_optimizer.optimize method creates a custom job."""
5555
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
5656
test_client.prompt_optimizer.optimize(
57-
method="vapo",
58-
config=types.PromptOptimizerVAPOConfig(
57+
method=types.PromptOptimizerMethod.VAPO,
58+
config=types.PromptOptimizerConfig(
59+
config_path="gs://ssusie-vapo-sdk-test/config.json",
60+
wait_for_completion=False,
61+
service_account="test-service-account",
62+
),
63+
)
64+
mock_client.assert_called_once()
65+
mock_custom_job.assert_called_once()
66+
67+
@mock.patch.object(client.Client, "_get_api_client")
68+
@mock.patch.object(prompt_optimizer.PromptOptimizer, "_create_custom_job_resource")
69+
def test_prompt_optimizer_optimize_nano(self, mock_custom_job, mock_client):
70+
"""Test that prompt_optimizer.optimize method creates a custom job."""
71+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
72+
test_client.prompt_optimizer.optimize(
73+
method=types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO,
74+
config=types.PromptOptimizerConfig(
5975
config_path="gs://ssusie-vapo-sdk-test/config.json",
6076
wait_for_completion=False,
6177
service_account="test-service-account",

vertexai/_genai/_prompt_optimizer_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919

2020

2121
def _get_service_account(
22-
config: types.PromptOptimizerVAPOConfigOrDict,
22+
config: types.PromptOptimizerConfigOrDict,
2323
) -> str:
2424
"""Get the service account from the config for the custom job."""
2525
if isinstance(config, dict):
26-
config = types.PromptOptimizerVAPOConfig.model_validate(config)
26+
config = types.PromptOptimizerConfig.model_validate(config)
2727

2828
if config.service_account and config.service_account_project_number:
2929
raise ValueError(

vertexai/_genai/prompt_optimizer.py

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -407,37 +407,46 @@ def _wait_for_completion(self, job_name: str) -> types.CustomJob:
407407

408408
def optimize(
409409
self,
410-
method: str,
411-
config: types.PromptOptimizerVAPOConfigOrDict,
410+
method: types.PromptOptimizerMethod,
411+
config: types.PromptOptimizerConfigOrDict,
412412
) -> types.CustomJob:
413413
"""Call PO-Data optimizer.
414414
415415
Args:
416-
method: The method for optimizing multiple prompts.
417-
config: PromptOptimizerVAPOConfig instance containing the
416+
method: The method for optimizing multiple prompts. Supported methods:
417+
VAPO, OPTIMIZATION_TARGET_GEMINI_NANO.
418+
config: PromptOptimizerConfig instance containing the
418419
configuration for prompt optimization.
419420
Returns:
420421
The custom job that was created.
421422
"""
422423

423-
if method != "vapo":
424-
raise ValueError("Only vapo method is currently supported.")
425-
426424
if isinstance(config, dict):
427-
config = types.PromptOptimizerVAPOConfig(**config)
425+
config = types.PromptOptimizerConfig(**config)
426+
427+
if not config.config_path:
428+
raise ValueError("Config path is required.")
429+
430+
_OPTIMIZER_METHOD_TO_CONTAINER_URI = {
431+
types.PromptOptimizerMethod.VAPO: "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_v1_0",
432+
types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO: "us-docker.pkg.dev/vertex-ai/cair/android-apo:preview_v1_0",
433+
}
434+
container_uri = _OPTIMIZER_METHOD_TO_CONTAINER_URI.get(method)
435+
if not container_uri:
436+
raise ValueError(
437+
'Only "VAPO" and "OPTIMIZATION_TARGET_GEMINI_NANO" '
438+
"methods are currently supported."
439+
)
428440

429441
if config.optimizer_job_display_name:
430442
display_name = config.optimizer_job_display_name
431443
else:
432444
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
433-
display_name = f"vapo-optimizer-{timestamp}"
445+
display_name = f"{method.value.lower()}-optimizer-{timestamp}"
446+
434447
wait_for_completion = config.wait_for_completion
435-
if not config.config_path:
436-
raise ValueError("Config path is required.")
437448
bucket = "/".join(config.config_path.split("/")[:-1])
438449

439-
container_uri = "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_v1_0"
440-
441450
region = self._api_client.location
442451
project = self._api_client.project
443452
container_args = {
@@ -766,8 +775,8 @@ async def _get_custom_job(
766775
# Todo: b/428953357 - Add example in the README.
767776
async def optimize(
768777
self,
769-
method: str,
770-
config: types.PromptOptimizerVAPOConfigOrDict,
778+
method: types.PromptOptimizerMethod,
779+
config: types.PromptOptimizerConfigOrDict,
771780
) -> types.CustomJob:
772781
"""Call async Vertex AI Prompt Optimizer (VAPO).
773782
@@ -777,26 +786,37 @@ async def optimize(
777786
778787
Example usage:
779788
client = vertexai.Client(project=PROJECT_NAME, location='us-central1')
780-
vapo_config = vertexai.types.PromptOptimizerVAPOConfig(
789+
vapo_config = vertexai.types.PromptOptimizerConfig(
781790
config_path='gs://you-bucket-name/your-config.json',
782791
service_account=service_account,
783792
)
784793
job = await client.aio.prompt_optimizer.optimize(
785-
method='vapo', config=vapo_config)
794+
method=types.PromptOptimizerMethod.VAPO, config=vapo_config)
786795
787796
Args:
788-
method: The method for optimizing multiple prompts (currently only
789-
vapo is supported).
790-
config: PromptOptimizerVAPOConfig instance containing the
797+
method: The method for optimizing multiple prompts. Supported methods:
798+
VAPO, OPTIMIZATION_TARGET_GEMINI_NANO.
799+
config: PromptOptimizerConfig instance containing the
791800
configuration for prompt optimization.
792801
Returns:
793802
The custom job that was created.
794803
"""
795-
if method != "vapo":
796-
raise ValueError("Only vapo methods is currently supported.")
797-
798804
if isinstance(config, dict):
799-
config = types.PromptOptimizerVAPOConfig(**config)
805+
config = types.PromptOptimizerConfig(**config)
806+
807+
if not config.config_path:
808+
raise ValueError("Config path is required.")
809+
810+
_OPTIMIZER_METHOD_TO_CONTAINER_URI = {
811+
types.PromptOptimizerMethod.VAPO: "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_v1_0",
812+
types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO: "us-docker.pkg.dev/vertex-ai/cair/android-apo:preview_v1_0",
813+
}
814+
container_uri = _OPTIMIZER_METHOD_TO_CONTAINER_URI.get(method)
815+
if not container_uri:
816+
raise ValueError(
817+
'Only "VAPO" and "OPTIMIZATION_TARGET_GEMINI_NANO" '
818+
"methods are currently supported."
819+
)
800820

801821
if config.wait_for_completion:
802822
logger.info(
@@ -807,14 +827,12 @@ async def optimize(
807827
display_name = config.optimizer_job_display_name
808828
else:
809829
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
810-
display_name = f"vapo-optimizer-{timestamp}"
830+
display_name = f"{method.value.lower()}-optimizer-{timestamp}"
811831

812832
if not config.config_path:
813833
raise ValueError("Config path is required.")
814834
bucket = "/".join(config.config_path.split("/")[:-1])
815835

816-
container_uri = "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_v1_0"
817-
818836
region = self._api_client.location
819837
project = self._api_client.project
820838
container_args = {

vertexai/_genai/types/__init__.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,9 @@
587587
from .common import OptimizeResponseEndpointDict
588588
from .common import OptimizeResponseEndpointOrDict
589589
from .common import OptimizeResponseOrDict
590+
from .common import OptimizerMethodPlaceholder
591+
from .common import OptimizerMethodPlaceholderDict
592+
from .common import OptimizerMethodPlaceholderOrDict
590593
from .common import OptimizeTarget
591594
from .common import PairwiseChoice
592595
from .common import PairwiseMetricInput
@@ -618,9 +621,10 @@
618621
from .common import PromptDataDict
619622
from .common import PromptDataOrDict
620623
from .common import PromptDict
621-
from .common import PromptOptimizerVAPOConfig
622-
from .common import PromptOptimizerVAPOConfigDict
623-
from .common import PromptOptimizerVAPOConfigOrDict
624+
from .common import PromptOptimizerConfig
625+
from .common import PromptOptimizerConfigDict
626+
from .common import PromptOptimizerConfigOrDict
627+
from .common import PromptOptimizerMethod
624628
from .common import PromptOrDict
625629
from .common import PromptRef
626630
from .common import PromptRefDict
@@ -1739,9 +1743,12 @@
17391743
"UpdateDatasetConfig",
17401744
"UpdateDatasetConfigDict",
17411745
"UpdateDatasetConfigOrDict",
1742-
"PromptOptimizerVAPOConfig",
1743-
"PromptOptimizerVAPOConfigDict",
1744-
"PromptOptimizerVAPOConfigOrDict",
1746+
"PromptOptimizerConfig",
1747+
"PromptOptimizerConfigDict",
1748+
"PromptOptimizerConfigOrDict",
1749+
"OptimizerMethodPlaceholder",
1750+
"OptimizerMethodPlaceholderDict",
1751+
"OptimizerMethodPlaceholderOrDict",
17451752
"ApplicableGuideline",
17461753
"ApplicableGuidelineDict",
17471754
"ApplicableGuidelineOrDict",
@@ -1837,6 +1844,7 @@
18371844
"Importance",
18381845
"OptimizeTarget",
18391846
"GenerateMemoriesResponseGeneratedMemoryAction",
1847+
"PromptOptimizerMethod",
18401848
"PromptData",
18411849
"PromptDataDict",
18421850
"PromptDataOrDict",

vertexai/_genai/types/common.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,15 @@ class GenerateMemoriesResponseGeneratedMemoryAction(_common.CaseInSensitiveEnum)
354354
"""The memory was deleted."""
355355

356356

357+
class PromptOptimizerMethod(_common.CaseInSensitiveEnum):
358+
"""The method for data driven prompt optimization."""
359+
360+
VAPO = "VAPO"
361+
"""The default data driven Vertex AI Prompt Optimizer."""
362+
OPTIMIZATION_TARGET_GEMINI_NANO = "OPTIMIZATION_TARGET_GEMINI_NANO"
363+
"""The data driven prompt optimizer designer for prompts from Android core API."""
364+
365+
357366
class CreateEvaluationItemConfig(_common.BaseModel):
358367
"""Config to create an evaluation item."""
359368

@@ -12025,7 +12034,7 @@ class _UpdateDatasetParametersDict(TypedDict, total=False):
1202512034
]
1202612035

1202712036

12028-
class PromptOptimizerVAPOConfig(_common.BaseModel):
12037+
class PromptOptimizerConfig(_common.BaseModel):
1202912038
"""VAPO Prompt Optimizer Config."""
1203012039

1203112040
config_path: Optional[str] = Field(
@@ -12050,7 +12059,7 @@ class PromptOptimizerVAPOConfig(_common.BaseModel):
1205012059
)
1205112060

1205212061

12053-
class PromptOptimizerVAPOConfigDict(TypedDict, total=False):
12062+
class PromptOptimizerConfigDict(TypedDict, total=False):
1205412063
"""VAPO Prompt Optimizer Config."""
1205512064

1205612065
config_path: Optional[str]
@@ -12069,8 +12078,26 @@ class PromptOptimizerVAPOConfigDict(TypedDict, total=False):
1206912078
"""The display name of the optimization job. If not provided, a display name in the format of "vapo-optimizer-{timestamp}" will be used."""
1207012079

1207112080

12072-
PromptOptimizerVAPOConfigOrDict = Union[
12073-
PromptOptimizerVAPOConfig, PromptOptimizerVAPOConfigDict
12081+
PromptOptimizerConfigOrDict = Union[PromptOptimizerConfig, PromptOptimizerConfigDict]
12082+
12083+
12084+
class OptimizerMethodPlaceholder(_common.BaseModel):
12085+
"""Placeholder class to generate OptimizerMethod enum in common.py."""
12086+
12087+
method: Optional[PromptOptimizerMethod] = Field(
12088+
default=None, description="""The method for optimizing multiple prompts."""
12089+
)
12090+
12091+
12092+
class OptimizerMethodPlaceholderDict(TypedDict, total=False):
12093+
"""Placeholder class to generate OptimizerMethod enum in common.py."""
12094+
12095+
method: Optional[PromptOptimizerMethod]
12096+
"""The method for optimizing multiple prompts."""
12097+
12098+
12099+
OptimizerMethodPlaceholderOrDict = Union[
12100+
OptimizerMethodPlaceholder, OptimizerMethodPlaceholderDict
1207412101
]
1207512102

1207612103

0 commit comments

Comments
 (0)