9 changes: 9 additions & 0 deletions src/llama_stack_client/lib/cli/eval/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .eval import EvalParser

__all__ = ["EvalParser"]
28 changes: 28 additions & 0 deletions src/llama_stack_client/lib/cli/eval/eval.py
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import argparse

from llama_stack_client.lib.cli.subcommand import Subcommand

from .run_benchmark import EvalRunBenchmark


class EvalParser(Subcommand):
"""Run evaluation benchmark tasks."""

def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"eval",
prog="llama-stack-client eval",
description="Run evaluation tasks.",
formatter_class=argparse.RawTextHelpFormatter,
)

subparsers = self.parser.add_subparsers(title="eval_subcommands")
EvalRunBenchmark.create(subparsers)
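
EvalParser itself only registers the eval command and a nested run_benchmark subparser; actual dispatch relies on argparse's set_defaults(func=...) convention, with the top-level CLI calling args.func(args) after parsing. A minimal, self-contained sketch of that pattern is below; the command and class names in it are illustrative stand-ins, not part of the library.

import argparse


class GreetCommand:
    """Illustrative stand-in for a Subcommand such as EvalRunBenchmark."""

    def __init__(self, subparsers: argparse._SubParsersAction):
        self.parser = subparsers.add_parser("greet", description="Say hello.")
        self.parser.add_argument("--name", type=str, default="world")
        # Each subcommand registers its own handler; the top-level CLI only calls args.func(args).
        self.parser.set_defaults(func=self.run)

    def run(self, args: argparse.Namespace):
        print(f"hello, {args.name}")


parser = argparse.ArgumentParser(prog="demo-cli")
subparsers = parser.add_subparsers(title="subcommands")
GreetCommand(subparsers)

args = parser.parse_args(["greet", "--name", "llama"])
args.func(args)  # prints: hello, llama
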
139 changes: 139 additions & 0 deletions src/llama_stack_client/lib/cli/eval/run_benchmark.py
@@ -0,0 +1,139 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import json
import os
from pathlib import Path

from tqdm import tqdm

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.cli.configure import get_config
from llama_stack_client.lib.cli.subcommand import Subcommand


class EvalRunBenchmark(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"run_benchmark",
prog="llama-stack-client eval run_benchmark",
description="Run evaluation benchmark tasks.",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_benchmark_cmd)

def _add_arguments(self):
self.parser.add_argument(
"--endpoint",
type=str,
help="Llama Stack distribution endpoint",
)
self.parser.add_argument(
"--benchmark-id",
type=str,
help="Benchmark Task ID",
)

def _run_benchmark_cmd(self, args: argparse.Namespace):
args.endpoint = get_config().get("endpoint") or args.endpoint

client = LlamaStackClient(
base_url=args.endpoint,
provider_data={
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY"),
},
)

# eval_tasks_list_response = client.eval_tasks.list()

# register dataset
# TODO: move this to files for registering benchmarks
# client.datasets.register(
# dataset_id="mmlu",
# dataset_schema={
# "input_query": {
# "type": "string",
# },
# "expected_answer": {
# "type": "string",
# },
# },
# url={
# "uri": "https://huggingface.co/datasets/llamastack/evals",
# },
# metadata={
# "path": "llamastack/evals",
# "name": "evals__mmlu__details",
# "split": "train",
# },
# provider_id="huggingface-0",
# )
# # list datasets
# print(client.datasets.list())

# register eval task
# task_id = "meta-reference-mmlu"
# client.eval_tasks.register(
# eval_task_id=task_id,
# dataset_id="mmlu",
# scoring_functions=scoring_functions,
# )

        # Benchmark task to evaluate; fall back to the reference MMLU task when
        # --benchmark-id is not given.
        task_id = args.benchmark_id or "meta-reference-mmlu"

        scoring_functions = [
            "basic::regex_parser_multiple_choice_answer",
        ]
rows = client.datasetio.get_rows_paginated(
dataset_id="mmlu",
rows_in_page=10,
)

output_res = {
"chat_completion_input": [],
"generated_output": [],
"expected_output": [],
}
for x in scoring_functions:
output_res[x] = []

# run evaluate_rows row by row
# TODO: jobs for background job scheduling
for r in tqdm(rows.rows):
eval_response = client.eval.evaluate_rows(
task_id="meta-reference-mmlu",
input_rows=[r],
scoring_functions=scoring_functions,
task_config={
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": "Llama3.2-3B-Instruct",
"sampling_params": {
"strategy": "greedy",
"temperature": 0,
"top_p": 0.95,
"top_k": 0,
"max_tokens": 0,
"repetition_penalty": 1.0,
},
},
},
)
output_res["chat_completion_input"].append(r["chat_completion_input"])
output_res["expected_output"].append(r["expected_answer"])
output_res["generated_output"].append(eval_response.generations[0]["generated_answer"])
for scoring_fn in scoring_functions:
output_res[scoring_fn].append(eval_response.scores[scoring_fn].score_rows[0])

# TODO: specify output file
save_path = Path(os.path.abspath(__file__)).parent / f"eval-result-{task_id}.json"
with open(save_path, "w") as f:
json.dump(output_res, f, indent=4)

print(f"Eval result saved at {save_path}!")
2 changes: 2 additions & 0 deletions src/llama_stack_client/lib/cli/llama_stack_client.py
@@ -10,6 +10,7 @@

from .configure import ConfigureParser
from .datasets import DatasetsParser
from .eval import EvalParser
from .eval_tasks import EvalTasksParser
from .memory_banks import MemoryBanksParser

@@ -41,6 +42,7 @@ def __init__(self):
ProvidersParser.create(subparsers)
DatasetsParser.create(subparsers)
ScoringFunctionsParser.create(subparsers)
EvalParser.create(subparsers)

def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args()
3 changes: 1 addition & 2 deletions src/llama_stack_client/lib/cli/models/list.py
@@ -40,7 +40,6 @@ def _run_models_list_cmd(self, args: argparse.Namespace):
base_url=args.endpoint,
)

headers = ["identifier", "llama_model", "provider_id", "metadata"]
response = client.models.list()
if response:
print_table_from_response(response, headers)
print_table_from_response(response)
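
The models list change drops the hard-coded headers and lets print_table_from_response work out the columns from the response itself. Purely as an illustration of that idea (this is not the library's implementation), a toy printer that infers headers from the first row might look like:

from typing import Any, Dict, List


def print_table_from_rows(rows: List[Dict[str, Any]]) -> None:
    """Toy printer that infers column names from the first row's keys."""
    if not rows:
        return
    headers = list(rows[0].keys())
    print(" | ".join(headers))
    for row in rows:
        print(" | ".join(str(row.get(h, "")) for h in headers))


# Made-up rows, just to show the inferred header line.
print_table_from_rows(
    [
        {"identifier": "Llama3.2-3B-Instruct", "provider_id": "fireworks-0"},
        {"identifier": "Llama3.1-8B-Instruct", "provider_id": "fireworks-0"},
    ]
)
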