|
9 | 9 | from typing import Optional |
10 | 10 |
|
11 | 11 | import click |
| 12 | +from rich import print as rprint |
12 | 13 | from tqdm.rich import tqdm |
13 | 14 |
|
| 15 | +from ..common.utils import create_bar_chart |
| 16 | + |
14 | 17 |
|
15 | 18 | @click.command("run_benchmark") |
16 | 19 | @click.argument("eval-task-ids", nargs=-1, required=True) |
|
28 | 31 | @click.option( |
29 | 32 | "--num-examples", required=False, help="Number of examples to evaluate on, useful for debugging", default=None |
30 | 33 | ) |
| 34 | +@click.option( |
| 35 | + "--visualize", |
| 36 | + is_flag=True, |
| 37 | + default=False, |
| 38 | + help="Visualize evaluation results after completion", |
| 39 | +) |
31 | 40 | @click.pass_context |
32 | 41 | def run_benchmark( |
33 | | - ctx, eval_task_ids: tuple[str, ...], eval_task_config: str, output_dir: str, num_examples: Optional[int] |
| 42 | + ctx, |
| 43 | + eval_task_ids: tuple[str, ...], |
| 44 | + eval_task_config: str, |
| 45 | + output_dir: str, |
| 46 | + num_examples: Optional[int], |
| 47 | + visualize: bool, |
34 | 48 | ): |
 35 | 49 | """Run an evaluation benchmark"""
36 | 50 |
|
@@ -79,4 +93,13 @@ def run_benchmark( |
79 | 93 | with open(output_file, "w") as f: |
80 | 94 | json.dump(output_res, f, indent=2) |
81 | 95 |
|
82 | | - print(f"Results saved to: {output_file}") |
| 96 | + rprint(f"[green]✓[/green] Results saved to: [blue]{output_file}[/blue]!\n") |
| 97 | + |
| 98 | + if visualize: |
| 99 | + for scoring_fn in scoring_functions: |
| 100 | + res = output_res[scoring_fn] |
| 101 | + assert len(res) > 0 and "score" in res[0] |
| 102 | + scores = [str(r["score"]) for r in res] |
| 103 | + unique_scores = sorted(list(set(scores))) |
| 104 | + counts = [scores.count(s) for s in unique_scores] |
| 105 | + create_bar_chart(counts, unique_scores, title=f"{scoring_fn}") |
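
The new `--visualize` path depends on `create_bar_chart` from `..common.utils`, which is imported above but not shown in this diff. Below is a minimal sketch of what such a helper could look like, assuming it renders a per-score count table with `rich` (already used here via `rprint`); the actual helper may be implemented differently.

```python
# Hypothetical sketch of the create_bar_chart helper imported from ..common.utils.
# The real implementation is not part of this diff; this version assumes a
# rich-based console and simple Unicode block bars.
from rich.console import Console
from rich.table import Table


def create_bar_chart(values: list[int], labels: list[str], title: str = "") -> None:
    """Print a horizontal bar chart of `values` keyed by `labels`."""
    console = Console()
    table = Table(title=title)
    table.add_column("Score")
    table.add_column("Count", justify="right")
    table.add_column("Distribution")

    max_value = max(values) if values else 1
    for label, value in zip(labels, values):
        # Scale each bar to at most 40 character cells; non-zero counts get at least one.
        bar = "█" * max(1, int(40 * value / max_value)) if value else ""
        table.add_row(label, str(value), bar)

    console.print(table)
```

With a helper along these lines, invoking the command with `--visualize` prints one chart per scoring function showing how many evaluated examples received each score.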