From a4e263cf981eb98fa72a6297c40daf0a7fc830f7 Mon Sep 17 00:00:00 2001 From: tejas-kochar Date: Mon, 22 Dec 2025 11:14:26 +0000 Subject: [PATCH 1/7] custom scopes support in u2m --- config/auth_u2m.go | 30 +++-- config/auth_u2m_test.go | 118 +++++++++++++--- credentials/u2m/persistent_auth.go | 37 ++++- credentials/u2m/persistent_auth_test.go | 171 ++++++++++++++++++++++++ 4 files changed, 326 insertions(+), 30 deletions(-) diff --git a/config/auth_u2m.go b/config/auth_u2m.go index 4f0c602a3..9283b64bd 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -12,12 +12,18 @@ import ( "golang.org/x/oauth2" ) +// persistentAuthFactory is a function that creates a token source for U2M +// authentication. It can be replaced in tests to spy on the options passed. +type persistentAuthFactory func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) + // u2mCredentials is a credentials strategy that uses the U2M OAuth flow to // authenticate with Databricks. It loads a token from the token cache for the // given workspace or account, refreshing it using the associated refresh token // if needed. type u2mCredentials struct { - testTokenSource oauth2.TokenSource // replace u2m token source + // newPersistentAuth is the factory function to create a PersistentAuth. + // If nil, the default u2m.NewPersistentAuth is used. + newPersistentAuth persistentAuthFactory } // Name implements CredentialsStrategy. @@ -38,14 +44,22 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials return nil, err } - var ts oauth2.TokenSource - if u.testTokenSource != nil { - ts = u.testTokenSource - } else { - ts, err = u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(arg), u2m.WithPort(cfg.OAuthCallbackPort)) - if err != nil { - return nil, err + var factory persistentAuthFactory + if u.newPersistentAuth == nil { + factory = func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) { + return u2m.NewPersistentAuth(ctx, opts...) } + } else { + factory = u.newPersistentAuth + } + ts, err := factory(ctx, + u2m.WithOAuthArgument(arg), + u2m.WithPort(cfg.OAuthCallbackPort), + u2m.WithScopes(cfg.GetScopes()), + u2m.WithDisableOfflineAccess(cfg.DisableOAuthRefreshToken), + ) + if err != nil { + return nil, err } // TODO: Having to handle the CLI error here is not ideal as it couples the diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index 29be41bad..99ebfdd8e 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -5,12 +5,14 @@ import ( "errors" "fmt" "net/http" + "sort" "strings" "testing" "time" "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/credentials/u2m" + "github.com/google/go-cmp/cmp" "golang.org/x/oauth2" ) @@ -53,14 +55,37 @@ var ( errInvalidRefreshToken = &u2m.InvalidRefreshTokenError{} ) +// mockPersistentAuthFactory returns a persistentAuthFactory that returns ts. +func mockPersistentAuthFactory(ts oauth2.TokenSource) persistentAuthFactory { + return func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) { + return ts, nil + } +} + +// capturingPersistentAuthFactory returns a persistentAuthFactory that applies +// options to a real PersistentAuth and calls onCapture, allowing tests to spy +// on the options passed. It returns ts for token operations. +func capturingPersistentAuthFactory(ts oauth2.TokenSource, onCapture func(*u2m.PersistentAuth)) persistentAuthFactory { + return func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) { + pa, err := u2m.NewPersistentAuth(ctx, opts...) + if err != nil { + return nil, err + } + if onCapture != nil { + onCapture(pa) + } + return ts, nil + } +} + func TestU2MCredentials_Configure(t *testing.T) { testCases := []struct { - desc string - cfg *Config - testTokenSource *testTokenSource - wantConfigErr string // error message from Configure() - wantHeaderErr string // error message from SetHeaders() - wantAuthHeader string // expected Authorization header + desc string + cfg *Config + tokenSource *testTokenSource + wantConfigErr string // error message from Configure() + wantHeaderErr string // error message from SetHeaders() + wantAuthHeader string // expected Authorization header }{ { desc: "missing host returns error", @@ -74,7 +99,7 @@ func TestU2MCredentials_Configure(t *testing.T) { cfg: &Config{ Host: "https://workspace.cloud.databricks.com", }, - testTokenSource: &testTokenSource{ + tokenSource: &testTokenSource{ token: testValidToken, }, wantAuthHeader: "Bearer valid-access-token", @@ -85,7 +110,7 @@ func TestU2MCredentials_Configure(t *testing.T) { Host: "https://accounts.cloud.databricks.com", AccountID: "abc-123", }, - testTokenSource: &testTokenSource{ + tokenSource: &testTokenSource{ token: testValidToken, }, wantAuthHeader: "Bearer valid-access-token", @@ -95,7 +120,7 @@ func TestU2MCredentials_Configure(t *testing.T) { cfg: &Config{ Host: "https://workspace.cloud.databricks.com", }, - testTokenSource: &testTokenSource{ + tokenSource: &testTokenSource{ token: testExpiredToken, }, wantAuthHeader: "Bearer expired-access-token", @@ -105,7 +130,7 @@ func TestU2MCredentials_Configure(t *testing.T) { cfg: &Config{ Host: "https://workspace.cloud.databricks.com", }, - testTokenSource: &testTokenSource{ + tokenSource: &testTokenSource{ err: errNetwork, }, wantHeaderErr: "network timeout", @@ -115,7 +140,7 @@ func TestU2MCredentials_Configure(t *testing.T) { cfg: &Config{ Host: "https://workspace.cloud.databricks.com", }, - testTokenSource: &testTokenSource{ + tokenSource: &testTokenSource{ err: errAuthentication, }, wantHeaderErr: "authentication failed", @@ -127,7 +152,7 @@ func TestU2MCredentials_Configure(t *testing.T) { Profile: "my-workspace", resolved: true, }, - testTokenSource: &testTokenSource{ + tokenSource: &testTokenSource{ err: errInvalidRefreshToken, }, wantHeaderErr: "databricks auth login --profile my-workspace", @@ -138,7 +163,7 @@ func TestU2MCredentials_Configure(t *testing.T) { Host: "https://workspace.cloud.databricks.com", resolved: true, }, - testTokenSource: &testTokenSource{ + tokenSource: &testTokenSource{ err: errInvalidRefreshToken, }, wantHeaderErr: "databricks auth login --host https://workspace.cloud.databricks.com", @@ -151,7 +176,7 @@ func TestU2MCredentials_Configure(t *testing.T) { Profile: "prod-account", resolved: true, }, - testTokenSource: &testTokenSource{ + tokenSource: &testTokenSource{ err: errInvalidRefreshToken, }, wantHeaderErr: "databricks auth login --profile prod-account", @@ -163,7 +188,7 @@ func TestU2MCredentials_Configure(t *testing.T) { AccountID: "abc-123", resolved: true, }, - testTokenSource: &testTokenSource{ + tokenSource: &testTokenSource{ err: errInvalidRefreshToken, }, wantHeaderErr: "databricks auth login --host https://accounts.cloud.databricks.com --account-id abc-123", @@ -175,7 +200,7 @@ func TestU2MCredentials_Configure(t *testing.T) { Profile: "test", resolved: true, }, - testTokenSource: &testTokenSource{ + tokenSource: &testTokenSource{ err: fmt.Errorf("oauth2: %w", errInvalidRefreshToken), }, wantHeaderErr: "databricks auth login --profile test", @@ -187,7 +212,7 @@ func TestU2MCredentials_Configure(t *testing.T) { AccountID: "abc-456", resolved: true, }, - testTokenSource: &testTokenSource{ + tokenSource: &testTokenSource{ err: errInvalidRefreshToken, }, wantHeaderErr: "databricks auth login --host https://accounts.azure.databricks.net --account-id abc-456", @@ -197,7 +222,9 @@ func TestU2MCredentials_Configure(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { ctx := context.Background() - u := u2mCredentials{testTokenSource: tc.testTokenSource} + u := u2mCredentials{ + newPersistentAuth: mockPersistentAuthFactory(tc.tokenSource), + } cp, gotConfigErr := u.Configure(ctx, tc.cfg) @@ -238,7 +265,9 @@ func TestU2MCredentials_Configure(t *testing.T) { func TestU2MCredentials_Configure_TokenCaching(t *testing.T) { ts := &testTokenSource{token: testValidToken} - u := u2mCredentials{testTokenSource: ts} + u := u2mCredentials{ + newPersistentAuth: mockPersistentAuthFactory(ts), + } cfg := &Config{ Host: "https://workspace.cloud.databricks.com", } @@ -261,3 +290,54 @@ func TestU2MCredentials_Configure_TokenCaching(t *testing.T) { t.Errorf("token source call count = %d, want 1 (should use cache)", ts.counts) } } + +func TestU2MCredentials_Configure_Scopes(t *testing.T) { + testCases := []struct { + desc string + configScopes []string + expectedScopes []string + sortScopes bool // whether to sort captured scopes before comparison + }{ + { + desc: "default scopes when not specified", + configScopes: nil, + expectedScopes: []string{"all-apis"}, + sortScopes: false, + }, + { + desc: "custom scopes are passed through", + configScopes: []string{"sql", "clusters"}, + expectedScopes: []string{"clusters", "sql"}, // sorted during config resolution + sortScopes: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ts := &testTokenSource{token: testValidToken} + var capturedScopes []string + + u := u2mCredentials{ + newPersistentAuth: capturingPersistentAuthFactory(ts, func(pa *u2m.PersistentAuth) { + capturedScopes = pa.GetScopes() + }), + } + cfg := &Config{ + Host: "https://workspace.cloud.databricks.com", + Scopes: tc.configScopes, + } + + _, err := u.Configure(context.Background(), cfg) + if err != nil { + t.Fatalf("Configure() error = %v", err) + } + + if tc.sortScopes { + sort.Strings(capturedScopes) + } + if diff := cmp.Diff(tc.expectedScopes, capturedScopes); diff != "" { + t.Errorf("scopes mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go index b91be3b6a..bd59f1754 100644 --- a/credentials/u2m/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -88,6 +88,14 @@ type PersistentAuth struct { // netListen is an optional function to listen on a TCP address. If not set, // it will use net.Listen by default. This is useful for testing. netListen func(network, address string) (net.Listener, error) + + // scopes is the list of OAuth scopes to request. + scopes []string + + // disableOfflineAccess controls whether offline_access scope is requested. + // When true, offline_access will NOT be automatically added to scopes, + // meaning the token will not include a refresh token. + disableOfflineAccess bool } type PersistentAuthOption func(*PersistentAuth) @@ -135,6 +143,26 @@ func WithPort(port int) PersistentAuthOption { } } +// WithScopes sets the OAuth scopes for the PersistentAuth. +func WithScopes(scopes []string) PersistentAuthOption { + return func(a *PersistentAuth) { + a.scopes = scopes + } +} + +// WithDisableOfflineAccess controls whether offline_access scope is requested. +// When true, offline_access will NOT be automatically added to scopes. +func WithDisableOfflineAccess(disable bool) PersistentAuthOption { + return func(a *PersistentAuth) { + a.disableOfflineAccess = disable + } +} + +// GetScopes returns the OAuth scopes configured for this PersistentAuth. +func (a *PersistentAuth) GetScopes() []string { + return a.scopes +} + // NewPersistentAuth creates a new PersistentAuth with the provided options. func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*PersistentAuth, error) { p := &PersistentAuth{} @@ -368,10 +396,13 @@ func (a *PersistentAuth) validateArg() error { // oauth2Config returns the OAuth2 configuration for the given OAuthArgument. func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) { - scopes := []string{ - "offline_access", // ensures OAuth token includes refresh token - "all-apis", // ensures OAuth token has access to all control-plane APIs + scopes := a.scopes + if !a.disableOfflineAccess { + // Use append to create a new slice with "offline_access" added, + // avoiding mutation of the original a.scopes slice. + scopes = append(append([]string{}, scopes...), "offline_access") } + var endpoints *OAuthAuthorizationServer var err error switch argg := a.oAuthArgument.(type) { diff --git a/credentials/u2m/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go index a09cf8d03..b42fa14c2 100644 --- a/credentials/u2m/persistent_auth_test.go +++ b/credentials/u2m/persistent_auth_test.go @@ -487,3 +487,174 @@ func TestPersistentAuth_startListener_explicitPortNoFallBack(t *testing.T) { t.Fatalf("pa.startListener(): want error %v, got %v", testError, gotErr) } } + +// TestU2M_ScopesAndOfflineAccess verifies that OAuth scopes are correctly configured +// and sent during the authorization flow, and that the disableOfflineAccess flag +// correctly controls whether offline_access is added to the scope. +func TestU2M_ScopesAndOfflineAccess(t *testing.T) { + tests := []struct { + name string + createArg func() (OAuthArgument, error) + scopes []string + disableOffline bool + expectedScope string + expectedPath string + tokenResponse string + tokenResource string + }{ + { + name: "single scope", + createArg: func() (OAuthArgument, error) { + return NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + }, + scopes: []string{"dashboards"}, + disableOffline: false, + expectedScope: "dashboards offline_access", + tokenResponse: `access_token=token&refresh_token=refresh`, + tokenResource: "/oidc/accounts/xyz/v1/token", + }, + { + name: "multiple scopes", + createArg: func() (OAuthArgument, error) { + return NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") + }, + scopes: []string{"files", "jobs", "mlflow"}, + disableOffline: false, + expectedScope: "files jobs mlflow offline_access", + tokenResponse: `access_token=token&refresh_token=refresh`, + tokenResource: "/oidc/accounts/xyz/v1/token", + }, + { + name: "workspace OAuth argument", + createArg: func() (OAuthArgument, error) { + return NewBasicWorkspaceOAuthArgument("https://my-workspace.cloud.databricks.com") + }, + scopes: []string{"genie"}, + disableOffline: false, + expectedScope: "genie offline_access", + expectedPath: "/oidc/v1/authorize", + tokenResponse: `access_token=token&refresh_token=refresh`, + tokenResource: "/oidc/v1/token", + }, + { + name: "account OAuth argument", + createArg: func() (OAuthArgument, error) { + return NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "my-account") + }, + scopes: []string{"files", "iam"}, + disableOffline: false, + expectedScope: "files iam offline_access", + expectedPath: "/oidc/accounts/my-account/v1/authorize", + tokenResponse: `access_token=token&refresh_token=refresh`, + tokenResource: "/oidc/accounts/my-account/v1/token", + }, + { + name: "unified OAuth argument", + createArg: func() (OAuthArgument, error) { + return NewBasicUnifiedOAuthArgument("https://unified.cloud.databricks.com", "my-account") + }, + scopes: []string{"pipelines", "workspaces"}, + disableOffline: false, + expectedScope: "pipelines workspaces offline_access", + expectedPath: "/oidc/accounts/my-account/v1/authorize", + tokenResponse: `access_token=token&refresh_token=refresh`, + tokenResource: "/oidc/accounts/my-account/v1/token", + }, + { + name: "disable offline_access", + createArg: func() (OAuthArgument, error) { + return NewBasicWorkspaceOAuthArgument("https://my-workspace.cloud.databricks.com") + }, + scopes: []string{"files", "jobs"}, + disableOffline: true, + expectedScope: "files jobs", + tokenResponse: `access_token=token`, + tokenResource: "/oidc/v1/token", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + scopeReceived := make(chan string, 1) + browserOpened := make(chan string, 1) + browser := func(redirect string) error { + u, err := url.ParseRequestURI(redirect) + if err != nil { + return err + } + if tt.expectedPath != "" && u.Path != tt.expectedPath { + t.Errorf("browser(): want path '%s', got %s", tt.expectedPath, u.Path) + } + query := u.Query() + scopeReceived <- query.Get("scope") + browserOpened <- query.Get("state") + return nil + } + + cache := &tokenCacheMock{ + store: func(key string, tok *oauth2.Token) error { + return nil + }, + } + + arg, err := tt.createArg() + if err != nil { + t.Fatalf("createArg(): want no error, got %v", err) + } + + opts := []PersistentAuthOption{ + WithTokenCache(cache), + WithBrowser(browser), + WithHttpClient(&http.Client{ + Transport: fixtures.SliceTransport{ + { + Method: "POST", + Resource: tt.tokenResource, + Response: tt.tokenResponse, + ResponseHeaders: map[string][]string{ + "Content-Type": {"application/x-www-form-urlencoded"}, + }, + }, + }, + }), + WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{}), + WithOAuthArgument(arg), + WithDisableOfflineAccess(tt.disableOffline), + WithScopes(tt.scopes), + } + + p, err := NewPersistentAuth(ctx, opts...) + if err != nil { + t.Fatalf("NewPersistentAuth(): want no error, got %v", err) + } + defer p.Close() + + errc := make(chan error) + go func() { + err := p.Challenge() + errc <- err + close(errc) + }() + + scope := <-scopeReceived + state := <-browserOpened + + if scope != tt.expectedScope { + t.Errorf("scope: want '%s', got '%s'", tt.expectedScope, scope) + } + + resp, err := http.Get(fmt.Sprintf("http://localhost:8020?code=__CODE__&state=%s", state)) + if err != nil { + t.Fatalf("http.Get(): want no error, got %v", err) + } + defer resp.Body.Close() + + err = <-errc + if err != nil { + t.Fatalf("p.Challenge(): want no error, got %v", err) + } + }) + } +} From 06702001d2cb9130660e0b7dfad376ded965b32b Mon Sep 17 00:00:00 2001 From: tejas-kochar Date: Sun, 4 Jan 2026 13:46:16 +0000 Subject: [PATCH 2/7] clean up --- config/auth_u2m_test.go | 49 +++++++++++++++--------------- credentials/u2m/persistent_auth.go | 6 ++-- 2 files changed, 27 insertions(+), 28 deletions(-) diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index 99ebfdd8e..838d5d2f3 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -55,16 +55,17 @@ var ( errInvalidRefreshToken = &u2m.InvalidRefreshTokenError{} ) -// mockPersistentAuthFactory returns a persistentAuthFactory that returns ts. +// mockPersistentAuthFactory creates a test factory for bypassing real auth setup. +// Use this when tests only need to control token behavior without caring about auth configuration. func mockPersistentAuthFactory(ts oauth2.TokenSource) persistentAuthFactory { return func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) { return ts, nil } } -// capturingPersistentAuthFactory returns a persistentAuthFactory that applies -// options to a real PersistentAuth and calls onCapture, allowing tests to spy -// on the options passed. It returns ts for token operations. +// capturingPersistentAuthFactory creates a test factory for inspecting auth configuration. +// Use this when tests need to verify what options were passed to PersistentAuth while +// still controlling token behavior through ts. func capturingPersistentAuthFactory(ts oauth2.TokenSource, onCapture func(*u2m.PersistentAuth)) persistentAuthFactory { return func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) { pa, err := u2m.NewPersistentAuth(ctx, opts...) @@ -80,12 +81,12 @@ func capturingPersistentAuthFactory(ts oauth2.TokenSource, onCapture func(*u2m.P func TestU2MCredentials_Configure(t *testing.T) { testCases := []struct { - desc string - cfg *Config - tokenSource *testTokenSource - wantConfigErr string // error message from Configure() - wantHeaderErr string // error message from SetHeaders() - wantAuthHeader string // expected Authorization header + desc string + cfg *Config + testTokenSource *testTokenSource + wantConfigErr string // error message from Configure() + wantHeaderErr string // error message from SetHeaders() + wantAuthHeader string // expected Authorization header }{ { desc: "missing host returns error", @@ -99,7 +100,7 @@ func TestU2MCredentials_Configure(t *testing.T) { cfg: &Config{ Host: "https://workspace.cloud.databricks.com", }, - tokenSource: &testTokenSource{ + testTokenSource: &testTokenSource{ token: testValidToken, }, wantAuthHeader: "Bearer valid-access-token", @@ -110,7 +111,7 @@ func TestU2MCredentials_Configure(t *testing.T) { Host: "https://accounts.cloud.databricks.com", AccountID: "abc-123", }, - tokenSource: &testTokenSource{ + testTokenSource: &testTokenSource{ token: testValidToken, }, wantAuthHeader: "Bearer valid-access-token", @@ -120,7 +121,7 @@ func TestU2MCredentials_Configure(t *testing.T) { cfg: &Config{ Host: "https://workspace.cloud.databricks.com", }, - tokenSource: &testTokenSource{ + testTokenSource: &testTokenSource{ token: testExpiredToken, }, wantAuthHeader: "Bearer expired-access-token", @@ -130,7 +131,7 @@ func TestU2MCredentials_Configure(t *testing.T) { cfg: &Config{ Host: "https://workspace.cloud.databricks.com", }, - tokenSource: &testTokenSource{ + testTokenSource: &testTokenSource{ err: errNetwork, }, wantHeaderErr: "network timeout", @@ -140,7 +141,7 @@ func TestU2MCredentials_Configure(t *testing.T) { cfg: &Config{ Host: "https://workspace.cloud.databricks.com", }, - tokenSource: &testTokenSource{ + testTokenSource: &testTokenSource{ err: errAuthentication, }, wantHeaderErr: "authentication failed", @@ -152,7 +153,7 @@ func TestU2MCredentials_Configure(t *testing.T) { Profile: "my-workspace", resolved: true, }, - tokenSource: &testTokenSource{ + testTokenSource: &testTokenSource{ err: errInvalidRefreshToken, }, wantHeaderErr: "databricks auth login --profile my-workspace", @@ -163,7 +164,7 @@ func TestU2MCredentials_Configure(t *testing.T) { Host: "https://workspace.cloud.databricks.com", resolved: true, }, - tokenSource: &testTokenSource{ + testTokenSource: &testTokenSource{ err: errInvalidRefreshToken, }, wantHeaderErr: "databricks auth login --host https://workspace.cloud.databricks.com", @@ -176,7 +177,7 @@ func TestU2MCredentials_Configure(t *testing.T) { Profile: "prod-account", resolved: true, }, - tokenSource: &testTokenSource{ + testTokenSource: &testTokenSource{ err: errInvalidRefreshToken, }, wantHeaderErr: "databricks auth login --profile prod-account", @@ -188,7 +189,7 @@ func TestU2MCredentials_Configure(t *testing.T) { AccountID: "abc-123", resolved: true, }, - tokenSource: &testTokenSource{ + testTokenSource: &testTokenSource{ err: errInvalidRefreshToken, }, wantHeaderErr: "databricks auth login --host https://accounts.cloud.databricks.com --account-id abc-123", @@ -200,7 +201,7 @@ func TestU2MCredentials_Configure(t *testing.T) { Profile: "test", resolved: true, }, - tokenSource: &testTokenSource{ + testTokenSource: &testTokenSource{ err: fmt.Errorf("oauth2: %w", errInvalidRefreshToken), }, wantHeaderErr: "databricks auth login --profile test", @@ -212,7 +213,7 @@ func TestU2MCredentials_Configure(t *testing.T) { AccountID: "abc-456", resolved: true, }, - tokenSource: &testTokenSource{ + testTokenSource: &testTokenSource{ err: errInvalidRefreshToken, }, wantHeaderErr: "databricks auth login --host https://accounts.azure.databricks.net --account-id abc-456", @@ -223,7 +224,7 @@ func TestU2MCredentials_Configure(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { ctx := context.Background() u := u2mCredentials{ - newPersistentAuth: mockPersistentAuthFactory(tc.tokenSource), + newPersistentAuth: mockPersistentAuthFactory(tc.testTokenSource), } cp, gotConfigErr := u.Configure(ctx, tc.cfg) @@ -296,7 +297,7 @@ func TestU2MCredentials_Configure_Scopes(t *testing.T) { desc string configScopes []string expectedScopes []string - sortScopes bool // whether to sort captured scopes before comparison + sortScopes bool }{ { desc: "default scopes when not specified", @@ -307,7 +308,7 @@ func TestU2MCredentials_Configure_Scopes(t *testing.T) { { desc: "custom scopes are passed through", configScopes: []string{"sql", "clusters"}, - expectedScopes: []string{"clusters", "sql"}, // sorted during config resolution + expectedScopes: []string{"clusters", "sql"}, sortScopes: true, }, } diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go index bd59f1754..1c552c98c 100644 --- a/credentials/u2m/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -16,6 +16,7 @@ import ( "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" "github.com/pkg/browser" + "golang.org/x/exp/slices" "golang.org/x/oauth2" "golang.org/x/oauth2/authhandler" ) @@ -151,7 +152,6 @@ func WithScopes(scopes []string) PersistentAuthOption { } // WithDisableOfflineAccess controls whether offline_access scope is requested. -// When true, offline_access will NOT be automatically added to scopes. func WithDisableOfflineAccess(disable bool) PersistentAuthOption { return func(a *PersistentAuth) { a.disableOfflineAccess = disable @@ -398,9 +398,7 @@ func (a *PersistentAuth) validateArg() error { func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) { scopes := a.scopes if !a.disableOfflineAccess { - // Use append to create a new slice with "offline_access" added, - // avoiding mutation of the original a.scopes slice. - scopes = append(append([]string{}, scopes...), "offline_access") + scopes = append(slices.Clone(scopes), "offline_access") } var endpoints *OAuthAuthorizationServer From 3c3886a72f027135ff95ee0c86e065dab00b2c6f Mon Sep 17 00:00:00 2001 From: tejas-kochar Date: Mon, 5 Jan 2026 08:38:26 +0000 Subject: [PATCH 3/7] simplify tests --- config/auth_u2m_test.go | 36 +++++----- credentials/u2m/persistent_auth_test.go | 91 ++++++------------------- 2 files changed, 36 insertions(+), 91 deletions(-) diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index 838d5d2f3..785d02086 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/http" - "sort" "strings" "testing" "time" @@ -294,27 +293,29 @@ func TestU2MCredentials_Configure_TokenCaching(t *testing.T) { func TestU2MCredentials_Configure_Scopes(t *testing.T) { testCases := []struct { - desc string - configScopes []string - expectedScopes []string - sortScopes bool + name string + scopes []string + want []string }{ { - desc: "default scopes when not specified", - configScopes: nil, - expectedScopes: []string{"all-apis"}, - sortScopes: false, + name: "nil scopes uses default", + scopes: nil, + want: []string{"all-apis"}, }, { - desc: "custom scopes are passed through", - configScopes: []string{"sql", "clusters"}, - expectedScopes: []string{"clusters", "sql"}, - sortScopes: true, + name: "empty scopes uses default", + scopes: []string{}, + want: []string{"all-apis"}, + }, + { + name: "multiple scopes are sorted", + scopes: []string{"clusters", "jobs", "sql:read"}, + want: []string{"clusters", "jobs", "sql:read"}, }, } for _, tc := range testCases { - t.Run(tc.desc, func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { ts := &testTokenSource{token: testValidToken} var capturedScopes []string @@ -325,7 +326,7 @@ func TestU2MCredentials_Configure_Scopes(t *testing.T) { } cfg := &Config{ Host: "https://workspace.cloud.databricks.com", - Scopes: tc.configScopes, + Scopes: tc.scopes, } _, err := u.Configure(context.Background(), cfg) @@ -333,10 +334,7 @@ func TestU2MCredentials_Configure_Scopes(t *testing.T) { t.Fatalf("Configure() error = %v", err) } - if tc.sortScopes { - sort.Strings(capturedScopes) - } - if diff := cmp.Diff(tc.expectedScopes, capturedScopes); diff != "" { + if diff := cmp.Diff(tc.want, capturedScopes); diff != "" { t.Errorf("scopes mismatch (-want +got):\n%s", diff) } }) diff --git a/credentials/u2m/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go index b42fa14c2..7a8b5c0d9 100644 --- a/credentials/u2m/persistent_auth_test.go +++ b/credentials/u2m/persistent_auth_test.go @@ -494,82 +494,27 @@ func TestPersistentAuth_startListener_explicitPortNoFallBack(t *testing.T) { func TestU2M_ScopesAndOfflineAccess(t *testing.T) { tests := []struct { name string - createArg func() (OAuthArgument, error) scopes []string disableOffline bool - expectedScope string - expectedPath string - tokenResponse string - tokenResource string + want string }{ { - name: "single scope", - createArg: func() (OAuthArgument, error) { - return NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") - }, + name: "single scope with offline_access", scopes: []string{"dashboards"}, disableOffline: false, - expectedScope: "dashboards offline_access", - tokenResponse: `access_token=token&refresh_token=refresh`, - tokenResource: "/oidc/accounts/xyz/v1/token", - }, - { - name: "multiple scopes", - createArg: func() (OAuthArgument, error) { - return NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "xyz") - }, - scopes: []string{"files", "jobs", "mlflow"}, - disableOffline: false, - expectedScope: "files jobs mlflow offline_access", - tokenResponse: `access_token=token&refresh_token=refresh`, - tokenResource: "/oidc/accounts/xyz/v1/token", - }, - { - name: "workspace OAuth argument", - createArg: func() (OAuthArgument, error) { - return NewBasicWorkspaceOAuthArgument("https://my-workspace.cloud.databricks.com") - }, - scopes: []string{"genie"}, - disableOffline: false, - expectedScope: "genie offline_access", - expectedPath: "/oidc/v1/authorize", - tokenResponse: `access_token=token&refresh_token=refresh`, - tokenResource: "/oidc/v1/token", + want: "dashboards offline_access", }, { - name: "account OAuth argument", - createArg: func() (OAuthArgument, error) { - return NewBasicAccountOAuthArgument("https://accounts.cloud.databricks.com", "my-account") - }, - scopes: []string{"files", "iam"}, - disableOffline: false, - expectedScope: "files iam offline_access", - expectedPath: "/oidc/accounts/my-account/v1/authorize", - tokenResponse: `access_token=token&refresh_token=refresh`, - tokenResource: "/oidc/accounts/my-account/v1/token", - }, - { - name: "unified OAuth argument", - createArg: func() (OAuthArgument, error) { - return NewBasicUnifiedOAuthArgument("https://unified.cloud.databricks.com", "my-account") - }, - scopes: []string{"pipelines", "workspaces"}, + name: "multiple scopes with offline_access", + scopes: []string{"files", "jobs", "mlflow:read"}, disableOffline: false, - expectedScope: "pipelines workspaces offline_access", - expectedPath: "/oidc/accounts/my-account/v1/authorize", - tokenResponse: `access_token=token&refresh_token=refresh`, - tokenResource: "/oidc/accounts/my-account/v1/token", + want: "files jobs mlflow:read offline_access", }, { - name: "disable offline_access", - createArg: func() (OAuthArgument, error) { - return NewBasicWorkspaceOAuthArgument("https://my-workspace.cloud.databricks.com") - }, + name: "disable offline_access", scopes: []string{"files", "jobs"}, disableOffline: true, - expectedScope: "files jobs", - tokenResponse: `access_token=token`, - tokenResource: "/oidc/v1/token", + want: "files jobs", }, } @@ -584,9 +529,6 @@ func TestU2M_ScopesAndOfflineAccess(t *testing.T) { if err != nil { return err } - if tt.expectedPath != "" && u.Path != tt.expectedPath { - t.Errorf("browser(): want path '%s', got %s", tt.expectedPath, u.Path) - } query := u.Query() scopeReceived <- query.Get("scope") browserOpened <- query.Get("state") @@ -599,9 +541,14 @@ func TestU2M_ScopesAndOfflineAccess(t *testing.T) { }, } - arg, err := tt.createArg() + arg, err := NewBasicWorkspaceOAuthArgument("https://workspace.cloud.databricks.com") if err != nil { - t.Fatalf("createArg(): want no error, got %v", err) + t.Fatalf("NewBasicWorkspaceOAuthArgument(): want no error, got %v", err) + } + + tokenResponse := `access_token=token&refresh_token=refresh` + if tt.disableOffline { + tokenResponse = `access_token=token` } opts := []PersistentAuthOption{ @@ -611,8 +558,8 @@ func TestU2M_ScopesAndOfflineAccess(t *testing.T) { Transport: fixtures.SliceTransport{ { Method: "POST", - Resource: tt.tokenResource, - Response: tt.tokenResponse, + Resource: "/oidc/v1/token", + Response: tokenResponse, ResponseHeaders: map[string][]string{ "Content-Type": {"application/x-www-form-urlencoded"}, }, @@ -641,8 +588,8 @@ func TestU2M_ScopesAndOfflineAccess(t *testing.T) { scope := <-scopeReceived state := <-browserOpened - if scope != tt.expectedScope { - t.Errorf("scope: want '%s', got '%s'", tt.expectedScope, scope) + if scope != tt.want { + t.Errorf("scope: want %q, got %q", tt.want, scope) } resp, err := http.Get(fmt.Sprintf("http://localhost:8020?code=__CODE__&state=%s", state)) From 92e3b5ab617a4047455ca07fd2a50109d68b5609 Mon Sep 17 00:00:00 2001 From: tejas-kochar Date: Wed, 7 Jan 2026 08:47:36 +0000 Subject: [PATCH 4/7] small improvements and fixes --- credentials/u2m/persistent_auth.go | 8 +++++-- credentials/u2m/persistent_auth_test.go | 30 ++++++++++++++++++++++--- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go index 1c552c98c..dc6a6030d 100644 --- a/credentials/u2m/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -16,7 +16,6 @@ import ( "github.com/databricks/databricks-sdk-go/httpclient" "github.com/databricks/databricks-sdk-go/logger" "github.com/pkg/browser" - "golang.org/x/exp/slices" "golang.org/x/oauth2" "golang.org/x/oauth2/authhandler" ) @@ -396,9 +395,14 @@ func (a *PersistentAuth) validateArg() error { // oauth2Config returns the OAuth2 configuration for the given OAuthArgument. func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) { + // Default to "all-apis" for backwards compatibility with direct users of PersistentAuth + // i.e. people implementing their own U2M authentication. scopes := a.scopes + if len(scopes) == 0 { + scopes = []string{"all-apis"} + } if !a.disableOfflineAccess { - scopes = append(slices.Clone(scopes), "offline_access") + scopes = append([]string{"offline_access"}, scopes...) } var endpoints *OAuthAuthorizationServer diff --git a/credentials/u2m/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go index 7a8b5c0d9..d92beb7c9 100644 --- a/credentials/u2m/persistent_auth_test.go +++ b/credentials/u2m/persistent_auth_test.go @@ -492,12 +492,30 @@ func TestPersistentAuth_startListener_explicitPortNoFallBack(t *testing.T) { // and sent during the authorization flow, and that the disableOfflineAccess flag // correctly controls whether offline_access is added to the scope. func TestU2M_ScopesAndOfflineAccess(t *testing.T) { + const ( + testWorkspaceHost = "https://workspace.cloud.databricks.com" + testTokenEndpoint = "/oidc/v1/token" + testCallbackURL = "http://localhost:8020" + ) + tests := []struct { name string scopes []string disableOffline bool want string }{ + { + name: "nil scopes uses default with offline_access", + scopes: nil, + disableOffline: false, + want: "all-apis offline_access", + }, + { + name: "empty scopes uses default with offline_access", + scopes: []string{}, + disableOffline: false, + want: "all-apis offline_access", + }, { name: "single scope with offline_access", scopes: []string{"dashboards"}, @@ -516,6 +534,12 @@ func TestU2M_ScopesAndOfflineAccess(t *testing.T) { disableOffline: true, want: "files jobs", }, + { + name: "nil scopes with disable offline_access", + scopes: nil, + disableOffline: true, + want: "all-apis", + }, } for _, tt := range tests { @@ -541,7 +565,7 @@ func TestU2M_ScopesAndOfflineAccess(t *testing.T) { }, } - arg, err := NewBasicWorkspaceOAuthArgument("https://workspace.cloud.databricks.com") + arg, err := NewBasicWorkspaceOAuthArgument(testWorkspaceHost) if err != nil { t.Fatalf("NewBasicWorkspaceOAuthArgument(): want no error, got %v", err) } @@ -558,7 +582,7 @@ func TestU2M_ScopesAndOfflineAccess(t *testing.T) { Transport: fixtures.SliceTransport{ { Method: "POST", - Resource: "/oidc/v1/token", + Resource: testTokenEndpoint, Response: tokenResponse, ResponseHeaders: map[string][]string{ "Content-Type": {"application/x-www-form-urlencoded"}, @@ -592,7 +616,7 @@ func TestU2M_ScopesAndOfflineAccess(t *testing.T) { t.Errorf("scope: want %q, got %q", tt.want, scope) } - resp, err := http.Get(fmt.Sprintf("http://localhost:8020?code=__CODE__&state=%s", state)) + resp, err := http.Get(fmt.Sprintf("%s?code=__CODE__&state=%s", testCallbackURL, state)) if err != nil { t.Fatalf("http.Get(): want no error, got %v", err) } From 19acdcdc319f57bf09bcbeb1f8f7de29599e0d97 Mon Sep 17 00:00:00 2001 From: tejas-kochar Date: Wed, 7 Jan 2026 17:30:13 +0000 Subject: [PATCH 5/7] fix test --- credentials/u2m/persistent_auth_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/credentials/u2m/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go index d92beb7c9..7943e2d0e 100644 --- a/credentials/u2m/persistent_auth_test.go +++ b/credentials/u2m/persistent_auth_test.go @@ -508,25 +508,25 @@ func TestU2M_ScopesAndOfflineAccess(t *testing.T) { name: "nil scopes uses default with offline_access", scopes: nil, disableOffline: false, - want: "all-apis offline_access", + want: "offline_access all-apis", }, { name: "empty scopes uses default with offline_access", scopes: []string{}, disableOffline: false, - want: "all-apis offline_access", + want: "offline_access all-apis", }, { name: "single scope with offline_access", scopes: []string{"dashboards"}, disableOffline: false, - want: "dashboards offline_access", + want: "offline_access dashboards", }, { name: "multiple scopes with offline_access", scopes: []string{"files", "jobs", "mlflow:read"}, disableOffline: false, - want: "files jobs mlflow:read offline_access", + want: "offline_access files jobs mlflow:read", }, { name: "disable offline_access", From 31fb45708364da3ceaeb3f80f341d91cd568e720 Mon Sep 17 00:00:00 2001 From: tejas-kochar Date: Wed, 7 Jan 2026 18:05:59 +0000 Subject: [PATCH 6/7] address comments --- credentials/u2m/persistent_auth_test.go | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/credentials/u2m/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go index 7943e2d0e..4e909ec58 100644 --- a/credentials/u2m/persistent_auth_test.go +++ b/credentials/u2m/persistent_auth_test.go @@ -546,16 +546,17 @@ func TestU2M_ScopesAndOfflineAccess(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - scopeReceived := make(chan string, 1) - browserOpened := make(chan string, 1) + var scopeReceived, stateReceived string + browserCalled := make(chan struct{}) browser := func(redirect string) error { u, err := url.ParseRequestURI(redirect) if err != nil { return err } query := u.Query() - scopeReceived <- query.Get("scope") - browserOpened <- query.Get("state") + scopeReceived = query.Get("scope") + stateReceived = query.Get("state") + close(browserCalled) return nil } @@ -570,9 +571,11 @@ func TestU2M_ScopesAndOfflineAccess(t *testing.T) { t.Fatalf("NewBasicWorkspaceOAuthArgument(): want no error, got %v", err) } - tokenResponse := `access_token=token&refresh_token=refresh` + var tokenResponse string if tt.disableOffline { tokenResponse = `access_token=token` + } else { + tokenResponse = `access_token=token&refresh_token=refresh` } opts := []PersistentAuthOption{ @@ -609,14 +612,13 @@ func TestU2M_ScopesAndOfflineAccess(t *testing.T) { close(errc) }() - scope := <-scopeReceived - state := <-browserOpened + <-browserCalled - if scope != tt.want { - t.Errorf("scope: want %q, got %q", tt.want, scope) + if scopeReceived != tt.want { + t.Errorf("scope: want %q, got %q", tt.want, scopeReceived) } - resp, err := http.Get(fmt.Sprintf("%s?code=__CODE__&state=%s", testCallbackURL, state)) + resp, err := http.Get(fmt.Sprintf("%s?code=__CODE__&state=%s", testCallbackURL, stateReceived)) if err != nil { t.Fatalf("http.Get(): want no error, got %v", err) } From b9dac8eb099f243056e711dcab6aa064cfcc4ec4 Mon Sep 17 00:00:00 2001 From: tejas-kochar Date: Thu, 8 Jan 2026 15:00:06 +0000 Subject: [PATCH 7/7] Refactor test to not require exporting a getter --- config/auth_u2m_test.go | 51 ++++++++++++++++-------------- credentials/u2m/persistent_auth.go | 5 --- 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index 785d02086..4ad0cc208 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -12,6 +12,7 @@ import ( "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "golang.org/x/oauth2" ) @@ -62,22 +63,6 @@ func mockPersistentAuthFactory(ts oauth2.TokenSource) persistentAuthFactory { } } -// capturingPersistentAuthFactory creates a test factory for inspecting auth configuration. -// Use this when tests need to verify what options were passed to PersistentAuth while -// still controlling token behavior through ts. -func capturingPersistentAuthFactory(ts oauth2.TokenSource, onCapture func(*u2m.PersistentAuth)) persistentAuthFactory { - return func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) { - pa, err := u2m.NewPersistentAuth(ctx, opts...) - if err != nil { - return nil, err - } - if onCapture != nil { - onCapture(pa) - } - return ts, nil - } -} - func TestU2MCredentials_Configure(t *testing.T) { testCases := []struct { desc string @@ -292,6 +277,7 @@ func TestU2MCredentials_Configure_TokenCaching(t *testing.T) { } func TestU2MCredentials_Configure_Scopes(t *testing.T) { + const workspaceHost = "https://workspace.cloud.databricks.com" testCases := []struct { name string scopes []string @@ -317,15 +303,20 @@ func TestU2MCredentials_Configure_Scopes(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { ts := &testTokenSource{token: testValidToken} - var capturedScopes []string + var capturedPA *u2m.PersistentAuth u := u2mCredentials{ - newPersistentAuth: capturingPersistentAuthFactory(ts, func(pa *u2m.PersistentAuth) { - capturedScopes = pa.GetScopes() - }), + newPersistentAuth: func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) { + pa, err := u2m.NewPersistentAuth(ctx, opts...) + if err != nil { + return nil, err + } + capturedPA = pa + return ts, nil + }, } cfg := &Config{ - Host: "https://workspace.cloud.databricks.com", + Host: workspaceHost, Scopes: tc.scopes, } @@ -334,8 +325,22 @@ func TestU2MCredentials_Configure_Scopes(t *testing.T) { t.Fatalf("Configure() error = %v", err) } - if diff := cmp.Diff(tc.want, capturedScopes); diff != "" { - t.Errorf("scopes mismatch (-want +got):\n%s", diff) + arg, _ := u2m.NewBasicWorkspaceOAuthArgument(workspaceHost) + wantPA, err := u2m.NewPersistentAuth(context.Background(), + u2m.WithOAuthArgument(arg), + u2m.WithScopes(tc.want), + ) + if err != nil { + t.Fatalf("NewPersistentAuth() error = %v", err) + } + + if diff := cmp.Diff(wantPA, capturedPA, + cmp.AllowUnexported(u2m.PersistentAuth{}), + cmpopts.IgnoreFields(u2m.PersistentAuth{}, + "cache", "client", "endpointSupplier", "oAuthArgument", + "browser", "ln", "ctx", "redirectAddr", "port", "netListen", + "disableOfflineAccess")); diff != "" { + t.Errorf("PersistentAuth mismatch (-want +got):\n%s", diff) } }) } diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go index dc6a6030d..50b997588 100644 --- a/credentials/u2m/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -157,11 +157,6 @@ func WithDisableOfflineAccess(disable bool) PersistentAuthOption { } } -// GetScopes returns the OAuth scopes configured for this PersistentAuth. -func (a *PersistentAuth) GetScopes() []string { - return a.scopes -} - // NewPersistentAuth creates a new PersistentAuth with the provided options. func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*PersistentAuth, error) { p := &PersistentAuth{}