Skip to content

Commit 31296fb

Browse files
committed
pre-commit
1 parent b74e46a commit 31296fb

File tree

2 files changed

+9
-27
lines changed

2 files changed

+9
-27
lines changed

src/llama_stack_client/lib/cli/eval/run_benchmark.py

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -148,21 +148,13 @@ def run_benchmark(
148148
for aggregation_function in aggregation_functions:
149149
scoring_results = output_res[scoring_fn]
150150
if aggregation_function == "categorical_count":
151-
output_res[scoring_fn].append(
152-
aggregate_categorical_count(scoring_results)
153-
)
151+
output_res[scoring_fn].append(aggregate_categorical_count(scoring_results))
154152
elif aggregation_function == "average":
155-
output_res[scoring_fn].append(
156-
aggregate_average(scoring_results)
157-
)
153+
output_res[scoring_fn].append(aggregate_average(scoring_results))
158154
elif aggregation_function == "median":
159-
output_res[scoring_fn].append(
160-
aggregate_median(scoring_results)
161-
)
155+
output_res[scoring_fn].append(aggregate_median(scoring_results))
162156
elif aggregation_function == "accuracy":
163-
output_res[scoring_fn].append(
164-
aggregate_accuracy(scoring_results)
165-
)
157+
output_res[scoring_fn].append(aggregate_accuracy(scoring_results))
166158
else:
167159
raise NotImplementedError(
168160
f"Aggregation function {aggregation_function} is not supported yet"

src/llama_stack_client/lib/cli/eval/utils.py

Lines changed: 5 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -8,42 +8,32 @@
88

99

1010
def aggregate_categorical_count(
11-
scoring_results: List[
12-
Dict[str, Union[bool, float, str, List[object], object, None]]
13-
],
11+
scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
1412
) -> Dict[str, Any]:
1513
scores = [str(r["score"]) for r in scoring_results]
1614
unique_scores = sorted(list(set(scores)))
1715
return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
1816

1917

2018
def aggregate_average(
21-
scoring_results: List[
22-
Dict[str, Union[bool, float, str, List[object], object, None]]
23-
],
19+
scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
2420
) -> Dict[str, Any]:
2521
return {
26-
"average": sum(
27-
result["score"] for result in scoring_results if result["score"] is not None
28-
)
22+
"average": sum(result["score"] for result in scoring_results if result["score"] is not None)
2923
/ len([_ for _ in scoring_results if _["score"] is not None]),
3024
}
3125

3226

3327
def aggregate_median(
34-
scoring_results: List[
35-
Dict[str, Union[bool, float, str, List[object], object, None]]
36-
],
28+
scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
3729
) -> Dict[str, Any]:
3830
scores = [r["score"] for r in scoring_results if r["score"] is not None]
3931
median = statistics.median(scores) if scores else None
4032
return {"median": median}
4133

4234

4335
def aggregate_accuracy(
44-
scoring_results: List[
45-
Dict[str, Union[bool, float, str, List[object], object, None]]
46-
],
36+
scoring_results: List[Dict[str, Union[bool, float, str, List[object], object, None]]],
4737
) -> Dict[str, Any]:
4838
num_correct = sum(result["score"] for result in scoring_results)
4939
avg_score = num_correct / len(scoring_results)

0 commit comments

Comments (0)