Skip to content

Commit 9171e97

Browse files
authored
add toolgroup cli commands (#78)
add llama stack client subcommand for tool groups test plan: ``` llama-stack-client toolgroups list ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━┓ ┃ identifier ┃ provider_id ┃ args ┃ mcp_endpoint ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━┩ │ builtin::code_interpreter │ code-interpreter │ None │ None │ │ builtin::memory │ memory-runtime │ None │ None │ │ builtin::websearch │ tavily-search │ None │ None │ │ mcp-memory │ model-context-protocol │ None │ None │ │ brave-search │ model-context-protocol │ None │ None │ └───────────────────────────┴────────────────────────┴──────┴──────────────┘ ``` ![Screenshot 2025-01-09 at 1 29 50 PM](https://github.com/user-attachments/assets/350127fa-befd-44d1-a1a5-c209013b02e5)
1 parent 6a38b39 commit 9171e97

File tree

4 files changed

+167
-11
lines changed

4 files changed

+167
-11
lines changed

src/llama_stack_client/lib/cli/llama_stack_client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,16 @@
2525
from .providers import providers
2626
from .scoring_functions import scoring_functions
2727
from .shields import shields
28+
from .toolgroups import toolgroups
2829

2930

3031
@click.group()
31-
@click.version_option(version=version("llama-stack-client"), prog_name="llama-stack-client")
32-
@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint", default="")
32+
@click.version_option(
33+
version=version("llama-stack-client"), prog_name="llama-stack-client"
34+
)
35+
@click.option(
36+
"--endpoint", type=str, help="Llama Stack distribution endpoint", default=""
37+
)
3338
@click.option("--config", type=str, help="Path to config file", default=None)
3439
@click.pass_context
3540
def cli(ctx, endpoint: str, config: str | None):
@@ -81,6 +86,7 @@ def cli(ctx, endpoint: str, config: str | None):
8186
cli.add_command(inference, "inference")
8287
cli.add_command(post_training, "post_training")
8388
cli.add_command(inspect, "inspect")
89+
cli.add_command(toolgroups, "toolgroups")
8490

8591

8692
def main():

src/llama_stack_client/lib/cli/models/models.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,22 @@ def list_models(ctx):
2626
client = ctx.obj["client"]
2727
console = Console()
2828

29-
headers = ["identifier", "provider_id", "provider_resource_id", "metadata"]
29+
headers = [
30+
"identifier",
31+
"provider_id",
32+
"provider_resource_id",
33+
"metadata",
34+
"model_type",
35+
]
3036
response = client.models.list()
3137
if response:
3238
table = Table()
3339
for header in headers:
3440
table.add_column(header)
3541

3642
for item in response:
37-
table.add_row(
38-
str(getattr(item, headers[0])),
39-
str(getattr(item, headers[1])),
40-
str(getattr(item, headers[2])),
41-
str(getattr(item, headers[3])),
42-
)
43+
row = [str(getattr(item, header)) for header in headers]
44+
table.add_row(*row)
4345
console.print(table)
4446

4547

@@ -79,14 +81,21 @@ def get_model(ctx, model_id: str):
7981
@click.pass_context
8082
@handle_client_errors("register model")
8183
def register_model(
82-
ctx, model_id: str, provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str]
84+
ctx,
85+
model_id: str,
86+
provider_id: Optional[str],
87+
provider_model_id: Optional[str],
88+
metadata: Optional[str],
8389
):
8490
"""Register a new model at distribution endpoint"""
8591
client = ctx.obj["client"]
8692
console = Console()
8793

8894
response = client.models.register(
89-
model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata
95+
model_id=model_id,
96+
provider_id=provider_id,
97+
provider_model_id=provider_model_id,
98+
metadata=metadata,
9099
)
91100
if response:
92101
console.print(f"[green]Successfully registered model {model_id}[/green]")
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from .toolgroups import toolgroups
8+
9+
__all__ = ["toolgroups"]
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from typing import Optional
8+
9+
import click
10+
from rich.console import Console
11+
from rich.table import Table
12+
13+
from ..common.utils import handle_client_errors
14+
15+
16+
@click.group()
17+
def toolgroups():
18+
"""Query details about available toolgroups on Llama Stack distribution."""
19+
pass
20+
21+
22+
@click.command(
23+
name="list", help="Show available llama toolgroups at distribution endpoint"
24+
)
25+
@click.pass_context
26+
@handle_client_errors("list toolgroups")
27+
def list_toolgroups(ctx):
28+
client = ctx.obj["client"]
29+
console = Console()
30+
31+
headers = ["identifier", "provider_id", "args", "mcp_endpoint"]
32+
response = client.toolgroups.list()
33+
if response:
34+
table = Table()
35+
for header in headers:
36+
table.add_column(header)
37+
38+
for item in response:
39+
row = [str(getattr(item, header)) for header in headers]
40+
table.add_row(*row)
41+
console.print(table)
42+
43+
44+
@click.command(name="get")
45+
@click.argument("toolgroup_id")
46+
@click.pass_context
47+
@handle_client_errors("get toolgroup details")
48+
def get_toolgroup(ctx, toolgroup_id: str):
49+
"""Show available llama toolgroups at distribution endpoint"""
50+
client = ctx.obj["client"]
51+
console = Console()
52+
53+
toolgroups_get_response = client.tools.list()
54+
# filter response to only include provided toolgroup_id
55+
toolgroups_get_response = [
56+
toolgroup
57+
for toolgroup in toolgroups_get_response
58+
if toolgroup.toolgroup_id == toolgroup_id
59+
]
60+
if len(toolgroups_get_response) == 0:
61+
console.print(
62+
f"Toolgroup {toolgroup_id} is not found at distribution endpoint. "
63+
"Please ensure endpoint is serving specified toolgroup.",
64+
style="bold red",
65+
)
66+
return
67+
68+
headers = sorted(toolgroups_get_response[0].__dict__.keys())
69+
table = Table()
70+
for header in headers:
71+
table.add_column(header)
72+
73+
for toolgroup in toolgroups_get_response:
74+
row = [str(getattr(toolgroup, header)) for header in headers]
75+
table.add_row(*row)
76+
console.print(table)
77+
78+
79+
@click.command(
80+
name="register", help="Register a new toolgroup at distribution endpoint"
81+
)
82+
@click.argument("toolgroup_id")
83+
@click.option("--provider-id", help="Provider ID for the toolgroup", default=None)
84+
@click.option("--provider-toolgroup-id", help="Provider's toolgroup ID", default=None)
85+
@click.option("--mcp-config", help="JSON mcp_config for the toolgroup", default=None)
86+
@click.option("--args", help="JSON args for the toolgroup", default=None)
87+
@click.pass_context
88+
@handle_client_errors("register toolgroup")
89+
def register_toolgroup(
90+
ctx,
91+
toolgroup_id: str,
92+
provider_id: Optional[str],
93+
provider_toolgroup_id: Optional[str],
94+
mcp_config: Optional[str],
95+
args: Optional[str],
96+
):
97+
"""Register a new toolgroup at distribution endpoint"""
98+
client = ctx.obj["client"]
99+
console = Console()
100+
101+
response = client.toolgroups.register(
102+
toolgroup_id=toolgroup_id,
103+
provider_id=provider_id,
104+
args=args,
105+
mcp_config=mcp_config,
106+
)
107+
if response:
108+
console.print(
109+
f"[green]Successfully registered toolgroup {toolgroup_id}[/green]"
110+
)
111+
112+
113+
@click.command(
114+
name="unregister", help="Unregister a toolgroup from distribution endpoint"
115+
)
116+
@click.argument("toolgroup_id")
117+
@click.pass_context
118+
@handle_client_errors("unregister toolgroup")
119+
def unregister_toolgroup(ctx, toolgroup_id: str):
120+
client = ctx.obj["client"]
121+
console = Console()
122+
123+
response = client.toolgroups.unregister(tool_group_id=toolgroup_id)
124+
if response:
125+
console.print(f"[green]Successfully deleted toolgroup {toolgroup_id}[/green]")
126+
127+
128+
# Register subcommands
129+
toolgroups.add_command(list_toolgroups)
130+
toolgroups.add_command(get_toolgroup)
131+
toolgroups.add_command(register_toolgroup)
132+
toolgroups.add_command(unregister_toolgroup)

0 commit comments

Comments
 (0)