Skip to content

Commit 1c792c7

Browse files
jsondaicopybara-github
authored andcommitted
feat: GenAI Client(evals) - Add location override parameter to run_inference and evaluate methods
PiperOrigin-RevId: 836360793
1 parent f1df7a2 commit 1c792c7

File tree

3 files changed

+77
-1
lines changed

3 files changed

+77
-1
lines changed

tests/unit/vertexai/genai/test_evals.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def mock_api_client_fixture():
6969
)
7070
mock_client._credentials.universe_domain = "googleapis.com"
7171
mock_client._evals_client = mock.Mock(spec=evals.Evals)
72+
mock_client._http_options = None
7273
return mock_client
7374

7475

@@ -139,6 +140,40 @@ def mock_evaluate_instances_side_effect(*args, **kwargs):
139140
}
140141

141142

143+
class TestGetApiClientWithLocation:
144+
@mock.patch("vertexai._genai._evals_common.vertexai.Client")
145+
def test_get_api_client_with_location_override(
146+
self, mock_vertexai_client, mock_api_client_fixture
147+
):
148+
mock_api_client_fixture.location = "us-central1"
149+
new_location = "europe-west1"
150+
_evals_common._get_api_client_with_location(mock_api_client_fixture, new_location)
151+
mock_vertexai_client.assert_called_once_with(
152+
project=mock_api_client_fixture.project,
153+
location=new_location,
154+
credentials=mock_api_client_fixture._credentials,
155+
http_options=mock_api_client_fixture._http_options,
156+
)
157+
158+
@mock.patch("vertexai._genai._evals_common.vertexai.Client")
159+
def test_get_api_client_with_same_location(
160+
self, mock_vertexai_client, mock_api_client_fixture
161+
):
162+
mock_api_client_fixture.location = "us-central1"
163+
new_location = "us-central1"
164+
_evals_common._get_api_client_with_location(mock_api_client_fixture, new_location)
165+
mock_vertexai_client.assert_not_called()
166+
167+
@mock.patch("vertexai._genai._evals_common.vertexai.Client")
168+
def test_get_api_client_with_none_location(
169+
self, mock_vertexai_client, mock_api_client_fixture
170+
):
171+
mock_api_client_fixture.location = "us-central1"
172+
new_location = None
173+
_evals_common._get_api_client_with_location(mock_api_client_fixture, new_location)
174+
mock_vertexai_client.assert_not_called()
175+
176+
142177
class TestEvals:
143178
"""Unit tests for the GenAI client."""
144179

vertexai/_genai/_evals_common.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,22 @@
5656
AGENT_MAX_WORKERS = 10
5757

5858

59+
def _get_api_client_with_location(
60+
api_client: BaseApiClient, location: Optional[str]
61+
) -> BaseApiClient:
62+
"""Returns a new API client with the specified location."""
63+
if not location or location == api_client.location:
64+
return api_client
65+
66+
logger.info("Overriding location from %s to %s", api_client.location, location)
67+
return vertexai.Client(
68+
project=api_client.project,
69+
location=location,
70+
credentials=api_client._credentials,
71+
http_options=api_client._http_options,
72+
)._api_client
73+
74+
5975
def _get_agent_engine_instance(
6076
agent_name: str, api_client: BaseApiClient
6177
) -> Union[types.AgentEngine, Any]:
@@ -715,6 +731,7 @@ def _execute_inference(
715731
dest: Optional[str] = None,
716732
config: Optional[genai_types.GenerateContentConfig] = None,
717733
prompt_template: Optional[Union[str, types.PromptTemplateOrDict]] = None,
734+
location: Optional[str] = None,
718735
) -> pd.DataFrame:
719736
"""Executes inference on a given dataset using the specified model.
720737
@@ -730,12 +747,18 @@ def _execute_inference(
730747
representing a file path or a GCS URI.
731748
config: The generation configuration for the model.
732749
prompt_template: The prompt template to use for inference.
750+
location: The location to use for the inference. If not specified, the
751+
location configured in the client will be used.
733752
734753
Returns:
735754
A pandas DataFrame containing the inference results.
736755
"""
737756
if not api_client:
738757
raise ValueError("'api_client' instance must be provided.")
758+
759+
if location:
760+
api_client = _get_api_client_with_location(api_client, location)
761+
739762
prompt_dataset = _load_dataframe(api_client, src)
740763
if prompt_template:
741764
logger.info("Applying prompt template...")
@@ -1056,6 +1079,7 @@ def _execute_evaluation( # type: ignore[no-untyped-def]
10561079
metrics: list[types.Metric],
10571080
dataset_schema: Optional[Literal["GEMINI", "FLATTEN", "OPENAI"]] = None,
10581081
dest: Optional[str] = None,
1082+
location: Optional[str] = None,
10591083
**kwargs,
10601084
) -> types.EvaluationResult:
10611085
"""Evaluates a dataset using the provided metrics.
@@ -1066,12 +1090,17 @@ def _execute_evaluation( # type: ignore[no-untyped-def]
10661090
metrics: The metrics to evaluate the dataset against.
10671091
dataset_schema: The schema of the dataset.
10681092
dest: The destination to save the evaluation results.
1093+
location: The location to use for the evaluation. If not specified, the
1094+
location configured in the client will be used.
10691095
**kwargs: Extra arguments to pass to evaluation, such as `agent_info`.
10701096
10711097
Returns:
10721098
The evaluation result.
10731099
"""
10741100

1101+
if location:
1102+
api_client = _get_api_client_with_location(api_client, location)
1103+
10751104
logger.info("Preparing dataset(s) and metrics...")
10761105
if isinstance(dataset, types.EvaluationDataset):
10771106
dataset_list = [dataset]

vertexai/_genai/evals.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -904,6 +904,7 @@ def run_inference(
904904
src: Union[str, pd.DataFrame, types.EvaluationDataset],
905905
model: Optional[Union[str, Callable[[Any], Any]]] = None,
906906
agent: Optional[Union[str, types.AgentEngine]] = None,
907+
location: Optional[str] = None,
907908
config: Optional[types.EvalRunInferenceConfigOrDict] = None,
908909
) -> types.EvaluationDataset:
909910
"""Runs inference on a dataset for evaluation.
@@ -928,6 +929,10 @@ def run_inference(
928929
`projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine_id}`,
929930
run_inference will fetch the agent engine from the resource name.
930931
- Or `types.AgentEngine` object.
932+
location: The location to use for the inference. If not specified, the
933+
location configured in the client will be used. If specified,
934+
this will override the location set in `vertexai.Client` only
935+
for this API call.
931936
config: The optional configuration for the inference run. Must be a dict or
932937
`types.EvalRunInferenceConfig` type.
933938
- dest: The destination path for storage of the inference results.
@@ -955,8 +960,9 @@ def run_inference(
955960
agent_engine=agent,
956961
src=src,
957962
dest=config.dest,
958-
config=config.generate_content_config,
959963
prompt_template=config.prompt_template,
964+
location=location,
965+
config=config.generate_content_config,
960966
)
961967

962968
def evaluate(
@@ -968,6 +974,7 @@ def evaluate(
968974
list[types.EvaluationDatasetOrDict],
969975
],
970976
metrics: list[types.MetricOrDict] = None,
977+
location: Optional[str] = None,
971978
config: Optional[types.EvaluateMethodConfigOrDict] = None,
972979
**kwargs,
973980
) -> types.EvaluationResult:
@@ -977,6 +984,10 @@ def evaluate(
977984
dataset: The dataset(s) to evaluate. Can be a pandas DataFrame, a single
978985
`types.EvaluationDataset` or a list of `types.EvaluationDataset`.
979986
metrics: The list of metrics to use for evaluation.
987+
location: The location to use for the evaluation service. If not specified,
988+
the location configured in the client will be used. If specified,
989+
this will override the location set in `vertexai.Client` only for
990+
this API call.
980991
config: Optional configuration for the evaluation. Can be a dictionary or a
981992
`types.EvaluateMethodConfig` object.
982993
- dataset_schema: Schema to use for the dataset. If not specified, the
@@ -1022,6 +1033,7 @@ def evaluate(
10221033
metrics=metrics,
10231034
dataset_schema=config.dataset_schema,
10241035
dest=config.dest,
1036+
location=location,
10251037
**kwargs,
10261038
)
10271039

0 commit comments

Comments
 (0)