
Commit 3b7c36a

jsondai authored and copybara-github committed
feat: GenAI Client(evals) - Support CustomCodeExecution metric in Vertex Gen AI Eval Service
PiperOrigin-RevId: 839489822
1 parent da79e21 commit 3b7c36a
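
Based on the replay test added in this commit, here is a minimal usage sketch of the new metric type. The client construction is an assumption; the Metric and EvaluationDataset usage mirrors the test code further down.

# Minimal sketch, assuming a configured Vertex AI GenAI client; the Metric and
# EvaluationDataset fields mirror the replay test added in this commit.
import pandas as pd
import vertexai
from vertexai._genai import types

client = vertexai.Client(project="my-project", location="us-central1")  # assumed setup

# The custom function is shipped as source code and executed remotely by the
# Eval Service. It receives one instance dict per dataset row and returns a
# float score.
code_snippet = """
def evaluate(instance):
    if instance['response'] == instance['reference']:
        return 1.0
    return 0.0
"""

metric = types.Metric(
    name="my_custom_code_metric",
    remote_custom_function=code_snippet,
)

dataset = types.EvaluationDataset(
    eval_dataset_df=pd.DataFrame(
        {"prompt": ["What is 2+2?"], "response": ["4"], "reference": ["4"]}
    ),
)

result = client.evals.evaluate(dataset=dataset, metrics=[metric])
print(result.summary_metrics)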

File tree

6 files changed: +520 −4 lines changed

Lines changed: 105 additions & 0 deletions (new test file)
@@ -0,0 +1,105 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

from tests.unit.vertexai.genai.replays import pytest_helper
from vertexai._genai import types
import pandas as pd


def test_custom_code_execution(client):
    """Tests that custom code execution metric produces a correctly structured EvaluationResult."""

    code_snippet = """
def evaluate(instance):
    if instance['response'] == instance['reference']:
        return 1.0
    return 0.0
"""

    custom_metric = types.Metric(
        name="my_custom_code_metric",
        remote_custom_function=code_snippet,
    )

    prompts_df = pd.DataFrame(
        {
            "prompt": ["What is 2+2?", "What is 3+3?"],
            "response": ["4", "5"],
            "reference": ["4", "6"],
        }
    )

    eval_dataset = types.EvaluationDataset(
        eval_dataset_df=prompts_df,
        candidate_name="test_model",
    )

    evaluation_result = client.evals.evaluate(
        dataset=eval_dataset,
        metrics=[custom_metric],
    )

    assert isinstance(evaluation_result, types.EvaluationResult)

    assert evaluation_result.summary_metrics is not None
    assert evaluation_result.summary_metrics
    for summary in evaluation_result.summary_metrics:
        assert isinstance(summary, types.AggregatedMetricResult)
        assert summary.metric_name == "my_custom_code_metric"

    assert evaluation_result.eval_case_results is not None
    assert evaluation_result.eval_case_results
    for case_result in evaluation_result.eval_case_results:
        assert isinstance(case_result, types.EvalCaseResult)
        assert case_result.eval_case_index is not None
        assert case_result.response_candidate_results is not None


def test_custom_code_execution_batch_evaluate(client):
    """Tests that batch_evaluate() works with custom code execution metric."""

    code_snippet = """
def evaluate(instance):
    if instance['response'] == instance['reference']:
        return 1.0
    return 0.0
"""

    custom_metric = types.Metric(
        name="my_custom_code_metric",
        remote_custom_function=code_snippet,
    )

    eval_dataset = types.EvaluationDataset(
        gcs_source=types.GcsSource(
            uris=["gs://genai-eval-sdk-replay-test/test_data/inference_results.jsonl"]
        ),
    )

    evaluation_result = client.evals.batch_evaluate(
        dataset=eval_dataset,
        metrics=[custom_metric],
        dest="gs://genai-eval-sdk-replay-test/test_data/batch_eval_output",
    )

    assert evaluation_result is not None


pytestmark = pytest_helper.setup(
    file=__file__,
    globals_for_file=globals(),
    test_method="evals.evaluate",
)
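
For context, the function passed via remote_custom_function is executed by the Eval Service against one dataset row at a time. A small local sketch of that contract follows; the exact instance schema on the service side is an assumption, and the field names simply mirror the test dataframe columns above.

# Hedged sketch of the per-row contract: evaluate() gets one instance dict and
# returns a float score. Field names mirror the test dataframe; the service's
# exact instance schema is an assumption.
def evaluate(instance):
    # Same exact-match logic as the code_snippet used in the tests.
    if instance['response'] == instance['reference']:
        return 1.0
    return 0.0

# The second test row above is an intentional mismatch ("5" vs "6"), so the
# metric should score it 0.0.
assert evaluate({"prompt": "What is 2+2?", "response": "4", "reference": "4"}) == 1.0
assert evaluate({"prompt": "What is 3+3?", "response": "5", "reference": "6"}) == 0.0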

vertexai/_genai/_evals_metric_handlers.py

Lines changed: 142 additions & 2 deletions
@@ -685,10 +685,9 @@ def get_metric_result(
             )
         except Exception as e:  # pylint: disable=broad-exception-caught
             logger.error(
-                "Error processing metric %s for case %s: %s",
+                "Error processing metric %s for case %s.",
                 metric_name,
                 eval_case.eval_case_id,
-                e,
                 exc_info=True,
             )
             return types.EvalCaseMetricResult(
@@ -1099,7 +1098,147 @@ def aggregate(
         )
 
 
+class CustomCodeExecutionMetricHandler(MetricHandler):
+    """Metric handler for custom code execution metrics."""
+
+    def __init__(self, module: "evals.Evals", metric: types.Metric):
+        super().__init__(module=module, metric=metric)
+
+        if not self.metric.remote_custom_function:
+            raise ValueError(
+                f"CustomCodeExecutionMetricHandler for '{self.metric.name}' needs "
+                " Metric.remote_custom_function to be set."
+            )
+
+    def _build_request_payload(
+        self, eval_case: types.EvalCase, response_index: int
+    ) -> dict[str, Any]:
+        """Builds the request parameters for evaluate instances request."""
+        if not eval_case.responses or response_index >= len(eval_case.responses):
+            raise IndexError(f"response_index {response_index} is out of bounds.")
+
+        response_content = eval_case.responses[response_index].response
+        if not response_content:
+            raise ValueError(
+                f"Response content missing for candidate {response_index}."
+            )
+
+        reference_instance_data = None
+        if eval_case.reference:
+            reference_instance_data = PredefinedMetricHandler._content_to_instance_data(
+                eval_case.reference.response
+            )
+
+        prompt_instance_data = PredefinedMetricHandler._content_to_instance_data(
+            eval_case.prompt
+        )
+
+        instance_payload = types.EvaluationInstance(
+            prompt=prompt_instance_data,
+            response=PredefinedMetricHandler._content_to_instance_data(
+                response_content
+            ),
+            reference=reference_instance_data,
+        )
+
+        return {
+            "instance": instance_payload,
+        }
+
+    @override
+    def get_metric_result(
+        self, eval_case: types.EvalCase, response_index: int
+    ) -> types.EvalCaseMetricResult:
+        """Processes a single evaluation case for a specific custom code execution metric."""
+        metric_name = self.metric.name
+        try:
+            payload = self._build_request_payload(eval_case, response_index)
+            for attempt in range(_MAX_RETRIES):
+                try:
+                    api_response = self.module._evaluate_instances(
+                        metrics=[self.metric],
+                        instance=payload.get("instance"),
+                    )
+                    break
+                except genai_errors.ClientError as e:
+                    if e.code == 429:
+                        logger.warning(
+                            "Resource Exhausted error on attempt %d/%d: %s. Retrying in %s"
+                            " seconds...",
+                            attempt + 1,
+                            _MAX_RETRIES,
+                            e,
+                            2**attempt,
+                        )
+                        if attempt == _MAX_RETRIES - 1:
+                            return types.EvalCaseMetricResult(
+                                metric_name=metric_name,
+                                error_message=f"Resource exhausted after {_MAX_RETRIES} retries: {e}",
+                            )
+                        time.sleep(2**attempt)
+                    else:
+                        raise e
+
+            if (
+                api_response
+                and hasattr(api_response, "metric_results")
+                and api_response.metric_results
+            ):
+                result_data = api_response.metric_results[0]
+
+                error_message = None
+                if result_data.error and getattr(result_data.error, "code"):
+                    error_message = f"Error in metric result: {result_data.error}"
+                return types.EvalCaseMetricResult(
+                    metric_name=metric_name,
+                    score=result_data.score,
+                    explanation=result_data.explanation,
+                    error_message=error_message,
+                )
+            else:
+                logger.error(
+                    "Metric results missing in API response for metric '%s'."
+                    " API response: %s",
+                    metric_name,
+                    (
+                        api_response.model_dump_json(exclude_none=True)
+                        if api_response
+                        else "None"
+                    ),
+                )
+                return types.EvalCaseMetricResult(
+                    metric_name=metric_name,
+                    error_message="Metric results missing in API response.",
+                )
+        except Exception as e:  # pylint: disable=broad-exception-caught
+            logger.error(
+                "Error processing metric %s for case %s",
+                metric_name,
+                eval_case.eval_case_id,
+                exc_info=True,
+            )
+            return types.EvalCaseMetricResult(
+                metric_name=metric_name, error_message=str(e)
+            )
+
+    @override
+    def aggregate(
+        self, eval_case_metric_results: list[types.EvalCaseMetricResult]
+    ) -> types.AggregatedMetricResult:
+        """Aggregates the metric results for a custom code execution metric."""
+        logger.debug(
+            "Aggregating results for custom code execution metric: %s", self.metric.name
+        )
+        return _default_aggregate_scores(
+            self.metric.name, eval_case_metric_results, calculate_pass_rate=True
+        )
+
+
 _METRIC_HANDLER_MAPPING = [
+    (
+        lambda m: hasattr(m, "remote_custom_function") and m.remote_custom_function,
+        CustomCodeExecutionMetricHandler,
+    ),
     (
         lambda m: m.custom_function and isinstance(m.custom_function, Callable),
         CustomMetricHandler,

@@ -1125,6 +1264,7 @@ def aggregate(
     TranslationMetricHandler,
     LLMMetricHandler,
     CustomMetricHandler,
+    CustomCodeExecutionMetricHandler,
     PredefinedMetricHandler,
 )
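
The handler is picked by the first matching entry in _METRIC_HANDLER_MAPPING, so placing the remote_custom_function predicate before the custom_function one keeps a remotely executed metric from being routed to the local custom-function handler. A standalone sketch of that dispatch pattern, using simplified stand-in names rather than the module's real symbols:

# Simplified sketch of predicate-based handler dispatch; class and variable
# names here are stand-ins, not the real module's symbols.
from typing import Any, Callable, Optional


class Metric:
    def __init__(self, name: str, remote_custom_function: Optional[str] = None,
                 custom_function: Optional[Callable[..., Any]] = None):
        self.name = name
        self.remote_custom_function = remote_custom_function
        self.custom_function = custom_function


class CustomCodeExecutionHandler: ...
class CustomFunctionHandler: ...


HANDLER_MAPPING: list[tuple[Callable[[Metric], Any], type]] = [
    # Checked first, so a metric carrying remote source code is never routed
    # to the locally executed custom_function handler.
    (lambda m: getattr(m, "remote_custom_function", None), CustomCodeExecutionHandler),
    (lambda m: m.custom_function is not None and callable(m.custom_function), CustomFunctionHandler),
]


def pick_handler(metric: Metric) -> type:
    # Return the handler class for the first predicate that matches.
    for predicate, handler_cls in HANDLER_MAPPING:
        if predicate(metric):
            return handler_cls
    raise ValueError(f"No handler registered for metric {metric.name!r}")


assert pick_handler(Metric("m", remote_custom_function="def evaluate(i): ...")) is CustomCodeExecutionHandler
assert pick_handler(Metric("m", custom_function=lambda i: 1.0)) is CustomFunctionHandler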

vertexai/_genai/_transformers.py

Lines changed: 7 additions & 0 deletions
@@ -60,6 +60,13 @@ def t_metrics(
                 "metric_spec_name": metric_name,
                 "metric_spec_parameters": metric.metric_spec_parameters,
             }
+        # Custom Code Execution Metric
+        elif (
+            hasattr(metric, "remote_custom_function") and metric.remote_custom_function
+        ):
+            metric_payload_item["custom_code_execution_spec"] = {
+                "evaluation_function": metric.remote_custom_function
+            }
         # Pointwise metrics
         elif hasattr(metric, "prompt_template") and metric.prompt_template:
             pointwise_spec = {"metric_prompt_template": metric.prompt_template}
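
Put together, the new t_metrics branch turns a Metric whose remote_custom_function is set into a custom_code_execution_spec entry in the request payload. A hedged sketch of the resulting fragment follows; only the spec key and its evaluation_function field come from the diff above, and any other keys on the payload item are assumed to be filled elsewhere.

# Hedged sketch of the payload fragment produced for a remote-custom-function
# metric. Only "custom_code_execution_spec" / "evaluation_function" are taken
# from the diff above; the surrounding item structure is an assumption.
code_snippet = (
    "def evaluate(instance):\n"
    "    return 1.0 if instance['response'] == instance['reference'] else 0.0\n"
)

metric_payload_item = {
    # Other fields (e.g. the metric's name) are set outside the shown hunk.
    "custom_code_execution_spec": {
        "evaluation_function": code_snippet,
    },
}

print(metric_payload_item)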
