From 853c24989f14a7b58104f31cafcf220089da738d Mon Sep 17 00:00:00 2001
From: Charlie Doern
Date: Wed, 26 Feb 2025 10:04:54 -0500
Subject: [PATCH] add support for chat sessions

Today, to chat with a model, a user has to run one command per
completion.

Add --session, which enables a user to have an interactive chat session
with the inference model. --session can be passed with or without
--message. If no --message is passed, the user is prompted for the
first message:

```
llama-stack-client inference chat-completion --session
>>> hi whats up!
Assistant> Not much! How's your day going so far? Is there something I can help you with or would you like to chat?
>>> what color is the sky?
Assistant> The color of the sky can vary depending on the time of day and atmospheric conditions. Here are some common colors you might see:

* During the daytime, when the sun is overhead, the sky typically appears blue.
* At sunrise and sunset, the sky can take on hues of red, orange, pink, and purple due to the scattering of light by atmospheric particles.
* On a clear day with no clouds, the sky can appear a bright blue, often referred to as "cerulean."
* In areas with high levels of pollution or dust, the sky can appear more hazy or grayish.
* At night, the sky can be dark and black, although some stars and moonlight can make it visible.

So, what's your favorite color of the sky?
>>>
```

Signed-off-by: Charlie Doern
---
 .../lib/cli/inference/inference.py | 64 +++++++++++++++----
 1 file changed, 51 insertions(+), 13 deletions(-)

diff --git a/src/llama_stack_client/lib/cli/inference/inference.py b/src/llama_stack_client/lib/cli/inference/inference.py
index 7c4562a4..7280ceff 100644
--- a/src/llama_stack_client/lib/cli/inference/inference.py
+++ b/src/llama_stack_client/lib/cli/inference/inference.py
@@ -4,7 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Optional
+from typing import Optional, List, Dict
+import traceback
 
 import click
 from rich.console import Console
@@ -19,13 +20,20 @@ def inference():
 
 
 @click.command("chat-completion")
-@click.option("--message", required=True, help="Message")
+@click.option("--message", help="Message")
 @click.option("--stream", is_flag=True, help="Streaming", default=False)
+@click.option("--session", is_flag=True, help="Start a Chat Session", default=False)
 @click.option("--model-id", required=False, help="Model ID")
 @click.pass_context
 @handle_client_errors("inference chat-completion")
-def chat_completion(ctx, message: str, stream: bool, model_id: Optional[str]):
+def chat_completion(ctx, message: str, stream: bool, session: bool, model_id: Optional[str]):
     """Show available inference chat completion endpoints on distribution endpoint"""
+    if not message and not session:
+        click.secho(
+            "you must specify either --message or --session",
+            fg="red",
+        )
+        raise click.exceptions.Exit(1)
     client = ctx.obj["client"]
     console = Console()
 
@@ -33,16 +41,46 @@ def chat_completion(ctx, message: str, stream: bool, model_id: Optional[str]):
         available_models = [model.identifier for model in client.models.list() if model.model_type == "llm"]
         model_id = available_models[0]
 
-    response = client.inference.chat_completion(
-        model_id=model_id,
-        messages=[{"role": "user", "content": message}],
-        stream=stream,
-    )
-    if not stream:
-        console.print(response)
-    else:
-        for event in EventLogger().log(response):
-            event.print()
+    messages = []
+    if message:
+        messages.append({"role": "user", "content": message})
+        response = client.inference.chat_completion(
+            model_id=model_id,
+            messages=messages,
+            stream=stream,
+        )
+        if not stream:
+            console.print(response)
+        else:
+            for event in EventLogger().log(response):
+                event.print()
+    if session:
+        chat_session(client=client, model_id=model_id, messages=messages, console=console)
+
+
+def chat_session(client, model_id: Optional[str], messages: List[Dict[str, str]], console: Console):
+    """Run an interactive chat session with the served model"""
+    while True:
+        try:
+            message = input(">>> ")
+            if message in ["\\q", "quit"]:
+                console.print("Exiting")
+                break
+            messages.append({"role": "user", "content": message})
+            response = client.inference.chat_completion(
+                model_id=model_id,
+                messages=messages,
+                stream=True,
+            )
+            for event in EventLogger().log(response):
+                event.print()
+        except Exception as exc:
+            traceback.print_exc()
+            console.print(f"Error in chat session {exc}")
+            break
+        except KeyboardInterrupt as exc:
+            console.print("\nDetected user interrupt, exiting")
+            break
 
 
 # Register subcommands
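
For anyone who wants to exercise the same flow outside the CLI, below is a
minimal standalone sketch of what the new session loop does: it appends each
user turn to a single `messages` list and re-sends the growing list on every
completion. This is illustrative only; the `base_url` and the hard-coded
prompts are assumptions, not part of the patch.

```
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.inference.event_logger import EventLogger

# Point at a running distribution; this URL is illustrative.
client = LlamaStackClient(base_url="http://localhost:8321")

# Mirror the CLI's default model selection: first available LLM.
available_models = [m.identifier for m in client.models.list() if m.model_type == "llm"]
model_id = available_models[0]

# Accumulate user turns and re-send the list each time -- the same
# shape of loop that chat_session runs interactively.
messages = []
for user_input in ["hi whats up!", "what color is the sky?"]:
    messages.append({"role": "user", "content": user_input})
    response = client.inference.chat_completion(
        model_id=model_id,
        messages=messages,
        stream=True,
    )
    for event in EventLogger().log(response):
        event.print()
```

In the interactive session itself, typing `\q` or `quit` at the `>>> ` prompt
exits cleanly, and Ctrl-C is caught by the `KeyboardInterrupt` handler.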