diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go index 5f2f7641..58e92913 100644 --- a/cmd/cli/commands/run.go +++ b/cmd/cli/commands/run.go @@ -8,6 +8,7 @@ import ( "io" "os" "os/signal" + "strconv" "strings" "syscall" @@ -15,6 +16,8 @@ import ( "github.com/docker/model-runner/cmd/cli/commands/completion" "github.com/docker/model-runner/cmd/cli/desktop" "github.com/docker/model-runner/cmd/cli/readline" + "github.com/docker/model-runner/pkg/inference" + "github.com/docker/model-runner/pkg/inference/scheduling" "github.com/fatih/color" "github.com/muesli/termenv" "github.com/spf13/cobra" @@ -90,11 +93,12 @@ func readMultilineInput(cmd *cobra.Command, scanner *bufio.Scanner) (string, err func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop.Client, model string) error { usage := func() { fmt.Fprintln(os.Stderr, "Available Commands:") - fmt.Fprintln(os.Stderr, " /set system Set or update the system message") fmt.Fprintln(os.Stderr, " /bye Exit") + fmt.Fprintln(os.Stderr, " /set Set a session variable") fmt.Fprintln(os.Stderr, " /?, /help Help for a command") fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") fmt.Fprintln(os.Stderr, " /? files Help for file inclusion with @ symbol") + fmt.Fprintln(os.Stderr, " /? set Help for /set command") fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, `Use """ to begin a multi-line message.`) fmt.Fprintln(os.Stderr, "") @@ -134,6 +138,14 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop. fmt.Fprintln(os.Stderr, "") } + usageSet := func() { + fmt.Fprintln(os.Stderr, "Available /set commands:") + fmt.Fprintln(os.Stderr, " /set system Set system message for the conversation") + fmt.Fprintln(os.Stderr, " /set num_ctx Set context window size (in tokens)") + fmt.Fprintln(os.Stderr, " /set parameter num_ctx Set context window size (in tokens) [deprecated]") + fmt.Fprintln(os.Stderr, "") + } + scanner, err := readline.New(readline.Prompt{ Prompt: "> ", AltPrompt: ". ", @@ -204,36 +216,103 @@ func generateInteractiveWithReadline(cmd *cobra.Command, desktopClient *desktop. case scanner.Pasting: fmt.Fprintln(&sb, line) continue - case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"): + case strings.HasPrefix(line, "/"): args := strings.Fields(line) - if len(args) > 1 { + switch args[0] { + case "/help", "/?": + if len(args) > 1 { + switch args[1] { + case "shortcut", "shortcuts": + usageShortcuts() + case "file", "files": + usageFiles() + case "set": + usageSet() + default: + usage() + } + } else { + usage() + } + case "/exit", "/bye": + return nil + case "/set": + if len(args) < 2 { + usageSet() + continue + } switch args[1] { - case "shortcut", "shortcuts": - usageShortcuts() - case "file", "files": - usageFiles() + case "system": + // Extract the system prompt text after "/set system" + if len(args) > 2 { + systemPrompt = strings.Join(args[2:], " ") + } else { + systemPrompt = "" + } + if systemPrompt == "" { + fmt.Fprintln(os.Stderr, "Cleared system message.") + } else { + fmt.Fprintln(os.Stderr, "Set system message.") + } + case "num_ctx": + // Handle /set num_ctx syntax + if len(args) < 3 { + fmt.Fprintln(os.Stderr, "Usage: /set num_ctx ") + continue + } + paramValue := args[2] + if val, err := strconv.ParseInt(paramValue, 10, 32); err == nil && val > 0 { + ctx := int32(val) + if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{ + Model: model, + BackendConfiguration: inference.BackendConfiguration{ + ContextSize: &ctx, + }, + }); err != nil { + fmt.Fprintf(os.Stderr, "Failed to set num_ctx: %v\n", err) + } else { + fmt.Fprintf(os.Stderr, "Set num_ctx to %d\n", val) + } + } else { + fmt.Fprintf(os.Stderr, "Invalid value for num_ctx: %s (must be a positive integer)\n", paramValue) + } + case "parameter": + // Handle legacy /set parameter syntax for backward compatibility + if len(args) < 4 { + fmt.Fprintln(os.Stderr, "Usage: /set parameter ") + fmt.Fprintln(os.Stderr, "Available parameters: num_ctx") + continue + } + paramName, paramValue := args[2], args[3] + switch paramName { + case "num_ctx": + if val, err := strconv.ParseInt(paramValue, 10, 32); err == nil && val > 0 { + ctx := int32(val) + if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{ + Model: model, + BackendConfiguration: inference.BackendConfiguration{ + ContextSize: &ctx, + }, + }); err != nil { + fmt.Fprintf(os.Stderr, "Failed to set num_ctx: %v\n", err) + } else { + fmt.Fprintf(os.Stderr, "Set num_ctx to %d\n", val) + } + } else { + fmt.Fprintf(os.Stderr, "Invalid value for num_ctx: %s (must be a positive integer)\n", paramValue) + } + default: + fmt.Fprintf(os.Stderr, "Unknown parameter: %s\n", paramName) + fmt.Fprintln(os.Stderr, "Available parameters: num_ctx") + } default: - usage() + fmt.Fprintf(os.Stderr, "Unknown /set option: %s\n", args[1]) + usageSet() } - } else { - usage() - } - continue - case strings.HasPrefix(line, "/set system ") || line == "/set system": - // Extract the system prompt text after "/set system " - systemPrompt = strings.TrimPrefix(line, "/set system ") - systemPrompt = strings.TrimSpace(systemPrompt) - if systemPrompt == "" { - fmt.Fprintln(os.Stderr, "Cleared system message.") - } else { - fmt.Fprintln(os.Stderr, "Set system message.") + default: + fmt.Printf("Unknown command '%s'. Type /? for help\n", args[0]) } continue - case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"): - return nil - case strings.HasPrefix(line, "/"): - fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0]) - continue default: sb.WriteString(line) }