Skip to content

Commit fa75193

Browse files
committed
add toolgroup cli commands
1 parent 2183dc9 commit fa75193

File tree

4 files changed

+170
-13
lines changed

4 files changed

+170
-13
lines changed

src/llama_stack_client/lib/cli/llama_stack_client.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,36 @@
55
# the root directory of this source tree.
66

77
import os
8+
from importlib.metadata import version
89

910
import click
1011
import yaml
1112

1213
from llama_stack_client import LlamaStackClient
13-
from importlib.metadata import version
1414

1515
from .configure import configure
1616
from .constants import get_config_file_path
1717
from .datasets import datasets
1818
from .eval import eval
1919
from .eval_tasks import eval_tasks
2020
from .inference import inference
21+
from .inspect import inspect
2122
from .memory_banks import memory_banks
2223
from .models import models
2324
from .post_training import post_training
2425
from .providers import providers
2526
from .scoring_functions import scoring_functions
2627
from .shields import shields
27-
from .inspect import inspect
28+
from .toolgroups import toolgroups
29+
2830

2931
@click.group()
30-
@click.version_option(version=version("llama-stack-client"), prog_name="llama-stack-client")
31-
@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint", default="")
32+
@click.version_option(
33+
version=version("llama-stack-client"), prog_name="llama-stack-client"
34+
)
35+
@click.option(
36+
"--endpoint", type=str, help="Llama Stack distribution endpoint", default=""
37+
)
3238
@click.option("--config", type=str, help="Path to config file", default=None)
3339
@click.pass_context
3440
def cli(ctx, endpoint: str, config: str | None):
@@ -80,6 +86,7 @@ def cli(ctx, endpoint: str, config: str | None):
8086
cli.add_command(inference, "inference")
8187
cli.add_command(post_training, "post_training")
8288
cli.add_command(inspect, "inspect")
89+
cli.add_command(toolgroups, "toolgroups")
8390

8491

8592
def main():

src/llama_stack_client/lib/cli/models/models.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,22 @@ def list_models(ctx):
2626
client = ctx.obj["client"]
2727
console = Console()
2828

29-
headers = ["identifier", "provider_id", "provider_resource_id", "metadata"]
29+
headers = [
30+
"identifier",
31+
"provider_id",
32+
"provider_resource_id",
33+
"metadata",
34+
"model_type",
35+
]
3036
response = client.models.list()
3137
if response:
3238
table = Table()
3339
for header in headers:
3440
table.add_column(header)
3541

3642
for item in response:
37-
table.add_row(
38-
str(getattr(item, headers[0])),
39-
str(getattr(item, headers[1])),
40-
str(getattr(item, headers[2])),
41-
str(getattr(item, headers[3])),
42-
)
43+
row = [str(getattr(item, header)) for header in headers]
44+
table.add_row(*row)
4345
console.print(table)
4446

4547

@@ -79,14 +81,21 @@ def get_model(ctx, model_id: str):
7981
@click.pass_context
8082
@handle_client_errors("register model")
8183
def register_model(
82-
ctx, model_id: str, provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str]
84+
ctx,
85+
model_id: str,
86+
provider_id: Optional[str],
87+
provider_model_id: Optional[str],
88+
metadata: Optional[str],
8389
):
8490
"""Register a new model at distribution endpoint"""
8591
client = ctx.obj["client"]
8692
console = Console()
8793

8894
response = client.models.register(
89-
model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata
95+
model_id=model_id,
96+
provider_id=provider_id,
97+
provider_model_id=provider_model_id,
98+
metadata=metadata,
9099
)
91100
if response:
92101
console.print(f"[green]Successfully registered model {model_id}[/green]")
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from .toolgroups import toolgroups
8+
9+
__all__ = ["toolgroups"]
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from typing import Optional
8+
9+
import click
10+
from rich.console import Console
11+
from rich.table import Table
12+
13+
from ..common.utils import handle_client_errors
14+
15+
16+
@click.group()
17+
def toolgroups():
18+
"""Query details about available toolgroups on Llama Stack distribution."""
19+
pass
20+
21+
22+
@click.command(
23+
name="list", help="Show available llama toolgroups at distribution endpoint"
24+
)
25+
@click.pass_context
26+
@handle_client_errors("list toolgroups")
27+
def list_toolgroups(ctx):
28+
client = ctx.obj["client"]
29+
console = Console()
30+
31+
headers = ["identifier", "provider_id", "args", "mcp_endpoint"]
32+
response = client.toolgroups.list()
33+
if response:
34+
table = Table()
35+
for header in headers:
36+
table.add_column(header)
37+
38+
for item in response:
39+
row = [str(getattr(item, header)) for header in headers]
40+
table.add_row(*row)
41+
console.print(table)
42+
43+
44+
@click.command(name="get")
45+
@click.argument("toolgroup_id")
46+
@click.pass_context
47+
@handle_client_errors("get toolgroup details")
48+
def get_toolgroup(ctx, toolgroup_id: str):
49+
"""Show available llama toolgroups at distribution endpoint"""
50+
client = ctx.obj["client"]
51+
console = Console()
52+
53+
toolgroups_get_response = client.tools.list()
54+
# filter response to only include provided toolgroup_id
55+
toolgroups_get_response = [
56+
toolgroup
57+
for toolgroup in toolgroups_get_response
58+
if toolgroup.toolgroup_id == toolgroup_id
59+
]
60+
if len(toolgroups_get_response) == 0:
61+
console.print(
62+
f"Toolgroup {toolgroup_id} is not found at distribution endpoint. "
63+
"Please ensure endpoint is serving specified toolgroup.",
64+
style="bold red",
65+
)
66+
return
67+
68+
headers = sorted(toolgroups_get_response[0].__dict__.keys())
69+
table = Table()
70+
for header in headers:
71+
table.add_column(header)
72+
73+
for toolgroup in toolgroups_get_response:
74+
row = [str(getattr(toolgroup, header)) for header in headers]
75+
table.add_row(*row)
76+
console.print(table)
77+
78+
79+
@click.command(
80+
name="register", help="Register a new toolgroup at distribution endpoint"
81+
)
82+
@click.argument("toolgroup_id")
83+
@click.option("--provider-id", help="Provider ID for the toolgroup", default=None)
84+
@click.option("--provider-toolgroup-id", help="Provider's toolgroup ID", default=None)
85+
@click.option("--mcp-config", help="JSON mcp_config for the toolgroup", default=None)
86+
@click.option("--args", help="JSON args for the toolgroup", default=None)
87+
@click.pass_context
88+
@handle_client_errors("register toolgroup")
89+
def register_toolgroup(
90+
ctx,
91+
toolgroup_id: str,
92+
provider_id: Optional[str],
93+
provider_toolgroup_id: Optional[str],
94+
mcp_config: Optional[str],
95+
args: Optional[str],
96+
):
97+
"""Register a new toolgroup at distribution endpoint"""
98+
client = ctx.obj["client"]
99+
console = Console()
100+
101+
response = client.toolgroups.register(
102+
toolgroup_id=toolgroup_id,
103+
provider_id=provider_id,
104+
args=args,
105+
mcp_config=mcp_config,
106+
)
107+
if response:
108+
console.print(
109+
f"[green]Successfully registered toolgroup {toolgroup_id}[/green]"
110+
)
111+
112+
113+
@click.command(
114+
name="unregister", help="Unregister a toolgroup from distribution endpoint"
115+
)
116+
@click.argument("toolgroup_id")
117+
@click.pass_context
118+
@handle_client_errors("unregister toolgroup")
119+
def unregister_toolgroup(ctx, toolgroup_id: str):
120+
client = ctx.obj["client"]
121+
console = Console()
122+
123+
response = client.toolgroups.unregister(tool_group_id=toolgroup_id)
124+
if response:
125+
console.print(f"[green]Successfully deleted toolgroup {toolgroup_id}[/green]")
126+
127+
128+
# Register subcommands
129+
toolgroups.add_command(list_toolgroups)
130+
toolgroups.add_command(get_toolgroup)
131+
toolgroups.add_command(register_toolgroup)
132+
toolgroups.add_command(unregister_toolgroup)

0 commit comments

Comments
 (0)