Commit 77fb325

jsondai authored and copybara-github committed
feat: GenAI Client(evals) - support setting autorater generation config for predefined rubric metrics
PiperOrigin-RevId: 833487047
1 parent 186e6d8 commit 77fb325

File tree

4 files changed: +126 -4 lines changed

4 files changed

+126
-4
lines changed
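Usage sketch (editor's note, not part of the commit): the replay test added below exercises the new capability end to end. A minimal version of the same flow, assuming an already constructed Vertex AI GenAI client named `client` with the evals module available:

import pandas as pd
from google.genai import types as genai_types
from vertexai import types

# Dataset with one prompt/response pair, mirroring the replay test below.
prompts_df = pd.DataFrame(
    {
        "prompt": ["Explain the concept of machine learning in simple terms."],
        "response": ["Machine learning lets computers learn from data."],
    }
)
eval_dataset = types.EvaluationDataset(
    eval_dataset_df=prompts_df,
    candidate_name="gemini-2.5-flash",
)

# New in this commit: a predefined rubric metric accepts a generation config
# for the judge (autorater) model.
metric = types.RubricMetric.GENERAL_QUALITY(
    judge_model_generation_config=genai_types.GenerationConfig(
        temperature=0.1,
        max_output_tokens=1024,
    )
)

result = client.evals.evaluate(dataset=eval_dataset, metrics=[metric])
for summary in result.summary_metrics:
    print(summary.metric_name, summary.mean_score)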

tests/unit/vertexai/genai/replays/test_evaluate_predefined_metrics.py

Lines changed: 45 additions & 0 deletions
@@ -16,6 +16,7 @@
 
 from tests.unit.vertexai.genai.replays import pytest_helper
 from vertexai import types
+from google.genai import types as genai_types
 import pandas as pd
 
 
@@ -60,6 +61,50 @@ def test_evaluation_result(client):
         assert case_result.response_candidate_results is not None
 
 
+def test_evaluation_result_with_autorater_config(client):
+    """Tests that evaluate() produces a correctly structured EvaluationResult."""
+    prompts_df = pd.DataFrame(
+        {
+            "prompt": ["Explain the concept of machine learning in simple terms."],
+            "response": [
+                "Machine learning is a type of artificial intelligence that allows"
+                " computers to learn from data without being explicitly programmed."
+            ],
+        }
+    )
+
+    eval_dataset = types.EvaluationDataset(
+        eval_dataset_df=prompts_df,
+        candidate_name="gemini-2.5-flash",
+    )
+
+    predefined_metric_with_autorater_config = types.RubricMetric.GENERAL_QUALITY(
+        judge_model_generation_config=genai_types.GenerationConfig(
+            temperature=0.1,
+            max_output_tokens=1024,
+        )
+    )
+
+    evaluation_result = client.evals.evaluate(
+        dataset=eval_dataset,
+        metrics=[predefined_metric_with_autorater_config],
+    )
+
+    assert isinstance(evaluation_result, types.EvaluationResult)
+
+    assert evaluation_result.summary_metrics is not None
+    for summary in evaluation_result.summary_metrics:
+        assert isinstance(summary, types.AggregatedMetricResult)
+        assert summary.metric_name == "general_quality_v1"
+        assert summary.mean_score is not None
+
+    assert evaluation_result.eval_case_results is not None
+    for case_result in evaluation_result.eval_case_results:
+        assert isinstance(case_result, types.EvalCaseResult)
+        assert case_result.eval_case_index is not None
+        assert case_result.response_candidate_results is not None
+
+
 def test_multi_turn_predefined_metric(client):
     """Tests that evaluate works with multi-turn predefined metrics."""
     prompts_data = {

tests/unit/vertexai/genai/test_evals.py

Lines changed: 51 additions & 2 deletions
@@ -189,6 +189,51 @@ def test_eval_evaluate_with_agent_info(self, mock_execute_evaluation):
         assert "agent_info" in kwargs
         assert kwargs["agent_info"] == agent_info
 
+    def test_evaluate_predefined_metric_with_autorater_config(self):
+        dataset = vertexai_genai_types.EvaluationDataset(
+            eval_dataset_df=pd.DataFrame([{"prompt": "p1", "response": "r1"}])
+        )
+        generation_config = genai_types.GenerationConfig(
+            temperature=0.1,
+            max_output_tokens=1024,
+        )
+        metrics = [
+            vertexai_genai_types.RubricMetric.GENERAL_QUALITY(
+                judge_model_generation_config=generation_config
+            )
+        ]
+
+        mock_prebuilt_metric = vertexai_genai_types.RubricMetric(
+            name="general_quality_v1",
+            judge_model_generation_config=generation_config,
+        )
+        mock_prebuilt_metric._is_predefined = True
+        mock_prebuilt_metric._version = "v1"
+
+        with mock.patch(
+            "vertexai._genai._evals_metric_loaders.LazyLoadedPrebuiltMetric._fetch_and_parse",
+            return_value=mock_prebuilt_metric,
+        ), mock.patch(
+            "vertexai._genai.evals.Evals._evaluate_instances"
+        ) as mock_evaluate_instances, mock.patch(
+            "vertexai._genai._evals_metric_handlers._evals_constant.SUPPORTED_PREDEFINED_METRICS",
+            frozenset(["general_quality_v1"]),
+        ):
+            mock_evaluate_instances.return_value = (
+                vertexai_genai_types.EvaluateInstancesResponse(
+                    metric_results=[vertexai_genai_types.MetricResult(score=0.9)]
+                )
+            )
+            self.client.evals.evaluate(
+                dataset=dataset,
+                metrics=metrics,
+            )
+
+            mock_evaluate_instances.assert_called_once()
+            _, kwargs = mock_evaluate_instances.call_args
+            assert "autorater_config" in kwargs
+            assert kwargs["autorater_config"].generation_config == generation_config
+
 
 class TestEvalsVisualization:
     @mock.patch(
@@ -4990,7 +5035,9 @@ def test_execute_evaluation_adds_creation_timestamp(
         frozenset(["summarization_quality"]),
     )
     @mock.patch("time.sleep", return_value=None)
-    @mock.patch("vertexai._genai.evals.Evals._evaluate_instances")
+    @mock.patch(
+        "vertexai._genai.evals.Evals._evaluate_instances"
+    )
     def test_predefined_metric_retry_on_resource_exhausted(
         self,
         mock_private_evaluate_instances,
@@ -5043,7 +5090,9 @@ def test_predefined_metric_retry_on_resource_exhausted(
         frozenset(["summarization_quality"]),
     )
     @mock.patch("time.sleep", return_value=None)
-    @mock.patch("vertexai._genai.evals.Evals._evaluate_instances")
+    @mock.patch(
+        "vertexai._genai.evals.Evals._evaluate_instances"
+    )
     def test_predefined_metric_retry_fail_on_resource_exhausted(
         self,
         mock_private_evaluate_instances,

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 23 additions & 2 deletions
@@ -620,6 +620,10 @@ def _add_autorater_config(self, payload: dict[str, Any]) -> None:
         autorater_config = {}
         if self.metric.judge_model:
            autorater_config["autorater_model"] = self.metric.judge_model
+        if self.metric.judge_model_generation_config:
+            autorater_config["generation_config"] = (
+                self.metric.judge_model_generation_config
+            )
         if self.metric.judge_model_sampling_count:
             autorater_config["sampling_count"] = self.metric.judge_model_sampling_count  # type: ignore[assignment]
 
@@ -986,10 +990,25 @@ def _build_request_payload(
             agent_data=PredefinedMetricHandler._eval_case_to_agent_data(eval_case),
         )
 
-        return {
+        request_payload = {
             "instance": instance_payload,
         }
 
+        autorater_config = {}
+        if self.metric.judge_model:
+            autorater_config["autorater_model"] = self.metric.judge_model
+        if self.metric.judge_model_generation_config:
+            autorater_config["generation_config"] = (
+                self.metric.judge_model_generation_config
+            )
+        if self.metric.judge_model_sampling_count:
+            autorater_config["sampling_count"] = self.metric.judge_model_sampling_count
+        if autorater_config:
+            request_payload["autorater_config"] = genai_types.AutoraterConfig(
+                **autorater_config
+            )
+        return request_payload
+
     @override
     def get_metric_result(
         self, eval_case: types.EvalCase, response_index: int
@@ -1001,7 +1020,9 @@ def get_metric_result(
         for attempt in range(_MAX_RETRIES):
             try:
                 api_response = self.module._evaluate_instances(
-                    metrics=[self.metric], instance=payload.get("instance")
+                    metrics=[self.metric],
+                    instance=payload.get("instance"),
+                    autorater_config=payload.get("autorater_config"),
                 )
                 break
             except genai_errors.ClientError as e:
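For readers skimming the diff: the handler now folds the metric's judge settings into a genai_types.AutoraterConfig and forwards it alongside the instance payload. A standalone sketch of that mapping, assuming only a metric object with the judge_* attributes shown above (the helper name build_autorater_config is hypothetical, not part of the SDK):

from typing import Any, Optional

from google.genai import types as genai_types


def build_autorater_config(metric: Any) -> Optional[genai_types.AutoraterConfig]:
    """Collect judge_* fields from a metric into an AutoraterConfig, if any are set."""
    config: dict[str, Any] = {}
    if metric.judge_model:
        config["autorater_model"] = metric.judge_model
    if metric.judge_model_generation_config:
        config["generation_config"] = metric.judge_model_generation_config
    if metric.judge_model_sampling_count:
        config["sampling_count"] = metric.judge_model_sampling_count
    return genai_types.AutoraterConfig(**config) if config else None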

vertexai/_genai/types/common.py

Lines changed: 7 additions & 0 deletions
@@ -2622,6 +2622,10 @@ class Metric(_common.BaseModel):
     judge_model: Optional[str] = Field(
         default=None, description="""The judge model for the metric."""
     )
+    judge_model_generation_config: Optional[genai_types.GenerationConfig] = Field(
+        default=None,
+        description="""The generation config for the judge LLM (temperature, top_k, top_p, etc).""",
+    )
     judge_model_sampling_count: Optional[int] = Field(
         default=None, description="""The sampling count for the judge model."""
     )
@@ -2825,6 +2829,9 @@ class MetricDict(TypedDict, total=False):
     judge_model: Optional[str]
     """The judge model for the metric."""
 
+    judge_model_generation_config: Optional[genai_types.GenerationConfigDict]
+    """The generation config for the judge LLM (temperature, top_k, top_p, etc)."""
+
     judge_model_sampling_count: Optional[int]
     """The sampling count for the judge model."""
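The new field can also be set when building a Metric directly. A hedged sketch, assuming Metric is re-exported from vertexai.types like the other eval types used above; the metric name is a placeholder for illustration only:

from google.genai import types as genai_types
from vertexai import types

# "my_custom_metric" is a placeholder name, not from this commit.
metric = types.Metric(
    name="my_custom_metric",
    judge_model="gemini-2.5-flash",
    judge_model_generation_config=genai_types.GenerationConfig(
        temperature=0.0,
        top_p=0.95,
    ),
    judge_model_sampling_count=3,
)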
