From dd68f9dc827390988404f593821aea90586b846a Mon Sep 17 00:00:00 2001
From: Ilia Babanov
Date: Fri, 19 Dec 2025 10:42:30 +0100
Subject: [PATCH 1/3] WIP: Add serverless GPU compute support to SSH tunnel

Jobs API is not yet ready
---
 experimental/ssh/cmd/connect.go               |  21 +-
 experimental/ssh/cmd/server.go                |   4 +
 experimental/ssh/internal/client/client.go    | 201 +++++++++++++-----
 .../internal/client/ssh-server-bootstrap.py   |  48 +++--
 experimental/ssh/internal/keys/keys.go        |   9 +-
 experimental/ssh/internal/keys/secrets.go     |  12 +-
 experimental/ssh/internal/server/server.go    |  16 +-
 experimental/ssh/internal/server/sshd.go      |   2 +-
 experimental/ssh/internal/setup/setup.go      |  24 ++-
 experimental/ssh/internal/setup/setup_test.go |  13 +-
 .../ssh/internal/workspace/workspace.go       |  32 +--
 11 files changed, 270 insertions(+), 112 deletions(-)

diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go
index 7063a5799e..8570ccb144 100644
--- a/experimental/ssh/cmd/connect.go
+++ b/experimental/ssh/cmd/connect.go
@@ -1,6 +1,7 @@
 package ssh
 
 import (
+	"errors"
 	"time"
 
 	"github.com/databricks/cli/cmd/root"
@@ -18,10 +19,18 @@ func newConnectCommand() *cobra.Command {
 This command establishes an SSH connection to Databricks compute, setting up
 the SSH server and handling the connection proxy.
 
+For dedicated clusters:
+  databricks ssh connect --cluster=<cluster-id>
+
+For serverless compute:
+  databricks ssh connect --name=<connection-name> [--accelerator=<accelerator-type>]
+
 ` + disclaimer,
 	}
 
 	var clusterID string
+	var connectionName string
+	// var accelerator string
 	var proxyMode bool
 	var serverMetadata string
 	var shutdownDelay time.Duration
@@ -31,8 +40,8 @@ the SSH server and handling the connection proxy.
 	var autoStartCluster bool
 	var userKnownHostsFile string
 
-	cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (required)")
-	cmd.MarkFlagRequired("cluster")
+	cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)")
+	cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)")
 	cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects")
 	cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients")
 	cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running")
@@ -64,9 +73,17 @@ the SSH server and handling the connection proxy.
 	cmd.RunE = func(cmd *cobra.Command, args []string) error {
 		ctx := cmd.Context()
 		wsClient := cmdctx.WorkspaceClient(ctx)
+
+		if !proxyMode && clusterID == "" && connectionName == "" {
+			return errors.New("please provide the --cluster flag with a cluster ID, or the --name flag with a serverless connection name")
+		}
+
+		// TODO: validate connectionName if provided
+
 		opts := client.ClientOptions{
 			Profile:        wsClient.Config.Profile,
 			ClusterID:      clusterID,
+			ConnectionName: connectionName,
 			ProxyMode:      proxyMode,
 			ServerMetadata: serverMetadata,
 			ShutdownDelay:  shutdownDelay,
diff --git a/experimental/ssh/cmd/server.go b/experimental/ssh/cmd/server.go
index 77b8c1c156..efe283f28a 100644
--- a/experimental/ssh/cmd/server.go
+++ b/experimental/ssh/cmd/server.go
@@ -26,12 +26,15 @@ and proxies them to local SSH daemon processes.
 	var maxClients int
 	var shutdownDelay time.Duration
 	var clusterID string
+	var sessionID string
 	var version string
 	var secretScopeName string
 	var authorizedKeySecretName string
 
 	cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
 	cmd.MarkFlagRequired("cluster")
+	cmd.Flags().StringVar(&sessionID, "session-id", "", "Session identifier (cluster ID or serverless connection name)")
+	cmd.MarkFlagRequired("session-id")
 	cmd.Flags().StringVar(&secretScopeName, "secret-scope-name", "", "Databricks secret scope name to store SSH keys")
 	cmd.MarkFlagRequired("secret-scope-name")
 	cmd.Flags().StringVar(&authorizedKeySecretName, "authorized-key-secret-name", "", "Name of the secret containing the client public key")
@@ -56,6 +59,7 @@ and proxies them to local SSH daemon processes.
 		wsc := cmdctx.WorkspaceClient(ctx)
 		opts := server.ServerOptions{
 			ClusterID:               clusterID,
+			SessionID:               sessionID,
 			MaxClients:              maxClients,
 			ShutdownDelay:           shutdownDelay,
 			Version:                 version,
diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go
index 6a294b5935..8b24645c3f 100644
--- a/experimental/ssh/internal/client/client.go
+++ b/experimental/ssh/internal/client/client.go
@@ -36,8 +36,12 @@ var sshServerBootstrapScript string
 var errServerMetadata = errors.New("server metadata error")
 
 type ClientOptions struct {
-	// Id of the cluster to connect to
+	// Id of the cluster to connect to (for dedicated clusters)
 	ClusterID string
+	// Connection name (for serverless compute). Used as unique identifier instead of ClusterID.
+	ConnectionName string
+	// GPU accelerator type (for serverless compute)
+	Accelerator string
 	// Delay before shutting down the server after the last client disconnects
 	ShutdownDelay time.Duration
 	// Maximum number of SSH clients
@@ -46,7 +50,7 @@ type ClientOptions struct {
 	// to the cluster and proxy all traffic through stdin/stdout.
 	// In the non proxy mode the CLI spawns an ssh client with the ProxyCommand config.
 	ProxyMode bool
-	// Expected format: "<user-name>,<server-port>".
+	// Expected format: "<user-name>,<server-port>,<cluster-id>".
 	// If present, the CLI won't attempt to start the server.
 	ServerMetadata string
 	// How often the CLI should reconnect to the server with new auth.
@@ -72,6 +76,19 @@ type ClientOptions struct {
 	UserKnownHostsFile string
 }
 
+func (o *ClientOptions) IsServerlessMode() bool {
+	return o.ClusterID == ""
+}
+
+// SessionIdentifier returns the unique identifier for the session.
+// For dedicated clusters, this is the cluster ID. For serverless, this is the connection name.
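+// A hypothetical illustration (the IDs below are made up, not real cluster IDs):
+//
+//	opts := &ClientOptions{ClusterID: "0123-456789-abcdefgh"}
+//	opts.SessionIdentifier() // "0123-456789-abcdefgh"
+//
+//	opts = &ClientOptions{ConnectionName: "my-gpu-session"}
+//	opts.SessionIdentifier() // "my-gpu-session"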
+func (o *ClientOptions) SessionIdentifier() string {
+	if o.IsServerlessMode() {
+		return o.ConnectionName
+	}
+	return o.ClusterID
+}
+
 func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions) error {
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
@@ -84,22 +101,30 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
 		cancel()
 	}()
 
-	err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster)
-	if err != nil {
-		return err
+	sessionID := opts.SessionIdentifier()
+	if sessionID == "" {
+		return errors.New("either --cluster or --name must be provided")
 	}
 
-	secretScopeName, err := keys.CreateKeysSecretScope(ctx, client, opts.ClusterID)
+	// Only check cluster state for dedicated clusters
+	if !opts.IsServerlessMode() {
+		err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster)
+		if err != nil {
+			return err
+		}
+	}
+
+	secretScopeName, err := keys.CreateKeysSecretScope(ctx, client, sessionID)
 	if err != nil {
 		return fmt.Errorf("failed to create secret scope: %w", err)
 	}
 
-	privateKeyBytes, publicKeyBytes, err := keys.CheckAndGenerateSSHKeyPairFromSecrets(ctx, client, opts.ClusterID, secretScopeName, opts.ClientPrivateKeyName, opts.ClientPublicKeyName)
+	privateKeyBytes, publicKeyBytes, err := keys.CheckAndGenerateSSHKeyPairFromSecrets(ctx, client, secretScopeName, opts.ClientPrivateKeyName, opts.ClientPublicKeyName)
 	if err != nil {
 		return fmt.Errorf("failed to get or generate SSH key pair from secrets: %w", err)
 	}
 
-	keyPath, err := keys.GetLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir)
+	keyPath, err := keys.GetLocalSSHKeyPath(sessionID, opts.SSHKeysDir)
 	if err != nil {
 		return fmt.Errorf("failed to get local keys folder: %w", err)
 	}
@@ -113,6 +138,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
 
 	var userName string
 	var serverPort int
+	var clusterID string
 
 	version := build.GetInfo().Version
 
@@ -121,14 +147,15 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
 		if err := UploadTunnelReleases(ctx, client, version, opts.ReleasesDir); err != nil {
 			return fmt.Errorf("failed to upload ssh-tunnel binaries: %w", err)
 		}
-		userName, serverPort, err = ensureSSHServerIsRunning(ctx, client, version, secretScopeName, opts)
+		userName, serverPort, clusterID, err = ensureSSHServerIsRunning(ctx, client, version, secretScopeName, opts)
 		if err != nil {
 			return fmt.Errorf("failed to ensure that ssh server is running: %w", err)
 		}
 	} else {
+		// Metadata format: "<user-name>,<server-port>,<cluster-id>"
 		metadata := strings.Split(opts.ServerMetadata, ",")
-		if len(metadata) != 2 {
-			return fmt.Errorf("invalid metadata: %s, expected format: <user-name>,<server-port>", opts.ServerMetadata)
+		if len(metadata) < 2 {
+			return fmt.Errorf("invalid metadata: %s, expected format: <user-name>,<server-port>[,<cluster-id>]", opts.ServerMetadata)
 		}
 		userName = metadata[0]
 		if userName == "" {
@@ -138,55 +165,88 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
 		if err != nil {
 			return fmt.Errorf("cannot parse port from metadata: %s, %w", opts.ServerMetadata, err)
 		}
+		if len(metadata) >= 3 {
+			clusterID = metadata[2]
+		} else {
+			clusterID = opts.ClusterID
+		}
+	}
+
+	// For serverless mode, we need the cluster ID from metadata for Driver Proxy connections
+	if opts.IsServerlessMode() && clusterID == "" {
+		return errors.New("cluster ID is required for serverless connections but was not found in metadata")
 	}
 
 	cmdio.LogString(ctx, "Remote user name: "+userName)
 	cmdio.LogString(ctx, fmt.Sprintf("Server port: %d", serverPort))
+
if opts.IsServerlessMode() { + cmdio.LogString(ctx, "Cluster ID (from serverless job): "+clusterID) + } if opts.ProxyMode { - return runSSHProxy(ctx, client, serverPort, opts) + return runSSHProxy(ctx, client, serverPort, clusterID, opts) } else { cmdio.LogString(ctx, fmt.Sprintf("Additional SSH arguments: %v", opts.AdditionalArgs)) - return spawnSSHClient(ctx, userName, keyPath, serverPort, opts) + return spawnSSHClient(ctx, userName, keyPath, serverPort, clusterID, opts) } } -func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, clusterID, version string) (int, string, error) { - serverPort, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, clusterID) +// getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +// For dedicated clusters, clusterID should be the same as sessionID. +// For serverless, clusterID is read from the workspace metadata. +func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version string) (int, string, string, error) { + wsMetadata, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, sessionID) if err != nil { - return 0, "", errors.Join(errServerMetadata, err) + return 0, "", "", errors.Join(errServerMetadata, err) + } + cmdio.LogString(ctx, "Workspace metadata: "+fmt.Sprintf("%+v", wsMetadata)) + + // For serverless mode, the cluster ID comes from the metadata + effectiveClusterID := clusterID + if wsMetadata.ClusterID != "" { + effectiveClusterID = wsMetadata.ClusterID } + + if effectiveClusterID == "" { + return 0, "", "", errors.Join(errServerMetadata, errors.New("cluster ID not available in metadata")) + } + workspaceID, err := client.CurrentWorkspaceID(ctx) if err != nil { - return 0, "", err + return 0, "", "", err } - metadataURL := fmt.Sprintf("%s/driver-proxy-api/o/%d/%s/%d/metadata", client.Config.Host, workspaceID, clusterID, serverPort) + metadataURL := fmt.Sprintf("%s/driver-proxy-api/o/%d/%s/%d/metadata", client.Config.Host, workspaceID, effectiveClusterID, wsMetadata.Port) + cmdio.LogString(ctx, "Metadata URL: "+metadataURL) req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil) if err != nil { - return 0, "", err + return 0, "", "", err } if err := client.Config.Authenticate(req); err != nil { - return 0, "", err + return 0, "", "", err } resp, err := http.DefaultClient.Do(req) if err != nil { - return 0, "", err + return 0, "", "", err } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return 0, "", errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode)) - } - bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - return 0, "", err + return 0, "", "", err } - return serverPort, string(bodyBytes), nil + cmdio.LogString(ctx, "Metadata response: "+string(bodyBytes)) + + if resp.StatusCode != http.StatusOK { + return 0, "", "", errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode)) + } + + return wsMetadata.Port, string(bodyBytes), effectiveClusterID, nil } func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (int64, error) { - contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, opts.ClusterID) + sessionID := opts.SessionIdentifier() + contentDir, err := 
sshWorkspace.GetWorkspaceContentDir(ctx, client, version, sessionID) if err != nil { return 0, fmt.Errorf("failed to get workspace content directory: %w", err) } @@ -196,7 +256,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, return 0, fmt.Errorf("failed to create directory in the remote workspace: %w", err) } - sshTunnelJobName := "ssh-server-bootstrap-" + opts.ClusterID + sshTunnelJobName := "ssh-server-bootstrap-" + sessionID jobNotebookPath := filepath.ToSlash(filepath.Join(contentDir, "ssh-server-bootstrap")) notebookContent := "# Databricks notebook source\n" + sshServerBootstrapScript encodedContent := base64.StdEncoding.EncodeToString([]byte(notebookContent)) @@ -212,26 +272,45 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, return 0, fmt.Errorf("failed to create ssh-tunnel notebook: %w", err) } + baseParams := map[string]string{ + "version": version, + "secretScopeName": secretScopeName, + "authorizedKeySecretName": opts.ClientPublicKeyName, + "shutdownDelay": opts.ShutdownDelay.String(), + "maxClients": strconv.Itoa(opts.MaxClients), + "sessionId": sessionID, + } + + task := jobs.SubmitTask{ + TaskKey: "start_ssh_server", + NotebookTask: &jobs.NotebookTask{ + NotebookPath: jobNotebookPath, + BaseParameters: baseParams, + }, + TimeoutSeconds: int(opts.ServerTimeout.Seconds()), + } + + if opts.IsServerlessMode() { + task.EnvironmentKey = "ssh-tunnel-serverless" + // TODO: Add GPU accelerator configuration when Jobs API supports it + } else { + task.ExistingClusterId = opts.ClusterID + } + submitRun := jobs.SubmitRun{ RunName: sshTunnelJobName, TimeoutSeconds: int(opts.ServerTimeout.Seconds()), - Tasks: []jobs.SubmitTask{ - { - TaskKey: "start_ssh_server", - NotebookTask: &jobs.NotebookTask{ - NotebookPath: jobNotebookPath, - BaseParameters: map[string]string{ - "version": version, - "secretScopeName": secretScopeName, - "authorizedKeySecretName": opts.ClientPublicKeyName, - "shutdownDelay": opts.ShutdownDelay.String(), - "maxClients": strconv.Itoa(opts.MaxClients), - }, - }, - TimeoutSeconds: int(opts.ServerTimeout.Seconds()), - ExistingClusterId: opts.ClusterID, + Tasks: []jobs.SubmitTask{task}, + } + + if opts.IsServerlessMode() { + env := jobs.JobEnvironment{ + EnvironmentKey: "ssh-tunnel-serverless", + Spec: &compute.Environment{ + EnvironmentVersion: "3", }, - }, + } + submitRun.Environments = []jobs.JobEnvironment{env} } cmdio.LogString(ctx, "Submitting a job to start the ssh server...") @@ -243,12 +322,14 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, return runResult.Response.RunId, nil } -func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, opts ClientOptions) error { - proxyCommand, err := setup.GenerateProxyCommand(opts.ClusterID, opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, userName, serverPort, opts.HandoverTimeout) +func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error { + proxyCommand, err := setup.GenerateProxyCommand(opts.SessionIdentifier(), clusterID, opts.IsServerlessMode(), opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, userName, serverPort, opts.HandoverTimeout) if err != nil { return fmt.Errorf("failed to generate ProxyCommand: %w", err) } + hostName := opts.SessionIdentifier() + sshArgs := []string{ "-l", userName, "-i", privateKeyPath, @@ -259,7 +340,7 @@ func spawnSSHClient(ctx context.Context, userName, 
privateKeyPath string, server if opts.UserKnownHostsFile != "" { sshArgs = append(sshArgs, "-o", "UserKnownHostsFile="+opts.UserKnownHostsFile) } - sshArgs = append(sshArgs, opts.ClusterID) + sshArgs = append(sshArgs, hostName) sshArgs = append(sshArgs, opts.AdditionalArgs...) cmdio.LogString(ctx, "Launching SSH client: ssh "+strings.Join(sshArgs, " ")) @@ -273,9 +354,9 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server return sshCmd.Run() } -func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, serverPort int, opts ClientOptions) error { +func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, serverPort int, clusterID string, opts ClientOptions) error { createConn := func(ctx context.Context, connID string) (*websocket.Conn, error) { - return createWebsocketConnection(ctx, client, connID, opts.ClusterID, serverPort) + return createWebsocketConnection(ctx, client, connID, clusterID, serverPort) } requestHandoverTick := func() <-chan time.Time { return time.After(opts.HandoverTimeout) @@ -303,14 +384,18 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, return nil } -func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (string, int, error) { - serverPort, userName, err := getServerMetadata(ctx, client, opts.ClusterID, version) +func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (string, int, string, error) { + sessionID := opts.SessionIdentifier() + // For dedicated clusters, use clusterID; for serverless, it will be read from metadata + clusterID := opts.ClusterID + + serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version) if errors.Is(err, errServerMetadata) { cmdio.LogString(ctx, "SSH server is not running, starting it now...") runID, err := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts) if err != nil { - return "", 0, fmt.Errorf("failed to submit ssh server job: %w", err) + return "", 0, "", fmt.Errorf("failed to submit ssh server job: %w", err) } cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", runID)) @@ -318,21 +403,21 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC maxRetries := 30 for retries := range maxRetries { if ctx.Err() != nil { - return "", 0, ctx.Err() + return "", 0, "", ctx.Err() } - serverPort, userName, err = getServerMetadata(ctx, client, opts.ClusterID, version) + serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version) if err == nil { cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...") break } else if retries < maxRetries-1 { time.Sleep(2 * time.Second) } else { - return "", 0, fmt.Errorf("failed to start the ssh server: %w", err) + return "", 0, "", fmt.Errorf("failed to start the ssh server: %w", err) } } } else if err != nil { - return "", 0, err + return "", 0, "", err } - return userName, serverPort, nil + return userName, serverPort, effectiveClusterID, nil } diff --git a/experimental/ssh/internal/client/ssh-server-bootstrap.py b/experimental/ssh/internal/client/ssh-server-bootstrap.py index 8b8170bf42..8dc0aff16a 100644 --- a/experimental/ssh/internal/client/ssh-server-bootstrap.py +++ b/experimental/ssh/internal/client/ssh-server-bootstrap.py @@ -17,6 +17,7 @@ 
 dbutils.widgets.text("authorizedKeySecretName", "")
 dbutils.widgets.text("maxClients", "10")
 dbutils.widgets.text("shutdownDelay", "10m")
+dbutils.widgets.text("sessionId", "")  # Required: unique identifier for the session
 
 
 def cleanup():
@@ -111,6 +112,9 @@ def run_ssh_server():
     shutdown_delay = dbutils.widgets.get("shutdownDelay")
     max_clients = dbutils.widgets.get("maxClients")
+    session_id = dbutils.widgets.get("sessionId")
+    if not session_id:
+        raise RuntimeError("Session ID is required. Please provide it using the 'sessionId' widget.")
 
     arch = platform.machine()
     if arch == "x86_64":
@@ -127,29 +131,29 @@ def run_ssh_server():
 
     binary_path = f"/Workspace/Users/{user_name}/.databricks/ssh-tunnel/{version}/{cli_name}/databricks"
 
+    server_args = [
+        binary_path,
+        "ssh",
+        "server",
+        f"--cluster={ctx.clusterId}",
+        f"--session-id={session_id}",
+        f"--secret-scope-name={secrets_scope}",
+        f"--authorized-key-secret-name={public_key_secret_name}",
+        f"--max-clients={max_clients}",
+        f"--shutdown-delay={shutdown_delay}",
+        f"--version={version}",
+        # "info" has enough verbosity for debugging purposes, and "debug" log level prints too much (including secrets)
+        "--log-level=info",
+        "--log-format=json",
+        # To get the server logs:
+        # 1. Get a job run id from the "databricks ssh connect" output
+        # 2. Run "databricks jobs get-run <run-id>" and open a run_page_url
+        # TODO: file with log rotation
+        "--log-file=stdout",
+    ]
+
     try:
-        subprocess.run(
-            [
-                binary_path,
-                "ssh",
-                "server",
-                f"--cluster={ctx.clusterId}",
-                f"--secret-scope-name={secrets_scope}",
-                f"--authorized-key-secret-name={public_key_secret_name}",
-                f"--max-clients={max_clients}",
-                f"--shutdown-delay={shutdown_delay}",
-                f"--version={version}",
-                # "info" has enough verbosity for debugging purposes, and "debug" log level prints too much (including secrets)
-                "--log-level=info",
-                "--log-format=json",
-                # To get the server logs:
-                # 1. Get a job run id from the "databricks ssh connect" output
-                # 2. Run "databricks jobs get-run <run-id>" and open a run_page_url
-                # TODO: file with log rotation
-                "--log-file=stdout",
-            ],
-            check=True,
-        )
+        subprocess.run(server_args, check=True)
     finally:
         kill_all_children()
 
diff --git a/experimental/ssh/internal/keys/keys.go b/experimental/ssh/internal/keys/keys.go
index a1c279c749..735f4f0f1a 100644
--- a/experimental/ssh/internal/keys/keys.go
+++ b/experimental/ssh/internal/keys/keys.go
@@ -14,8 +14,9 @@ import (
 	"golang.org/x/crypto/ssh"
 )
 
-// We use different client keys for each cluster as a good practice for better isolation and control.
-func GetLocalSSHKeyPath(clusterID, keysDir string) (string, error) {
+// We use different client keys for each session as a good practice for better isolation and control.
+// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless).
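+// For illustration, with the default keysDir the key pair for a session is stored at
+// a path of this shape (the home directory and session ID are examples):
+//
+//	~/.databricks/ssh-tunnel-keys/<session-id>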
+func GetLocalSSHKeyPath(sessionID, keysDir string) (string, error) { if keysDir == "" { homeDir, err := os.UserHomeDir() if err != nil { @@ -23,7 +24,7 @@ func GetLocalSSHKeyPath(clusterID, keysDir string) (string, error) { } keysDir = filepath.Join(homeDir, ".databricks", "ssh-tunnel-keys") } - return filepath.Join(keysDir, clusterID), nil + return filepath.Join(keysDir, sessionID), nil } func generateSSHKeyPair() ([]byte, []byte, error) { @@ -68,7 +69,7 @@ func SaveSSHKeyPair(keyPath string, privateKeyBytes, publicKeyBytes []byte) erro return nil } -func CheckAndGenerateSSHKeyPairFromSecrets(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretScopeName, privateKeyName, publicKeyName string) ([]byte, []byte, error) { +func CheckAndGenerateSSHKeyPairFromSecrets(ctx context.Context, client *databricks.WorkspaceClient, secretScopeName, privateKeyName, publicKeyName string) ([]byte, []byte, error) { privateKeyBytes, err := GetSecret(ctx, client, secretScopeName, privateKeyName) if err != nil { privateKeyBytes, publicKeyBytes, err := generateSSHKeyPair() diff --git a/experimental/ssh/internal/keys/secrets.go b/experimental/ssh/internal/keys/secrets.go index 0a7b2c1266..eac692f235 100644 --- a/experimental/ssh/internal/keys/secrets.go +++ b/experimental/ssh/internal/keys/secrets.go @@ -10,12 +10,14 @@ import ( "github.com/databricks/databricks-sdk-go/service/workspace" ) -func CreateKeysSecretScope(ctx context.Context, client *databricks.WorkspaceClient, clusterID string) (string, error) { +// CreateKeysSecretScope creates or retrieves the secret scope for SSH keys. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func CreateKeysSecretScope(ctx context.Context, client *databricks.WorkspaceClient, sessionID string) (string, error) { me, err := client.CurrentUser.Me(ctx) if err != nil { return "", fmt.Errorf("failed to get current user: %w", err) } - secretScopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, clusterID) + secretScopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, sessionID) err = client.Secrets.CreateScope(ctx, workspace.CreateScope{ Scope: secretScopeName, }) @@ -53,8 +55,10 @@ func putSecret(ctx context.Context, client *databricks.WorkspaceClient, scope, k return nil } -func PutSecretInScope(ctx context.Context, client *databricks.WorkspaceClient, clusterID, key, value string) (string, error) { - scopeName, err := CreateKeysSecretScope(ctx, client, clusterID) +// PutSecretInScope creates the secret scope if needed and stores the secret. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). 
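+// For illustration, user "alice@example.com" with serverless session "my-gpu-session"
+// gets the scope "alice@example.com-my-gpu-session-ssh-tunnel-keys" (names are examples,
+// derived from the fmt.Sprintf format used by CreateKeysSecretScope above).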
+func PutSecretInScope(ctx context.Context, client *databricks.WorkspaceClient, sessionID, key, value string) (string, error) { + scopeName, err := CreateKeysSecretScope(ctx, client, sessionID) if err != nil { return "", err } diff --git a/experimental/ssh/internal/server/server.go b/experimental/ssh/internal/server/server.go index 66837cbc72..92fa76050a 100644 --- a/experimental/ssh/internal/server/server.go +++ b/experimental/ssh/internal/server/server.go @@ -29,8 +29,11 @@ type ServerOptions struct { MaxClients int // Delay before shutting down the server when there are no active connections ShutdownDelay time.Duration - // The cluster ID that the client started this server on + // The cluster ID that the client started this server on (required for Driver Proxy connections) ClusterID string + // SessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). + // Used for metadata storage path. Defaults to ClusterID if not set. + SessionID string // The directory to store sshd configuration ConfigDir string // The name of the secrets scope to use for client and server keys @@ -56,7 +59,12 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ServerOpt listenAddr := fmt.Sprintf("0.0.0.0:%d", port) log.Info(ctx, "Starting server on "+listenAddr) - err = workspace.SaveWorkspaceMetadata(ctx, client, opts.Version, opts.ClusterID, port) + // Save metadata including ClusterID (required for Driver Proxy connections in serverless mode) + metadata := &workspace.WorkspaceMetadata{ + Port: port, + ClusterID: opts.ClusterID, + } + err = workspace.SaveWorkspaceMetadata(ctx, client, opts.Version, opts.SessionID, metadata) if err != nil { return fmt.Errorf("failed to save metadata to the workspace: %w", err) } @@ -77,6 +85,10 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ServerOpt connections := proxy.NewConnectionsManager(opts.MaxClients, opts.ShutdownDelay) http.Handle("/ssh", proxy.NewProxyServer(ctx, connections, createServerCommand)) http.HandleFunc("/metadata", serveMetadata) + + http.Handle("/driver-proxy-http/ssh", proxy.NewProxyServer(ctx, connections, createServerCommand)) + http.HandleFunc("/driver-proxy-http/metadata", serveMetadata) + go handleTimeout(ctx, connections.TimedOut, opts.ShutdownDelay) return http.ListenAndServe(listenAddr, nil) diff --git a/experimental/ssh/internal/server/sshd.go b/experimental/ssh/internal/server/sshd.go index 2f45588a33..7a038e73b5 100644 --- a/experimental/ssh/internal/server/sshd.go +++ b/experimental/ssh/internal/server/sshd.go @@ -36,7 +36,7 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, return "", fmt.Errorf("failed to create SSH directory: %w", err) } - privateKeyBytes, publicKeyBytes, err := keys.CheckAndGenerateSSHKeyPairFromSecrets(ctx, client, opts.ClusterID, opts.SecretScopeName, opts.ServerPrivateKeyName, opts.ServerPublicKeyName) + privateKeyBytes, publicKeyBytes, err := keys.CheckAndGenerateSSHKeyPairFromSecrets(ctx, client, opts.SecretScopeName, opts.ServerPrivateKeyName, opts.ServerPublicKeyName) if err != nil { return "", fmt.Errorf("failed to get SSH key pair from secrets: %w", err) } diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index 4359a954da..7961c3b10f 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -56,17 +56,31 @@ func resolveConfigPath(configPath string) (string, error) { return 
filepath.Join(homeDir, ".ssh", "config"), nil } -func GenerateProxyCommand(clusterId string, autoStartCluster bool, shutdownDelay time.Duration, profile, userName string, serverPort int, handoverTimeout time.Duration) (string, error) { +// GenerateProxyCommand generates the ProxyCommand string for SSH config. +// sessionID is the unique identifier (cluster ID for dedicated clusters, connection name for serverless). +// clusterID is the actual cluster ID for Driver Proxy connections (same as sessionID for dedicated clusters, +// but obtained from job metadata for serverless). +func GenerateProxyCommand(sessionID, clusterID string, serverlessMode, autoStartCluster bool, shutdownDelay time.Duration, profile, userName string, serverPort int, handoverTimeout time.Duration) (string, error) { executablePath, err := os.Executable() if err != nil { return "", fmt.Errorf("failed to get current executable path: %w", err) } - proxyCommand := fmt.Sprintf("%q ssh connect --proxy --cluster=%s --auto-start-cluster=%t --shutdown-delay=%s", - executablePath, clusterId, autoStartCluster, shutdownDelay.String()) + var proxyCommand string + if serverlessMode { + proxyCommand = fmt.Sprintf("%q ssh connect --proxy --name=%s --shutdown-delay=%s", + executablePath, sessionID, shutdownDelay.String()) + } else { + proxyCommand = fmt.Sprintf("%q ssh connect --proxy --cluster=%s --auto-start-cluster=%t --shutdown-delay=%s", + executablePath, clusterID, autoStartCluster, shutdownDelay.String()) + } if userName != "" && serverPort != 0 { - proxyCommand += " --metadata=" + userName + "," + strconv.Itoa(serverPort) + if serverlessMode && clusterID != "" { + proxyCommand += " --metadata=" + userName + "," + strconv.Itoa(serverPort) + "," + clusterID + } else { + proxyCommand += " --metadata=" + userName + "," + strconv.Itoa(serverPort) + } } if handoverTimeout > 0 { @@ -86,7 +100,7 @@ func generateHostConfig(opts SetupOptions) (string, error) { return "", fmt.Errorf("failed to get local keys folder: %w", err) } - proxyCommand, err := GenerateProxyCommand(opts.ClusterID, opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, "", 0, 0) + proxyCommand, err := GenerateProxyCommand(opts.ClusterID, opts.ClusterID, false, opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, "", 0, 0) if err != nil { return "", fmt.Errorf("failed to generate ProxyCommand: %w", err) } diff --git a/experimental/ssh/internal/setup/setup_test.go b/experimental/ssh/internal/setup/setup_test.go index 7c4cb20925..27a0ced5bc 100644 --- a/experimental/ssh/internal/setup/setup_test.go +++ b/experimental/ssh/internal/setup/setup_test.go @@ -56,7 +56,7 @@ func TestValidateClusterAccess_ClusterNotFound(t *testing.T) { } func TestGenerateProxyCommand(t *testing.T) { - cmd, err := GenerateProxyCommand("cluster-123", true, 45*time.Second, "", "", 0, 0) + cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "", "", 0, 0) assert.NoError(t, err) assert.Contains(t, cmd, "ssh connect --proxy --cluster=cluster-123 --auto-start-cluster=true --shutdown-delay=45s") assert.NotContains(t, cmd, "--metadata") @@ -65,7 +65,7 @@ func TestGenerateProxyCommand(t *testing.T) { } func TestGenerateProxyCommand_WithExtraArgs(t *testing.T) { - cmd, err := GenerateProxyCommand("cluster-123", true, 45*time.Second, "test-profile", "user", 2222, 2*time.Minute) + cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "test-profile", "user", 2222, 2*time.Minute) assert.NoError(t, err) assert.Contains(t, cmd, 
"ssh connect --proxy --cluster=cluster-123 --auto-start-cluster=true --shutdown-delay=45s") assert.Contains(t, cmd, " --metadata=user,2222") @@ -73,6 +73,15 @@ func TestGenerateProxyCommand_WithExtraArgs(t *testing.T) { assert.Contains(t, cmd, " --profile=test-profile") } +func TestGenerateProxyCommand_ServerlessMode(t *testing.T) { + cmd, err := GenerateProxyCommand("my-connection", "serverless-cluster-id", true, false, 45*time.Second, "", "user", 2222, 0) + assert.NoError(t, err) + assert.Contains(t, cmd, "ssh connect --proxy --name=my-connection --shutdown-delay=45s") + assert.Contains(t, cmd, " --metadata=user,2222,serverless-cluster-id") + assert.NotContains(t, cmd, "--cluster=") + assert.NotContains(t, cmd, "--auto-start-cluster") +} + func TestGenerateHostConfig_Valid(t *testing.T) { // Create a temporary directory for testing tmpDir := t.TempDir() diff --git a/experimental/ssh/internal/workspace/workspace.go b/experimental/ssh/internal/workspace/workspace.go index 10e593951a..2f017cbae1 100644 --- a/experimental/ssh/internal/workspace/workspace.go +++ b/experimental/ssh/internal/workspace/workspace.go @@ -16,6 +16,8 @@ const metadataFileName = "metadata.json" type WorkspaceMetadata struct { Port int `json:"port"` + // ClusterID is required for Driver Proxy websocket connections (for any compute type, including serverless) + ClusterID string `json:"cluster_id,omitempty"` } func getWorkspaceRootDir(ctx context.Context, client *databricks.WorkspaceClient) (string, error) { @@ -34,49 +36,55 @@ func GetWorkspaceVersionedDir(ctx context.Context, client *databricks.WorkspaceC return filepath.ToSlash(filepath.Join(contentDir, version)), nil } -func GetWorkspaceContentDir(ctx context.Context, client *databricks.WorkspaceClient, version, clusterID string) (string, error) { +// GetWorkspaceContentDir returns the directory for storing session content. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func GetWorkspaceContentDir(ctx context.Context, client *databricks.WorkspaceClient, version, sessionID string) (string, error) { contentDir, err := GetWorkspaceVersionedDir(ctx, client, version) if err != nil { return "", fmt.Errorf("failed to get versioned workspace directory: %w", err) } - return filepath.ToSlash(filepath.Join(contentDir, clusterID)), nil + return filepath.ToSlash(filepath.Join(contentDir, sessionID)), nil } -func GetWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, clusterID string) (int, error) { - contentDir, err := GetWorkspaceContentDir(ctx, client, version, clusterID) +// GetWorkspaceMetadata loads session metadata from the workspace. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). 
+func GetWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, sessionID string) (*WorkspaceMetadata, error) { + contentDir, err := GetWorkspaceContentDir(ctx, client, version, sessionID) if err != nil { - return 0, fmt.Errorf("failed to get workspace content directory: %w", err) + return nil, fmt.Errorf("failed to get workspace content directory: %w", err) } metadataPath := filepath.ToSlash(filepath.Join(contentDir, metadataFileName)) content, err := client.Workspace.Download(ctx, metadataPath) if err != nil { - return 0, fmt.Errorf("failed to download metadata file: %w", err) + return nil, fmt.Errorf("failed to download metadata file: %w", err) } defer content.Close() metadataBytes, err := io.ReadAll(content) if err != nil { - return 0, fmt.Errorf("failed to read metadata content: %w", err) + return nil, fmt.Errorf("failed to read metadata content: %w", err) } var metadata WorkspaceMetadata err = json.Unmarshal(metadataBytes, &metadata) if err != nil { - return 0, fmt.Errorf("failed to parse metadata JSON: %w", err) + return nil, fmt.Errorf("failed to parse metadata JSON: %w", err) } - return metadata.Port, nil + return &metadata, nil } -func SaveWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, clusterID string, port int) error { - metadataBytes, err := json.Marshal(WorkspaceMetadata{Port: port}) +// SaveWorkspaceMetadata saves session metadata to the workspace. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func SaveWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, sessionID string, metadata *WorkspaceMetadata) error { + metadataBytes, err := json.Marshal(metadata) if err != nil { return fmt.Errorf("failed to marshal metadata: %w", err) } - contentDir, err := GetWorkspaceContentDir(ctx, client, version, clusterID) + contentDir, err := GetWorkspaceContentDir(ctx, client, version, sessionID) if err != nil { return fmt.Errorf("failed to get workspace content directory: %w", err) } From 2e6e2a71c4a332368cd790cbb00ab6a4a886b254 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Tue, 23 Dec 2025 12:11:25 +0100 Subject: [PATCH 2/3] Add liteswap header value for traffic routing (dev/test only). 
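
The client sets this header on both the Driver Proxy metadata request and the
websocket connection, e.g. (illustrative value, taken from the hidden
--liteswap flag):

    x-databricks-traffic-id: testenv://liteswap/my-test-env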
--- experimental/ssh/cmd/connect.go | 5 +++++ experimental/ssh/internal/client/client.go | 16 +++++++++++----- experimental/ssh/internal/client/websockets.go | 5 ++++- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 8570ccb144..f44a25cb28 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -39,6 +39,7 @@ For serverless compute: var releasesDir string var autoStartCluster bool var userKnownHostsFile string + var liteswap string cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)") @@ -59,6 +60,9 @@ For serverless compute: cmd.Flags().StringVar(&userKnownHostsFile, "user-known-hosts-file", "", "Path to user known hosts file for SSH client") cmd.Flags().MarkHidden("user-known-hosts-file") + cmd.Flags().StringVar(&liteswap, "liteswap", "", "Liteswap header value for traffic routing (dev/test only)") + cmd.Flags().MarkHidden("liteswap") + cmd.PreRunE = func(cmd *cobra.Command, args []string) error { // CLI in the proxy mode is executed by the ssh client and can't prompt for input if proxyMode { @@ -95,6 +99,7 @@ For serverless compute: ClientPublicKeyName: clientPublicKeyName, ClientPrivateKeyName: clientPrivateKeyName, UserKnownHostsFile: userKnownHostsFile, + Liteswap: liteswap, AdditionalArgs: args, } return client.Run(ctx, wsClient, opts) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 8b24645c3f..46e6380f6a 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -74,6 +74,8 @@ type ClientOptions struct { AdditionalArgs []string // Optional path to the user known hosts file. UserKnownHostsFile string + // Liteswap header value for traffic routing (dev/test only). + Liteswap string } func (o *ClientOptions) IsServerlessMode() bool { @@ -107,7 +109,8 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt } // Only check cluster state for dedicated clusters - if !opts.IsServerlessMode() { + // TODO: we can remove liteswap check when we can start serverless GPU clusters via API. + if !opts.IsServerlessMode() && opts.Liteswap == "" { err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster) if err != nil { return err @@ -195,7 +198,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt // sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). // For dedicated clusters, clusterID should be the same as sessionID. // For serverless, clusterID is read from the workspace metadata. 
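+// liteswap is optional; when non-empty, the metadata request (and the websocket
+// connect in createWebsocketConnection) sends the x-databricks-traffic-id
+// routing header described in the commit message above.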
-func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version string) (int, string, string, error) { +func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version, liteswap string) (int, string, string, error) { wsMetadata, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, sessionID) if err != nil { return 0, "", "", errors.Join(errServerMetadata, err) @@ -222,6 +225,9 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, if err != nil { return 0, "", "", err } + if liteswap != "" { + req.Header.Set("x-databricks-traffic-id", "testenv://liteswap/"+liteswap) + } if err := client.Config.Authenticate(req); err != nil { return 0, "", "", err } @@ -356,7 +362,7 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, serverPort int, clusterID string, opts ClientOptions) error { createConn := func(ctx context.Context, connID string) (*websocket.Conn, error) { - return createWebsocketConnection(ctx, client, connID, clusterID, serverPort) + return createWebsocketConnection(ctx, client, connID, clusterID, serverPort, opts.Liteswap) } requestHandoverTick := func() <-chan time.Time { return time.After(opts.HandoverTimeout) @@ -389,7 +395,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC // For dedicated clusters, use clusterID; for serverless, it will be read from metadata clusterID := opts.ClusterID - serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version) + serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap) if errors.Is(err, errServerMetadata) { cmdio.LogString(ctx, "SSH server is not running, starting it now...") @@ -405,7 +411,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC if ctx.Err() != nil { return "", 0, "", ctx.Err() } - serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version) + serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap) if err == nil { cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...") break diff --git a/experimental/ssh/internal/client/websockets.go b/experimental/ssh/internal/client/websockets.go index b1ab20889f..fba53c891e 100644 --- a/experimental/ssh/internal/client/websockets.go +++ b/experimental/ssh/internal/client/websockets.go @@ -9,7 +9,7 @@ import ( "github.com/gorilla/websocket" ) -func createWebsocketConnection(ctx context.Context, client *databricks.WorkspaceClient, connID, clusterID string, serverPort int) (*websocket.Conn, error) { +func createWebsocketConnection(ctx context.Context, client *databricks.WorkspaceClient, connID, clusterID string, serverPort int, liteswap string) (*websocket.Conn, error) { url, err := getProxyURL(ctx, client, connID, clusterID, serverPort) if err != nil { return nil, fmt.Errorf("failed to get proxy URL: %w", err) @@ -20,6 +20,9 @@ func createWebsocketConnection(ctx context.Context, client *databricks.Workspace return nil, fmt.Errorf("failed to create request: %w", err) } + if liteswap != "" { + req.Header.Set("x-databricks-traffic-id", "testenv://liteswap/"+liteswap) + } if err := client.Config.Authenticate(req); err != nil { return nil, 
fmt.Errorf("failed to authenticate: %w", err) } From 92fea47a1d86e00f67158ffb9c3d5b877b412f6e Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Tue, 23 Dec 2025 14:30:05 +0100 Subject: [PATCH 3/3] Add liteswap option to the ProxyCommand. --- experimental/ssh/internal/setup/setup.go | 8 ++++++-- experimental/ssh/internal/setup/setup_test.go | 6 +++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index 7961c3b10f..6c86087413 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -60,7 +60,7 @@ func resolveConfigPath(configPath string) (string, error) { // sessionID is the unique identifier (cluster ID for dedicated clusters, connection name for serverless). // clusterID is the actual cluster ID for Driver Proxy connections (same as sessionID for dedicated clusters, // but obtained from job metadata for serverless). -func GenerateProxyCommand(sessionID, clusterID string, serverlessMode, autoStartCluster bool, shutdownDelay time.Duration, profile, userName string, serverPort int, handoverTimeout time.Duration) (string, error) { +func GenerateProxyCommand(sessionID, clusterID string, serverlessMode, autoStartCluster bool, shutdownDelay time.Duration, profile, userName string, serverPort int, handoverTimeout time.Duration, liteswap string) (string, error) { executablePath, err := os.Executable() if err != nil { return "", fmt.Errorf("failed to get current executable path: %w", err) @@ -91,6 +91,10 @@ func GenerateProxyCommand(sessionID, clusterID string, serverlessMode, autoStart proxyCommand += " --profile=" + profile } + if liteswap != "" { + proxyCommand += " --liteswap=" + liteswap + } + return proxyCommand, nil } @@ -100,7 +104,7 @@ func generateHostConfig(opts SetupOptions) (string, error) { return "", fmt.Errorf("failed to get local keys folder: %w", err) } - proxyCommand, err := GenerateProxyCommand(opts.ClusterID, opts.ClusterID, false, opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, "", 0, 0) + proxyCommand, err := GenerateProxyCommand(opts.ClusterID, opts.ClusterID, false, opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, "", 0, 0, "") if err != nil { return "", fmt.Errorf("failed to generate ProxyCommand: %w", err) } diff --git a/experimental/ssh/internal/setup/setup_test.go b/experimental/ssh/internal/setup/setup_test.go index 27a0ced5bc..f2e1bf6c1b 100644 --- a/experimental/ssh/internal/setup/setup_test.go +++ b/experimental/ssh/internal/setup/setup_test.go @@ -56,7 +56,7 @@ func TestValidateClusterAccess_ClusterNotFound(t *testing.T) { } func TestGenerateProxyCommand(t *testing.T) { - cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "", "", 0, 0) + cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "", "", 0, 0, "") assert.NoError(t, err) assert.Contains(t, cmd, "ssh connect --proxy --cluster=cluster-123 --auto-start-cluster=true --shutdown-delay=45s") assert.NotContains(t, cmd, "--metadata") @@ -65,7 +65,7 @@ func TestGenerateProxyCommand(t *testing.T) { } func TestGenerateProxyCommand_WithExtraArgs(t *testing.T) { - cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "test-profile", "user", 2222, 2*time.Minute) + cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "test-profile", "user", 2222, 2*time.Minute, "") assert.NoError(t, err) assert.Contains(t, 
cmd, "ssh connect --proxy --cluster=cluster-123 --auto-start-cluster=true --shutdown-delay=45s") assert.Contains(t, cmd, " --metadata=user,2222") @@ -74,7 +74,7 @@ func TestGenerateProxyCommand_WithExtraArgs(t *testing.T) { } func TestGenerateProxyCommand_ServerlessMode(t *testing.T) { - cmd, err := GenerateProxyCommand("my-connection", "serverless-cluster-id", true, false, 45*time.Second, "", "user", 2222, 0) + cmd, err := GenerateProxyCommand("my-connection", "serverless-cluster-id", true, false, 45*time.Second, "", "user", 2222, 0, "") assert.NoError(t, err) assert.Contains(t, cmd, "ssh connect --proxy --name=my-connection --shutdown-delay=45s") assert.Contains(t, cmd, " --metadata=user,2222,serverless-cluster-id")