diff --git a/components/backend/handlers/oauth.go b/components/backend/handlers/oauth.go index 2975a490..3efe0d40 100644 --- a/components/backend/handlers/oauth.go +++ b/components/backend/handlers/oauth.go @@ -230,25 +230,18 @@ func HandleOAuth2Callback(c *gin.Context) { log.Printf("OAuth2 callback received - provider: %s, hasCode: %v, hasState: %v, error: %s", provider, code != "", state != "", errorParam) - // Create callback data record - callbackData := OAuthCallbackData{ - Provider: provider, - Code: code, - State: state, - Error: errorParam, - ErrorDesc: errorDesc, - ReceivedAt: time.Now(), - Consumed: false, - } - - // Try to get user ID from session (may not be available for MCP flows) - if userID, exists := c.Get("userID"); exists && userID != nil { - callbackData.UserID = userID.(string) - } - - // Handle OAuth errors + // Handle OAuth errors early if errorParam != "" { log.Printf("OAuth error received: %s - %s", errorParam, errorDesc) + callbackData := OAuthCallbackData{ + Provider: provider, + Code: code, + State: state, + Error: errorParam, + ErrorDesc: errorDesc, + ReceivedAt: time.Now(), + Consumed: false, + } // Store the error for MCP to retrieve if err := storeOAuthCallback(c.Request.Context(), state, &callbackData); err != nil { log.Printf("Failed to store OAuth error: %v", err) @@ -263,6 +256,52 @@ func HandleOAuth2Callback(c *gin.Context) { return } + // IMPORTANT: Check for cluster-level OAuth BEFORE exchanging the code + // Authorization codes are single-use, so we must route to the correct handler first + var stateMap map[string]interface{} + stateBytes, err := base64.URLEncoding.DecodeString(strings.Split(state, ".")[0]) + if err == nil { + if jsonErr := json.Unmarshal(stateBytes, &stateMap); jsonErr == nil { + // Check if this is cluster-level OAuth + if isCluster, ok := stateMap["cluster"].(bool); ok && isCluster { + log.Printf("Detected cluster-level OAuth flow") + + // Handle cluster-level Google OAuth (this will exchange the code) + if err := HandleGoogleOAuthCallback(c.Request.Context(), code, stateMap); err != nil { + log.Printf("Cluster-level OAuth failed: %v", err) + // Return generic error to client, details logged server-side only + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte( + "

Authorization Error

Failed to connect Google Drive. Please try again.

You can close this window.

", + )) + return + } + + // Success + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte( + "

Authorization Successful!

Google Drive is now connected!

All your sessions will have access to Google Drive.

You can close this window.

", + )) + return + } + } + } + + // Legacy session-specific OAuth flow + // Create callback data record + callbackData := OAuthCallbackData{ + Provider: provider, + Code: code, + State: state, + Error: errorParam, + ErrorDesc: errorDesc, + ReceivedAt: time.Now(), + Consumed: false, + } + + // Try to get user ID from session (may not be available for MCP flows) + if userID, exists := c.Get("userID"); exists && userID != nil { + callbackData.UserID = userID.(string) + } + // Get provider configuration providerConfig, err := getOAuthProvider(provider) if err != nil { @@ -278,7 +317,7 @@ func HandleOAuth2Callback(c *gin.Context) { } redirectURI := fmt.Sprintf("%s/oauth2callback", backendURL) - // Exchange code for token + // Exchange code for token (for legacy session-specific flow) tokenData, err := exchangeOAuthCode(c.Request.Context(), providerConfig, code, redirectURI) if err != nil { log.Printf("Failed to exchange OAuth code: %v", err) @@ -298,35 +337,6 @@ func HandleOAuth2Callback(c *gin.Context) { callbackData.ExpiresIn = tokenData.ExpiresIn callbackData.TokenType = tokenData.TokenType - // Try to parse state as new format (map) or legacy format (OAuthStateData struct) - // New cluster-level OAuth uses map with "cluster":true flag - var stateMap map[string]interface{} - stateBytes, err := base64.URLEncoding.DecodeString(strings.Split(state, ".")[0]) - if err == nil { - if jsonErr := json.Unmarshal(stateBytes, &stateMap); jsonErr == nil { - // Check if this is cluster-level OAuth - if isCluster, ok := stateMap["cluster"].(bool); ok && isCluster { - log.Printf("Detected cluster-level OAuth flow") - - // Handle cluster-level Google OAuth - if err := HandleGoogleOAuthCallback(c.Request.Context(), code, stateMap); err != nil { - log.Printf("Cluster-level OAuth failed: %v", err) - // Return generic error to client, details logged server-side only - c.Data(http.StatusOK, "text/html; charset=utf-8", []byte( - "

Authorization Error

Failed to connect Google Drive. Please try again.

You can close this window.

", - )) - return - } - - // Success - c.Data(http.StatusOK, "text/html; charset=utf-8", []byte( - "

Authorization Successful!

Google Drive is now connected!

All your sessions will have access to Google Drive.

You can close this window.

", - )) - return - } - } - } - // Fallback to legacy session-specific OAuth stateData, err := validateAndParseOAuthState(state) if err != nil {