From 1bc31a57e5f91c5a847754bdbdba2053e43a6cda Mon Sep 17 00:00:00 2001 From: Bryce Mecum Date: Fri, 12 Dec 2025 14:05:05 -0800 Subject: [PATCH] refactor config dir into common internal helper --- auth/credentials.go | 31 ++++----------------- drivers.go | 10 +++---- internal/dirs.go | 25 +++++++++++++++++ internal/dirs_test.go | 63 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 96 insertions(+), 33 deletions(-) create mode 100644 internal/dirs.go create mode 100644 internal/dirs_test.go diff --git a/auth/credentials.go b/auth/credentials.go index 2bd67d64..82465dc9 100644 --- a/auth/credentials.go +++ b/auth/credentials.go @@ -22,10 +22,10 @@ import ( "net/url" "os" "path/filepath" - "runtime" "slices" "sync" + "github.com/columnar-tech/dbc/internal" "github.com/pelletier/go-toml/v2" ) @@ -136,32 +136,11 @@ func init() { } func getCredentialPath() (string, error) { - dir := os.Getenv("XDG_DATA_HOME") - if dir == "" { - switch runtime.GOOS { - case "windows": - dir = os.Getenv("LocalAppData") - if dir == "" { - return "", errors.New("%LocalAppData% is not set") - } - case "darwin": - home, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get user home directory: %w", err) - } - dir = filepath.Join(home, "Library") - default: // unix - home, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get user home directory: %w", err) - } - dir = filepath.Join(home, ".local", "share") - } - } else if !filepath.IsAbs(dir) { - return "", errors.New("path in $XDG_DATA_HOME is relative") + dbcConfigDir, err := internal.GetDbcConfigDir() + if err != nil { + return "", fmt.Errorf("failed to get credentials path: %v", err) } - - return filepath.Join(dir, "dbc", "credentials", "credentials.toml"), nil + return filepath.Join(dbcConfigDir, "credentials", "credentials.toml"), nil } func loadCreds() ([]Credential, error) { diff --git a/drivers.go b/drivers.go index 3b9473e2..9e286ca9 100644 --- a/drivers.go +++ b/drivers.go @@ -34,6 +34,7 @@ import ( "github.com/Masterminds/semver/v3" "github.com/ProtonMail/gopenpgp/v3/crypto" "github.com/columnar-tech/dbc/auth" + "github.com/columnar-tech/dbc/internal" "github.com/go-faster/yaml" "github.com/google/uuid" machineid "github.com/zeroshade/machine-id" @@ -108,7 +109,7 @@ func init() { mid, _ = machineid.ProtectedID() // get user config dir - userdir, err := os.UserConfigDir() + dbcConfigDir, err := internal.GetDbcConfigDir() if err != nil { // if we can't get the dir for some reason, just generate a new UUID uid = uuid.New() @@ -116,12 +117,7 @@ func init() { } // try to read the existing UUID file - dirname := "columnar" - if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { - dirname = "Columnar" - } - - fp := filepath.Join(userdir, dirname, "dbc", "uid.uuid") + fp := filepath.Join(dbcConfigDir, "uid.uuid") data, err := os.ReadFile(fp) if err == nil { if err = uid.UnmarshalBinary(data); err == nil { diff --git a/internal/dirs.go b/internal/dirs.go new file mode 100644 index 00000000..ebb6ec08 --- /dev/null +++ b/internal/dirs.go @@ -0,0 +1,25 @@ +package internal + +import ( + "fmt" + "os" + "path/filepath" + "runtime" +) + +// Get a platform-specific config dir for reading and writing dbc config files +// and credentails +func GetDbcConfigDir() (string, error) { + userdir, err := os.UserConfigDir() + if err != nil { + return "", fmt.Errorf("failed to get dbc configuration directory: %v", err) + } + + orgDirName := "columnar" + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + orgDirName = "Columnar" + } + dbcDirName := "dbc" + + return filepath.Join(userdir, orgDirName, dbcDirName), nil +} diff --git a/internal/dirs_test.go b/internal/dirs_test.go new file mode 100644 index 00000000..a8d4b011 --- /dev/null +++ b/internal/dirs_test.go @@ -0,0 +1,63 @@ +package internal + +import ( + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetDbcConfigDirBasicTest(t *testing.T) { + dir, err := GetDbcConfigDir() + require.NoError(t, err) + assert.NotEmpty(t, dir) + assert.True(t, filepath.IsAbs(dir), "should return absolute path") +} + +func TestGetDbcConfigDirCapitalization(t *testing.T) { + dir, err := GetDbcConfigDir() + require.NoError(t, err) + assert.NotEmpty(t, dir) + + parent := filepath.Dir(dir) + orgName := filepath.Base(parent) + + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + assert.Equal(t, "Columnar", orgName) + } else { + assert.Equal(t, "columnar", orgName) + } +} + +func TestGetDbcConfigDirDarwin(t *testing.T) { + if runtime.GOOS != "darwin" { + t.Skip("skipping macOS-specific test") + } + dir, err := GetDbcConfigDir() + require.NoError(t, err) + home, _ := os.UserHomeDir() + assert.Equal(t, filepath.Join(home, "Library/Application Support/Columnar/dbc"), dir) +} + +func TestGetDbcConfigDirLinux(t *testing.T) { + if runtime.GOOS != "linux" { + t.Skip("skipping Linux-specific test") + } + dir, err := GetDbcConfigDir() + require.NoError(t, err) + home, _ := os.UserHomeDir() + assert.Equal(t, filepath.Join(home, ".config/columnar/dbc"), dir) +} + +func TestGetDbcConfigDirWindows(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("skipping Windows-specific test") + } + dir, err := GetDbcConfigDir() + require.NoError(t, err) + appData := os.Getenv("AppData") + assert.Equal(t, filepath.Join(appData, "Columnar", "dbc"), dir) +}