diff --git a/tests/unit/vertexai/genai/test_evals.py b/tests/unit/vertexai/genai/test_evals.py index 7c4ed635f5..048a8c6142 100644 --- a/tests/unit/vertexai/genai/test_evals.py +++ b/tests/unit/vertexai/genai/test_evals.py @@ -1885,9 +1885,133 @@ def test_run_agent_internal_error_response(self, mock_run_agent): assert "response" in result_df.columns response_content = result_df["response"][0] - assert "Unexpected response type from agent run" in response_content + assert "agent run failed" in response_content assert not result_df["intermediate_events"][0] + @mock.patch.object(_evals_common, "_run_agent") + def test_run_agent_internal_multi_turn_success(self, mock_run_agent): + mock_run_agent.return_value = [ + [ + {"turn_index": 0, "turn_id": "t1", "events": []}, + {"turn_index": 1, "turn_id": "t2", "events": []}, + ] + ] + prompt_dataset = pd.DataFrame({"prompt": ["p1"], "conversation_plan": ["plan"]}) + mock_agent_engine = mock.Mock() + mock_api_client = mock.Mock() + result_df = _evals_common._run_agent_internal( + api_client=mock_api_client, + agent_engine=mock_agent_engine, + agent=None, + prompt_dataset=prompt_dataset, + ) + + assert "agent_data" in result_df.columns + agent_data = result_df["agent_data"][0] + assert agent_data["turns"] == [ + {"turn_index": 0, "turn_id": "t1", "events": []}, + {"turn_index": 1, "turn_id": "t2", "events": []}, + ] + + @mock.patch( + "vertexai._genai._evals_common.ADK_SessionInput" + ) + @mock.patch( + "vertexai._genai._evals_common.EvaluationGenerator" + ) + @mock.patch( + "vertexai._genai._evals_common.LlmBackedUserSimulator" + ) + @mock.patch( + "vertexai._genai._evals_common.ConversationScenario" + ) + @mock.patch( + "vertexai._genai._evals_common.LlmBackedUserSimulatorConfig" + ) + @pytest.mark.asyncio + async def test_run_adk_user_simulation_with_intermediate_events( + self, + mock_config, + mock_scenario, + mock_simulator, + mock_generator, + mock_session_input, + ): + """Tests that intermediate invocation events (e.g. tool calls) are parsed successfully.""" + row = pd.Series( + { + "starting_prompt": "I want a laptop.", + "conversation_plan": "Ask for a laptop", + "session_inputs": json.dumps({"user_id": "u1"}), + } + ) + mock_agent = mock.Mock() + + mock_invocation = mock.Mock() + mock_invocation.invocation_id = "turn_123" + mock_invocation.creation_timestamp = 1771811084.88 + mock_invocation.user_content.model_dump.return_value = { + "parts": [{"text": "I want a laptop."}], + "role": "user", + } + mock_event_1 = mock.Mock() + mock_event_1.author = "ecommerce_agent" + mock_event_1.content.model_dump.return_value = { + "parts": [ + { + "function_call": { + "name": "search_products", + "args": {"query": "laptop"}, + } + } + ] + } + mock_event_2 = mock.Mock() + mock_event_2.author = "ecommerce_agent" + mock_event_2.content.model_dump.return_value = { + "parts": [ + { + "function_response": { + "name": "search_products", + "response": {"products": []}, + } + } + ] + } + + mock_invocation.intermediate_data.invocation_events = [ + mock_event_1, + mock_event_2, + ] + mock_invocation.final_response.model_dump.return_value = { + "parts": [{"text": "There are no laptops matching your search."}], + "role": "model", + } + mock_generator._generate_inferences_from_root_agent = mock.AsyncMock( + return_value=[mock_invocation] + ) + turns = await _evals_common._run_adk_user_simulation(row, mock_agent) + + assert len(turns) == 1 + turn = turns[0] + assert turn["turn_index"] == 0 + assert turn["turn_id"] == "turn_123" + assert len(turn["events"]) == 4 + assert turn["events"][0]["author"] == "user" + assert turn["events"][0]["content"]["parts"][0]["text"] == "I want a laptop." + assert turn["events"][1]["author"] == "ecommerce_agent" + assert "function_call" in turn["events"][1]["content"]["parts"][0] + assert turn["events"][2]["author"] == "ecommerce_agent" + assert "function_response" in turn["events"][2]["content"]["parts"][0] + assert turn["events"][3]["author"] == "agent" + assert ( + turn["events"][3]["content"]["parts"][0]["text"] + == "There are no laptops matching your search." + ) + mock_invocation.user_content.model_dump.assert_called_with(mode="json") + mock_event_1.content.model_dump.assert_called_with(mode="json") + mock_invocation.final_response.model_dump.assert_called_with(mode="json") + @mock.patch.object(_evals_common, "_run_agent") def test_run_agent_internal_malformed_event(self, mock_run_agent): mock_run_agent.return_value = [ @@ -1915,6 +2039,28 @@ def test_run_agent_internal_malformed_event(self, mock_run_agent): assert not result_df["intermediate_events"][0] +class TestIsMultiTurnAgentRun: + """Unit tests for the _is_multi_turn_agent_run function.""" + + def test_is_multi_turn_agent_run_with_config(self): + config = vertexai_genai_types.UserSimulatorConfig(model_name="gemini-pro") + assert _evals_common._is_multi_turn_agent_run( + user_simulator_config=config, prompt_dataset=pd.DataFrame() + ) + + def test_is_multi_turn_agent_run_with_conversation_plan(self): + prompt_dataset = pd.DataFrame({"conversation_plan": ["plan"]}) + assert _evals_common._is_multi_turn_agent_run( + user_simulator_config=None, prompt_dataset=prompt_dataset + ) + + def test_is_multi_turn_agent_run_false(self): + prompt_dataset = pd.DataFrame({"prompt": ["prompt"]}) + assert not _evals_common._is_multi_turn_agent_run( + user_simulator_config=None, prompt_dataset=prompt_dataset + ) + + class TestMetricPromptBuilder: """Unit tests for the MetricPromptBuilder class.""" @@ -4228,6 +4374,101 @@ def test_tool_use_quality_metric_no_tool_call_logs_warning( ) +@pytest.mark.usefixtures("google_auth_mock") +class TestRunAdkUserSimulation: + """Unit tests for the _run_adk_user_simulation function.""" + + @mock.patch( + "vertexai._genai._evals_common.ADK_SessionInput" + ) + @mock.patch( + "vertexai._genai._evals_common.EvaluationGenerator" + ) + @mock.patch( + "vertexai._genai._evals_common.LlmBackedUserSimulator" + ) + @mock.patch( + "vertexai._genai._evals_common.ConversationScenario" + ) + @mock.patch( + "vertexai._genai._evals_common.LlmBackedUserSimulatorConfig" + ) + @pytest.mark.asyncio + async def test_run_adk_user_simulation_success( + self, + mock_config_cls, + mock_scenario_cls, + mock_simulator_cls, + mock_generator_cls, + mock_session_input_cls, + ): + row = pd.Series( + { + "starting_prompt": "start", + "conversation_plan": "plan", + "session_inputs": json.dumps({"user_id": "u1"}), + } + ) + mock_agent = mock.Mock() + mock_invocation = mock.Mock() + mock_invocation.user_content.model_dump.return_value = {"text": "user msg"} + mock_invocation.final_response.model_dump.return_value = {"text": "agent msg"} + mock_invocation.intermediate_data = None + mock_invocation.creation_timestamp = 12345 + mock_invocation.invocation_id = "turn1" + + mock_generator_cls._generate_inferences_from_root_agent = mock.AsyncMock( + return_value=[mock_invocation] + ) + + turns = await _evals_common._run_adk_user_simulation(row, mock_agent) + + assert len(turns) == 1 + turn = turns[0] + assert turn["turn_index"] == 0 + assert turn["turn_id"] == "turn1" + assert len(turn["events"]) == 2 + assert turn["events"][0]["author"] == "user" + assert turn["events"][0]["content"] == {"text": "user msg"} + assert turn["events"][1]["author"] == "agent" + assert turn["events"][1]["content"] == {"text": "agent msg"} + + mock_scenario_cls.assert_called_once_with( + starting_prompt="start", conversation_plan="plan" + ) + mock_session_input_cls.assert_called_once() + + @mock.patch( + "vertexai._genai._evals_common.ADK_SessionInput" + ) + @mock.patch( + "vertexai._genai._evals_common.EvaluationGenerator" + ) + @mock.patch( + "vertexai._genai._evals_common.LlmBackedUserSimulator" + ) + @mock.patch( + "vertexai._genai._evals_common.ConversationScenario" + ) + @mock.patch( + "vertexai._genai._evals_common.LlmBackedUserSimulatorConfig" + ) + @pytest.mark.asyncio + async def test_run_adk_user_simulation_missing_columns( + self, + mock_config_cls, + mock_scenario_cls, + mock_simulator_cls, + mock_generator_cls, + mock_session_input_cls, + ): + row = pd.Series({"conversation_plan": "plan"}) + mock_agent = mock.Mock() + + with pytest.raises(ValueError, match="User simulation requires"): + await _evals_common._run_adk_user_simulation(row, mock_agent) + + @pytest.mark.usefixtures("google_auth_mock") class TestLLMMetricHandlerPayload: def setup_method(self): diff --git a/vertexai/_genai/_evals_common.py b/vertexai/_genai/_evals_common.py index f33320324a..f0e6965e89 100644 --- a/vertexai/_genai/_evals_common.py +++ b/vertexai/_genai/_evals_common.py @@ -54,10 +54,24 @@ from google.adk.agents import LlmAgent from google.adk.runners import Runner from google.adk.sessions import InMemorySessionService + from google.adk.evaluation.simulation.llm_backed_user_simulator import ( + LlmBackedUserSimulator, + ) + from google.adk.evaluation.simulation.llm_backed_user_simulator import ( + LlmBackedUserSimulatorConfig, + ) + from google.adk.evaluation.conversation_scenarios import ConversationScenario + from google.adk.evaluation.evaluation_generator import EvaluationGenerator + from google.adk.evaluation.eval_case import SessionInput as ADK_SessionInput except ImportError: LlmAgent = None Runner = None InMemorySessionService = None + LlmBackedUserSimulator = None + LlmBackedUserSimulatorConfig = None + ConversationScenario = None + EvaluationGenerator = None + ADK_SessionInput = None logger = logging.getLogger(__name__) @@ -267,6 +281,7 @@ def _execute_inference_concurrently( inference_fn: Optional[Callable[..., Any]] = None, agent_engine: Optional[Union[str, types.AgentEngine]] = None, agent: Optional[LlmAgent] = None, + user_simulator_config: Optional[types.UserSimulatorConfig] = None, ) -> list[ Union[ genai_types.GenerateContentResponse, @@ -325,6 +340,7 @@ def agent_run_wrapper( # type: ignore[no-untyped-def] agent_arg, inference_fn_arg, api_client_arg, + user_simulator_config_arg, ) -> Any: if agent_engine_arg: if isinstance(agent_engine_arg, str): @@ -343,6 +359,7 @@ def agent_run_wrapper( # type: ignore[no-untyped-def] return inference_fn_arg( row=row_arg, contents=contents_arg, + user_simulator_config=user_simulator_config_arg, agent=agent_arg, ) @@ -354,6 +371,7 @@ def agent_run_wrapper( # type: ignore[no-untyped-def] agent, inference_fn, api_client, + user_simulator_config, ) elif isinstance(model_or_fn, str): generation_content_config = _build_generate_content_config( @@ -720,6 +738,115 @@ def _run_inference_internal( return results_df +async def _run_adk_user_simulation( + row: pd.Series, + agent: LlmAgent, + config: Optional[types.UserSimulatorConfig] = None, +) -> list[dict[str, Any]]: + """Runs a multi-turn user simulation using ADK's EvaluationGenerator.""" + + starting_prompt = row.get("starting_prompt") + conversation_plan = row.get("conversation_plan") + + if not starting_prompt or not conversation_plan: + raise ValueError( + "User simulation requires 'starting_prompt' and 'conversation_plan'" + " columns." + ) + + scenario = ConversationScenario( + starting_prompt=starting_prompt, conversation_plan=conversation_plan + ) + + user_simulator_kwargs = {} + if config: + if config.model_name: + user_simulator_kwargs["model"] = config.model_name + if config.model_configuration is not None: + user_simulator_kwargs["model_configuration"] = config.model_configuration + if config.max_turn is not None: + user_simulator_kwargs["max_allowed_invocations"] = config.max_turn + + user_simulator_config = LlmBackedUserSimulatorConfig(**user_simulator_kwargs) + user_simulator = LlmBackedUserSimulator( + conversation_scenario=scenario, config=user_simulator_config + ) + + initial_session = _get_session_inputs(row) + + invocations = await EvaluationGenerator._generate_inferences_from_root_agent( # pylint: disable=protected-access + root_agent=agent, + user_simulator=user_simulator, + reset_func=getattr(agent, "reset_data", None), + initial_session=ADK_SessionInput( + app_name=initial_session.app_name or "user_simulation_app", + user_id=initial_session.user_id or "user_simulation_default_user", + state=initial_session.state or {}, + ), + ) + + turns = [] + for i, invocation in enumerate(invocations): + events = [] + if invocation.user_content: + events.append( + { + "id": str(uuid.uuid4()), + "author": "user", + "content": invocation.user_content.model_dump(mode="json"), + "timestamp": invocation.creation_timestamp, + } + ) + if invocation.intermediate_data: + if ( + hasattr(invocation.intermediate_data, "invocation_events") + and invocation.intermediate_data.invocation_events + ): + for ie in invocation.intermediate_data.invocation_events: + events.append( + { + "id": str(uuid.uuid4()), + "author": ie.author, + "content": ( + ie.content.model_dump(mode="json") + if ie.content + else None + ), + "timestamp": invocation.creation_timestamp, + } + ) + elif hasattr(invocation.intermediate_data, "tool_uses"): + for tool_call in invocation.intermediate_data.tool_uses: + events.append( + { + "id": str(uuid.uuid4()), + "author": "tool_call", + "content": tool_call.model_dump(mode="json"), + "timestamp": invocation.creation_timestamp, + } + ) + + if invocation.final_response: + events.append( + { + "id": str(uuid.uuid4()), + "author": "agent", + "content": invocation.final_response.model_dump(mode="json"), + "timestamp": invocation.creation_timestamp, + } + ) + + turns.append( + { + "turn_index": i, + "turn_id": invocation.invocation_id or str(uuid.uuid4()), + "events": events, + } + ) + + return turns + + def _apply_prompt_template( df: pd.DataFrame, prompt_template: types.PromptTemplate ) -> None: @@ -786,6 +913,7 @@ def _execute_inference( config: Optional[genai_types.GenerateContentConfig] = None, prompt_template: Optional[Union[str, types.PromptTemplateOrDict]] = None, location: Optional[str] = None, + user_simulator_config: Optional[types.UserSimulatorConfig] = None, ) -> pd.DataFrame: """Executes inference on a given dataset using the specified model. @@ -804,6 +932,8 @@ def _execute_inference( prompt_template: The prompt template to use for inference. location: The location to use for the inference. If not specified, the location configured in the client will be used. + user_simulator_config: The configuration for the user simulator in + multi-turn agent scraping. Returns: A pandas DataFrame containing the inference results. @@ -882,6 +1012,7 @@ def _execute_inference( agent_engine=agent_engine, agent=agent, prompt_dataset=prompt_dataset, + user_simulator_config=user_simulator_config, ) end_time = time.time() logger.info("Agent Run completed in %.2f seconds.", end_time - start_time) @@ -1286,11 +1417,23 @@ def _get_session_inputs(row: pd.Series) -> types.evals.SessionInput: ) +def _is_multi_turn_agent_run( + user_simulator_config: Optional[types.UserSimulatorConfig] = None, + prompt_dataset: pd.DataFrame = None, +) -> bool: + """Checks if the agent run is multi-turn.""" + return ( + user_simulator_config is not None + or "conversation_plan" in prompt_dataset.columns + ) + + def _run_agent_internal( api_client: BaseApiClient, agent_engine: Optional[Union[str, types.AgentEngine]], agent: Optional[LlmAgent], prompt_dataset: pd.DataFrame, + user_simulator_config: Optional[types.UserSimulatorConfig] = None, ) -> pd.DataFrame: """Runs an agent.""" raw_responses = _run_agent( @@ -1298,32 +1441,41 @@ def _run_agent_internal( agent_engine=agent_engine, agent=agent, prompt_dataset=prompt_dataset, + user_simulator_config=user_simulator_config, ) processed_intermediate_events = [] processed_responses = [] + processed_agent_data = [] + for resp_item in raw_responses: intermediate_events_row: list[dict[str, Any]] = [] response_row = None - if isinstance(resp_item, list): + agent_data_row = None + + if _is_multi_turn_agent_run(user_simulator_config, prompt_dataset): + agent_data_row = {"turns": resp_item} + elif isinstance(resp_item, list): try: response_row = resp_item[-1]["content"]["parts"][0]["text"] for intermediate_event in resp_item[:-1]: intermediate_events_row.append( { - "event_id": intermediate_event["id"], - "content": intermediate_event["content"], - "creation_timestamp": intermediate_event["timestamp"], - "author": intermediate_event["author"], + "event_id": intermediate_event.get("id"), + "content": intermediate_event.get("content"), + "creation_timestamp": intermediate_event.get("timestamp"), + "author": intermediate_event.get("author"), } ) except Exception as e: # pylint: disable=broad-exception-caught error_payload = { "error": ( f"Failed to parse agent run response {str(resp_item)} to " - f"intermediate events and final response: {e}" + f"agent data: {e}" ), } response_row = json.dumps(error_payload) + elif isinstance(resp_item, dict) and "error" in resp_item: + response_row = json.dumps(resp_item) else: error_payload = { "error": "Unexpected response type from agent run", @@ -1334,30 +1486,42 @@ def _run_agent_internal( processed_intermediate_events.append(intermediate_events_row) processed_responses.append(response_row) - - if len(processed_responses) != len(prompt_dataset) or len( - processed_responses - ) != len(processed_intermediate_events): - raise RuntimeError( - "Critical prompt/response/intermediate_events count mismatch: %d" - " prompts vs %d vs %d responses. This indicates an issue in response" - " collection." - % ( - len(prompt_dataset), - len(processed_responses), - len(processed_intermediate_events), + processed_agent_data.append(agent_data_row) + + df_dict = {} + if _is_multi_turn_agent_run(user_simulator_config, prompt_dataset): + df_dict["agent_data"] = processed_agent_data + if len(processed_agent_data) != len(prompt_dataset): + raise RuntimeError( + "Critical prompt/agent_data count mismatch: %d" + " prompts vs %d agent_data. This indicates an issue in response" + " collection." + % ( + len(prompt_dataset), + len(processed_agent_data), + ) + ) + else: + df_dict[_evals_constant.INTERMEDIATE_EVENTS] = processed_intermediate_events + df_dict[_evals_constant.RESPONSE] = processed_responses + if len(processed_responses) != len(prompt_dataset) or len( + processed_responses + ) != len(processed_intermediate_events): + raise RuntimeError( + "Critical prompt/response/intermediate_events count mismatch: %d" + " prompts vs %d vs %d responses. This indicates an issue in response" + " collection." + % ( + len(prompt_dataset), + len(processed_responses), + len(processed_intermediate_events), + ) ) - ) - results_df_responses_only = pd.DataFrame( - { - _evals_constant.INTERMEDIATE_EVENTS: processed_intermediate_events, - _evals_constant.RESPONSE: processed_responses, - } - ) + results_df_raw = pd.DataFrame(df_dict) prompt_dataset_indexed = prompt_dataset.reset_index(drop=True) - results_df_responses_only_indexed = results_df_responses_only.reset_index(drop=True) + results_df_responses_only_indexed = results_df_raw.reset_index(drop=True) results_df = pd.concat( [prompt_dataset_indexed, results_df_responses_only_indexed], axis=1 @@ -1370,6 +1534,7 @@ def _run_agent( agent_engine: Optional[Union[str, types.AgentEngine]], agent: Optional[LlmAgent], prompt_dataset: pd.DataFrame, + user_simulator_config: Optional[types.UserSimulatorConfig] = None, ) -> list[ Union[ list[dict[str, Any]], @@ -1385,6 +1550,7 @@ def _run_agent( prompt_dataset=prompt_dataset, progress_desc="Agent Run", gemini_config=None, + user_simulator_config=None, inference_fn=_execute_agent_run_with_retry, ) elif agent: @@ -1394,6 +1560,7 @@ def _run_agent( prompt_dataset=prompt_dataset, progress_desc="Local Agent Run", gemini_config=None, + user_simulator_config=user_simulator_config, inference_fn=_execute_local_agent_run_with_retry, ) else: @@ -1461,10 +1628,13 @@ def _execute_local_agent_run_with_retry( contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict], agent: LlmAgent, max_retries: int = 3, + user_simulator_config: Optional[types.UserSimulatorConfig] = None, ) -> Union[list[dict[str, Any]], dict[str, Any]]: """Executes agent run locally for a single prompt synchronously.""" return asyncio.run( - _execute_local_agent_run_with_retry_async(row, contents, agent, max_retries) + _execute_local_agent_run_with_retry_async( + row, contents, agent, max_retries, user_simulator_config + ) ) @@ -1473,8 +1643,20 @@ async def _execute_local_agent_run_with_retry_async( contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict], agent: LlmAgent, max_retries: int = 3, + user_simulator_config: Optional[types.UserSimulatorConfig] = None, ) -> Union[list[dict[str, Any]], dict[str, Any]]: """Executes agent run locally for a single prompt asynchronously.""" + + # Multi-turn agent scraping with user simulation. + if user_simulator_config or "conversation_plan" in row: + try: + return await _run_adk_user_simulation(row, agent, user_simulator_config) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("User simulation failed: %s", e) + return { + "error": f"Multi-turn agent scraping with user simulation failed: {e}" + } + session_inputs = _get_session_inputs(row) user_id = session_inputs.user_id session_id = str(uuid.uuid4()) diff --git a/vertexai/_genai/evals.py b/vertexai/_genai/evals.py index 3632628b87..8890072b89 100644 --- a/vertexai/_genai/evals.py +++ b/vertexai/_genai/evals.py @@ -1206,6 +1206,7 @@ def run_inference( prompt_template=config.prompt_template, location=location, config=config.generate_content_config, + user_simulator_config=config.user_simulator_config, ) def evaluate( diff --git a/vertexai/_genai/types/__init__.py b/vertexai/_genai/types/__init__.py index 8b02bc222c..ea91adbc4b 100644 --- a/vertexai/_genai/types/__init__.py +++ b/vertexai/_genai/types/__init__.py @@ -1042,6 +1042,12 @@ from .common import UpdateMultimodalDatasetConfig from .common import UpdateMultimodalDatasetConfigDict from .common import UpdateMultimodalDatasetConfigOrDict +from .common import UserScenario +from .common import UserScenarioDict +from .common import UserScenarioOrDict +from .common import UserSimulatorConfig +from .common import UserSimulatorConfigDict +from .common import UserSimulatorConfigOrDict from .common import VertexBaseConfig from .common import VertexBaseConfigDict from .common import VertexBaseConfigOrDict @@ -1059,6 +1065,9 @@ "PromptTemplateData", "PromptTemplateDataDict", "PromptTemplateDataOrDict", + "UserScenario", + "UserScenarioDict", + "UserScenarioOrDict", "EvaluationPrompt", "EvaluationPromptDict", "EvaluationPromptOrDict", @@ -1155,6 +1164,9 @@ "EvaluationRunAgentConfig", "EvaluationRunAgentConfigDict", "EvaluationRunAgentConfigOrDict", + "UserSimulatorConfig", + "UserSimulatorConfigDict", + "UserSimulatorConfigOrDict", "EvaluationRunInferenceConfig", "EvaluationRunInferenceConfigDict", "EvaluationRunInferenceConfigOrDict", diff --git a/vertexai/_genai/types/common.py b/vertexai/_genai/types/common.py index 2ec662eded..5c2553f94d 100644 --- a/vertexai/_genai/types/common.py +++ b/vertexai/_genai/types/common.py @@ -490,6 +490,32 @@ class PromptTemplateDataDict(TypedDict, total=False): PromptTemplateDataOrDict = Union[PromptTemplateData, PromptTemplateDataDict] +class UserScenario(_common.BaseModel): + """User scenario to help simulate multi-turn agent running results.""" + + starting_prompt: Optional[str] = Field( + default=None, + description="""The prompt that starts the conversation between the simulated user and the agent under test.""", + ) + conversation_plan: Optional[str] = Field( + default=None, + description="""The plan for the conversation, used to drive the multi-turn agent run and generate the simulated agent evaluation dataset.""", + ) + + +class UserScenarioDict(TypedDict, total=False): + """User scenario to help simulate multi-turn agent running results.""" + + starting_prompt: Optional[str] + """The prompt that starts the conversation between the simulated user and the agent under test.""" + + conversation_plan: Optional[str] + """The plan for the conversation, used to drive the multi-turn agent run and generate the simulated agent evaluation dataset.""" + + +UserScenarioOrDict = Union[UserScenario, UserScenarioDict] + + class EvaluationPrompt(_common.BaseModel): """Represents the prompt to be evaluated.""" @@ -501,6 +527,10 @@ class EvaluationPrompt(_common.BaseModel): prompt_template_data: Optional[PromptTemplateData] = Field( default=None, description="""Prompt template data.""" ) + user_scenario: Optional[UserScenario] = Field( + default=None, + description="""User scenario to help simulate multi-turn agent running results.""", + ) class EvaluationPromptDict(TypedDict, total=False): @@ -515,6 +545,9 @@ class EvaluationPromptDict(TypedDict, total=False): prompt_template_data: Optional[PromptTemplateDataDict] """Prompt template data.""" + user_scenario: Optional[UserScenarioDict] + """User scenario to help simulate multi-turn agent running results.""" + EvaluationPromptOrDict = Union[EvaluationPrompt, EvaluationPromptDict] @@ -1830,6 +1863,38 @@ class EvaluationRunAgentConfigDict(TypedDict, total=False): ] +class UserSimulatorConfig(_common.BaseModel): + """Configuration for a user simulator that uses an LLM to generate messages.""" + + model_name: Optional[str] = Field( + default=None, + description="""The model name to use for multi-turn agent scraping.""", + ) + model_configuration: Optional[genai_types.GenerateContentConfig] = Field( + default=None, description="""The configuration for the model.""" + ) + max_turn: Optional[int] = Field( + default=None, + description="""Maximum number of invocations allowed. Stops run-off conversations.""", + ) + + +class UserSimulatorConfigDict(TypedDict, total=False): + """Configuration for a user simulator that uses an LLM to generate messages.""" + + model_name: Optional[str] + """The model name to use for multi-turn agent scraping.""" + + model_configuration: Optional[genai_types.GenerateContentConfigDict] + """The configuration for the model.""" + + max_turn: Optional[int] + """Maximum number of invocations allowed. Stops run-off conversations.""" + + +UserSimulatorConfigOrDict = Union[UserSimulatorConfig, UserSimulatorConfigDict] + + class EvaluationRunInferenceConfig(_common.BaseModel): """This field is experimental and may change in future versions. @@ -1843,6 +1908,10 @@ class EvaluationRunInferenceConfig(_common.BaseModel): default=None, description="""The fully qualified name of the publisher model or endpoint to use for inference.""", ) + user_simulator_config: Optional[UserSimulatorConfig] = Field( + default=None, + description="""Configuration for user simulation in multi-turn agent scraping. If provided, and the dataset contains conversation plans, user simulation will be triggered.""", + ) class EvaluationRunInferenceConfigDict(TypedDict, total=False): @@ -1857,6 +1926,9 @@ class EvaluationRunInferenceConfigDict(TypedDict, total=False): model: Optional[str] """The fully qualified name of the publisher model or endpoint to use for inference.""" + user_simulator_config: Optional[UserSimulatorConfigDict] + """Configuration for user simulation in multi-turn agent scraping. If provided, and the dataset contains conversation plans, user simulation will be triggered.""" + EvaluationRunInferenceConfigOrDict = Union[ EvaluationRunInferenceConfig, EvaluationRunInferenceConfigDict @@ -13546,6 +13618,11 @@ class EvalRunInferenceConfig(_common.BaseModel): generate_content_config: Optional[genai_types.GenerateContentConfig] = Field( default=None, description="""The config for the generate content call.""" ) + user_simulator_config: Optional[UserSimulatorConfig] = Field( + default=None, + description="""Configuration for user simulation in multi-turn agent scraping. If provided, and the dataset contains + conversation plans, user simulation will be triggered.""", + ) class EvalRunInferenceConfigDict(TypedDict, total=False): @@ -13560,6 +13637,10 @@ class EvalRunInferenceConfigDict(TypedDict, total=False): generate_content_config: Optional[genai_types.GenerateContentConfigDict] """The config for the generate content call.""" + user_simulator_config: Optional[UserSimulatorConfigDict] + """Configuration for user simulation in multi-turn agent scraping. If provided, and the dataset contains + conversation plans, user simulation will be triggered.""" + EvalRunInferenceConfigOrDict = Union[EvalRunInferenceConfig, EvalRunInferenceConfigDict]