Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/llama_stack_client/lib/cli/llama_stack_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@
from .providers import providers
from .scoring_functions import scoring_functions
from .shields import shields
from .toolgroups import toolgroups


@click.group()
@click.version_option(version=version("llama-stack-client"), prog_name="llama-stack-client")
@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint", default="")
@click.version_option(
version=version("llama-stack-client"), prog_name="llama-stack-client"
)
@click.option(
"--endpoint", type=str, help="Llama Stack distribution endpoint", default=""
)
@click.option("--config", type=str, help="Path to config file", default=None)
@click.pass_context
def cli(ctx, endpoint: str, config: str | None):
Expand Down Expand Up @@ -81,6 +86,7 @@ def cli(ctx, endpoint: str, config: str | None):
cli.add_command(inference, "inference")
cli.add_command(post_training, "post_training")
cli.add_command(inspect, "inspect")
cli.add_command(toolgroups, "toolgroups")


def main():
Expand Down
27 changes: 18 additions & 9 deletions src/llama_stack_client/lib/cli/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,22 @@ def list_models(ctx):
client = ctx.obj["client"]
console = Console()

headers = ["identifier", "provider_id", "provider_resource_id", "metadata"]
headers = [
"identifier",
"provider_id",
"provider_resource_id",
"metadata",
"model_type",
]
response = client.models.list()
if response:
table = Table()
for header in headers:
table.add_column(header)

for item in response:
table.add_row(
str(getattr(item, headers[0])),
str(getattr(item, headers[1])),
str(getattr(item, headers[2])),
str(getattr(item, headers[3])),
)
row = [str(getattr(item, header)) for header in headers]
table.add_row(*row)
console.print(table)


Expand Down Expand Up @@ -79,14 +81,21 @@ def get_model(ctx, model_id: str):
@click.pass_context
@handle_client_errors("register model")
def register_model(
ctx, model_id: str, provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str]
ctx,
model_id: str,
provider_id: Optional[str],
provider_model_id: Optional[str],
metadata: Optional[str],
):
"""Register a new model at distribution endpoint"""
client = ctx.obj["client"]
console = Console()

response = client.models.register(
model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata
model_id=model_id,
provider_id=provider_id,
provider_model_id=provider_model_id,
metadata=metadata,
)
if response:
console.print(f"[green]Successfully registered model {model_id}[/green]")
Expand Down
9 changes: 9 additions & 0 deletions src/llama_stack_client/lib/cli/toolgroups/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .toolgroups import toolgroups

__all__ = ["toolgroups"]
132 changes: 132 additions & 0 deletions src/llama_stack_client/lib/cli/toolgroups/toolgroups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Optional

import click
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors


@click.group()
def toolgroups():
"""Query details about available toolgroups on Llama Stack distribution."""
pass


@click.command(
name="list", help="Show available llama toolgroups at distribution endpoint"
)
@click.pass_context
@handle_client_errors("list toolgroups")
def list_toolgroups(ctx):
client = ctx.obj["client"]
console = Console()

headers = ["identifier", "provider_id", "args", "mcp_endpoint"]
response = client.toolgroups.list()
if response:
table = Table()
for header in headers:
table.add_column(header)

for item in response:
row = [str(getattr(item, header)) for header in headers]
table.add_row(*row)
console.print(table)


@click.command(name="get")
@click.argument("toolgroup_id")
@click.pass_context
@handle_client_errors("get toolgroup details")
def get_toolgroup(ctx, toolgroup_id: str):
"""Show available llama toolgroups at distribution endpoint"""
client = ctx.obj["client"]
console = Console()

toolgroups_get_response = client.tools.list()
# filter response to only include provided toolgroup_id
toolgroups_get_response = [
toolgroup
for toolgroup in toolgroups_get_response
if toolgroup.toolgroup_id == toolgroup_id
]
if len(toolgroups_get_response) == 0:
console.print(
f"Toolgroup {toolgroup_id} is not found at distribution endpoint. "
"Please ensure endpoint is serving specified toolgroup.",
style="bold red",
)
return

headers = sorted(toolgroups_get_response[0].__dict__.keys())
table = Table()
for header in headers:
table.add_column(header)

for toolgroup in toolgroups_get_response:
row = [str(getattr(toolgroup, header)) for header in headers]
table.add_row(*row)
console.print(table)


@click.command(
name="register", help="Register a new toolgroup at distribution endpoint"
)
@click.argument("toolgroup_id")
@click.option("--provider-id", help="Provider ID for the toolgroup", default=None)
@click.option("--provider-toolgroup-id", help="Provider's toolgroup ID", default=None)
@click.option("--mcp-config", help="JSON mcp_config for the toolgroup", default=None)
@click.option("--args", help="JSON args for the toolgroup", default=None)
@click.pass_context
@handle_client_errors("register toolgroup")
def register_toolgroup(
ctx,
toolgroup_id: str,
provider_id: Optional[str],
provider_toolgroup_id: Optional[str],
mcp_config: Optional[str],
args: Optional[str],
):
"""Register a new toolgroup at distribution endpoint"""
client = ctx.obj["client"]
console = Console()

response = client.toolgroups.register(
toolgroup_id=toolgroup_id,
provider_id=provider_id,
args=args,
mcp_config=mcp_config,
)
if response:
console.print(
f"[green]Successfully registered toolgroup {toolgroup_id}[/green]"
)


@click.command(
name="unregister", help="Unregister a toolgroup from distribution endpoint"
)
@click.argument("toolgroup_id")
@click.pass_context
@handle_client_errors("unregister toolgroup")
def unregister_toolgroup(ctx, toolgroup_id: str):
client = ctx.obj["client"]
console = Console()

response = client.toolgroups.unregister(tool_group_id=toolgroup_id)
if response:
console.print(f"[green]Successfully deleted toolgroup {toolgroup_id}[/green]")


# Register subcommands
toolgroups.add_command(list_toolgroups)
toolgroups.add_command(get_toolgroup)
toolgroups.add_command(register_toolgroup)
toolgroups.add_command(unregister_toolgroup)
Loading