
Commit f43e34f

vertex-sdk-bot authored and copybara-github committed
prefix: Add new GenAI Eval converter
PiperOrigin-RevId: 781578088
1 parent bcdf041 commit f43e34f

3 files changed: +284, -2 lines changed


vertexai/_genai/_evals_data_converters.py

Lines changed: 180 additions & 0 deletions
@@ -35,6 +35,7 @@ class EvalDatasetSchema(_common.CaseInSensitiveEnum):
     GEMINI = "gemini"
     FLATTEN = "flatten"
     OPENAI = "openai"
+    OBSERVABILITY = "observability"
     UNKNOWN = "unknown"
@@ -442,6 +443,179 @@ def convert(self, raw_data: list[dict[str, Any]]) -> types.EvaluationDataset:
         return types.EvaluationDataset(eval_cases=eval_cases)
 
 
+class _ObservabilityDataConverter(_EvalDataConverter):
+    """Converter for dataset in GCP Observability GenAI format."""
+
+    def _message_to_content(self, message: dict[str, Any]) -> genai_types.Content:
+        """Converts Obs GenAI message format to Content."""
+        parts = []
+        message_parts = message.get("parts", [])
+        if isinstance(message_parts, list):
+            for message_part in message_parts:
+                part = None
+                part_type = message_part.get("type", "")
+                match part_type:
+                    case "text":
+                        part = genai_types.Part(
+                            text=message_part.get("content", "")
+                        )
+
+                    case "blob":
+                        part = genai_types.Part(inline_data=genai_types.Blob(
+                            data=message_part.get("data", ""),
+                            mime_type=message_part.get("mime_type", "")
+                        ))
+
+                    case "file_data":
+                        part = genai_types.Part(file_data=genai_types.FileData(
+                            file_uri=message_part.get("file_uri", ""),
+                            mime_type=message_part.get("mime_type", "")
+                        ))
+
+                    case "tool_call":
+                        part = genai_types.Part(
+                            function_call=genai_types.FunctionCall(
+                                id=message_part.get("id", ""),
+                                name=message_part.get("name", ""),
+                                args=message_part.get("arguments", {})
+                            )
+                        )
+
+                    case "tool_call_response":
+                        part = genai_types.Part(
+                            function_response=genai_types.FunctionResponse(
+                                id=message_part.get("id", ""),
+                                name=message_part.get("name", ""),
+                                response=message_part.get("result", {})
+                            )
+                        )
+
+                    case _:
+                        logger.warning(
+                            "Unrecognized message part type of '%s' found."
+                            " Skipping part.",
+                            part_type
+                        )
+
+                if part is not None:
+                    parts.append(part)
+
+        return genai_types.Content(
+            parts=parts,
+            role=message.get("role", "")
+        )
+
+    def _parse_messages(
+        self,
+        eval_case_id: str,
+        input_dict: dict[str, Any],
+        output_dict: dict[str, Any],
+        system_dict: Optional[dict[str, Any]] = None
+    ) -> types.EvalCase:
+        """Parses a set of messages into an EvalCase."""
+
+        # System message
+        system_instruction = None
+        if system_dict is not None:
+            system_msgs = system_dict.get("messages", [])
+            if system_msgs:
+                system_instruction = self._message_to_content(system_msgs[0])
+
+        # Input message
+        prompt = None
+        conversation_history = []
+        input_msgs = input_dict.get("messages", [])
+        if input_msgs:
+            # Extract latest message as prompt
+            prompt = self._message_to_content(input_msgs[-1])
+
+            # All previous messages are history
+            if len(input_msgs) > 1:
+                for turn_id, msg in enumerate(input_msgs[:-1]):
+                    conversation_history.append(types.Message(
+                        turn_id=str(turn_id),
+                        content=self._message_to_content(msg),
+                        author=msg.get("role", "")
+                    ))
+
+        # Output message
+        responses = []
+        output_choices = output_dict.get("choices", [])
+        for choice in output_choices:
+            response = types.ResponseCandidate(
+                response=self._message_to_content(choice.get("message", {}))
+            )
+            responses.append(response)
+
+        return types.EvalCase(
+            eval_case_id=eval_case_id,
+            prompt=prompt,
+            responses=responses,
+            system_instruction=system_instruction,
+            conversation_history=conversation_history,
+            reference=None
+        )
+
+    def _load_raw_data(self, data: Any, case_index: int) -> Optional[dict[str, Any]]:
+        """Loads raw data into dict if possible."""
+        if isinstance(data, str):
+            try:
+                loaded_json = json.loads(data)
+                if isinstance(loaded_json, dict):
+                    return loaded_json
+                else:
+                    logger.warning(
+                        "Decoded response JSON is not a dictionary for case"
+                        " %s. Type: %s",
+                        case_index,
+                        type(loaded_json),
+                    )
+            except json.JSONDecodeError:
+                logger.warning(
+                    "Could not decode response JSON string for case %s."
+                    " Treating as empty response.",
+                    case_index,
+                )
+        elif isinstance(data, dict):
+            return data
+
+    @override
+    def convert(self, raw_data: list[dict[str, Any]]) -> types.EvaluationDataset:
+        """Converts a list of GCP Observability GenAI data into an EvaluationDataset."""
+        eval_cases = []
+
+        for i, item in enumerate(raw_data):
+            eval_case_id = f"observability_eval_case_{i}"
+
+            if "input" not in item or "output" not in item:
+                logger.warning(
+                    "Skipping case %s due to missing 'input' or 'output' key.",
+                    i
+                )
+                continue
+
+            input_data = item.get("input", {})
+            input_dict = self._load_raw_data(input_data, i)
+
+            output_data = item.get("output", {})
+            output_dict = self._load_raw_data(output_data, i)
+
+            system_dict = None
+            if "system" in item:
+                system_data = item.get("system", {})
+                system_dict = self._load_raw_data(system_data, i)
+
+            eval_case = self._parse_messages(
+                eval_case_id,
+                input_dict,
+                output_dict,
+                system_dict
+            )
+            eval_cases.append(eval_case)
+
+        return types.EvaluationDataset(eval_cases=eval_cases)
+
+
 def auto_detect_dataset_schema(
     raw_dataset: list[dict[str, Any]],
 ) -> Union[EvalDatasetSchema, str]:
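
For orientation, a minimal sketch of the record shape this converter appears to expect, inferred from convert(), _load_raw_data(), and _parse_messages() above. The field values, and the no-argument construction of the converter, are illustrative assumptions rather than part of this commit.

# Illustrative record only; every value here is hypothetical.
raw_item = {
    "input": {  # may also be a JSON string; _load_raw_data() parses it
        "messages": [
            {"role": "user", "parts": [{"type": "text", "content": "Hi"}]},
            {"role": "model", "parts": [{"type": "text", "content": "Hello."}]},
            {"role": "user", "parts": [{"type": "text", "content": "What is 2 + 2?"}]},
        ]
    },
    "output": {
        "choices": [
            {"message": {"role": "model", "parts": [{"type": "text", "content": "4"}]}}
        ]
    },
    "system": {
        "messages": [
            {"role": "system", "parts": [{"type": "text", "content": "Be concise."}]}
        ]
    },
}

# Assuming a no-argument constructor, like the sibling converters:
dataset = _ObservabilityDataConverter().convert([raw_item])
# The last "input" message becomes the prompt, earlier input messages become
# conversation_history, and each "output" choice becomes a ResponseCandidate.
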
@@ -476,6 +650,11 @@ def auto_detect_dataset_schema(
         if "role" in messages_list[0] and "content" in messages_list[0]:
             return EvalDatasetSchema.OPENAI
 
+    if "format" in keys:
+        format_content = first_item.get("format", "")
+        if isinstance(format_content, str) and format_content == "observability":
+            return EvalDatasetSchema.OBSERVABILITY
+
     if {"prompt", "response"}.issubset(keys) or {
         "response",
         "reference",
@@ -489,6 +668,7 @@ def auto_detect_dataset_schema(
        EvalDatasetSchema.GEMINI: _GeminiEvalDataConverter,
        EvalDatasetSchema.FLATTEN: _FlattenEvalDataConverter,
        EvalDatasetSchema.OPENAI: _OpenAIDataConverter,
+       EvalDatasetSchema.OBSERVABILITY: _ObservabilityDataConverter,
    }
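
Schema auto-detection keys off a literal "format" column, so a row shaped the way the new EvaluationDataset.load_from_sources builds it (see types.py below) should be routed to the new converter. A hedged sketch, with hypothetical payload strings:

raw_dataset = [
    {"format": "observability", "input": "<input JSON>", "output": "<output JSON>"}
]
schema = auto_detect_dataset_schema(raw_dataset)
# Expected: EvalDatasetSchema.OBSERVABILITY, which the mapping above resolves
# to _ObservabilityDataConverter.
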

vertexai/_genai/_evals_visualization.py

Lines changed: 11 additions & 2 deletions
@@ -490,9 +490,18 @@ def display_evaluation_result(
     processed_df = _preprocess_df_for_json(single_dataset.eval_dataset_df)
     if processed_df is not None:
         for _, row in processed_df.iterrows():
-            prompt_key = "request" if "request" in row else "prompt"
+            prompt_key = "prompt"
+            if "request" in row:
+                prompt_key = "request"
+            elif "input" in row:
+                prompt_key = "input"
+
+            response_key = "response"
+            if "output" in row:
+                response_key = "output"
+
             prompt_info = _extract_text_and_raw_json(row.get(prompt_key))
-            response_info = _extract_text_and_raw_json(row.get("response"))
+            response_info = _extract_text_and_raw_json(row.get(response_key))
             processed_row = {
                 "prompt_display_text": prompt_info["display_text"],
                 "prompt_raw_json": prompt_info["raw_json"],

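For reference, the effect of the new fallback on an observability-style row; the dict below is a hypothetical stand-in for a DataFrame row and the payload strings are placeholders.

row = {"input": '{"messages": [...]}', "output": '{"choices": [...]}'}

prompt_key = "prompt"
if "request" in row:
    prompt_key = "request"
elif "input" in row:
    prompt_key = "input"

response_key = "output" if "output" in row else "response"
# prompt_key == "input" and response_key == "output", so rows built by
# load_from_sources can be rendered by display_evaluation_result.
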
vertexai/_genai/types.py

Lines changed: 93 additions & 0 deletions
@@ -6395,6 +6395,99 @@ def _check_pandas_installed(cls, data: Any) -> Any:
             )
         return data
 
+    @classmethod
+    def load_from_sources(
+        cls,
+        input_source: str,
+        output_source: str,
+        system_source: Optional[str] = None,
+        client: Optional[Any] = None,
+    ) -> "EvaluationDataset":
+        if (
+            not input_source.startswith("gs://")
+            or not output_source.startswith("gs://")
+            or (
+                system_source is not None
+                and not system_source.startswith("gs://")
+            )
+        ):
+            raise TypeError("Only GCS sources are supported.")
+
+        try:
+            from google.cloud import storage
+
+            storage_client = storage.Client(
+                credentials=client._api_client._credentials if client else None
+            )
+
+            # Input source
+            try:
+                path_without_prefix = input_source[len("gs://") :]
+                bucket_name, blob_path = path_without_prefix.split("/", 1)
+
+                bucket = storage_client.bucket(bucket_name)
+                blob = bucket.blob(blob_path)
+
+                input_str = blob.download_as_bytes().decode("utf-8")
+            except Exception as e:
+                raise IOError(
+                    f"Failed to read from GCS path {input_source}: {e}"
+                ) from e
+
+            # Output source
+            try:
+                path_without_prefix = output_source[len("gs://") :]
+                bucket_name, blob_path = path_without_prefix.split("/", 1)
+
+                bucket = storage_client.bucket(bucket_name)
+                blob = bucket.blob(blob_path)
+
+                output_str = blob.download_as_bytes().decode("utf-8")
+            except Exception as e:
+                raise IOError(
+                    f"Failed to read from GCS path {output_source}: {e}"
+                ) from e
+
+            # System source
+            system_str = ""
+            if system_source is not None:
+                try:
+                    path_without_prefix = system_source[len("gs://") :]
+                    bucket_name, blob_path = path_without_prefix.split("/", 1)
+
+                    bucket = storage_client.bucket(bucket_name)
+                    blob = bucket.blob(blob_path)
+
+                    system_str = blob.download_as_bytes().decode("utf-8")
+                except Exception as e:
+                    raise IOError(
+                        f"Failed to read from GCS path {system_source}: {e}"
+                    ) from e
+
+        except ImportError as e:
+            raise ImportError(
+                "Reading from GCS requires the 'google-cloud-storage'"
+                " library. Please install it with 'pip install"
+                " google-cloud-aiplatform[evaluation]'."
+            ) from e
+
+        try:
+            import pandas as pd
+
+            eval_dataset_df = pd.DataFrame(
+                {
+                    "format": ["observability"],
+                    "input": [input_str],
+                    "output": [output_str],
+                    "system": [system_str],
+                }
+            )
+
+        except ImportError as e:
+            raise ImportError("Pandas DataFrame library is required.") from e
+
+        return EvaluationDataset(eval_dataset_df=eval_dataset_df)
+
     def show(self) -> None:
         """Shows the evaluation dataset."""
         from . import _evals_visualization
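
A hedged usage sketch of the new classmethod. The bucket and object names are placeholders; each object is expected to hold the corresponding observability JSON payload, which is stored verbatim in a single DataFrame row.

dataset = EvaluationDataset.load_from_sources(
    input_source="gs://my-bucket/obs/input.json",
    output_source="gs://my-bucket/obs/output.json",
    system_source="gs://my-bucket/obs/system.json",  # optional
)

# The resulting one-row DataFrame carries format="observability", so downstream
# schema auto-detection should route it to _ObservabilityDataConverter.
dataset.show()
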
