diff --git a/config/auth_u2m.go b/config/auth_u2m.go index 4f0c602a3..9283b64bd 100644 --- a/config/auth_u2m.go +++ b/config/auth_u2m.go @@ -12,12 +12,18 @@ import ( "golang.org/x/oauth2" ) +// persistentAuthFactory is a function that creates a token source for U2M +// authentication. It can be replaced in tests to spy on the options passed. +type persistentAuthFactory func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) + // u2mCredentials is a credentials strategy that uses the U2M OAuth flow to // authenticate with Databricks. It loads a token from the token cache for the // given workspace or account, refreshing it using the associated refresh token // if needed. type u2mCredentials struct { - testTokenSource oauth2.TokenSource // replace u2m token source + // newPersistentAuth is the factory function to create a PersistentAuth. + // If nil, the default u2m.NewPersistentAuth is used. + newPersistentAuth persistentAuthFactory } // Name implements CredentialsStrategy. @@ -38,14 +44,22 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials return nil, err } - var ts oauth2.TokenSource - if u.testTokenSource != nil { - ts = u.testTokenSource - } else { - ts, err = u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(arg), u2m.WithPort(cfg.OAuthCallbackPort)) - if err != nil { - return nil, err + var factory persistentAuthFactory + if u.newPersistentAuth == nil { + factory = func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) { + return u2m.NewPersistentAuth(ctx, opts...) } + } else { + factory = u.newPersistentAuth + } + ts, err := factory(ctx, + u2m.WithOAuthArgument(arg), + u2m.WithPort(cfg.OAuthCallbackPort), + u2m.WithScopes(cfg.GetScopes()), + u2m.WithDisableOfflineAccess(cfg.DisableOAuthRefreshToken), + ) + if err != nil { + return nil, err } // TODO: Having to handle the CLI error here is not ideal as it couples the diff --git a/config/auth_u2m_test.go b/config/auth_u2m_test.go index 29be41bad..e29fbbcf9 100644 --- a/config/auth_u2m_test.go +++ b/config/auth_u2m_test.go @@ -11,6 +11,8 @@ import ( "github.com/databricks/databricks-sdk-go/config/credentials" "github.com/databricks/databricks-sdk-go/credentials/u2m" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "golang.org/x/oauth2" ) @@ -53,6 +55,14 @@ var ( errInvalidRefreshToken = &u2m.InvalidRefreshTokenError{} ) +// mockPersistentAuthFactory creates a test factory for bypassing real auth setup. +// Use this when tests only need to control token behavior without caring about auth configuration. +func mockPersistentAuthFactory(ts oauth2.TokenSource) persistentAuthFactory { + return func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) { + return ts, nil + } +} + func TestU2MCredentials_Configure(t *testing.T) { testCases := []struct { desc string @@ -197,7 +207,9 @@ func TestU2MCredentials_Configure(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { ctx := context.Background() - u := u2mCredentials{testTokenSource: tc.testTokenSource} + u := u2mCredentials{ + newPersistentAuth: mockPersistentAuthFactory(tc.testTokenSource), + } cp, gotConfigErr := u.Configure(ctx, tc.cfg) @@ -238,7 +250,9 @@ func TestU2MCredentials_Configure(t *testing.T) { func TestU2MCredentials_Configure_TokenCaching(t *testing.T) { ts := &testTokenSource{token: testValidToken} - u := u2mCredentials{testTokenSource: ts} + u := u2mCredentials{ + newPersistentAuth: mockPersistentAuthFactory(ts), + } cfg := &Config{ Host: "https://workspace.cloud.databricks.com", } @@ -261,3 +275,76 @@ func TestU2MCredentials_Configure_TokenCaching(t *testing.T) { t.Errorf("token source call count = %d, want 1 (should use cache)", ts.counts) } } + +func TestU2MCredentials_Configure_Scopes(t *testing.T) { + const workspaceHost = "https://workspace.cloud.databricks.com" + testCases := []struct { + name string + scopes []string + want []string + }{ + { + name: "nil scopes uses default", + scopes: nil, + want: []string{"all-apis"}, + }, + { + name: "empty scopes uses default", + scopes: []string{}, + want: []string{"all-apis"}, + }, + { + name: "multiple scopes are sorted", + scopes: []string{"clusters", "jobs", "sql:read"}, + want: []string{"clusters", "jobs", "sql:read"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ts := &testTokenSource{token: testValidToken} + var capturedPA *u2m.PersistentAuth + + u := u2mCredentials{ + newPersistentAuth: func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) { + pa, err := u2m.NewPersistentAuth(ctx, opts...) + if err != nil { + return nil, err + } + capturedPA = pa + return ts, nil + }, + } + cfg := &Config{ + Host: workspaceHost, + Scopes: tc.scopes, + } + + _, err := u.Configure(context.Background(), cfg) + if err != nil { + t.Fatalf("Configure() error = %v", err) + } + + arg, err := u2m.NewBasicWorkspaceOAuthArgument(workspaceHost) + if err != nil { + t.Fatalf("NewBasicWorkspaceOAuthArgument() error = %v", err) + } + wantPA, err := u2m.NewPersistentAuth(context.Background(), + u2m.WithOAuthArgument(arg), + u2m.WithScopes(tc.want), + ) + if err != nil { + t.Fatalf("NewPersistentAuth() error = %v", err) + } + + if diff := cmp.Diff(wantPA, capturedPA, + cmp.AllowUnexported(u2m.PersistentAuth{}, u2m.BasicWorkspaceOAuthArgument{}), + cmpopts.IgnoreFields(u2m.PersistentAuth{}, + "cache", "client", "endpointSupplier", "browser", + "ln", "ctx", "redirectAddr", "port", "netListen", "disableOfflineAccess"), + ); diff != "" { + t.Errorf("PersistentAuth mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/credentials/u2m/persistent_auth.go b/credentials/u2m/persistent_auth.go index b91be3b6a..50b997588 100644 --- a/credentials/u2m/persistent_auth.go +++ b/credentials/u2m/persistent_auth.go @@ -88,6 +88,14 @@ type PersistentAuth struct { // netListen is an optional function to listen on a TCP address. If not set, // it will use net.Listen by default. This is useful for testing. netListen func(network, address string) (net.Listener, error) + + // scopes is the list of OAuth scopes to request. + scopes []string + + // disableOfflineAccess controls whether offline_access scope is requested. + // When true, offline_access will NOT be automatically added to scopes, + // meaning the token will not include a refresh token. + disableOfflineAccess bool } type PersistentAuthOption func(*PersistentAuth) @@ -135,6 +143,20 @@ func WithPort(port int) PersistentAuthOption { } } +// WithScopes sets the OAuth scopes for the PersistentAuth. +func WithScopes(scopes []string) PersistentAuthOption { + return func(a *PersistentAuth) { + a.scopes = scopes + } +} + +// WithDisableOfflineAccess controls whether offline_access scope is requested. +func WithDisableOfflineAccess(disable bool) PersistentAuthOption { + return func(a *PersistentAuth) { + a.disableOfflineAccess = disable + } +} + // NewPersistentAuth creates a new PersistentAuth with the provided options. func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*PersistentAuth, error) { p := &PersistentAuth{} @@ -368,10 +390,16 @@ func (a *PersistentAuth) validateArg() error { // oauth2Config returns the OAuth2 configuration for the given OAuthArgument. func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) { - scopes := []string{ - "offline_access", // ensures OAuth token includes refresh token - "all-apis", // ensures OAuth token has access to all control-plane APIs + // Default to "all-apis" for backwards compatibility with direct users of PersistentAuth + // i.e. people implementing their own U2M authentication. + scopes := a.scopes + if len(scopes) == 0 { + scopes = []string{"all-apis"} + } + if !a.disableOfflineAccess { + scopes = append([]string{"offline_access"}, scopes...) } + var endpoints *OAuthAuthorizationServer var err error switch argg := a.oAuthArgument.(type) { diff --git a/credentials/u2m/persistent_auth_test.go b/credentials/u2m/persistent_auth_test.go index a09cf8d03..6e8319492 100644 --- a/credentials/u2m/persistent_auth_test.go +++ b/credentials/u2m/persistent_auth_test.go @@ -487,3 +487,156 @@ func TestPersistentAuth_startListener_explicitPortNoFallBack(t *testing.T) { t.Fatalf("pa.startListener(): want error %v, got %v", testError, gotErr) } } + +// TestU2M_ScopesAndOfflineAccess verifies that OAuth scopes are correctly configured +// and sent during the authorization flow, and that the disableOfflineAccess flag +// correctly controls whether offline_access is added to the scope. +func TestU2M_ScopesAndOfflineAccess(t *testing.T) { + const ( + testWorkspaceHost = "https://workspace.cloud.databricks.com" + testTokenEndpoint = "/oidc/v1/token" + testCallbackURL = "http://localhost:8020" + ) + + tests := []struct { + name string + scopes []string + disableOffline bool + want string + }{ + { + name: "nil scopes uses default with offline_access", + scopes: nil, + disableOffline: false, + want: "offline_access all-apis", + }, + { + name: "empty scopes uses default with offline_access", + scopes: []string{}, + disableOffline: false, + want: "offline_access all-apis", + }, + { + name: "single scope with offline_access", + scopes: []string{"dashboards"}, + disableOffline: false, + want: "offline_access dashboards", + }, + { + name: "multiple scopes with offline_access", + scopes: []string{"files", "jobs", "mlflow:read"}, + disableOffline: false, + want: "offline_access files jobs mlflow:read", + }, + { + name: "disable offline_access", + scopes: []string{"files", "jobs"}, + disableOffline: true, + want: "files jobs", + }, + { + name: "nil scopes with disable offline_access", + scopes: nil, + disableOffline: true, + want: "all-apis", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + var scopeReceived, stateReceived string + browserCalled := make(chan struct{}) + defer close(browserCalled) + browser := func(redirect string) error { + u, err := url.ParseRequestURI(redirect) + if err != nil { + return err + } + query := u.Query() + scopeReceived = query.Get("scope") + stateReceived = query.Get("state") + browserCalled <- struct{}{} + return nil + } + + cache := &tokenCacheMock{ + store: func(key string, tok *oauth2.Token) error { + return nil + }, + } + + arg, err := NewBasicWorkspaceOAuthArgument(testWorkspaceHost) + if err != nil { + t.Fatalf("NewBasicWorkspaceOAuthArgument(): want no error, got %v", err) + } + + var tokenResponse string + if tt.disableOffline { + tokenResponse = `access_token=token` + } else { + tokenResponse = `access_token=token&refresh_token=refresh` + } + + opts := []PersistentAuthOption{ + WithTokenCache(cache), + WithBrowser(browser), + WithHttpClient(&http.Client{ + Transport: fixtures.SliceTransport{ + { + Method: "POST", + Resource: testTokenEndpoint, + Response: tokenResponse, + ResponseHeaders: map[string][]string{ + "Content-Type": {"application/x-www-form-urlencoded"}, + }, + }, + }, + }), + WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{}), + WithOAuthArgument(arg), + WithDisableOfflineAccess(tt.disableOffline), + WithScopes(tt.scopes), + } + + p, err := NewPersistentAuth(ctx, opts...) + if err != nil { + t.Fatalf("NewPersistentAuth(): want no error, got %v", err) + } + defer p.Close() + + errc := make(chan error) + defer close(errc) + go func() { + err := p.Challenge() + errc <- err + }() + + select { + case <-browserCalled: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for browser to be called") + } + + if scopeReceived != tt.want { + t.Errorf("scope: want %q, got %q", tt.want, scopeReceived) + } + + resp, err := http.Get(fmt.Sprintf("%s?code=__CODE__&state=%s", testCallbackURL, stateReceived)) + if err != nil { + t.Fatalf("http.Get(): want no error, got %v", err) + } + defer resp.Body.Close() + + select { + case err = <-errc: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for Challenge() to complete") + } + if err != nil { + t.Fatalf("p.Challenge(): want no error, got %v", err) + } + }) + } +}