diff --git a/src/llama_stack_client/lib/cli/configure.py b/src/llama_stack_client/lib/cli/configure.py
index 88f5b0bb..b6bcba6f 100644
--- a/src/llama_stack_client/lib/cli/configure.py
+++ b/src/llama_stack_client/lib/cli/configure.py
@@ -24,42 +24,42 @@ def get_config():
 
 
 @click.command()
-@click.option("--host", type=str, help="Llama Stack distribution host")
-@click.option("--port", type=str, help="Llama Stack distribution port number")
-@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint")
-def configure(host: str | None, port: str | None, endpoint: str | None):
+@click.option("--endpoint", type=str, help="Llama Stack distribution endpoint", default="")
+@click.option("--api-key", type=str, help="Llama Stack distribution API key", default="")
+def configure(endpoint: str | None, api_key: str | None):
     """Configure Llama Stack Client CLI"""
     os.makedirs(LLAMA_STACK_CLIENT_CONFIG_DIR, exist_ok=True)
     config_path = get_config_file_path()
 
-    if endpoint:
+    if endpoint != "":
         final_endpoint = endpoint
     else:
-        if host and port:
-            final_endpoint = f"http://{host}:{port}"
-        else:
-            host = prompt(
-                "> Enter the host name of the Llama Stack distribution server: ",
-                validator=Validator.from_callable(
-                    lambda x: len(x) > 0,
-                    error_message="Host cannot be empty, please enter a valid host",
-                ),
-            )
-            port = prompt(
-                "> Enter the port number of the Llama Stack distribution server: ",
-                validator=Validator.from_callable(
-                    lambda x: x.isdigit(),
-                    error_message="Please enter a valid port number",
-                ),
-            )
-            final_endpoint = f"http://{host}:{port}"
+        final_endpoint = prompt(
+            "> Enter the endpoint of the Llama Stack distribution server: ",
+            validator=Validator.from_callable(
+                lambda x: len(x) > 0,
+                error_message="Endpoint cannot be empty, please enter a valid endpoint",
+            ),
+        )
+
+    if api_key != "":
+        final_api_key = api_key
+    else:
+        final_api_key = prompt(
+            "> Enter the API key (leave empty if no key is needed): ",
+        )
+
+    # Prepare config dict before writing it
+    config_dict = {
+        "endpoint": final_endpoint,
+    }
+    if final_api_key != "":
+        config_dict["api_key"] = final_api_key
 
     with open(config_path, "w") as f:
         f.write(
             yaml.dump(
-                {
-                    "endpoint": final_endpoint,
-                },
+                config_dict,
                 sort_keys=True,
             )
         )
diff --git a/src/llama_stack_client/lib/cli/llama_stack_client.py b/src/llama_stack_client/lib/cli/llama_stack_client.py
index 1bd5cfac..d40f2bc8 100644
--- a/src/llama_stack_client/lib/cli/llama_stack_client.py
+++ b/src/llama_stack_client/lib/cli/llama_stack_client.py
@@ -35,9 +35,12 @@
 @click.option(
     "--endpoint", type=str, help="Llama Stack distribution endpoint", default=""
 )
+@click.option(
+    "--api-key", type=str, help="Llama Stack distribution API key", default=""
+)
 @click.option("--config", type=str, help="Path to config file", default=None)
 @click.pass_context
-def cli(ctx, endpoint: str, config: str | None):
+def cli(ctx, endpoint: str, api_key: str, config: str | None):
     """Welcome to the LlamaStackClient CLI"""
     ctx.ensure_object(dict)
 
@@ -55,6 +58,7 @@ def cli(ctx, endpoint: str, config: str | None):
             with open(config, "r") as f:
                 config_dict = yaml.safe_load(f)
                 endpoint = config_dict.get("endpoint", endpoint)
+                api_key = config_dict.get("api_key", "")
         except Exception as e:
             click.echo(f"Error loading config from {config}: {str(e)}", err=True)
             click.echo("Falling back to HTTP client with endpoint", err=True)
@@ -62,6 +66,12 @@ def cli(ctx, endpoint: str, config: str | None):
     if endpoint == "":
         endpoint = "http://localhost:8321"
 
+    default_headers = {}
+    if api_key != "":
+        default_headers = {
+            "Authorization": f"Bearer {api_key}",
+        }
+
     client = LlamaStackClient(
         base_url=endpoint,
         provider_data={
@@ -69,6 +79,7 @@ def cli(ctx, endpoint: str, config: str | None):
             "together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
             "openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
         },
+        default_headers=default_headers,
     )
     ctx.obj = {"client": client}