From 4646f92e7234ac51809f9bda7f32216d43b7b959 Mon Sep 17 00:00:00 2001 From: Anant Vindal Date: Mon, 22 Dec 2025 12:17:04 +0530 Subject: [PATCH 1/2] create new client if token decoding fails --- Cargo.toml | 2 +- src/handlers/http/middleware.rs | 36 +++++++------ src/handlers/http/modal/mod.rs | 50 +++++++++++------- src/handlers/http/oidc.rs | 92 ++++++++++++++++++++++++--------- 4 files changed, 121 insertions(+), 59 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 495802ee0..74a13001f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,7 +62,7 @@ argon2 = "0.5.0" base64 = "0.22.0" cookie = "0.18.1" hex = "0.4" -openid = { version = "0.15.0", default-features = false, features = ["rustls"] } +openid = { version = "0.18.3", default-features = false, features = ["rustls"] } rustls = "0.22.4" rustls-pemfile = "2.1.2" sha2 = "0.10.8" diff --git a/src/handlers/http/middleware.rs b/src/handlers/http/middleware.rs index 7b7d6652a..34a58df0b 100644 --- a/src/handlers/http/middleware.rs +++ b/src/handlers/http/middleware.rs @@ -24,7 +24,6 @@ use actix_web::{ dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready}, error::{ErrorBadRequest, ErrorForbidden, ErrorUnauthorized}, http::header::{self, HeaderName}, - web::Data, }; use chrono::{Duration, Utc}; use futures_util::future::LocalBoxFuture; @@ -32,9 +31,9 @@ use futures_util::future::LocalBoxFuture; use crate::{ handlers::{ AUTHORIZATION_KEY, KINESIS_COMMON_ATTRIBUTES_KEY, LOG_SOURCE_KEY, LOG_SOURCE_KINESIS, - STREAM_NAME_HEADER_KEY, http::rbac::RBACError, + STREAM_NAME_HEADER_KEY, + http::{modal::OIDC_CLIENT, rbac::RBACError}, }, - oidc::DiscoveredClient, option::Mode, parseable::PARSEABLE, rbac::{ @@ -145,7 +144,7 @@ where when request is made from Kinesis Firehose. For requests made from other clients, no change. - ## Section start */ + ## Section start */ if let Some(kinesis_common_attributes) = req.request().headers().get(KINESIS_COMMON_ATTRIBUTES_KEY) { @@ -183,12 +182,13 @@ where // if session is expired, refresh token if sessions().is_session_expired(&key) { - let oidc_client = match http_req.app_data::>>() { - Some(client) => { - let c = client.clone().into_inner(); - c.as_ref().clone() - } - None => None, + let oidc_client = if let Some(client) = OIDC_CLIENT.get() + && let Some(client) = client + { + let guard = client.read().await; + Some(guard.client().clone()) + } else { + None }; if let Some(client) = oidc_client @@ -208,13 +208,19 @@ where }; if let Some(oauth_data) = bearer_to_refresh { - let Ok(refreshed_token) = client + let refreshed_token = match client .refresh_token(&oauth_data, Some(PARSEABLE.options.scope.as_str())) .await - else { - return Err(ErrorUnauthorized( - "Your session has expired or is no longer valid. Please re-authenticate to access this resource.", - )); + { + Ok(bearer) => bearer, + Err(e) => { + tracing::error!("client refresh_token call failed- {e}"); + // remove user session + Users.remove_session(&key); + return Err(ErrorUnauthorized( + "Your session has expired or is no longer valid. Please re-authenticate to access this resource.", + )); + } }; let expires_in = diff --git a/src/handlers/http/modal/mod.rs b/src/handlers/http/modal/mod.rs index c8be6c89a..96faf1d18 100644 --- a/src/handlers/http/modal/mod.rs +++ b/src/handlers/http/modal/mod.rs @@ -18,23 +18,20 @@ use std::{fmt, path::Path, sync::Arc}; -use actix_web::{ - App, HttpServer, - middleware::from_fn, - web::{self, ServiceConfig}, -}; +use actix_web::{App, HttpServer, middleware::from_fn, web::ServiceConfig}; use actix_web_prometheus::PrometheusMetrics; use anyhow::Context; use async_trait::async_trait; use base64::{Engine, prelude::BASE64_STANDARD}; use bytes::Bytes; use futures::future; +use once_cell::sync::OnceCell; use openid::Discovered; use relative_path::RelativePathBuf; use serde::{Deserialize, Serialize}; use serde_json::{Map, Value}; use ssl_acceptor::get_ssl_acceptor; -use tokio::sync::oneshot; +use tokio::sync::{RwLock, oneshot}; use tracing::{error, info, warn}; use crate::{ @@ -43,7 +40,7 @@ use crate::{ correlation::CORRELATIONS, hottier::{HotTierManager, StreamHotTier}, metastore::metastore_traits::MetastoreObject, - oidc::Claims, + oidc::{Claims, DiscoveredClient}, option::Mode, parseable::PARSEABLE, storage::{ObjectStorageProvider, PARSEABLE_ROOT_DIRECTORY}, @@ -63,6 +60,27 @@ pub mod utils; pub type OpenIdClient = Arc>; +pub static OIDC_CLIENT: OnceCell>>> = OnceCell::new(); + +#[derive(Debug)] +pub struct GlobalClient { + client: DiscoveredClient, +} + +impl GlobalClient { + pub fn set(&mut self, client: DiscoveredClient) { + self.client = client; + } + + pub fn client(&self) -> &DiscoveredClient { + &self.client + } + + pub fn new(client: DiscoveredClient) -> Self { + Self { client } + } +} + // to be decided on what the Default version should be pub const DEFAULT_VERSION: &str = "v4"; @@ -95,16 +113,13 @@ pub trait ParseableServer { where Self: Sized, { - let oidc_client = match oidc_client { - Some(config) => { - let client = config - .connect(&format!("{API_BASE_PATH}/{API_VERSION}/o/code")) - .await?; - Some(client) - } - - None => None, - }; + if let Some(config) = oidc_client { + let client = config + .connect(&format!("{API_BASE_PATH}/{API_VERSION}/o/code")) + .await?; + OIDC_CLIENT + .get_or_init(|| Some(Arc::new(RwLock::new(GlobalClient::new(client.clone()))))); + } // get the ssl stuff let ssl = get_ssl_acceptor( @@ -120,7 +135,6 @@ pub trait ParseableServer { // fn that creates the app let create_app_fn = move || { App::new() - .app_data(web::Data::new(oidc_client.clone())) .wrap(prometheus.clone()) .configure(|config| Self::configure_routes(config)) .wrap(from_fn(health_check::check_shutdown_middleware)) diff --git a/src/handlers/http/oidc.rs b/src/handlers/http/oidc.rs index d0853a65e..cfc9d2c2e 100644 --- a/src/handlers/http/oidc.rs +++ b/src/handlers/http/oidc.rs @@ -16,7 +16,7 @@ * */ -use std::collections::HashSet; +use std::{collections::HashSet, sync::Arc}; use actix_web::{ HttpRequest, HttpResponse, @@ -29,11 +29,18 @@ use http::StatusCode; use openid::{Bearer, Options, Token, Userinfo}; use regex::Regex; use serde::Deserialize; +use tokio::sync::RwLock; use ulid::Ulid; use url::Url; use crate::{ - handlers::{COOKIE_AGE_DAYS, SESSION_COOKIE_NAME, USER_COOKIE_NAME, USER_ID_COOKIE_NAME}, + handlers::{ + COOKIE_AGE_DAYS, SESSION_COOKIE_NAME, USER_COOKIE_NAME, USER_ID_COOKIE_NAME, + http::{ + API_BASE_PATH, API_VERSION, + modal::{GlobalClient, OIDC_CLIENT}, + }, + }, oidc::{Claims, DiscoveredClient}, parseable::PARSEABLE, rbac::{ @@ -73,20 +80,18 @@ pub async fn login( )); } - let oidc_client = match req.app_data::>>() { - Some(client) => { - let c = client.clone().into_inner(); - c.as_ref().clone() - } + let oidc_client = match OIDC_CLIENT.get() { + Some(c) => c.as_ref().cloned(), None => None, }; + let session_key = extract_session_key_from_req(&req).ok(); let (session_key, oidc_client) = match (session_key, oidc_client) { (None, None) => return Ok(redirect_no_oauth_setup(query.redirect.clone())), (None, Some(client)) => { return Ok(redirect_to_oidc( query, - &client, + client.read().await.client(), PARSEABLE.options.scope.to_string().as_str(), )); } @@ -131,7 +136,7 @@ pub async fn login( if let Some(oidc_client) = oidc_client { redirect_to_oidc( query, - &oidc_client, + oidc_client.read().await.client(), PARSEABLE.options.scope.to_string().as_str(), ) } else { @@ -170,16 +175,21 @@ pub async fn logout(req: HttpRequest, query: web::Query) -> /// Handler for code callback /// User should be redirected to page they were trying to access with cookie -pub async fn reply_login( - req: HttpRequest, - login_query: web::Query, -) -> Result { - let oidc_client = req.app_data::>>().unwrap(); - let oidc_client = oidc_client.clone().into_inner().as_ref().clone().unwrap(); - let Ok((mut claims, user_info, bearer)): Result<(Claims, Userinfo, Bearer), anyhow::Error> = - request_token(oidc_client, &login_query).await - else { - return Ok(HttpResponse::Unauthorized().finish()); +pub async fn reply_login(login_query: web::Query) -> Result { + let oidc_client = if let Some(oidc_client) = OIDC_CLIENT.get() + && let Some(oidc_client) = oidc_client + { + oidc_client + } else { + return Err(OIDCError::Unauthorized); + }; + + let (mut claims, user_info, bearer) = match request_token(oidc_client, &login_query).await { + Ok(v) => v, + Err(e) => { + tracing::error!("reply_login call failed- {e}"); + return Ok(HttpResponse::Unauthorized().finish()); + } }; let username = user_info .name @@ -351,6 +361,7 @@ pub fn redirect_to_client( response.cookie(cookie); } response.insert_header((header::CACHE_CONTROL, "no-store")); + response.finish() } @@ -387,19 +398,50 @@ pub fn cookie_userid(user_id: &str) -> Cookie<'static> { } pub async fn request_token( - oidc_client: DiscoveredClient, + oidc_client: &Arc>, login_query: &Login, ) -> anyhow::Result<(Claims, Userinfo, Bearer)> { - let mut token: Token = oidc_client.request_token(&login_query.code).await?.into(); - let Some(id_token) = token.id_token.as_mut() else { + let old_client = oidc_client.read().await.client().clone(); + let mut token: Token = old_client.request_token(&login_query.code).await?.into(); + + let id_token = if let Some(token) = token.id_token.as_mut() { + token + } else { return Err(anyhow::anyhow!("No id_token provided")); }; - oidc_client.decode_token(id_token)?; - oidc_client.validate_token(id_token, None, None)?; + if let Err(e) = old_client.decode_token(id_token) { + tracing::error!("error while decoding the id_token- {e}"); + let new_client = PARSEABLE + .options + .openid() + .unwrap() + .connect(&format!("{API_BASE_PATH}/{API_VERSION}/o/code")) + .await?; + let mut token: Token = new_client.request_token(&login_query.code).await?.into(); + let id_token = if let Some(token) = token.id_token.as_mut() { + token + } else { + return Err(anyhow::anyhow!("No id_token provided")); + }; + new_client.decode_token(id_token)?; + new_client.validate_token(id_token, None, None)?; + let claims = id_token.payload().expect("payload is decoded").clone(); + + let userinfo = new_client.request_userinfo(&token).await?; + let bearer = token.bearer; + + // replace old client with new one + drop(old_client); + + oidc_client.write().await.set(new_client); + return Ok((claims, userinfo, bearer)); + } + old_client.decode_token(id_token)?; + old_client.validate_token(id_token, None, None)?; let claims = id_token.payload().expect("payload is decoded").clone(); - let userinfo = oidc_client.request_userinfo(&token).await?; + let userinfo = old_client.request_userinfo(&token).await?; let bearer = token.bearer; Ok((claims, userinfo, bearer)) } From 6e44dcef2eb5169c3f58393b38e61eecda7b5721 Mon Sep 17 00:00:00 2001 From: Anant Vindal Date: Mon, 22 Dec 2025 13:00:20 +0530 Subject: [PATCH 2/2] coderabbit suggestions --- src/handlers/http/modal/mod.rs | 5 +++-- src/handlers/http/oidc.rs | 20 +++++++------------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/handlers/http/modal/mod.rs b/src/handlers/http/modal/mod.rs index 96faf1d18..061159d57 100644 --- a/src/handlers/http/modal/mod.rs +++ b/src/handlers/http/modal/mod.rs @@ -117,8 +117,9 @@ pub trait ParseableServer { let client = config .connect(&format!("{API_BASE_PATH}/{API_VERSION}/o/code")) .await?; - OIDC_CLIENT - .get_or_init(|| Some(Arc::new(RwLock::new(GlobalClient::new(client.clone()))))); + OIDC_CLIENT.get_or_init(|| Some(Arc::new(RwLock::new(GlobalClient::new(client))))); + } else { + OIDC_CLIENT.get_or_init(|| None); } // get the ssl stuff diff --git a/src/handlers/http/oidc.rs b/src/handlers/http/oidc.rs index cfc9d2c2e..e877f033d 100644 --- a/src/handlers/http/oidc.rs +++ b/src/handlers/http/oidc.rs @@ -22,7 +22,7 @@ use actix_web::{ HttpRequest, HttpResponse, cookie::{Cookie, SameSite, time}, http::header::{self, ContentType}, - web::{self, Data}, + web, }; use chrono::{Duration, TimeDelta}; use http::StatusCode; @@ -149,13 +149,11 @@ pub async fn login( } pub async fn logout(req: HttpRequest, query: web::Query) -> HttpResponse { - let oidc_client = match req.app_data::>>() { - Some(client) => { - let c = client.clone().into_inner(); - c.as_ref().clone() - } + let oidc_client = match OIDC_CLIENT.get() { + Some(c) => Some(c.as_ref().unwrap().read().await.client().clone()), None => None, }; + let Some(session) = extract_session_key_from_req(&req).ok() else { return redirect_to_client(query.redirect.as_str(), None); }; @@ -418,12 +416,8 @@ pub async fn request_token( .unwrap() .connect(&format!("{API_BASE_PATH}/{API_VERSION}/o/code")) .await?; - let mut token: Token = new_client.request_token(&login_query.code).await?.into(); - let id_token = if let Some(token) = token.id_token.as_mut() { - token - } else { - return Err(anyhow::anyhow!("No id_token provided")); - }; + + // Reuse the already-obtained token, just decode with new client's JWKS new_client.decode_token(id_token)?; new_client.validate_token(id_token, None, None)?; let claims = id_token.payload().expect("payload is decoded").clone(); @@ -437,7 +431,7 @@ pub async fn request_token( oidc_client.write().await.set(new_client); return Ok((claims, userinfo, bearer)); } - old_client.decode_token(id_token)?; + old_client.validate_token(id_token, None, None)?; let claims = id_token.payload().expect("payload is decoded").clone();