Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions pkg/auth/oauth/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,16 +449,21 @@ func (*Flow) writeErrorPage(w http.ResponseWriter, err error) {
}

// processToken processes the received token and extracts claims
func (f *Flow) processToken(ctx context.Context, token *oauth2.Token) *TokenResult {
func (f *Flow) processToken(_ context.Context, token *oauth2.Token) *TokenResult {
result := &TokenResult{
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
TokenType: token.TokenType,
Expiry: token.Expiry,
}

// Create a base token source using the original token with the provided context
base := f.oauth2Config.TokenSource(ctx, token)
// Create a base token source using the original token with a background context.
// We use context.Background() instead of the passed ctx because the TokenSource
// is long-lived and will be used for token refresh operations long after the
// initial OAuth flow completes. Using the original ctx would cause "context canceled"
// errors when attempting to refresh tokens, as that context gets cancelled when
// the OAuth callback server shuts down.
base := f.oauth2Config.TokenSource(context.Background(), token)

// ReuseTokenSource ensures that refresh happens only when needed
f.tokenSource = oauth2.ReuseTokenSource(token, base)
Expand Down
62 changes: 62 additions & 0 deletions pkg/auth/oauth/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -932,3 +932,65 @@ func TestExtractJWTClaims_ErrorCases(t *testing.T) {
})
}
}

func TestTokenRefreshAfterContextCancellation(t *testing.T) {
t.Parallel()

// Create a mock token server that tracks refresh attempts
refreshCalled := false
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
require.NoError(t, err)

if r.Form.Get("grant_type") == "refresh_token" {
refreshCalled = true
}

response := map[string]interface{}{
"access_token": "new-access-token",
"token_type": "Bearer",
"expires_in": 3600,
"refresh_token": "new-refresh-token",
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(response)
require.NoError(t, err)
}))
defer tokenServer.Close()

config := &Config{
ClientID: "test-client",
AuthURL: "https://example.com/auth",
TokenURL: tokenServer.URL,
}

flow, err := NewFlow(config)
require.NoError(t, err)

// Create a context that we will cancel (simulating OAuth flow completion)
ctx, cancel := context.WithCancel(context.Background())

// Process token with the cancellable context.
// Use an already-expired token to force refresh on next Token() call.
token := &oauth2.Token{
AccessToken: "original-access-token",
RefreshToken: "test-refresh-token",
TokenType: "Bearer",
Expiry: time.Now().Add(-time.Hour), // Already expired
}

_ = flow.processToken(ctx, token)

// Cancel the context (simulates OAuth callback server shutdown)
cancel()

// Now attempt to get a token - this should trigger refresh.
// Before the fix: fails with "context canceled" because processToken
// stored a TokenSource using the now-cancelled ctx.
// After the fix: succeeds because processToken uses context.Background().
newToken, err := flow.tokenSource.Token()

require.NoError(t, err, "token refresh should succeed even after context cancellation")
assert.True(t, refreshCalled, "refresh endpoint should have been called")
assert.Equal(t, "new-access-token", newToken.AccessToken)
}
Loading