diff --git a/src/llama_stack_client/lib/cli/eval/__init__.py b/src/llama_stack_client/lib/cli/eval/__init__.py
new file mode 100644
index 00000000..ab29a9d9
--- /dev/null
+++ b/src/llama_stack_client/lib/cli/eval/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .eval import EvalParser
+
+__all__ = ["EvalParser"]
diff --git a/src/llama_stack_client/lib/cli/eval/eval.py b/src/llama_stack_client/lib/cli/eval/eval.py
new file mode 100644
index 00000000..30f11966
--- /dev/null
+++ b/src/llama_stack_client/lib/cli/eval/eval.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+import argparse
+
+from llama_stack_client.lib.cli.subcommand import Subcommand
+
+from .run_benchmark import EvalRunBenchmark
+
+
+class EvalParser(Subcommand):
+    """Run evaluation benchmark tasks."""
+
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "eval",
+            prog="llama-stack-client eval",
+            description="Run evaluation tasks.",
+            formatter_class=argparse.RawTextHelpFormatter,
+        )
+
+        subparsers = self.parser.add_subparsers(title="eval_subcommands")
+        EvalRunBenchmark.create(subparsers)
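
EvalParser only wires a nested argparse subparser: the top-level CLI gets an "eval" command, "eval" gets a "run_benchmark" subcommand, and the leaf binds its handler via set_defaults(func=...). Below is a minimal, self-contained sketch of that plain-argparse pattern; the demo handler and the endpoint value are illustrative only and not part of the library.

    import argparse

    # Top-level parser with an "eval" group that itself has subcommands,
    # mirroring EvalParser -> EvalRunBenchmark above.
    parser = argparse.ArgumentParser(prog="llama-stack-client")
    top = parser.add_subparsers(title="subcommands")

    eval_parser = top.add_parser("eval", description="Run evaluation tasks.")
    eval_sub = eval_parser.add_subparsers(title="eval_subcommands")

    run = eval_sub.add_parser("run_benchmark", description="Run evaluation benchmark tasks.")
    run.add_argument("--endpoint", type=str, help="Llama Stack distribution endpoint")
    run.add_argument("--benchmark-id", type=str, help="Benchmark Task ID")
    # The real command binds EvalRunBenchmark._run_benchmark_cmd here.
    run.set_defaults(func=lambda args: print(args.endpoint, args.benchmark_id))

    args = parser.parse_args(["eval", "run_benchmark", "--endpoint", "http://localhost:5000"])
    args.func(args)
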
diff --git a/src/llama_stack_client/lib/cli/eval/run_benchmark.py b/src/llama_stack_client/lib/cli/eval/run_benchmark.py
new file mode 100644
index 00000000..4e9f0b82
--- /dev/null
+++ b/src/llama_stack_client/lib/cli/eval/run_benchmark.py
@@ -0,0 +1,139 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import argparse
+import json
+
+import os
+from pathlib import Path
+
+from tqdm import tqdm
+
+from llama_stack_client import LlamaStackClient
+from llama_stack_client.lib.cli.configure import get_config
+from llama_stack_client.lib.cli.subcommand import Subcommand
+
+
+class EvalRunBenchmark(Subcommand):
+    def __init__(self, subparsers: argparse._SubParsersAction):
+        super().__init__()
+        self.parser = subparsers.add_parser(
+            "run_benchmark",
+            prog="llama-stack-client eval run_benchmark",
+            description="Run evaluation benchmark tasks.",
+            formatter_class=argparse.RawTextHelpFormatter,
+        )
+        self._add_arguments()
+        self.parser.set_defaults(func=self._run_benchmark_cmd)
+
+    def _add_arguments(self):
+        self.parser.add_argument(
+            "--endpoint",
+            type=str,
+            help="Llama Stack distribution endpoint",
+        )
+        self.parser.add_argument(
+            "--benchmark-id",
+            type=str,
+            help="Benchmark Task ID",
+        )
+
+    def _run_benchmark_cmd(self, args: argparse.Namespace):
+        args.endpoint = get_config().get("endpoint") or args.endpoint
+
+        client = LlamaStackClient(
+            base_url=args.endpoint,
+            provider_data={
+                "fireworks_api_key": os.environ.get("FIREWORKS_API_KEY"),
+            },
+        )
+
+        # eval_tasks_list_response = client.eval_tasks.list()
+
+        # register dataset
+        # TODO: move this to files for registering benchmarks
+        # client.datasets.register(
+        #     dataset_id="mmlu",
+        #     dataset_schema={
+        #         "input_query": {
+        #             "type": "string",
+        #         },
+        #         "expected_answer": {
+        #             "type": "string",
+        #         },
+        #     },
+        #     url={
+        #         "uri": "https://huggingface.co/datasets/llamastack/evals",
+        #     },
+        #     metadata={
+        #         "path": "llamastack/evals",
+        #         "name": "evals__mmlu__details",
+        #         "split": "train",
+        #     },
+        #     provider_id="huggingface-0",
+        # )
+        # # list datasets
+        # print(client.datasets.list())
+
+        # register eval task
+        # task_id = "meta-reference-mmlu"
+        # client.eval_tasks.register(
+        #     eval_task_id=task_id,
+        #     dataset_id="mmlu",
+        #     scoring_functions=scoring_functions,
+        # )
+        task_id = args.benchmark_id or "meta-reference-mmlu"  # fall back to the MMLU benchmark task
+        scoring_functions = [
+            "basic::regex_parser_multiple_choice_answer",
+        ]
+        rows = client.datasetio.get_rows_paginated(
+            dataset_id="mmlu",
+            rows_in_page=10,
+        )
+
+        output_res = {
+            "chat_completion_input": [],
+            "generated_output": [],
+            "expected_output": [],
+        }
+        for x in scoring_functions:
+            output_res[x] = []
+
+        # run evaluate_rows row by row
+        # TODO: jobs for background job scheduling
+        for r in tqdm(rows.rows):
+            eval_response = client.eval.evaluate_rows(
+                task_id=task_id,
+                input_rows=[r],
+                scoring_functions=scoring_functions,
+                task_config={
+                    "type": "benchmark",
+                    "eval_candidate": {
+                        "type": "model",
+                        "model": "Llama3.2-3B-Instruct",
+                        "sampling_params": {
+                            "strategy": "greedy",
+                            "temperature": 0,
+                            "top_p": 0.95,
+                            "top_k": 0,
+                            "max_tokens": 0,
+                            "repetition_penalty": 1.0,
+                        },
+                    },
+                },
+            )
+            output_res["chat_completion_input"].append(r["chat_completion_input"])
+            output_res["expected_output"].append(r["expected_answer"])
+            output_res["generated_output"].append(eval_response.generations[0]["generated_answer"])
+            for scoring_fn in scoring_functions:
+                output_res[scoring_fn].append(eval_response.scores[scoring_fn].score_rows[0])
+
+        # TODO: specify output file
+        save_path = Path(os.path.abspath(__file__)).parent / f"eval-result-{task_id}.json"
+        with open(save_path, "w") as f:
+            json.dump(output_res, f, indent=4)
+
+        print(f"Eval result saved at {save_path}!")
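
The heart of _run_benchmark_cmd is the per-row loop: each dataset row is sent through client.eval.evaluate_rows, and the generated answer plus the first score row for every scoring function are appended to output_res before the whole dict is dumped to JSON. The sketch below replays that accumulation with the server call stubbed out; the stub's shape (generations, scores[...].score_rows) mirrors what the loop assumes rather than the exact API payload, and the sample row and score values are made up.

    import json
    from dataclasses import dataclass

    @dataclass
    class StubScore:
        score_rows: list

    @dataclass
    class StubEvalResponse:
        generations: list
        scores: dict

    def stub_evaluate_row(row):
        # Pretend the model answered "B" and the regex scorer marked it correct.
        return StubEvalResponse(
            generations=[{"generated_answer": "B"}],
            scores={"basic::regex_parser_multiple_choice_answer": StubScore(score_rows=[{"score": 1.0}])},
        )

    scoring_functions = ["basic::regex_parser_multiple_choice_answer"]
    rows = [{"chat_completion_input": "What is 2 + 2? (A) 3 (B) 4", "expected_answer": "B"}]

    output_res = {"chat_completion_input": [], "generated_output": [], "expected_output": []}
    for fn in scoring_functions:
        output_res[fn] = []

    for r in rows:
        resp = stub_evaluate_row(r)
        output_res["chat_completion_input"].append(r["chat_completion_input"])
        output_res["expected_output"].append(r["expected_answer"])
        output_res["generated_output"].append(resp.generations[0]["generated_answer"])
        for fn in scoring_functions:
            output_res[fn].append(resp.scores[fn].score_rows[0])

    print(json.dumps(output_res, indent=4))

The printed structure matches what run_benchmark.py writes to eval-result-<task_id>.json, one list entry per evaluated row.
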
diff --git a/src/llama_stack_client/lib/cli/llama_stack_client.py b/src/llama_stack_client/lib/cli/llama_stack_client.py
index 3631aa25..2362a950 100644
--- a/src/llama_stack_client/lib/cli/llama_stack_client.py
+++ b/src/llama_stack_client/lib/cli/llama_stack_client.py
@@ -10,6 +10,7 @@
 from .configure import ConfigureParser
 from .datasets import DatasetsParser
+from .eval import EvalParser
 from .eval_tasks import EvalTasksParser
 from .memory_banks import MemoryBanksParser
@@ -41,6 +42,7 @@ def __init__(self):
         ProvidersParser.create(subparsers)
         DatasetsParser.create(subparsers)
         ScoringFunctionsParser.create(subparsers)
+        EvalParser.create(subparsers)

     def parse_args(self) -> argparse.Namespace:
         return self.parser.parse_args()
diff --git a/src/llama_stack_client/lib/cli/models/list.py b/src/llama_stack_client/lib/cli/models/list.py
index 42d665e2..998388fa 100644
--- a/src/llama_stack_client/lib/cli/models/list.py
+++ b/src/llama_stack_client/lib/cli/models/list.py
@@ -40,7 +40,6 @@ def _run_models_list_cmd(self, args: argparse.Namespace):
             base_url=args.endpoint,
         )

-        headers = ["identifier", "llama_model", "provider_id", "metadata"]
         response = client.models.list()
         if response:
-            print_table_from_response(response, headers)
+            print_table_from_response(response)
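
After the new command is run, e.g. llama-stack-client eval run_benchmark --endpoint <endpoint> --benchmark-id meta-reference-mmlu (with FIREWORKS_API_KEY exported, since _run_benchmark_cmd forwards it as provider data), the results land in eval-result-<task_id>.json next to run_benchmark.py. A small offline summary could be derived from that file along the lines below; it assumes each stored score row is a dict with a numeric "score" entry, which is an assumption about the scoring payload rather than something the diff guarantees.

    import json
    from pathlib import Path

    # Path mirrors the hardcoded save location in run_benchmark.py; adjust as needed.
    result_path = Path("eval-result-meta-reference-mmlu.json")
    results = json.loads(result_path.read_text())

    fn = "basic::regex_parser_multiple_choice_answer"
    # Assumption: each score row carries a numeric "score" field.
    scores = [row["score"] for row in results[fn]]
    print(f"{fn}: mean score {sum(scores) / len(scores):.2f} over {len(scores)} rows")
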