Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 57 additions & 47 deletions components/backend/handlers/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,25 +230,18 @@ func HandleOAuth2Callback(c *gin.Context) {
log.Printf("OAuth2 callback received - provider: %s, hasCode: %v, hasState: %v, error: %s",
provider, code != "", state != "", errorParam)

// Create callback data record
callbackData := OAuthCallbackData{
Provider: provider,
Code: code,
State: state,
Error: errorParam,
ErrorDesc: errorDesc,
ReceivedAt: time.Now(),
Consumed: false,
}

// Try to get user ID from session (may not be available for MCP flows)
if userID, exists := c.Get("userID"); exists && userID != nil {
callbackData.UserID = userID.(string)
}

// Handle OAuth errors
// Handle OAuth errors early
if errorParam != "" {
log.Printf("OAuth error received: %s - %s", errorParam, errorDesc)
callbackData := OAuthCallbackData{
Provider: provider,
Code: code,
State: state,
Error: errorParam,
ErrorDesc: errorDesc,
ReceivedAt: time.Now(),
Consumed: false,
}
// Store the error for MCP to retrieve
if err := storeOAuthCallback(c.Request.Context(), state, &callbackData); err != nil {
log.Printf("Failed to store OAuth error: %v", err)
Expand All @@ -263,6 +256,52 @@ func HandleOAuth2Callback(c *gin.Context) {
return
}

// IMPORTANT: Check for cluster-level OAuth BEFORE exchanging the code
// Authorization codes are single-use, so we must route to the correct handler first
var stateMap map[string]interface{}
stateBytes, err := base64.URLEncoding.DecodeString(strings.Split(state, ".")[0])
if err == nil {
if jsonErr := json.Unmarshal(stateBytes, &stateMap); jsonErr == nil {
// Check if this is cluster-level OAuth
if isCluster, ok := stateMap["cluster"].(bool); ok && isCluster {
log.Printf("Detected cluster-level OAuth flow")

// Handle cluster-level Google OAuth (this will exchange the code)
if err := HandleGoogleOAuthCallback(c.Request.Context(), code, stateMap); err != nil {
log.Printf("Cluster-level OAuth failed: %v", err)
// Return generic error to client, details logged server-side only
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(
"<html><body><h1>Authorization Error</h1><p>Failed to connect Google Drive. Please try again.</p><p>You can close this window.</p><script>window.close();</script></body></html>",
))
return
}

// Success
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(
"<html><body><h1>Authorization Successful!</h1><p>Google Drive is now connected!</p><p>All your sessions will have access to Google Drive.</p><p>You can close this window.</p><script>window.close();</script></body></html>",
))
return
}
}
}

// Legacy session-specific OAuth flow
// Create callback data record
callbackData := OAuthCallbackData{
Provider: provider,
Code: code,
State: state,
Error: errorParam,
ErrorDesc: errorDesc,
ReceivedAt: time.Now(),
Consumed: false,
}

// Try to get user ID from session (may not be available for MCP flows)
if userID, exists := c.Get("userID"); exists && userID != nil {
callbackData.UserID = userID.(string)
}

// Get provider configuration
providerConfig, err := getOAuthProvider(provider)
if err != nil {
Expand All @@ -278,7 +317,7 @@ func HandleOAuth2Callback(c *gin.Context) {
}
redirectURI := fmt.Sprintf("%s/oauth2callback", backendURL)

// Exchange code for token
// Exchange code for token (for legacy session-specific flow)
tokenData, err := exchangeOAuthCode(c.Request.Context(), providerConfig, code, redirectURI)
if err != nil {
log.Printf("Failed to exchange OAuth code: %v", err)
Expand All @@ -298,35 +337,6 @@ func HandleOAuth2Callback(c *gin.Context) {
callbackData.ExpiresIn = tokenData.ExpiresIn
callbackData.TokenType = tokenData.TokenType

// Try to parse state as new format (map) or legacy format (OAuthStateData struct)
// New cluster-level OAuth uses map with "cluster":true flag
var stateMap map[string]interface{}
stateBytes, err := base64.URLEncoding.DecodeString(strings.Split(state, ".")[0])
if err == nil {
if jsonErr := json.Unmarshal(stateBytes, &stateMap); jsonErr == nil {
// Check if this is cluster-level OAuth
if isCluster, ok := stateMap["cluster"].(bool); ok && isCluster {
log.Printf("Detected cluster-level OAuth flow")

// Handle cluster-level Google OAuth
if err := HandleGoogleOAuthCallback(c.Request.Context(), code, stateMap); err != nil {
log.Printf("Cluster-level OAuth failed: %v", err)
// Return generic error to client, details logged server-side only
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(
"<html><body><h1>Authorization Error</h1><p>Failed to connect Google Drive. Please try again.</p><p>You can close this window.</p><script>window.close();</script></body></html>",
))
return
}

// Success
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(
"<html><body><h1>Authorization Successful!</h1><p>Google Drive is now connected!</p><p>All your sessions will have access to Google Drive.</p><p>You can close this window.</p><script>window.close();</script></body></html>",
))
return
}
}
}

// Fallback to legacy session-specific OAuth
stateData, err := validateAndParseOAuthState(state)
if err != nil {
Expand Down
Loading