Skip to content

Commit ecf6a48

Browse files
authored
Merge pull request #31 from meta-llama/pretty_table
[CLI] visualize categorical scores eval results with bars
2 parents b8a050b + b6d8d10 commit ecf6a48

File tree

7 files changed

+54
-20
lines changed

7 files changed

+54
-20
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ dependencies = [
1515
"distro>=1.7.0, <2",
1616
"sniffio",
1717
"cached-property; python_version < '3.8'",
18-
"tabulate>=0.9.0",
1918
]
2019
requires-python = ">= 3.7"
2120
classifiers = [

requirements-dev.lock

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,6 @@ sniffio==1.3.0
9090
# via anyio
9191
# via httpx
9292
# via llama-stack-client
93-
tabulate==0.9.0
94-
# via llama-stack-client
95-
termcolor==2.4.0
96-
# via llama-stack-client
9793
time-machine==2.9.0
9894
tomli==2.0.1
9995
# via mypy

requirements.lock

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,6 @@ sniffio==1.3.0
3939
# via anyio
4040
# via httpx
4141
# via llama-stack-client
42-
tabulate==0.9.0
43-
# via llama-stack-client
44-
termcolor==2.4.0
45-
# via llama-stack-client
4642
typing-extensions==4.8.0
4743
# via anyio
4844
# via llama-stack-client

src/llama_stack_client/lib/cli/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,10 @@
33
#
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
6+
7+
# Ignore tqdm experimental warning
8+
import warnings
9+
10+
from tqdm import TqdmExperimentalWarning
11+
12+
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)

src/llama_stack_client/lib/cli/common/utils.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,28 @@
33
#
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
6-
from tabulate import tabulate
6+
from rich.console import Console
7+
from rich.table import Table
78

89

9-
def print_table_from_response(response, headers=()):
10-
if not headers:
11-
headers = sorted(response[0].__dict__.keys())
10+
def create_bar_chart(data, labels, title=""):
11+
"""Create a bar chart using Rich Table."""
1212

13-
rows = []
14-
for spec in response:
15-
rows.append([spec.__dict__[headers[i]] for i in range(len(headers))])
13+
console = Console()
14+
table = Table(title=title)
15+
table.add_column("Score")
16+
table.add_column("Count")
1617

17-
print(tabulate(rows, headers=headers, tablefmt="grid"))
18+
max_value = max(data)
19+
total_count = sum(data)
20+
21+
# Define a list of colors to cycle through
22+
colors = ["green", "blue", "red", "yellow", "magenta", "cyan"]
23+
24+
for i, (label, value) in enumerate(zip(labels, data)):
25+
bar_length = int((value / max_value) * 20) # Adjust bar length as needed
26+
bar = "█" * bar_length + " " * (20 - bar_length)
27+
color = colors[i % len(colors)]
28+
table.add_row(label, f"[{color}]{bar}[/] {value}/{total_count}")
29+
30+
console.print(table)

src/llama_stack_client/lib/cli/eval/run_benchmark.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@
99
from typing import Optional
1010

1111
import click
12+
from rich import print as rprint
1213
from tqdm.rich import tqdm
1314

15+
from ..common.utils import create_bar_chart
16+
1417

1518
@click.command("run_benchmark")
1619
@click.argument("eval-task-ids", nargs=-1, required=True)
@@ -28,9 +31,20 @@
2831
@click.option(
2932
"--num-examples", required=False, help="Number of examples to evaluate on, useful for debugging", default=None
3033
)
34+
@click.option(
35+
"--visualize",
36+
is_flag=True,
37+
default=False,
38+
help="Visualize evaluation results after completion",
39+
)
3140
@click.pass_context
3241
def run_benchmark(
33-
ctx, eval_task_ids: tuple[str, ...], eval_task_config: str, output_dir: str, num_examples: Optional[int]
42+
ctx,
43+
eval_task_ids: tuple[str, ...],
44+
eval_task_config: str,
45+
output_dir: str,
46+
num_examples: Optional[int],
47+
visualize: bool,
3448
):
3549
"""Run a evaluation benchmark"""
3650

@@ -79,4 +93,13 @@ def run_benchmark(
7993
with open(output_file, "w") as f:
8094
json.dump(output_res, f, indent=2)
8195

82-
print(f"Results saved to: {output_file}")
96+
rprint(f"[green]✓[/green] Results saved to: [blue]{output_file}[/blue]!\n")
97+
98+
if visualize:
99+
for scoring_fn in scoring_functions:
100+
res = output_res[scoring_fn]
101+
assert len(res) > 0 and "score" in res[0]
102+
scores = [str(r["score"]) for r in res]
103+
unique_scores = sorted(list(set(scores)))
104+
counts = [scores.count(s) for s in unique_scores]
105+
create_bar_chart(counts, unique_scores, title=f"{scoring_fn}")

src/llama_stack_client/lib/cli/llama_stack_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def cli(ctx, endpoint: str, config: str | None):
5757
base_url=endpoint,
5858
provider_data={
5959
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
60-
"togethers_api_key": os.environ.get("TOGETHERS_API_KEY", ""),
60+
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
6161
},
6262
)
6363
ctx.obj = {"client": client}

0 commit comments

Comments
 (0)