
Commit c645726

refine the benchmark eval UX (#156)
## What does this PR do?

Refine the benchmark eval CLI to give a better user experience when running eval on standard benchmarks. The benchmarks need to be defined as resources in the distro template.

Improvements include:
- Users no longer need to pass an arbitrary `eval-task-config`; they only pass the list of benchmarks they'd like to eval, the model id to evaluate, and the output dir to store the eval results.
- Aggregate results are written to the output file; the aggregates are typically what users care about most (a sketch of reading them follows below).

## Test Plan

Spin up a Llama Stack server with eval benchmarks defined, then run:

`llama-stack-client --endpoint xxxx eval run-benchmark "meta-reference-simpleqa" --model_id "meta-llama/Llama-3.1-8B-Instruct" --output_dir "/home/markchen1015/" --num_examples 5`

Returns:

<img width="1284" alt="Screenshot 2025-02-20 at 4 29 35 PM" src="https://github.com/user-attachments/assets/624e0b59-fcbf-46b8-b2cf-2a36bba9aee5" />

What is inside the output file:

<img width="1436" alt="Screenshot 2025-02-20 at 4 30 08 PM" src="https://github.com/user-attachments/assets/d0a370ff-df98-4fbf-93ba-78104838d02b" />
<img width="1444" alt="Screenshot 2025-02-20 at 4 17 05 PM" src="https://github.com/user-attachments/assets/7da65c03-cc03-48f3-be70-1489a9430f18" />
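As a rough illustration of the second bullet, here is a hedged sketch of reading those aggregate entries back out of a results file. The output path and benchmark id are copied from the test-plan command; the file layout (per-row entries first, aggregate result appended last under each scoring-function key) is inferred from the `run_benchmark.py` diff below.

```python
import json
import os

# Values copied from the test-plan command above; adjust for your own run.
output_dir = "/home/markchen1015/"
benchmark_id = "meta-reference-simpleqa"

# run-benchmark writes one JSON file per benchmark: {benchmark_id}_results.json
with open(os.path.join(output_dir, f"{benchmark_id}_results.json")) as f:
    results = json.load(f)

# Each key maps to a list: input-row columns hold per-row values, while each
# scoring-function key holds per-row score dicts with the aggregate result
# (e.g. {"accuracy": ...} or {"categorical_count": ...}) appended as the last
# element, so printing the final element surfaces the aggregate numbers.
for key, values in results.items():
    print(key, "->", values[-1])
```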
1 parent 39b1248 commit c645726

File tree

2 files changed: +154, -25 lines changed

src/llama_stack_client/lib/cli/eval/run_benchmark.py

Lines changed: 109 additions & 25 deletions
@@ -13,15 +13,22 @@
 from tqdm.rich import tqdm

 from ..common.utils import create_bar_chart
+from .utils import (
+    aggregate_accuracy,
+    aggregate_average,
+    aggregate_categorical_count,
+    aggregate_median,
+)


 @click.command("run-benchmark")
-@click.argument("eval-task-ids", nargs=-1, required=True)
+@click.argument("benchmark-ids", nargs=-1, required=True)
 @click.option(
-    "--eval-task-config",
+    "--model-id",
     required=True,
-    help="Path to the eval task config file in JSON format",
-    type=click.Path(exists=True),
+    help="model id to run the benchmark eval on",
+    default=None,
+    type=str,
 )
 @click.option(
     "--output-dir",
@@ -35,6 +42,34 @@
     default=None,
     type=int,
 )
+@click.option(
+    "--temperature",
+    required=False,
+    help="temperature in the sampling params to run generation",
+    default=0.0,
+    type=float,
+)
+@click.option(
+    "--max-tokens",
+    required=False,
+    help="max-tokens in the sampling params to run generation",
+    default=4096,
+    type=int,
+)
+@click.option(
+    "--top-p",
+    required=False,
+    help="top-p in the sampling params to run generation",
+    default=0.9,
+    type=float,
+)
+@click.option(
+    "--repeat-penalty",
+    required=False,
+    help="repeat-penalty in the sampling params to run generation",
+    default=1.0,
+    type=float,
+)
 @click.option(
     "--visualize",
     is_flag=True,
@@ -44,36 +79,50 @@
 @click.pass_context
 def run_benchmark(
     ctx,
-    eval_task_ids: tuple[str, ...],
-    eval_task_config: str,
+    benchmark_ids: tuple[str, ...],
+    model_id: str,
     output_dir: str,
     num_examples: Optional[int],
+    temperature: float,
+    max_tokens: int,
+    top_p: float,
+    repeat_penalty: float,
     visualize: bool,
 ):
     """Run a evaluation benchmark task"""

     client = ctx.obj["client"]

-    for eval_task_id in eval_task_ids:
-        eval_task = client.eval_tasks.retrieve(name=eval_task_id)
-        scoring_functions = eval_task.scoring_functions
-        dataset_id = eval_task.dataset_id
+    for benchmark_id in benchmark_ids:
+        benchmark = client.benchmarks.retrieve(benchmark_id=benchmark_id)
+        scoring_functions = benchmark.scoring_functions
+        dataset_id = benchmark.dataset_id

         rows = client.datasetio.get_rows_paginated(
-            dataset_id=dataset_id, rows_in_page=-1 if num_examples is None else num_examples
+            dataset_id=dataset_id,
+            rows_in_page=-1 if num_examples is None else num_examples,
         )

-        with open(eval_task_config, "r") as f:
-            eval_task_config = json.load(f)
-
         output_res = {}

-        for r in tqdm(rows.rows):
-            eval_res = client.eval.evaluate_rows(
-                task_id=eval_task_id,
+        for i, r in enumerate(tqdm(rows.rows)):
+            eval_res = client.eval.evaluate_rows_alpha(
+                benchmark_id=benchmark_id,
                 input_rows=[r],
                 scoring_functions=scoring_functions,
-                task_config=eval_task_config,
+                task_config={
+                    "type": "benchmark",
+                    "eval_candidate": {
+                        "type": "model",
+                        "model": model_id,
+                        "sampling_params": {
+                            "temperature": temperature,
+                            "max_tokens": max_tokens,
+                            "top_p": top_p,
+                            "repeat_penalty": repeat_penalty,
+                        },
+                    },
+                },
             )
             for k in r.keys():
                 if k not in output_res:
@@ -90,20 +139,55 @@ def run_benchmark(
                     output_res[scoring_fn] = []
                 output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])

+                aggregation_functions = client.scoring_functions.retrieve(
+                    scoring_fn_id=scoring_fn
+                ).params.aggregation_functions
+
+                # only output the aggregation result for the last row
+                if i == len(rows.rows) - 1:
+                    for aggregation_function in aggregation_functions:
+                        scoring_results = output_res[scoring_fn]
+                        if aggregation_function == "categorical_count":
+                            output_res[scoring_fn].append(aggregate_categorical_count(scoring_results))
+                        elif aggregation_function == "average":
+                            output_res[scoring_fn].append(aggregate_average(scoring_results))
+                        elif aggregation_function == "median":
+                            output_res[scoring_fn].append(aggregate_median(scoring_results))
+                        elif aggregation_function == "accuracy":
+                            output_res[scoring_fn].append(aggregate_accuracy(scoring_results))
+                        else:
+                            raise NotImplementedError(
+                                f"Aggregation function {aggregation_function} is not supported yet"
+                            )
+
         # Create output directory if it doesn't exist
         os.makedirs(output_dir, exist_ok=True)
         # Save results to JSON file
-        output_file = os.path.join(output_dir, f"{eval_task_id}_results.json")
+        output_file = os.path.join(output_dir, f"{benchmark_id}_results.json")
         with open(output_file, "w") as f:
             json.dump(output_res, f, indent=2)

         rprint(f"[green]✓[/green] Results saved to: [blue]{output_file}[/blue]!\n")

         if visualize:
             for scoring_fn in scoring_functions:
-                res = output_res[scoring_fn]
-                assert len(res) > 0 and "score" in res[0]
-                scores = [str(r["score"]) for r in res]
-                unique_scores = sorted(list(set(scores)))
-                counts = [scores.count(s) for s in unique_scores]
-                create_bar_chart(counts, unique_scores, title=f"{scoring_fn}")
+                aggregation_functions = client.scoring_functions.retrieve(
+                    scoring_fn_id=scoring_fn
+                ).params.aggregation_functions
+
+                for aggregation_function in aggregation_functions:
+                    res = output_res[scoring_fn]
+                    assert len(res) > 0 and "score" in res[0]
+                    if aggregation_function == "categorical_count":
+                        scores = [str(r["score"]) for r in res]
+                        unique_scores = sorted(list(set(scores)))
+                        counts = [scores.count(s) for s in unique_scores]
+                        create_bar_chart(
+                            counts,
+                            unique_scores,
+                            title=f"{scoring_fn}-{aggregation_function}",
+                        )
+                    else:
+                        raise NotImplementedError(
+                            f"Aggregation function {aggregation_function} is not supported for visualization yet"
+                        )
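For readers who prefer to drive the same flow from Python instead of the CLI, the command body above boils down to a few client calls. A minimal sketch follows, assuming a reachable Llama Stack server; the `LlamaStackClient(base_url=...)` construction, the endpoint, and the ids are placeholders, while every method used appears verbatim in the diff.

```python
from llama_stack_client import LlamaStackClient

# Assumed endpoint of a running Llama Stack server with benchmarks registered.
client = LlamaStackClient(base_url="http://localhost:8321")

benchmark_id = "meta-reference-simpleqa"  # benchmark id from the test plan
benchmark = client.benchmarks.retrieve(benchmark_id=benchmark_id)

# Fetch a single example row, mirroring --num-examples 1.
rows = client.datasetio.get_rows_paginated(
    dataset_id=benchmark.dataset_id,
    rows_in_page=1,
)

eval_res = client.eval.evaluate_rows_alpha(
    benchmark_id=benchmark_id,
    input_rows=[rows.rows[0]],
    scoring_functions=benchmark.scoring_functions,
    task_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            # Sampling defaults match the new CLI options.
            "sampling_params": {
                "temperature": 0.0,
                "max_tokens": 4096,
                "top_p": 0.9,
                "repeat_penalty": 1.0,
            },
        },
    },
)
print(eval_res.scores)
```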
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import statistics
+from typing import Any, Dict, List, Union
+
+
+def aggregate_categorical_count(
+    scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
+) -> Dict[str, Any]:
+    scores = [str(r["score"]) for r in scoring_results]
+    unique_scores = sorted(list(set(scores)))
+    return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
+
+
+def aggregate_average(
+    scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
+) -> Dict[str, Any]:
+    return {
+        "average": sum(result["score"] for result in scoring_results if result["score"] is not None)
+        / len([_ for _ in scoring_results if _["score"] is not None]),
+    }
+
+
+def aggregate_median(
+    scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
+) -> Dict[str, Any]:
+    scores = [r["score"] for r in scoring_results if r["score"] is not None]
+    median = statistics.median(scores) if scores else None
+    return {"median": median}
+
+
+def aggregate_accuracy(
+    scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
+) -> Dict[str, Any]:
+    num_correct = sum(result["score"] for result in scoring_results)
+    avg_score = num_correct / len(scoring_results)
+
+    return {
+        "accuracy": avg_score,
+        "num_correct": num_correct,
+        "num_total": len(scoring_results),
+    }
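To make the return shapes of these helpers concrete, a small usage sketch follows. The import path `llama_stack_client.lib.cli.eval.utils` is an assumption (the diff only shows the module imported as `.utils` from the eval CLI package), and the sample score rows are hypothetical.

```python
from llama_stack_client.lib.cli.eval.utils import (
    aggregate_accuracy,
    aggregate_average,
    aggregate_categorical_count,
    aggregate_median,
)

# Hypothetical per-row score dicts, shaped like eval_res.scores[fn].score_rows entries.
score_rows = [{"score": 1.0}, {"score": 0.0}, {"score": 1.0}]

print(aggregate_accuracy(score_rows))           # {'accuracy': 0.666..., 'num_correct': 2.0, 'num_total': 3}
print(aggregate_average(score_rows))            # {'average': 0.666...}
print(aggregate_median(score_rows))             # {'median': 1.0}
print(aggregate_categorical_count(score_rows))  # {'categorical_count': {'0.0': 1, '1.0': 2}}
```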
