diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 6c4695b6ac..63141e9497 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -12,7 +12,6 @@ import ( "os/exec" "os/signal" "path/filepath" - "regexp" "strconv" "strings" "syscall" @@ -20,6 +19,7 @@ import ( "github.com/databricks/cli/experimental/ssh/internal/keys" "github.com/databricks/cli/experimental/ssh/internal/proxy" + "github.com/databricks/cli/experimental/ssh/internal/sshconfig" sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace" "github.com/databricks/cli/internal/build" "github.com/databricks/cli/libs/cmdio" @@ -281,7 +281,7 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k databricksUserName := currentUser.UserName // Ensure SSH config entry exists - configPath, err := getSSHConfigPath() + configPath, err := sshconfig.GetMainConfigPath() if err != nil { return fmt.Errorf("failed to get SSH config path: %w", err) } @@ -310,47 +310,11 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k return ideCmd.Run() } -func getSSHConfigPath() (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get home directory: %w", err) - } - return filepath.Join(homeDir, ".ssh", "config"), nil -} - func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { - // Ensure SSH directory and config file exist - sshDir := filepath.Dir(configPath) - err := os.MkdirAll(sshDir, 0o700) + // Ensure the Include directive exists in the main SSH config + err := sshconfig.EnsureIncludeDirective(configPath) if err != nil { - return fmt.Errorf("failed to create SSH directory: %w", err) - } - - _, err = os.Stat(configPath) - if os.IsNotExist(err) { - err = os.WriteFile(configPath, []byte(""), 0o600) - if err != nil { - return fmt.Errorf("failed to create SSH config file: %w", err) - } - } else if err != nil { - return fmt.Errorf("failed to check SSH config file: %w", err) - } - - // Check if the host entry already exists - existingContent, err := os.ReadFile(configPath) - if err != nil { - return fmt.Errorf("failed to read SSH config file: %w", err) - } - - hostPattern := fmt.Sprintf(`(?m)^\s*Host\s+%s\s*$`, regexp.QuoteMeta(hostName)) - matched, err := regexp.Match(hostPattern, existingContent) - if err != nil { - return fmt.Errorf("failed to check for existing host: %w", err) - } - - if matched { - cmdio.LogString(ctx, fmt.Sprintf("SSH config entry for '%s' already exists", hostName)) - return nil + return err } // Generate ProxyCommand with server metadata @@ -362,30 +326,14 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k return fmt.Errorf("failed to generate ProxyCommand: %w", err) } - // Generate host config - hostConfig := fmt.Sprintf(` -Host %s - User %s - ConnectTimeout 360 - StrictHostKeyChecking accept-new - IdentitiesOnly yes - IdentityFile %q - ProxyCommand %s -`, hostName, userName, keyPath, proxyCommand) - - // Append to config file - content := string(existingContent) - if !strings.HasSuffix(content, "\n") && content != "" { - content += "\n" - } - content += hostConfig + hostConfig := sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand) - err = os.WriteFile(configPath, []byte(content), 0o600) + _, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true) if err != nil { - return fmt.Errorf("failed to update SSH config file: %w", err) + return err } - cmdio.LogString(ctx, fmt.Sprintf("Added SSH config entry for '%s'", hostName)) + cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config entry for '%s'", hostName)) return nil } diff --git a/experimental/ssh/internal/keys/secrets.go b/experimental/ssh/internal/keys/secrets.go index eac692f235..d4e00d10ba 100644 --- a/experimental/ssh/internal/keys/secrets.go +++ b/experimental/ssh/internal/keys/secrets.go @@ -18,10 +18,23 @@ func CreateKeysSecretScope(ctx context.Context, client *databricks.WorkspaceClie return "", fmt.Errorf("failed to get current user: %w", err) } secretScopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, sessionID) + + // Do not create the scope if it already exists. + // We can instead filter out "resource already exists" errors from CreateScope, + // but that API can also lead to "limit exceeded" errors, even if the scope does actually exist. + scope, err := client.Secrets.ListSecretsByScope(ctx, secretScopeName) + if err != nil && !errors.Is(err, databricks.ErrResourceDoesNotExist) { + return "", fmt.Errorf("failed to check if secret scope %s exists: %w", secretScopeName, err) + } + + if scope != nil && err == nil { + return secretScopeName, nil + } + err = client.Secrets.CreateScope(ctx, workspace.CreateScope{ Scope: secretScopeName, }) - if err != nil && !errors.Is(err, databricks.ErrResourceAlreadyExists) { + if err != nil { return "", fmt.Errorf("failed to create secrets scope: %w", err) } return secretScopeName, nil diff --git a/experimental/ssh/internal/server/sshd.go b/experimental/ssh/internal/server/sshd.go index 7a038e73b5..c8f23d02a5 100644 --- a/experimental/ssh/internal/server/sshd.go +++ b/experimental/ssh/internal/server/sshd.go @@ -52,15 +52,13 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, return "", err } - // Set all available env vars, wrapping values in quotes and escaping quotes inside values + // Set all available env vars, wrapping values in quotes, escaping quotes, and stripping newlines setEnv := "SetEnv" for _, env := range os.Environ() { parts := strings.SplitN(env, "=", 2) - if len(parts) != 2 { - continue + if len(parts) == 2 { + setEnv += " " + parts[0] + "=\"" + escapeEnvValue(parts[1]) + "\"" } - valEscaped := strings.ReplaceAll(parts[1], "\"", "\\\"") - setEnv += " " + parts[0] + "=\"" + valEscaped + "\"" } setEnv += " DATABRICKS_CLI_UPSTREAM=databricks_ssh_tunnel" setEnv += " DATABRICKS_CLI_UPSTREAM_VERSION=" + opts.Version @@ -94,3 +92,13 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, func createSSHDProcess(ctx context.Context, configPath string) *exec.Cmd { return exec.CommandContext(ctx, "/usr/sbin/sshd", "-f", configPath, "-i") } + +// escapeEnvValue escapes a value for use in sshd SetEnv directive. +// It strips newlines and escapes backslashes and quotes. +func escapeEnvValue(val string) string { + val = strings.ReplaceAll(val, "\r", "") + val = strings.ReplaceAll(val, "\n", "") + val = strings.ReplaceAll(val, "\\", "\\\\") + val = strings.ReplaceAll(val, "\"", "\\\"") + return val +} diff --git a/experimental/ssh/internal/server/sshd_test.go b/experimental/ssh/internal/server/sshd_test.go new file mode 100644 index 0000000000..a453d987a0 --- /dev/null +++ b/experimental/ssh/internal/server/sshd_test.go @@ -0,0 +1,73 @@ +package server + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEscapeEnvValue(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple value", + input: "hello", + expected: "hello", + }, + { + name: "value with quotes", + input: `say "hello"`, + expected: `say \"hello\"`, + }, + { + name: "value with newline", + input: "line1\nline2", + expected: "line1line2", + }, + { + name: "value with carriage return", + input: "line1\rline2", + expected: "line1line2", + }, + { + name: "value with CRLF", + input: "line1\r\nline2", + expected: "line1line2", + }, + { + name: "value with quotes and newlines", + input: "say \"hello\"\nworld", + expected: `say \"hello\"world`, + }, + { + name: "empty value", + input: "", + expected: "", + }, + { + name: "only newlines", + input: "\n\r\n", + expected: "", + }, + { + name: "backslashes", + input: `foo\bar\`, + expected: `foo\\bar\\`, + }, + { + name: "backslash before quote", + input: `foo\"bar`, + expected: `foo\\\"bar`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := escapeEnvValue(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index 0d76071a65..99b5a68902 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -4,13 +4,10 @@ import ( "context" "errors" "fmt" - "os" - "path/filepath" - "regexp" - "strings" "time" "github.com/databricks/cli/experimental/ssh/internal/keys" + "github.com/databricks/cli/experimental/ssh/internal/sshconfig" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/compute" @@ -46,97 +43,16 @@ func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClie return nil } -func resolveConfigPath(configPath string) (string, error) { - if configPath != "" { - return configPath, nil - } - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get home directory: %w", err) - } - return filepath.Join(homeDir, ".ssh", "config"), nil -} - func generateHostConfig(opts SetupOptions) (string, error) { identityFilePath, err := keys.GetLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir) if err != nil { return "", fmt.Errorf("failed to get local keys folder: %w", err) } - hostConfig := fmt.Sprintf(` -Host %s - User root - ConnectTimeout 360 - StrictHostKeyChecking accept-new - IdentitiesOnly yes - IdentityFile %q - ProxyCommand %s -`, opts.HostName, identityFilePath, opts.ProxyCommand) - + hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, opts.ProxyCommand) return hostConfig, nil } -func ensureSSHConfigExists(configPath string) error { - _, err := os.Stat(configPath) - if os.IsNotExist(err) { - sshDir := filepath.Dir(configPath) - err = os.MkdirAll(sshDir, 0o700) - if err != nil { - return fmt.Errorf("failed to create SSH directory: %w", err) - } - err = os.WriteFile(configPath, []byte(""), 0o600) - if err != nil { - return fmt.Errorf("failed to create SSH config file: %w", err) - } - return nil - } else if err != nil { - return fmt.Errorf("failed to check SSH config file: %w", err) - } - return nil -} - -func checkExistingHosts(content []byte, hostName string) (bool, error) { - existingContent := string(content) - pattern := fmt.Sprintf(`(?m)^\s*Host\s+%s\s*$`, regexp.QuoteMeta(hostName)) - matched, err := regexp.MatchString(pattern, existingContent) - if err != nil { - return false, fmt.Errorf("failed to check for existing host: %w", err) - } - if matched { - return true, nil - } - return false, nil -} - -func createBackup(content []byte, configPath string) (string, error) { - backupPath := configPath + ".bak" - err := os.WriteFile(backupPath, content, 0o600) - if err != nil { - return backupPath, fmt.Errorf("failed to create backup of SSH config file: %w", err) - } - return backupPath, nil -} - -func updateSSHConfigFile(configPath, hostConfig, hostName string) error { - content, err := os.ReadFile(configPath) - if err != nil { - return fmt.Errorf("failed to read SSH config file: %w", err) - } - - existingContent := string(content) - if !strings.HasSuffix(existingContent, "\n") && existingContent != "" { - existingContent += "\n" - } - newContent := existingContent + hostConfig - - err = os.WriteFile(configPath, []byte(newContent), 0o600) - if err != nil { - return fmt.Errorf("failed to update SSH config file: %w", err) - } - - return nil -} - func clusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading clusters.") @@ -174,50 +90,51 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp return err } - configPath, err := resolveConfigPath(opts.SSHConfigPath) + configPath, err := sshconfig.GetMainConfigPathOrDefault(opts.SSHConfigPath) if err != nil { return err } - hostConfig, err := generateHostConfig(opts) + err = sshconfig.EnsureIncludeDirective(configPath) if err != nil { return err } - err = ensureSSHConfigExists(configPath) + hostConfig, err := generateHostConfig(opts) if err != nil { return err } - existingContent, err := os.ReadFile(configPath) + exists, err := sshconfig.HostConfigExists(opts.HostName) if err != nil { - return fmt.Errorf("failed to read SSH config file: %w", err) + return err } - if len(existingContent) > 0 { - exists, err := checkExistingHosts(existingContent, opts.HostName) + recreate := false + if exists { + recreate, err = sshconfig.PromptRecreateConfig(ctx, opts.HostName) if err != nil { return err } - if exists { - cmdio.LogString(ctx, fmt.Sprintf("Host '%s' already exists in the SSH config, skipping setup", opts.HostName)) + if !recreate { + cmdio.LogString(ctx, fmt.Sprintf("Skipping setup for host '%s'", opts.HostName)) return nil } - backupPath, err := createBackup(existingContent, configPath) - if err != nil { - return err - } - cmdio.LogString(ctx, "Created backup of existing SSH config at "+backupPath) } cmdio.LogString(ctx, "Adding new entry to the SSH config:\n"+hostConfig) - err = updateSSHConfigFile(configPath, hostConfig, opts.HostName) + _, err = sshconfig.CreateOrUpdateHostConfig(ctx, opts.HostName, hostConfig, recreate) + if err != nil { + return err + } + + hostConfigPath, err := sshconfig.GetHostConfigPath(opts.HostName) if err != nil { return err } - cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config file at %s with '%s' host", configPath, opts.HostName)) + cmdio.LogString(ctx, fmt.Sprintf("Created SSH config file at %s for '%s' host", hostConfigPath, opts.HostName)) cmdio.LogString(ctx, fmt.Sprintf("You can now connect to the cluster using 'ssh %s' terminal command, or use remote capabilities of your IDE", opts.HostName)) return nil } diff --git a/experimental/ssh/internal/setup/setup_test.go b/experimental/ssh/internal/setup/setup_test.go index aa803dfe1c..975828a3c8 100644 --- a/experimental/ssh/internal/setup/setup_test.go +++ b/experimental/ssh/internal/setup/setup_test.go @@ -118,15 +118,24 @@ func TestGenerateProxyCommand_ServerlessModeWithAccelerator(t *testing.T) { } func TestGenerateHostConfig_Valid(t *testing.T) { - // Create a temporary directory for testing tmpDir := t.TempDir() + clientOpts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 30 * time.Second, + Profile: "test-profile", + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts := SetupOptions{ HostName: "test-host", ClusterID: "cluster-123", SSHKeysDir: tmpDir, ShutdownDelay: 30 * time.Second, Profile: "test-profile", + ProxyCommand: proxyCommand, } result, err := generateHostConfig(opts) @@ -139,29 +148,35 @@ func TestGenerateHostConfig_Valid(t *testing.T) { assert.Contains(t, result, "--shutdown-delay=30s") assert.Contains(t, result, "--profile=test-profile") - // Check that identity file path is included expectedKeyPath := filepath.Join(tmpDir, "cluster-123") assert.Contains(t, result, fmt.Sprintf(`IdentityFile %q`, expectedKeyPath)) } func TestGenerateHostConfig_WithoutProfile(t *testing.T) { - // Create a temporary directory for testing tmpDir := t.TempDir() + clientOpts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 30 * time.Second, + Profile: "", + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts := SetupOptions{ HostName: "test-host", ClusterID: "cluster-123", SSHKeysDir: tmpDir, ShutdownDelay: 30 * time.Second, - Profile: "", // No profile + Profile: "", + ProxyCommand: proxyCommand, } result, err := generateHostConfig(opts) assert.NoError(t, err) - // Should not contain profile option assert.NotContains(t, result, "--profile=") - // But should contain other elements assert.Contains(t, result, "Host test-host") assert.Contains(t, result, "--cluster=cluster-123") } @@ -187,181 +202,12 @@ func TestGenerateHostConfig_PathEscaping(t *testing.T) { assert.Contains(t, result, fmt.Sprintf(`IdentityFile %q`, expectedPath)) } -func TestEnsureSSHConfigExists(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, ".ssh", "config") - - err := ensureSSHConfigExists(configPath) - assert.NoError(t, err) - - // Check that directory was created - _, err = os.Stat(filepath.Dir(configPath)) - assert.NoError(t, err) - - // Check that file was created - _, err = os.Stat(configPath) - assert.NoError(t, err) - - // Check that file is empty - content, err := os.ReadFile(configPath) - assert.NoError(t, err) - assert.Empty(t, content) -} - -func TestCheckExistingHosts_NoExistingHost(t *testing.T) { - content := []byte(`Host other-host - User root - HostName example.com - -Host another-host - User admin -`) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCheckExistingHosts_HostAlreadyExists(t *testing.T) { - content := []byte(`Host test-host - User root - HostName example.com - -Host another-host - User admin -`) - exists, err := checkExistingHosts(content, "another-host") - assert.NoError(t, err) - assert.True(t, exists) -} - -func TestCheckExistingHosts_EmptyContent(t *testing.T) { - content := []byte("") - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCheckExistingHosts_HostNameWithWhitespaces(t *testing.T) { - content := []byte(` Host test-host `) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.True(t, exists) -} - -func TestCheckExistingHosts_PartialNameMatch(t *testing.T) { - content := []byte(`Host test-host-long`) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCreateBackup_CreatesBackupSuccessfully(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - content := []byte("original content") - - backupPath, err := createBackup(content, configPath) - assert.NoError(t, err) - assert.Equal(t, configPath+".bak", backupPath) - - // Check that backup file was created with correct content - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, content, backupContent) -} - -func TestCreateBackup_OverwritesExistingBackup(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - backupPath := configPath + ".bak" - - // Create existing backup - oldContent := []byte("old backup") - err := os.WriteFile(backupPath, oldContent, 0o644) - require.NoError(t, err) - - // Create new backup - newContent := []byte("new content") - resultPath, err := createBackup(newContent, configPath) - assert.NoError(t, err) - assert.Equal(t, backupPath, resultPath) - - // Check that backup was overwritten - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, newContent, backupContent) -} - -func TestUpdateSSHConfigFile_UpdatesSuccessfully(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create initial config file - initialContent := "# SSH Config\nHost existing\n User root\n" - err := os.WriteFile(configPath, []byte(initialContent), 0o600) - require.NoError(t, err) - - hostConfig := "\nHost new-host\n User root\n HostName example.com\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that content was appended - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - expected := initialContent + hostConfig - assert.Equal(t, expected, string(finalContent)) -} - -func TestUpdateSSHConfigFile_AddsNewlineIfMissing(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create config file without trailing newline - initialContent := "Host existing\n User root" - err := os.WriteFile(configPath, []byte(initialContent), 0o600) - require.NoError(t, err) - - hostConfig := "\nHost new-host\n User root\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that newline was added before the new content - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - expected := initialContent + "\n" + hostConfig - assert.Equal(t, expected, string(finalContent)) -} - -func TestUpdateSSHConfigFile_HandlesEmptyFile(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create empty config file - err := os.WriteFile(configPath, []byte(""), 0o600) - require.NoError(t, err) - - hostConfig := "Host new-host\n User root\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that content was added without extra newlines - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - assert.Equal(t, hostConfig, string(finalContent)) -} - -func TestUpdateSSHConfigFile_HandlesReadError(t *testing.T) { - configPath := "/nonexistent/file" - hostConfig := "Host new-host\n" - - err := updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to read SSH config file") -} - func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + configPath := filepath.Join(tmpDir, "ssh_config") m := mocks.NewMockWorkspaceClient(t) @@ -380,22 +226,43 @@ func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { Profile: "test-profile", } - err := Setup(ctx, m.WorkspaceClient, opts) + clientOpts := client.ClientOptions{ + ClusterID: opts.ClusterID, + AutoStartCluster: opts.AutoStartCluster, + ShutdownDelay: opts.ShutdownDelay, + Profile: opts.Profile, + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts.ProxyCommand = proxyCommand + + err = Setup(ctx, m.WorkspaceClient, opts) assert.NoError(t, err) - // Check that config file was created + // Check that main config has Include directive content, err := os.ReadFile(configPath) assert.NoError(t, err) - configStr := string(content) - assert.Contains(t, configStr, "Host test-host") - assert.Contains(t, configStr, "--cluster=cluster-123") - assert.Contains(t, configStr, "--profile=test-profile") + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + + // Check that host config file was created + hostConfigPath := filepath.Join(tmpDir, ".databricks", "ssh-tunnel-configs", "test-host") + hostContent, err := os.ReadFile(hostConfigPath) + assert.NoError(t, err) + hostConfigStr := string(hostContent) + assert.Contains(t, hostConfigStr, "Host test-host") + assert.Contains(t, hostConfigStr, "--cluster=cluster-123") + assert.Contains(t, hostConfigStr, "--profile=test-profile") } func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + configPath := filepath.Join(tmpDir, "ssh_config") // Create existing config file @@ -418,54 +285,34 @@ func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { ShutdownDelay: 60 * time.Second, } - err = Setup(ctx, m.WorkspaceClient, opts) - assert.NoError(t, err) - - // Check that config file was updated and backup was created - content, err := os.ReadFile(configPath) - assert.NoError(t, err) - - configStr := string(content) - assert.Contains(t, configStr, "# Existing SSH Config") // Original content preserved - assert.Contains(t, configStr, "Host new-host") // New content added - assert.Contains(t, configStr, "--cluster=cluster-456") - - // Check backup was created - backupPath := configPath + ".bak" - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, existingContent, string(backupContent)) -} - -func TestSetup_DoesNotOverrideExistingHost(t *testing.T) { - ctx := cmdio.MockDiscard(context.Background()) - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "ssh_config") - - // Create config file with existing host - existingContent := "Host duplicate-host\n User root\n" - err := os.WriteFile(configPath, []byte(existingContent), 0o600) - require.NoError(t, err) - - m := mocks.NewMockWorkspaceClient(t) - clustersAPI := m.GetMockClustersAPI() - - clustersAPI.EXPECT().Get(ctx, compute.GetClusterRequest{ClusterId: "cluster-123"}).Return(&compute.ClusterDetails{ - DataSecurityMode: compute.DataSecurityModeSingleUser, - }, nil) - - opts := SetupOptions{ - HostName: "duplicate-host", // Same as existing - ClusterID: "cluster-123", - SSHConfigPath: configPath, - SSHKeysDir: tmpDir, - ShutdownDelay: 30 * time.Second, + clientOpts := client.ClientOptions{ + ClusterID: opts.ClusterID, + AutoStartCluster: opts.AutoStartCluster, + ShutdownDelay: opts.ShutdownDelay, + Profile: opts.Profile, } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts.ProxyCommand = proxyCommand err = Setup(ctx, m.WorkspaceClient, opts) assert.NoError(t, err) + // Check that main config has Include directive and preserves existing content content, err := os.ReadFile(configPath) assert.NoError(t, err) - assert.Equal(t, "Host duplicate-host\n User root\n", string(content)) + configStr := string(content) + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + assert.Contains(t, configStr, "# Existing SSH Config") + assert.Contains(t, configStr, "Host existing-host") + + // Check that host config file was created + hostConfigPath := filepath.Join(tmpDir, ".databricks", "ssh-tunnel-configs", "new-host") + hostContent, err := os.ReadFile(hostConfigPath) + assert.NoError(t, err) + hostConfigStr := string(hostContent) + assert.Contains(t, hostConfigStr, "Host new-host") + assert.Contains(t, hostConfigStr, "--cluster=cluster-456") } diff --git a/experimental/ssh/internal/sshconfig/sshconfig.go b/experimental/ssh/internal/sshconfig/sshconfig.go new file mode 100644 index 0000000000..3a6713acbf --- /dev/null +++ b/experimental/ssh/internal/sshconfig/sshconfig.go @@ -0,0 +1,172 @@ +package sshconfig + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/databricks/cli/libs/cmdio" +) + +const ( + // configDirName is the directory name for Databricks SSH tunnel configs, relative to the user's home directory. + configDirName = ".databricks/ssh-tunnel-configs" +) + +func GetConfigDir() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, configDirName), nil +} + +func GetMainConfigPath() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, ".ssh", "config"), nil +} + +func GetMainConfigPathOrDefault(configPath string) (string, error) { + if configPath != "" { + return configPath, nil + } + return GetMainConfigPath() +} + +func EnsureMainConfigExists(configPath string) error { + _, err := os.Stat(configPath) + if os.IsNotExist(err) { + sshDir := filepath.Dir(configPath) + err = os.MkdirAll(sshDir, 0o700) + if err != nil { + return fmt.Errorf("failed to create SSH directory: %w", err) + } + err = os.WriteFile(configPath, []byte(""), 0o600) + if err != nil { + return fmt.Errorf("failed to create SSH config file: %w", err) + } + return nil + } + return err +} + +func EnsureIncludeDirective(configPath string) error { + configDir, err := GetConfigDir() + if err != nil { + return err + } + + err = os.MkdirAll(configDir, 0o700) + if err != nil { + return fmt.Errorf("failed to create Databricks SSH config directory: %w", err) + } + + err = EnsureMainConfigExists(configPath) + if err != nil { + return err + } + + content, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read SSH config file: %w", err) + } + + // Convert path to forward slashes for SSH config compatibility across platforms + configDirUnix := filepath.ToSlash(configDir) + + includeLine := fmt.Sprintf("Include %s/*", configDirUnix) + if strings.Contains(string(content), includeLine) { + return nil + } + + newContent := includeLine + "\n" + if len(content) > 0 && !strings.HasPrefix(string(content), "\n") { + newContent += "\n" + } + newContent += string(content) + + err = os.WriteFile(configPath, []byte(newContent), 0o600) + if err != nil { + return fmt.Errorf("failed to update SSH config file with Include directive: %w", err) + } + + return nil +} + +func GetHostConfigPath(hostName string) (string, error) { + configDir, err := GetConfigDir() + if err != nil { + return "", err + } + return filepath.Join(configDir, hostName), nil +} + +func HostConfigExists(hostName string) (bool, error) { + configPath, err := GetHostConfigPath(hostName) + if err != nil { + return false, err + } + _, err = os.Stat(configPath) + if os.IsNotExist(err) { + return false, nil + } + if err != nil { + return false, fmt.Errorf("failed to check host config file: %w", err) + } + return true, nil +} + +// Returns true if the config was created/updated, false if it was skipped. +func CreateOrUpdateHostConfig(ctx context.Context, hostName, hostConfig string, recreate bool) (bool, error) { + configPath, err := GetHostConfigPath(hostName) + if err != nil { + return false, err + } + + exists, err := HostConfigExists(hostName) + if err != nil { + return false, err + } + + if exists && !recreate { + return false, nil + } + + configDir := filepath.Dir(configPath) + err = os.MkdirAll(configDir, 0o700) + if err != nil { + return false, fmt.Errorf("failed to create config directory: %w", err) + } + + err = os.WriteFile(configPath, []byte(hostConfig), 0o600) + if err != nil { + return false, fmt.Errorf("failed to write host config file: %w", err) + } + + return true, nil +} + +func PromptRecreateConfig(ctx context.Context, hostName string) (bool, error) { + response, err := cmdio.AskYesOrNo(ctx, fmt.Sprintf("Host '%s' already exists. Do you want to recreate the config?", hostName)) + if err != nil { + return false, err + } + return response, nil +} + +func GenerateHostConfig(hostName, userName, identityFile, proxyCommand string) string { + return fmt.Sprintf(` +Host %s + User %s + ConnectTimeout 360 + StrictHostKeyChecking accept-new + IdentitiesOnly yes + IdentityFile %q + ProxyCommand %s +`, hostName, userName, identityFile, proxyCommand) +} diff --git a/experimental/ssh/internal/sshconfig/sshconfig_test.go b/experimental/ssh/internal/sshconfig/sshconfig_test.go new file mode 100644 index 0000000000..5fa13923ee --- /dev/null +++ b/experimental/ssh/internal/sshconfig/sshconfig_test.go @@ -0,0 +1,223 @@ +package sshconfig + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetConfigDir(t *testing.T) { + dir, err := GetConfigDir() + assert.NoError(t, err) + assert.Contains(t, dir, filepath.Join(".databricks", "ssh-tunnel-configs")) +} + +func TestGetMainConfigPath(t *testing.T) { + path, err := GetMainConfigPath() + assert.NoError(t, err) + assert.Contains(t, path, filepath.Join(".ssh", "config")) +} + +func TestGetMainConfigPathOrDefault(t *testing.T) { + path, err := GetMainConfigPathOrDefault("/custom/path") + assert.NoError(t, err) + assert.Equal(t, "/custom/path", path) + + path, err = GetMainConfigPathOrDefault("") + assert.NoError(t, err) + assert.Contains(t, path, filepath.Join(".ssh", "config")) +} + +func TestEnsureMainConfigExists(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + err := EnsureMainConfigExists(configPath) + assert.NoError(t, err) + + _, err = os.Stat(filepath.Dir(configPath)) + assert.NoError(t, err) + + _, err = os.Stat(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Empty(t, content) +} + +func TestEnsureIncludeDirective_NewConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + // Set home directory for test + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + err := EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + + configStr := string(content) + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") +} + +func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configPath := filepath.Join(tmpDir, ".ssh", "config") + + configDir, err := GetConfigDir() + require.NoError(t, err) + + // Use forward slashes as that's what SSH config uses + configDirUnix := filepath.ToSlash(configDir) + existingContent := "Include " + configDirUnix + "/*\n\nHost example\n User test\n" + err = os.MkdirAll(filepath.Dir(configPath), 0o700) + require.NoError(t, err) + err = os.WriteFile(configPath, []byte(existingContent), 0o600) + require.NoError(t, err) + + err = EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, existingContent, string(content)) +} + +func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + // Set home directory for test + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + existingContent := "Host example\n User test\n" + err := os.MkdirAll(filepath.Dir(configPath), 0o700) + require.NoError(t, err) + err = os.WriteFile(configPath, []byte(existingContent), 0o600) + require.NoError(t, err) + + err = EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + + configStr := string(content) + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + assert.Contains(t, configStr, "Host example") + + includeIndex := len("Include") + hostIndex := len(configStr) - len(existingContent) + assert.Less(t, includeIndex, hostIndex, "Include directive should come before existing content") +} + +func TestGetHostConfigPath(t *testing.T) { + path, err := GetHostConfigPath("test-host") + assert.NoError(t, err) + assert.Contains(t, path, filepath.Join(".databricks", "ssh-tunnel-configs", "test-host")) +} + +func TestHostConfigExists(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + exists, err := HostConfigExists("nonexistent") + assert.NoError(t, err) + assert.False(t, exists) + + configDir := filepath.Join(tmpDir, configDirName) + err = os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(configDir, "existing-host"), []byte("config"), 0o600) + require.NoError(t, err) + + exists, err = HostConfigExists("existing-host") + assert.NoError(t, err) + assert.True(t, exists) +} + +func TestCreateOrUpdateHostConfig_NewConfig(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + hostConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", hostConfig, false) + assert.NoError(t, err) + assert.True(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, hostConfig, string(content)) +} + +func TestCreateOrUpdateHostConfig_ExistingConfigNoRecreate(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configDir := filepath.Join(tmpDir, configDirName) + err := os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + existingConfig := "Host test\n User admin\n" + err = os.WriteFile(filepath.Join(configDir, "test-host"), []byte(existingConfig), 0o600) + require.NoError(t, err) + + newConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", newConfig, false) + assert.NoError(t, err) + assert.False(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, existingConfig, string(content)) +} + +func TestCreateOrUpdateHostConfig_ExistingConfigWithRecreate(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configDir := filepath.Join(tmpDir, configDirName) + err := os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + existingConfig := "Host test\n User admin\n" + err = os.WriteFile(filepath.Join(configDir, "test-host"), []byte(existingConfig), 0o600) + require.NoError(t, err) + + newConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", newConfig, true) + assert.NoError(t, err) + assert.True(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, newConfig, string(content)) +}