diff --git a/cmd/dbc/main.go b/cmd/dbc/main.go index d559aeb..5d7a4c0 100644 --- a/cmd/dbc/main.go +++ b/cmd/dbc/main.go @@ -153,7 +153,8 @@ func formatErr(err error) string { case errors.Is(err, auth.ErrNoTrialLicense): return errStyle.Render("Could not download license, trial not started") case errors.Is(err, dbc.ErrUnauthorized): - return errStyle.Render(err.Error()) + return errStyle.Render(err.Error()) + "\n" + + msgStyle.Render("Did you run `dbc auth login`?") case errors.Is(err, dbc.ErrUnauthorizedColumnar): return errStyle.Render(err.Error()) + "\n" + msgStyle.Render("Installing this driver requires a license. Verify you have an active license at https://console.columnar.tech/licenses and try this command again. Contact support@columnar.tech if you believe this is an error.") diff --git a/cmd/dbc/main_test.go b/cmd/dbc/main_test.go index 206945f..9981a42 100644 --- a/cmd/dbc/main_test.go +++ b/cmd/dbc/main_test.go @@ -17,13 +17,55 @@ package main import ( "bytes" "context" + "fmt" + "strings" "testing" "time" tea "github.com/charmbracelet/bubbletea" + "github.com/columnar-tech/dbc" "github.com/stretchr/testify/require" ) +func TestFormatErr(t *testing.T) { + tests := []struct { + name string + err error + wantSubstring []string + }{ + { + name: "ErrUnauthorized direct", + err: dbc.ErrUnauthorized, + wantSubstring: []string{dbc.ErrUnauthorized.Error(), "Did you run `dbc auth login`?"}, + }, + { + name: "ErrUnauthorized wrapped", + err: fmt.Errorf("operation failed: %w", dbc.ErrUnauthorized), + wantSubstring: []string{dbc.ErrUnauthorized.Error(), "Did you run `dbc auth login`?"}, + }, + { + name: "ErrUnauthorizedColumnar direct", + err: dbc.ErrUnauthorizedColumnar, + wantSubstring: []string{dbc.ErrUnauthorizedColumnar.Error(), "active license", "support@columnar.tech"}, + }, + { + name: "ErrUnauthorizedColumnar wrapped", + err: fmt.Errorf("operation failed: %w", dbc.ErrUnauthorizedColumnar), + wantSubstring: []string{dbc.ErrUnauthorizedColumnar.Error(), "active license", "support@columnar.tech"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatErr(tt.err) + for _, want := range tt.wantSubstring { + require.True(t, strings.Contains(got, want), + "formatErr(%v) = %q, expected to contain %q", tt.err, got, want) + } + }) + } +} + func TestCmdStatus(t *testing.T) { tests := []struct { name string