Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/alerts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,9 @@ impl AlertConfig {
}),
})
&& !broadcast_to.is_empty()
&& let Some(handler) = SSE_HANDLER.get()
{
SSE_HANDLER.broadcast(msg, Some(&broadcast_to)).await;
handler.broadcast(msg, Some(&broadcast_to)).await;
}

Ok(())
Expand Down
12 changes: 4 additions & 8 deletions src/handlers/http/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,7 @@ where

// if session is expired, refresh token
if sessions().is_session_expired(&key) {
let oidc_client = if let Some(client) = OIDC_CLIENT.get()
&& let Some(client) = client
{
let guard = client.read().await;
Some(guard.client().clone())
} else {
None
};
let oidc_client = OIDC_CLIENT.get();

if let Some(client) = oidc_client
&& let Ok(userid) = userid
Expand All @@ -209,6 +202,9 @@ where

if let Some(oauth_data) = bearer_to_refresh {
let refreshed_token = match client
.read()
.await
.client()
.refresh_token(&oauth_data, Some(PARSEABLE.options.scope.as_str()))
.await
{
Expand Down
6 changes: 2 additions & 4 deletions src/handlers/http/modal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub mod utils;

pub type OpenIdClient = Arc<openid::Client<Discovered, Claims>>;

pub static OIDC_CLIENT: OnceCell<Option<Arc<RwLock<GlobalClient>>>> = OnceCell::new();
pub static OIDC_CLIENT: OnceCell<Arc<RwLock<GlobalClient>>> = OnceCell::new();

#[derive(Debug)]
pub struct GlobalClient {
Expand Down Expand Up @@ -117,9 +117,7 @@ pub trait ParseableServer {
let client = config
.connect(&format!("{API_BASE_PATH}/{API_VERSION}/o/code"))
.await?;
OIDC_CLIENT.get_or_init(|| Some(Arc::new(RwLock::new(GlobalClient::new(client)))));
} else {
OIDC_CLIENT.get_or_init(|| None);
OIDC_CLIENT.get_or_init(|| Arc::new(RwLock::new(GlobalClient::new(client))));
}

// get the ssl stuff
Expand Down
27 changes: 14 additions & 13 deletions src/handlers/http/oidc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,7 @@ pub async fn login(
));
}

let oidc_client = match OIDC_CLIENT.get() {
Some(c) => c.as_ref().cloned(),
None => None,
};
let oidc_client = OIDC_CLIENT.get();

let session_key = extract_session_key_from_req(&req).ok();
let (session_key, oidc_client) = match (session_key, oidc_client) {
Expand Down Expand Up @@ -149,17 +146,23 @@ pub async fn login(
}

pub async fn logout(req: HttpRequest, query: web::Query<RedirectAfterLogin>) -> HttpResponse {
let oidc_client = match OIDC_CLIENT.get() {
Some(c) => Some(c.as_ref().unwrap().read().await.client().clone()),
None => None,
};
let oidc_client = OIDC_CLIENT.get();

let Some(session) = extract_session_key_from_req(&req).ok() else {
return redirect_to_client(query.redirect.as_str(), None);
};
let user = Users.remove_session(&session);
let logout_endpoint =
oidc_client.and_then(|client| client.config().end_session_endpoint.clone());
let logout_endpoint = if let Some(client) = oidc_client {
client
.read()
.await
.client()
.config()
.end_session_endpoint
.clone()
} else {
None
};

match (user, logout_endpoint) {
(Some(username), Some(logout_endpoint))
Expand All @@ -174,9 +177,7 @@ pub async fn logout(req: HttpRequest, query: web::Query<RedirectAfterLogin>) ->
/// Handler for code callback
/// User should be redirected to page they were trying to access with cookie
pub async fn reply_login(login_query: web::Query<Login>) -> Result<HttpResponse, OIDCError> {
let oidc_client = if let Some(oidc_client) = OIDC_CLIENT.get()
&& let Some(oidc_client) = oidc_client
{
let oidc_client = if let Some(oidc_client) = OIDC_CLIENT.get() {
oidc_client
} else {
return Err(OIDCError::Unauthorized);
Expand Down
9 changes: 6 additions & 3 deletions src/sse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use actix_web_lab::{
use futures_util::future;

use itertools::Itertools;
use once_cell::sync::Lazy;
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use tokio::sync::{RwLock, mpsc};
use tokio_stream::wrappers::ReceiverStream;
Expand All @@ -36,7 +36,7 @@ use crate::{
alerts::AlertState, rbac::map::SessionKey, utils::actix::extract_session_key_from_req,
};

pub static SSE_HANDLER: Lazy<Arc<Broadcaster>> = Lazy::new(Broadcaster::create);
pub static SSE_HANDLER: OnceCell<Arc<Broadcaster>> = OnceCell::new();

pub struct Broadcaster {
inner: RwLock<BroadcasterInner>,
Expand Down Expand Up @@ -174,7 +174,10 @@ pub async fn register_sse_client(
));
}
};
Ok(SSE_HANDLER.new_client(&sessionid).await)
Ok(SSE_HANDLER
.get_or_init(Broadcaster::create)
.new_client(&sessionid)
.await)
}

/// Struct to define the messages being sent using SSE
Expand Down
Loading