diff --git a/src/llama_stack_client/lib/cli/llama_stack_client.py b/src/llama_stack_client/lib/cli/llama_stack_client.py index 5e49c15a..0a3eeb35 100644 --- a/src/llama_stack_client/lib/cli/llama_stack_client.py +++ b/src/llama_stack_client/lib/cli/llama_stack_client.py @@ -25,11 +25,16 @@ from .providers import providers from .scoring_functions import scoring_functions from .shields import shields +from .toolgroups import toolgroups @click.group() -@click.version_option(version=version("llama-stack-client"), prog_name="llama-stack-client") -@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint", default="") +@click.version_option( + version=version("llama-stack-client"), prog_name="llama-stack-client" +) +@click.option( + "--endpoint", type=str, help="Llama Stack distribution endpoint", default="" +) @click.option("--config", type=str, help="Path to config file", default=None) @click.pass_context def cli(ctx, endpoint: str, config: str | None): @@ -81,6 +86,7 @@ def cli(ctx, endpoint: str, config: str | None): cli.add_command(inference, "inference") cli.add_command(post_training, "post_training") cli.add_command(inspect, "inspect") +cli.add_command(toolgroups, "toolgroups") def main(): diff --git a/src/llama_stack_client/lib/cli/models/models.py b/src/llama_stack_client/lib/cli/models/models.py index f080e314..a7eb2c44 100644 --- a/src/llama_stack_client/lib/cli/models/models.py +++ b/src/llama_stack_client/lib/cli/models/models.py @@ -26,7 +26,13 @@ def list_models(ctx): client = ctx.obj["client"] console = Console() - headers = ["identifier", "provider_id", "provider_resource_id", "metadata"] + headers = [ + "identifier", + "provider_id", + "provider_resource_id", + "metadata", + "model_type", + ] response = client.models.list() if response: table = Table() @@ -34,12 +40,8 @@ def list_models(ctx): table.add_column(header) for item in response: - table.add_row( - str(getattr(item, headers[0])), - str(getattr(item, headers[1])), - str(getattr(item, headers[2])), - str(getattr(item, headers[3])), - ) + row = [str(getattr(item, header)) for header in headers] + table.add_row(*row) console.print(table) @@ -79,14 +81,21 @@ def get_model(ctx, model_id: str): @click.pass_context @handle_client_errors("register model") def register_model( - ctx, model_id: str, provider_id: Optional[str], provider_model_id: Optional[str], metadata: Optional[str] + ctx, + model_id: str, + provider_id: Optional[str], + provider_model_id: Optional[str], + metadata: Optional[str], ): """Register a new model at distribution endpoint""" client = ctx.obj["client"] console = Console() response = client.models.register( - model_id=model_id, provider_id=provider_id, provider_model_id=provider_model_id, metadata=metadata + model_id=model_id, + provider_id=provider_id, + provider_model_id=provider_model_id, + metadata=metadata, ) if response: console.print(f"[green]Successfully registered model {model_id}[/green]") diff --git a/src/llama_stack_client/lib/cli/toolgroups/__init__.py b/src/llama_stack_client/lib/cli/toolgroups/__init__.py new file mode 100644 index 00000000..912d911b --- /dev/null +++ b/src/llama_stack_client/lib/cli/toolgroups/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .toolgroups import toolgroups + +__all__ = ["toolgroups"] diff --git a/src/llama_stack_client/lib/cli/toolgroups/toolgroups.py b/src/llama_stack_client/lib/cli/toolgroups/toolgroups.py new file mode 100644 index 00000000..acde6b03 --- /dev/null +++ b/src/llama_stack_client/lib/cli/toolgroups/toolgroups.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +import click +from rich.console import Console +from rich.table import Table + +from ..common.utils import handle_client_errors + + +@click.group() +def toolgroups(): + """Query details about available toolgroups on Llama Stack distribution.""" + pass + + +@click.command( + name="list", help="Show available llama toolgroups at distribution endpoint" +) +@click.pass_context +@handle_client_errors("list toolgroups") +def list_toolgroups(ctx): + client = ctx.obj["client"] + console = Console() + + headers = ["identifier", "provider_id", "args", "mcp_endpoint"] + response = client.toolgroups.list() + if response: + table = Table() + for header in headers: + table.add_column(header) + + for item in response: + row = [str(getattr(item, header)) for header in headers] + table.add_row(*row) + console.print(table) + + +@click.command(name="get") +@click.argument("toolgroup_id") +@click.pass_context +@handle_client_errors("get toolgroup details") +def get_toolgroup(ctx, toolgroup_id: str): + """Show available llama toolgroups at distribution endpoint""" + client = ctx.obj["client"] + console = Console() + + toolgroups_get_response = client.tools.list() + # filter response to only include provided toolgroup_id + toolgroups_get_response = [ + toolgroup + for toolgroup in toolgroups_get_response + if toolgroup.toolgroup_id == toolgroup_id + ] + if len(toolgroups_get_response) == 0: + console.print( + f"Toolgroup {toolgroup_id} is not found at distribution endpoint. " + "Please ensure endpoint is serving specified toolgroup.", + style="bold red", + ) + return + + headers = sorted(toolgroups_get_response[0].__dict__.keys()) + table = Table() + for header in headers: + table.add_column(header) + + for toolgroup in toolgroups_get_response: + row = [str(getattr(toolgroup, header)) for header in headers] + table.add_row(*row) + console.print(table) + + +@click.command( + name="register", help="Register a new toolgroup at distribution endpoint" +) +@click.argument("toolgroup_id") +@click.option("--provider-id", help="Provider ID for the toolgroup", default=None) +@click.option("--provider-toolgroup-id", help="Provider's toolgroup ID", default=None) +@click.option("--mcp-config", help="JSON mcp_config for the toolgroup", default=None) +@click.option("--args", help="JSON args for the toolgroup", default=None) +@click.pass_context +@handle_client_errors("register toolgroup") +def register_toolgroup( + ctx, + toolgroup_id: str, + provider_id: Optional[str], + provider_toolgroup_id: Optional[str], + mcp_config: Optional[str], + args: Optional[str], +): + """Register a new toolgroup at distribution endpoint""" + client = ctx.obj["client"] + console = Console() + + response = client.toolgroups.register( + toolgroup_id=toolgroup_id, + provider_id=provider_id, + args=args, + mcp_config=mcp_config, + ) + if response: + console.print( + f"[green]Successfully registered toolgroup {toolgroup_id}[/green]" + ) + + +@click.command( + name="unregister", help="Unregister a toolgroup from distribution endpoint" +) +@click.argument("toolgroup_id") +@click.pass_context +@handle_client_errors("unregister toolgroup") +def unregister_toolgroup(ctx, toolgroup_id: str): + client = ctx.obj["client"] + console = Console() + + response = client.toolgroups.unregister(tool_group_id=toolgroup_id) + if response: + console.print(f"[green]Successfully deleted toolgroup {toolgroup_id}[/green]") + + +# Register subcommands +toolgroups.add_command(list_toolgroups) +toolgroups.add_command(get_toolgroup) +toolgroups.add_command(register_toolgroup) +toolgroups.add_command(unregister_toolgroup)