Skip to content

Commit 715cc5b

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI SDK client - Enabling Few-shot Prompt Optimization by passing either "OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS" or "OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE" to the optimize_prompt method
together with example dataframe PiperOrigin-RevId: 848237033
1 parent d9c6eb1 commit 715cc5b

File tree

8 files changed

+575
-102
lines changed

8 files changed

+575
-102
lines changed

tests/unit/vertexai/genai/replays/test_prompt_optimizer_async_optimize_prompt_return_type.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from tests.unit.vertexai.genai.replays import pytest_helper
1818
from vertexai._genai import types
19+
import pandas as pd
1920
import pytest
2021

2122

@@ -32,6 +33,75 @@ async def test_optimize_prompt(client):
3233
assert response.raw_text_response
3334

3435

36+
@pytest.mark.asyncio
37+
async def test_optimize_prompt_w_optimization_target(client):
38+
"""Tests the optimize request parameters method with optimization target."""
39+
test_prompt = "Generate system instructions for analyzing medical articles"
40+
response = await client.aio.prompt_optimizer.optimize_prompt(
41+
prompt=test_prompt,
42+
config=types.OptimizeConfig(
43+
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
44+
),
45+
)
46+
assert isinstance(response, types.OptimizeResponse)
47+
assert response.raw_text_response
48+
49+
50+
@pytest.mark.asyncio
51+
async def test_optimize_prompt_w_few_shot_optimization_target(client):
52+
"""Tests the optimize request parameters method with few shot optimization target."""
53+
test_prompt = "Generate system instructions for analyzing medical articles"
54+
df = pd.DataFrame(
55+
{
56+
"prompt": ["prompt1", "prompt2"],
57+
"model_response": ["response1", "response2"],
58+
"target_response": ["target1", "target2"],
59+
}
60+
)
61+
response = await client.aio.prompt_optimizer.optimize_prompt(
62+
prompt=test_prompt,
63+
config=types.OptimizeConfig(
64+
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE,
65+
examples_dataframe=df,
66+
),
67+
)
68+
assert isinstance(response, types.OptimizeResponse)
69+
assert response.raw_text_response
70+
assert isinstance(response.raw_text_response, str)
71+
if response.parsed_response:
72+
assert isinstance(
73+
response.parsed_response, types.prompt_optimizer.ParsedResponseFewShot
74+
)
75+
76+
77+
@pytest.mark.asyncio
78+
async def test_optimize_prompt_w_few_shot_optimization_rubrics(client):
79+
"""Tests the optimize request parameters method with few shot optimization target."""
80+
test_prompt = "Generate system instructions for analyzing medical articles"
81+
df = pd.DataFrame(
82+
{
83+
"prompt": ["prompt1", "prompt2"],
84+
"model_response": ["response1", "response2"],
85+
"rubrics": ["rubric1", "rubric2"],
86+
"rubrics_evaluations": ["[True, True]", "[True, False]"],
87+
}
88+
)
89+
response = await client.aio.prompt_optimizer.optimize_prompt(
90+
prompt=test_prompt,
91+
config=types.OptimizeConfig(
92+
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS,
93+
examples_dataframe=df,
94+
),
95+
)
96+
assert isinstance(response, types.OptimizeResponse)
97+
assert response.raw_text_response
98+
assert isinstance(response.raw_text_response, str)
99+
if response.parsed_response:
100+
assert isinstance(
101+
response.parsed_response, types.prompt_optimizer.ParsedResponseFewShot
102+
)
103+
104+
35105
pytestmark = pytest_helper.setup(
36106
file=__file__,
37107
globals_for_file=globals(),

tests/unit/vertexai/genai/replays/test_prompt_optimizer_optimize_prompt_return_type.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from tests.unit.vertexai.genai.replays import pytest_helper
1818
from vertexai._genai import types
19+
import pandas as pd
1920

2021

2122
def test_optimize_prompt(client):
@@ -27,18 +28,70 @@ def test_optimize_prompt(client):
2728
assert response.raw_text_response
2829

2930

30-
# def test_optimize_prompt_w_optimization_target(client):
31-
# """Tests the optimize request parameters method with optimization target."""
32-
# from google.genai import types as genai_types
33-
# test_prompt = "Generate system instructions for analyzing medical articles"
34-
# response = client.prompt_optimizer.optimize_prompt(
35-
# prompt=test_prompt,
36-
# config=types.OptimizeConfig(
37-
# optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
38-
# ),
39-
# )
40-
# assert isinstance(response, types.OptimizeResponse)
41-
# assert response.raw_text_response
31+
def test_optimize_prompt_w_optimization_target(client):
32+
"""Tests the optimize request parameters method with optimization target."""
33+
test_prompt = "Generate system instructions for analyzing medical articles"
34+
response = client.prompt_optimizer.optimize_prompt(
35+
prompt=test_prompt,
36+
config=types.OptimizeConfig(
37+
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
38+
),
39+
)
40+
assert isinstance(response, types.OptimizeResponse)
41+
assert response.raw_text_response
42+
43+
44+
def test_optimize_prompt_w_few_shot_optimization_target(client):
45+
"""Tests the optimize request parameters method with few shot optimization target."""
46+
test_prompt = "Generate system instructions for analyzing medical articles"
47+
df = pd.DataFrame(
48+
{
49+
"prompt": ["prompt1", "prompt2"],
50+
"model_response": ["response1", "response2"],
51+
"target_response": ["target1", "target2"],
52+
}
53+
)
54+
response = client.prompt_optimizer.optimize_prompt(
55+
prompt=test_prompt,
56+
config=types.OptimizeConfig(
57+
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE,
58+
examples_dataframe=df,
59+
),
60+
)
61+
assert isinstance(response, types.OptimizeResponse)
62+
assert response.raw_text_response
63+
assert isinstance(response.raw_text_response, str)
64+
if response.parsed_response:
65+
assert isinstance(
66+
response.parsed_response, types.prompt_optimizer.ParsedResponseFewShot
67+
)
68+
69+
70+
def test_optimize_prompt_w_few_shot_optimization_rubrics(client):
71+
"""Tests the optimize request parameters method with few shot optimization target."""
72+
test_prompt = "Generate system instructions for analyzing medical articles"
73+
df = pd.DataFrame(
74+
{
75+
"prompt": ["prompt1", "prompt2"],
76+
"model_response": ["response1", "response2"],
77+
"rubrics": ["rubric1", "rubric2"],
78+
"rubrics_evaluations": ["[True, True]", "[True, False]"],
79+
}
80+
)
81+
response = client.prompt_optimizer.optimize_prompt(
82+
prompt=test_prompt,
83+
config=types.OptimizeConfig(
84+
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS,
85+
examples_dataframe=df,
86+
),
87+
)
88+
assert isinstance(response, types.OptimizeResponse)
89+
assert response.raw_text_response
90+
assert isinstance(response.raw_text_response, str)
91+
if response.parsed_response:
92+
assert isinstance(
93+
response.parsed_response, types.prompt_optimizer.ParsedResponseFewShot
94+
)
4295

4396

4497
pytestmark = pytest_helper.setup(

tests/unit/vertexai/genai/test_prompt_optimizer.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from vertexai._genai import prompt_optimizer
2222
from vertexai._genai import types
2323
from google.genai import client
24+
import pandas as pd
2425
import pytest
2526

2627

@@ -91,6 +92,34 @@ def test_prompt_optimizer_optimize_prompt(
9192
mock_client.assert_called_once()
9293
mock_custom_optimize_prompt.assert_called_once()
9394

95+
@mock.patch.object(prompt_optimizer.PromptOptimizer, "_custom_optimize_prompt")
96+
def test_prompt_optimizer_optimize_few_shot(self, mock_custom_optimize_prompt):
97+
"""Test that prompt_optimizer.optimize method for few shot optimizer."""
98+
df = pd.DataFrame(
99+
{
100+
"prompt": ["prompt1", "prompt2"],
101+
"model_response": ["response1", "response2"],
102+
"target_response": ["target1", "target2"],
103+
}
104+
)
105+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
106+
test_config = types.OptimizeConfig(
107+
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE,
108+
examples_dataframe=df,
109+
)
110+
test_client.prompt_optimizer.optimize_prompt(
111+
prompt="test_prompt",
112+
config=test_config,
113+
)
114+
mock_custom_optimize_prompt.assert_called_once()
115+
mock_kwargs = mock_custom_optimize_prompt.call_args.kwargs
116+
assert (
117+
mock_kwargs["config"].optimization_target == test_config.optimization_target
118+
)
119+
pd.testing.assert_frame_equal(
120+
mock_kwargs["config"].examples_dataframe, test_config.examples_dataframe
121+
)
122+
94123
@mock.patch.object(prompt_optimizer.PromptOptimizer, "_custom_optimize_prompt")
95124
def test_prompt_optimizer_optimize_prompt_with_optimization_target(
96125
self, mock_custom_optimize_prompt
@@ -138,4 +167,59 @@ async def test_async_prompt_optimizer_optimize_prompt_with_optimization_target(
138167
config=config,
139168
)
140169

170+
@pytest.mark.asyncio
171+
@mock.patch.object(prompt_optimizer.AsyncPromptOptimizer, "_custom_optimize_prompt")
172+
async def test_async_prompt_optimizer_optimize_prompt_few_shot_target_response(
173+
self, mock_custom_optimize_prompt
174+
):
175+
"""Test that async prompt_optimizer.optimize_prompt calls optimize_prompt with few shot target response."""
176+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
177+
df = pd.DataFrame(
178+
{
179+
"prompt": ["prompt1", "prompt2"],
180+
"model_response": ["response1", "response2"],
181+
"target_response": ["target1", "target2"],
182+
}
183+
)
184+
config = types.OptimizeConfig(
185+
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE,
186+
examples_dataframe=df,
187+
)
188+
await test_client.aio.prompt_optimizer.optimize_prompt(
189+
prompt="test_prompt",
190+
config=config,
191+
)
192+
mock_custom_optimize_prompt.assert_called_once_with(
193+
content=mock.ANY,
194+
config=config,
195+
)
196+
197+
@pytest.mark.asyncio
198+
@mock.patch.object(prompt_optimizer.AsyncPromptOptimizer, "_custom_optimize_prompt")
199+
async def test_async_prompt_optimizer_optimize_prompt_few_shot_rubrics(
200+
self, mock_custom_optimize_prompt
201+
):
202+
"""Test that async prompt_optimizer.optimize_prompt calls optimize_prompt with few shot rubrics."""
203+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
204+
df = pd.DataFrame(
205+
{
206+
"prompt": ["prompt1", "prompt2"],
207+
"model_response": ["response1", "response2"],
208+
"rubrics": ["rubric1", "rubric2"],
209+
"rubrics_evaluations": ["[True, True]", "[True, False]"],
210+
}
211+
)
212+
config = types.OptimizeConfig(
213+
optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS,
214+
examples_dataframe=df,
215+
)
216+
await test_client.aio.prompt_optimizer.optimize_prompt(
217+
prompt="test_prompt",
218+
config=config,
219+
)
220+
mock_custom_optimize_prompt.assert_called_once_with(
221+
content=mock.ANY,
222+
config=config,
223+
)
224+
141225
# # TODO(b/415060797): add more tests for prompt_optimizer.optimize

0 commit comments

Comments
 (0)