
Commit 8b9ed04

vertex-sdk-bot authored and copybara-github committed
feat: GenAI SDK client(evals) - Add agent run in run_inference
PiperOrigin-RevId: 821741545
1 parent 5b5e6bd commit 8b9ed04

5 files changed: +644, -40 lines changed

tests/unit/vertexai/genai/replays/test_evaluate_instances.py

Lines changed: 12 additions & 0 deletions
@@ -238,6 +238,18 @@ def test_inference_with_prompt_template(client):
     assert inference_result.gcs_source is None
 
 
+def test_run_inference_with_agent(client):
+    test_df = pd.DataFrame(
+        {"prompt": ["agent prompt"], "session_inputs": ['{"user_id": "user_123"}']}
+    )
+    inference_result = client.evals.run_inference(
+        agent="projects/977012026409/locations/us-central1/reasoningEngines/7188347537655332864",
+        src=test_df,
+    )
+    assert inference_result.candidate_name == "agent"
+    assert inference_result.gcs_source is None
+
+
 pytestmark = pytest_helper.setup(
     file=__file__,
     globals_for_file=globals(),
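For orientation, a minimal sketch of the calling pattern this replay test exercises follows. The DataFrame columns and the agent= argument come from the test above; the client construction, project, location, and engine resource name are placeholders assumed for illustration, not part of this commit.

import pandas as pd
import vertexai

# Assumed setup; project, location, and the engine resource name are placeholders.
client = vertexai.Client(project="my-project", location="us-central1")

# One row per prompt. "session_inputs" carries the JSON-encoded (or dict) session
# state that the agent run uses when it opens a session for the row.
src = pd.DataFrame(
    {
        "prompt": ["agent prompt"],
        "session_inputs": ['{"user_id": "user_123"}'],
    }
)

# Passing a Reasoning Engine resource name via agent= (instead of model=)
# routes run_inference through the agent path added by this commit.
result = client.evals.run_inference(
    agent="projects/my-project/locations/us-central1/reasoningEngines/ENGINE_ID",
    src=src,
)
assert result.candidate_name == "agent"  # as asserted in the replay test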

tests/unit/vertexai/genai/test_evals.py

Lines changed: 320 additions & 0 deletions
@@ -200,6 +200,9 @@ def setup_method(self):
         importlib.reload(_evals_metric_handlers)
         importlib.reload(_genai.evals)
 
+        if hasattr(_evals_common._thread_local_data, "agent_engine_instances"):
+            del _evals_common._thread_local_data.agent_engine_instances
+
         vertexai.init(
             project=_TEST_PROJECT,
             location=_TEST_LOCATION,
@@ -967,6 +970,227 @@ def test_inference_with_multimodal_content(
         assert inference_result.candidate_name == "gemini-pro"
         assert inference_result.gcs_source is None
 
+    @mock.patch.object(_evals_utils, "EvalDatasetLoader")
+    @mock.patch("vertexai._genai._evals_common.vertexai.Client")
+    def test_run_inference_with_agent_engine_and_session_inputs_dict(
+        self,
+        mock_vertexai_client,
+        mock_eval_dataset_loader,
+    ):
+        mock_df = pd.DataFrame(
+            {
+                "prompt": ["agent prompt"],
+                "session_inputs": [
+                    {
+                        "user_id": "123",
+                        "state": {"a": "1"},
+                    }
+                ],
+            }
+        )
+        mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict(
+            orient="records"
+        )
+
+        mock_agent_engine = mock.Mock()
+        mock_agent_engine.async_create_session = mock.AsyncMock(
+            return_value={"id": "session1"}
+        )
+        stream_query_return_value = [
+            {
+                "id": "1",
+                "content": {"parts": [{"text": "intermediate1"}]},
+                "timestamp": 123,
+                "author": "model",
+            },
+            {
+                "id": "2",
+                "content": {"parts": [{"text": "agent response"}]},
+                "timestamp": 124,
+                "author": "model",
+            },
+        ]
+
+        async def _async_iterator(iterable):
+            for item in iterable:
+                yield item
+
+        mock_agent_engine.async_stream_query.return_value = _async_iterator(
+            stream_query_return_value
+        )
+        mock_vertexai_client.return_value.agent_engines.get.return_value = (
+            mock_agent_engine
+        )
+
+        inference_result = self.client.evals.run_inference(
+            agent="projects/test-project/locations/us-central1/reasoningEngines/123",
+            src=mock_df,
+        )
+
+        mock_eval_dataset_loader.return_value.load.assert_called_once_with(mock_df)
+        mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with(
+            name="projects/test-project/locations/us-central1/reasoningEngines/123"
+        )
+        mock_agent_engine.async_create_session.assert_called_once_with(
+            user_id="123", state={"a": "1"}
+        )
+        mock_agent_engine.async_stream_query.assert_called_once_with(
+            user_id="123", session_id="session1", message="agent prompt"
+        )
+
+        pd.testing.assert_frame_equal(
+            inference_result.eval_dataset_df,
+            pd.DataFrame(
+                {
+                    "prompt": ["agent prompt"],
+                    "session_inputs": [
+                        {
+                            "user_id": "123",
+                            "state": {"a": "1"},
+                        }
+                    ],
+                    "intermediate_events": [
+                        [
+                            {
+                                "event_id": "1",
+                                "content": {"parts": [{"text": "intermediate1"}]},
+                                "creation_timestamp": 123,
+                                "author": "model",
+                            }
+                        ]
+                    ],
+                    "response": ["agent response"],
+                }
+            ),
+        )
+        assert inference_result.candidate_name == "agent"
+        assert inference_result.gcs_source is None
+
+    @mock.patch.object(_evals_utils, "EvalDatasetLoader")
+    @mock.patch("vertexai._genai._evals_common.vertexai.Client")
+    def test_run_inference_with_agent_engine_and_session_inputs_literal_string(
+        self,
+        mock_vertexai_client,
+        mock_eval_dataset_loader,
+    ):
+        session_inputs_str = '{"user_id": "123", "state": {"a": "1"}}'
+        mock_df = pd.DataFrame(
+            {
+                "prompt": ["agent prompt"],
+                "session_inputs": [session_inputs_str],
+            }
+        )
+        mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict(
+            orient="records"
+        )
+
+        mock_agent_engine = mock.Mock()
+        mock_agent_engine.async_create_session = mock.AsyncMock(
+            return_value={"id": "session1"}
+        )
+        stream_query_return_value = [
+            {
+                "id": "1",
+                "content": {"parts": [{"text": "intermediate1"}]},
+                "timestamp": 123,
+                "author": "model",
+            },
+            {
+                "id": "2",
+                "content": {"parts": [{"text": "agent response"}]},
+                "timestamp": 124,
+                "author": "model",
+            },
+        ]
+
+        async def _async_iterator(iterable):
+            for item in iterable:
+                yield item
+
+        mock_agent_engine.async_stream_query.return_value = _async_iterator(
+            stream_query_return_value
+        )
+        mock_vertexai_client.return_value.agent_engines.get.return_value = (
+            mock_agent_engine
+        )
+
+        inference_result = self.client.evals.run_inference(
+            agent="projects/test-project/locations/us-central1/reasoningEngines/123",
+            src=mock_df,
+        )
+
+        mock_eval_dataset_loader.return_value.load.assert_called_once_with(mock_df)
+        mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with(
+            name="projects/test-project/locations/us-central1/reasoningEngines/123"
+        )
+        mock_agent_engine.async_create_session.assert_called_once_with(
+            user_id="123", state={"a": "1"}
+        )
+        mock_agent_engine.async_stream_query.assert_called_once_with(
+            user_id="123", session_id="session1", message="agent prompt"
+        )
+
+        pd.testing.assert_frame_equal(
+            inference_result.eval_dataset_df,
+            pd.DataFrame(
+                {
+                    "prompt": ["agent prompt"],
+                    "session_inputs": [session_inputs_str],
+                    "intermediate_events": [
+                        [
+                            {
+                                "event_id": "1",
+                                "content": {"parts": [{"text": "intermediate1"}]},
+                                "creation_timestamp": 123,
+                                "author": "model",
+                            }
+                        ]
+                    ],
+                    "response": ["agent response"],
+                }
+            ),
+        )
+        assert inference_result.candidate_name == "agent"
+        assert inference_result.gcs_source is None
+
+    @mock.patch.object(_evals_utils, "EvalDatasetLoader")
+    @mock.patch("vertexai._genai._evals_common.vertexai.Client")
+    def test_run_inference_with_agent_engine_with_response_column_raises_error(
+        self,
+        mock_vertexai_client,
+        mock_eval_dataset_loader,
+    ):
+        mock_df = pd.DataFrame(
+            {
+                "prompt": ["agent prompt"],
+                "session_inputs": [
+                    {
+                        "user_id": "123",
+                        "state": {"a": "1"},
+                    }
+                ],
+                "response": ["some response"],
+            }
+        )
+        mock_eval_dataset_loader.return_value.load.return_value = mock_df.to_dict(
+            orient="records"
+        )
+
+        mock_agent_engine = mock.Mock()
+        mock_vertexai_client.return_value.agent_engines.get.return_value = (
+            mock_agent_engine
+        )
+
+        with pytest.raises(ValueError) as excinfo:
+            self.client.evals.run_inference(
+                agent="projects/test-project/locations/us-central1/reasoningEngines/123",
+                src=mock_df,
+            )
+        assert (
+            "The eval dataset provided for agent run should not contain "
+            "'intermediate_events' or 'response' columns"
+        ) in str(excinfo.value)
+
     def test_run_inference_with_litellm_string_prompt_format(
         self,
         mock_api_client_fixture,
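Taken together, the mock assertions in these three tests outline the agent inference flow: run_inference loads the dataset, resolves the Reasoning Engine via agent_engines.get, creates one session per row from session_inputs (accepted either as a dict or as a JSON string), then streams the row's prompt through async_stream_query and gathers the emitted events. The sketch below mirrors that per-row sequence; it is reconstructed from the mock expectations rather than copied from the SDK, and the helper name _run_single_row is hypothetical.

import json
from typing import Any, Optional, Union

async def _run_single_row(
    agent_engine: Any,
    prompt: str,
    session_inputs: Optional[Union[dict, str]] = None,
) -> list:
    """Hypothetical per-row agent run mirroring the mocked call order above."""
    # The tests cover session_inputs given as a dict and as a JSON string.
    if isinstance(session_inputs, str):
        session_inputs = json.loads(session_inputs)
    session_inputs = session_inputs or {}

    # Asserted call: async_create_session(user_id=..., state=...).
    session = await agent_engine.async_create_session(**session_inputs)

    # Asserted call: async_stream_query(user_id=..., session_id=..., message=prompt);
    # every streamed event is collected for later flattening.
    events = []
    async for event in agent_engine.async_stream_query(
        user_id=session_inputs.get("user_id"),
        session_id=session["id"],
        message=prompt,
    ):
        events.append(event)
    return events

The error-path test additionally pins down a precondition: when agent= is used, the source dataset must not already contain 'response' or 'intermediate_events' columns.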
@@ -1229,6 +1453,102 @@ def test_run_inference_with_litellm_parsing(
         pd.testing.assert_frame_equal(call_kwargs["prompt_dataset"], mock_df)
 
 
+@pytest.mark.usefixtures("google_auth_mock")
+class TestRunAgentInternal:
+    """Unit tests for the _run_agent_internal function."""
+
+    def setup_method(self):
+        importlib.reload(vertexai_genai_types)
+        importlib.reload(_evals_common)
+
+    @mock.patch.object(_evals_common, "_run_agent")
+    def test_run_agent_internal_success(self, mock_run_agent):
+        mock_run_agent.return_value = [
+            [
+                {
+                    "id": "1",
+                    "content": {"parts": [{"text": "intermediate1"}]},
+                    "timestamp": 123,
+                    "author": "model",
+                },
+                {
+                    "id": "2",
+                    "content": {"parts": [{"text": "final response"}]},
+                    "timestamp": 124,
+                    "author": "model",
+                },
+            ]
+        ]
+        prompt_dataset = pd.DataFrame({"prompt": ["prompt1"]})
+        mock_agent_engine = mock.Mock()
+        mock_api_client = mock.Mock()
+        result_df = _evals_common._run_agent_internal(
+            api_client=mock_api_client,
+            agent_engine=mock_agent_engine,
+            prompt_dataset=prompt_dataset,
+        )
+
+        expected_df = pd.DataFrame(
+            {
+                "prompt": ["prompt1"],
+                "intermediate_events": [
+                    [
+                        {
+                            "event_id": "1",
+                            "content": {"parts": [{"text": "intermediate1"}]},
+                            "creation_timestamp": 123,
+                            "author": "model",
+                        }
+                    ]
+                ],
+                "response": ["final response"],
+            }
+        )
+        pd.testing.assert_frame_equal(result_df, expected_df)
+
+    @mock.patch.object(_evals_common, "_run_agent")
+    def test_run_agent_internal_error_response(self, mock_run_agent):
+        mock_run_agent.return_value = [{"error": "agent run failed"}]
+        prompt_dataset = pd.DataFrame({"prompt": ["prompt1"]})
+        mock_agent_engine = mock.Mock()
+        mock_api_client = mock.Mock()
+        result_df = _evals_common._run_agent_internal(
+            api_client=mock_api_client,
+            agent_engine=mock_agent_engine,
+            prompt_dataset=prompt_dataset,
+        )
+
+        assert "response" in result_df.columns
+        response_content = result_df["response"][0]
+        assert "Unexpected response type from agent run" in response_content
+        assert not result_df["intermediate_events"][0]
+
+    @mock.patch.object(_evals_common, "_run_agent")
+    def test_run_agent_internal_malformed_event(self, mock_run_agent):
+        mock_run_agent.return_value = [
+            [
+                {
+                    "id": "1",
+                    "content": {"parts1": [{"text123": "final response"}]},
+                    "timestamp": 124,
+                    "author": "model",
+                },
+            ]
+        ]
+        prompt_dataset = pd.DataFrame({"prompt": ["prompt1"]})
+        mock_agent_engine = mock.Mock()
+        mock_api_client = mock.Mock()
+        result_df = _evals_common._run_agent_internal(
+            api_client=mock_api_client,
+            agent_engine=mock_agent_engine,
+            prompt_dataset=prompt_dataset,
+        )
+        assert "response" in result_df.columns
+        response_content = result_df["response"][0]
+        assert "Failed to parse agent run response" in response_content
+        assert not result_df["intermediate_events"][0]
+
+
 class TestMetricPromptBuilder:
     """Unit tests for the MetricPromptBuilder class."""
 
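The TestRunAgentInternal cases pin down the post-processing contract: per prompt, _run_agent_internal renames id to event_id and timestamp to creation_timestamp, keeps all but the last event as intermediate_events, promotes the last event's first text part to response, and falls back to an explanatory string with empty intermediate_events when the payload is not a list of events or cannot be parsed. Below is a hedged reconstruction of that per-row flattening; the helper name _events_to_row is hypothetical.

import pandas as pd

def _events_to_row(prompt, events):
    """Hypothetical per-prompt flattening consistent with the assertions above."""
    if not isinstance(events, list):
        # Mirrors test_run_agent_internal_error_response.
        return {
            "prompt": prompt,
            "intermediate_events": [],
            "response": f"Unexpected response type from agent run: {events!r}",
        }
    try:
        intermediate_events = [
            {
                "event_id": event["id"],
                "content": event["content"],
                "creation_timestamp": event["timestamp"],
                "author": event["author"],
            }
            for event in events[:-1]
        ]
        response = events[-1]["content"]["parts"][0]["text"]
    except (KeyError, IndexError, TypeError) as exc:
        # Mirrors test_run_agent_internal_malformed_event.
        return {
            "prompt": prompt,
            "intermediate_events": [],
            "response": f"Failed to parse agent run response: {exc!r}",
        }
    return {
        "prompt": prompt,
        "intermediate_events": intermediate_events,
        "response": response,
    }

# With the success-case fixture, the row flattens to one intermediate event and
# the final text "final response", matching expected_df in the test above.
rows = [
    _events_to_row(
        "prompt1",
        [
            {"id": "1", "content": {"parts": [{"text": "intermediate1"}]},
             "timestamp": 123, "author": "model"},
            {"id": "2", "content": {"parts": [{"text": "final response"}]},
             "timestamp": 124, "author": "model"},
        ],
    )
]
result_df = pd.DataFrame(rows)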
0 commit comments
