From 8aedf5ec3de843d04e3241a89701e7f9c596eb81 Mon Sep 17 00:00:00 2001 From: Anant Vindal Date: Tue, 30 Dec 2025 11:58:54 +0530 Subject: [PATCH] fix: bugfix for logout, oncecell for sse logout flow incorrectly assumed the oidc client to always be present shifted sse handler from lazy to oncecell --- src/alerts/mod.rs | 3 ++- src/handlers/http/middleware.rs | 12 ++++-------- src/handlers/http/modal/mod.rs | 6 ++---- src/handlers/http/oidc.rs | 27 ++++++++++++++------------- src/sse/mod.rs | 9 ++++++--- 5 files changed, 28 insertions(+), 29 deletions(-) diff --git a/src/alerts/mod.rs b/src/alerts/mod.rs index 0a971cee3..31b53f71d 100644 --- a/src/alerts/mod.rs +++ b/src/alerts/mod.rs @@ -634,8 +634,9 @@ impl AlertConfig { }), }) && !broadcast_to.is_empty() + && let Some(handler) = SSE_HANDLER.get() { - SSE_HANDLER.broadcast(msg, Some(&broadcast_to)).await; + handler.broadcast(msg, Some(&broadcast_to)).await; } Ok(()) diff --git a/src/handlers/http/middleware.rs b/src/handlers/http/middleware.rs index e07217f2d..280f8894c 100644 --- a/src/handlers/http/middleware.rs +++ b/src/handlers/http/middleware.rs @@ -182,14 +182,7 @@ where // if session is expired, refresh token if sessions().is_session_expired(&key) { - let oidc_client = if let Some(client) = OIDC_CLIENT.get() - && let Some(client) = client - { - let guard = client.read().await; - Some(guard.client().clone()) - } else { - None - }; + let oidc_client = OIDC_CLIENT.get(); if let Some(client) = oidc_client && let Ok(userid) = userid @@ -209,6 +202,9 @@ where if let Some(oauth_data) = bearer_to_refresh { let refreshed_token = match client + .read() + .await + .client() .refresh_token(&oauth_data, Some(PARSEABLE.options.scope.as_str())) .await { diff --git a/src/handlers/http/modal/mod.rs b/src/handlers/http/modal/mod.rs index f9626d624..ea6fffb76 100644 --- a/src/handlers/http/modal/mod.rs +++ b/src/handlers/http/modal/mod.rs @@ -60,7 +60,7 @@ pub mod utils; pub type OpenIdClient = Arc>; -pub static OIDC_CLIENT: OnceCell>>> = OnceCell::new(); +pub static OIDC_CLIENT: OnceCell>> = OnceCell::new(); #[derive(Debug)] pub struct GlobalClient { @@ -117,9 +117,7 @@ pub trait ParseableServer { let client = config .connect(&format!("{API_BASE_PATH}/{API_VERSION}/o/code")) .await?; - OIDC_CLIENT.get_or_init(|| Some(Arc::new(RwLock::new(GlobalClient::new(client))))); - } else { - OIDC_CLIENT.get_or_init(|| None); + OIDC_CLIENT.get_or_init(|| Arc::new(RwLock::new(GlobalClient::new(client)))); } // get the ssl stuff diff --git a/src/handlers/http/oidc.rs b/src/handlers/http/oidc.rs index b5b479711..6785a4e12 100644 --- a/src/handlers/http/oidc.rs +++ b/src/handlers/http/oidc.rs @@ -80,10 +80,7 @@ pub async fn login( )); } - let oidc_client = match OIDC_CLIENT.get() { - Some(c) => c.as_ref().cloned(), - None => None, - }; + let oidc_client = OIDC_CLIENT.get(); let session_key = extract_session_key_from_req(&req).ok(); let (session_key, oidc_client) = match (session_key, oidc_client) { @@ -149,17 +146,23 @@ pub async fn login( } pub async fn logout(req: HttpRequest, query: web::Query) -> HttpResponse { - let oidc_client = match OIDC_CLIENT.get() { - Some(c) => Some(c.as_ref().unwrap().read().await.client().clone()), - None => None, - }; + let oidc_client = OIDC_CLIENT.get(); let Some(session) = extract_session_key_from_req(&req).ok() else { return redirect_to_client(query.redirect.as_str(), None); }; let user = Users.remove_session(&session); - let logout_endpoint = - oidc_client.and_then(|client| client.config().end_session_endpoint.clone()); + let logout_endpoint = if let Some(client) = oidc_client { + client + .read() + .await + .client() + .config() + .end_session_endpoint + .clone() + } else { + None + }; match (user, logout_endpoint) { (Some(username), Some(logout_endpoint)) @@ -174,9 +177,7 @@ pub async fn logout(req: HttpRequest, query: web::Query) -> /// Handler for code callback /// User should be redirected to page they were trying to access with cookie pub async fn reply_login(login_query: web::Query) -> Result { - let oidc_client = if let Some(oidc_client) = OIDC_CLIENT.get() - && let Some(oidc_client) = oidc_client - { + let oidc_client = if let Some(oidc_client) = OIDC_CLIENT.get() { oidc_client } else { return Err(OIDCError::Unauthorized); diff --git a/src/sse/mod.rs b/src/sse/mod.rs index 02f3470c8..5e009e504 100644 --- a/src/sse/mod.rs +++ b/src/sse/mod.rs @@ -26,7 +26,7 @@ use actix_web_lab::{ use futures_util::future; use itertools::Itertools; -use once_cell::sync::Lazy; +use once_cell::sync::OnceCell; use serde::{Deserialize, Serialize}; use tokio::sync::{RwLock, mpsc}; use tokio_stream::wrappers::ReceiverStream; @@ -36,7 +36,7 @@ use crate::{ alerts::AlertState, rbac::map::SessionKey, utils::actix::extract_session_key_from_req, }; -pub static SSE_HANDLER: Lazy> = Lazy::new(Broadcaster::create); +pub static SSE_HANDLER: OnceCell> = OnceCell::new(); pub struct Broadcaster { inner: RwLock, @@ -174,7 +174,10 @@ pub async fn register_sse_client( )); } }; - Ok(SSE_HANDLER.new_client(&sessionid).await) + Ok(SSE_HANDLER + .get_or_init(Broadcaster::create) + .new_client(&sessionid) + .await) } /// Struct to define the messages being sent using SSE