Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions config/auth_u2m.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,18 @@ import (
"golang.org/x/oauth2"
)

// persistentAuthFactory is a function that creates a token source for U2M
// authentication. It can be replaced in tests to spy on the options passed.
type persistentAuthFactory func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error)

// u2mCredentials is a credentials strategy that uses the U2M OAuth flow to
// authenticate with Databricks. It loads a token from the token cache for the
// given workspace or account, refreshing it using the associated refresh token
// if needed.
type u2mCredentials struct {
testTokenSource oauth2.TokenSource // replace u2m token source
// newPersistentAuth is the factory function to create a PersistentAuth.
// If nil, the default u2m.NewPersistentAuth is used.
newPersistentAuth persistentAuthFactory
}

// Name implements CredentialsStrategy.
Expand All @@ -38,14 +44,22 @@ func (u u2mCredentials) Configure(ctx context.Context, cfg *Config) (credentials
return nil, err
}

var ts oauth2.TokenSource
if u.testTokenSource != nil {
ts = u.testTokenSource
} else {
ts, err = u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(arg), u2m.WithPort(cfg.OAuthCallbackPort))
if err != nil {
return nil, err
var factory persistentAuthFactory
if u.newPersistentAuth == nil {
factory = func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) {
return u2m.NewPersistentAuth(ctx, opts...)
}
} else {
factory = u.newPersistentAuth
}
ts, err := factory(ctx,
u2m.WithOAuthArgument(arg),
u2m.WithPort(cfg.OAuthCallbackPort),
u2m.WithScopes(cfg.GetScopes()),
u2m.WithDisableOfflineAccess(cfg.DisableOAuthRefreshToken),
)
if err != nil {
return nil, err
}

// TODO: Having to handle the CLI error here is not ideal as it couples the
Expand Down
83 changes: 81 additions & 2 deletions config/auth_u2m_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/databricks/databricks-sdk-go/config/credentials"
"github.com/databricks/databricks-sdk-go/credentials/u2m"
"github.com/google/go-cmp/cmp"
"golang.org/x/oauth2"
)

Expand Down Expand Up @@ -53,6 +54,30 @@ var (
errInvalidRefreshToken = &u2m.InvalidRefreshTokenError{}
)

// mockPersistentAuthFactory creates a test factory for bypassing real auth setup.
// Use this when tests only need to control token behavior without caring about auth configuration.
func mockPersistentAuthFactory(ts oauth2.TokenSource) persistentAuthFactory {
return func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) {
return ts, nil
}
}

// capturingPersistentAuthFactory creates a test factory for inspecting auth configuration.
// Use this when tests need to verify what options were passed to PersistentAuth while
// still controlling token behavior through ts.
func capturingPersistentAuthFactory(ts oauth2.TokenSource, onCapture func(*u2m.PersistentAuth)) persistentAuthFactory {
return func(ctx context.Context, opts ...u2m.PersistentAuthOption) (oauth2.TokenSource, error) {
pa, err := u2m.NewPersistentAuth(ctx, opts...)
if err != nil {
return nil, err
}
if onCapture != nil {
onCapture(pa)
}
return ts, nil
}
}

func TestU2MCredentials_Configure(t *testing.T) {
testCases := []struct {
desc string
Expand Down Expand Up @@ -197,7 +222,9 @@ func TestU2MCredentials_Configure(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
ctx := context.Background()
u := u2mCredentials{testTokenSource: tc.testTokenSource}
u := u2mCredentials{
newPersistentAuth: mockPersistentAuthFactory(tc.testTokenSource),
}

cp, gotConfigErr := u.Configure(ctx, tc.cfg)

Expand Down Expand Up @@ -238,7 +265,9 @@ func TestU2MCredentials_Configure(t *testing.T) {
func TestU2MCredentials_Configure_TokenCaching(t *testing.T) {
ts := &testTokenSource{token: testValidToken}

u := u2mCredentials{testTokenSource: ts}
u := u2mCredentials{
newPersistentAuth: mockPersistentAuthFactory(ts),
}
cfg := &Config{
Host: "https://workspace.cloud.databricks.com",
}
Expand All @@ -261,3 +290,53 @@ func TestU2MCredentials_Configure_TokenCaching(t *testing.T) {
t.Errorf("token source call count = %d, want 1 (should use cache)", ts.counts)
}
}

func TestU2MCredentials_Configure_Scopes(t *testing.T) {
testCases := []struct {
name string
scopes []string
want []string
}{
{
name: "nil scopes uses default",
scopes: nil,
want: []string{"all-apis"},
},
{
name: "empty scopes uses default",
scopes: []string{},
want: []string{"all-apis"},
},
{
name: "multiple scopes are sorted",
scopes: []string{"clusters", "jobs", "sql:read"},
want: []string{"clusters", "jobs", "sql:read"},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := &testTokenSource{token: testValidToken}
var capturedScopes []string

u := u2mCredentials{
newPersistentAuth: capturingPersistentAuthFactory(ts, func(pa *u2m.PersistentAuth) {
capturedScopes = pa.GetScopes()
}),
}
cfg := &Config{
Host: "https://workspace.cloud.databricks.com",
Scopes: tc.scopes,
}

_, err := u.Configure(context.Background(), cfg)
if err != nil {
t.Fatalf("Configure() error = %v", err)
}

if diff := cmp.Diff(tc.want, capturedScopes); diff != "" {
t.Errorf("scopes mismatch (-want +got):\n%s", diff)
}
})
}
}
39 changes: 36 additions & 3 deletions credentials/u2m/persistent_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ type PersistentAuth struct {
// netListen is an optional function to listen on a TCP address. If not set,
// it will use net.Listen by default. This is useful for testing.
netListen func(network, address string) (net.Listener, error)

// scopes is the list of OAuth scopes to request.
scopes []string

// disableOfflineAccess controls whether offline_access scope is requested.
// When true, offline_access will NOT be automatically added to scopes,
// meaning the token will not include a refresh token.
disableOfflineAccess bool
}

type PersistentAuthOption func(*PersistentAuth)
Expand Down Expand Up @@ -135,6 +143,25 @@ func WithPort(port int) PersistentAuthOption {
}
}

// WithScopes sets the OAuth scopes for the PersistentAuth.
func WithScopes(scopes []string) PersistentAuthOption {
return func(a *PersistentAuth) {
a.scopes = scopes
}
}

// WithDisableOfflineAccess controls whether offline_access scope is requested.
func WithDisableOfflineAccess(disable bool) PersistentAuthOption {
return func(a *PersistentAuth) {
a.disableOfflineAccess = disable
}
}

// GetScopes returns the OAuth scopes configured for this PersistentAuth.
func (a *PersistentAuth) GetScopes() []string {
return a.scopes
}
Comment on lines +160 to +163
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the use of this method? If it's just testing, I am not sure this should be public.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is just for testing but it is not used in the same package (used in auth_u2m_test.go). I think this indicates that the files are not in the right packages (auth_u2m is not in credentials/u2m while persistent_auth is), but I'm not sure if this is the right PR to fix that. I can rename it to GetScopesForTesting and add a warning comment, or perhaps move the files in /credentials/u2m up into the config package?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or perhaps the tests in auth_u2m shouldn't be making assertions on persistent_auth


// NewPersistentAuth creates a new PersistentAuth with the provided options.
func NewPersistentAuth(ctx context.Context, opts ...PersistentAuthOption) (*PersistentAuth, error) {
p := &PersistentAuth{}
Expand Down Expand Up @@ -368,10 +395,16 @@ func (a *PersistentAuth) validateArg() error {

// oauth2Config returns the OAuth2 configuration for the given OAuthArgument.
func (a *PersistentAuth) oauth2Config() (*oauth2.Config, error) {
scopes := []string{
"offline_access", // ensures OAuth token includes refresh token
"all-apis", // ensures OAuth token has access to all control-plane APIs
// Default to "all-apis" for backwards compatibility with direct users of PersistentAuth
// i.e. people implementing their own U2M authentication.
scopes := a.scopes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be a behavioral change, right? Because it is possible to have the scopes as empty. What's the impact of that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because it is possible to have the scopes as empty

Where exactly? It used to be hard coded to never be empty. Now, for existing users, a.scopes will first be empty but the code below will add offline_access and all-apis so that it is not empty anymore.

Could you give an example for the behavioural change? I don't see it.

if len(scopes) == 0 {
scopes = []string{"all-apis"}
}
if !a.disableOfflineAccess {
scopes = append([]string{"offline_access"}, scopes...)
}

var endpoints *OAuthAuthorizationServer
var err error
switch argg := a.oAuthArgument.(type) {
Expand Down
144 changes: 144 additions & 0 deletions credentials/u2m/persistent_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -487,3 +487,147 @@ func TestPersistentAuth_startListener_explicitPortNoFallBack(t *testing.T) {
t.Fatalf("pa.startListener(): want error %v, got %v", testError, gotErr)
}
}

// TestU2M_ScopesAndOfflineAccess verifies that OAuth scopes are correctly configured
// and sent during the authorization flow, and that the disableOfflineAccess flag
// correctly controls whether offline_access is added to the scope.
func TestU2M_ScopesAndOfflineAccess(t *testing.T) {
const (
testWorkspaceHost = "https://workspace.cloud.databricks.com"
testTokenEndpoint = "/oidc/v1/token"
testCallbackURL = "http://localhost:8020"
)

tests := []struct {
name string
scopes []string
disableOffline bool
want string
}{
{
name: "nil scopes uses default with offline_access",
scopes: nil,
disableOffline: false,
want: "offline_access all-apis",
},
{
name: "empty scopes uses default with offline_access",
scopes: []string{},
disableOffline: false,
want: "offline_access all-apis",
},
{
name: "single scope with offline_access",
scopes: []string{"dashboards"},
disableOffline: false,
want: "offline_access dashboards",
},
{
name: "multiple scopes with offline_access",
scopes: []string{"files", "jobs", "mlflow:read"},
disableOffline: false,
want: "offline_access files jobs mlflow:read",
},
{
name: "disable offline_access",
scopes: []string{"files", "jobs"},
disableOffline: true,
want: "files jobs",
},
{
name: "nil scopes with disable offline_access",
scopes: nil,
disableOffline: true,
want: "all-apis",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()

var scopeReceived, stateReceived string
browserCalled := make(chan struct{})
browser := func(redirect string) error {
u, err := url.ParseRequestURI(redirect)
if err != nil {
return err
}
query := u.Query()
scopeReceived = query.Get("scope")
stateReceived = query.Get("state")
close(browserCalled)
return nil
}

cache := &tokenCacheMock{
store: func(key string, tok *oauth2.Token) error {
return nil
},
}

arg, err := NewBasicWorkspaceOAuthArgument(testWorkspaceHost)
if err != nil {
t.Fatalf("NewBasicWorkspaceOAuthArgument(): want no error, got %v", err)
}

var tokenResponse string
if tt.disableOffline {
tokenResponse = `access_token=token`
} else {
tokenResponse = `access_token=token&refresh_token=refresh`
}

opts := []PersistentAuthOption{
WithTokenCache(cache),
WithBrowser(browser),
WithHttpClient(&http.Client{
Transport: fixtures.SliceTransport{
{
Method: "POST",
Resource: testTokenEndpoint,
Response: tokenResponse,
ResponseHeaders: map[string][]string{
"Content-Type": {"application/x-www-form-urlencoded"},
},
},
},
}),
WithOAuthEndpointSupplier(MockOAuthEndpointSupplier{}),
WithOAuthArgument(arg),
WithDisableOfflineAccess(tt.disableOffline),
WithScopes(tt.scopes),
}

p, err := NewPersistentAuth(ctx, opts...)
if err != nil {
t.Fatalf("NewPersistentAuth(): want no error, got %v", err)
}
defer p.Close()

errc := make(chan error)
go func() {
err := p.Challenge()
errc <- err
close(errc)
}()

<-browserCalled

if scopeReceived != tt.want {
t.Errorf("scope: want %q, got %q", tt.want, scopeReceived)
}

resp, err := http.Get(fmt.Sprintf("%s?code=__CODE__&state=%s", testCallbackURL, stateReceived))
if err != nil {
t.Fatalf("http.Get(): want no error, got %v", err)
}
defer resp.Body.Close()

err = <-errc
if err != nil {
t.Fatalf("p.Challenge(): want no error, got %v", err)
}
})
}
}
Loading